package com.quillstudios.pokegoalshelper

import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Log
import org.opencv.android.Utils
import org.opencv.core.*
import org.opencv.dnn.Dnn
import org.opencv.imgproc.Imgproc
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel
import kotlin.math.max
import kotlin.math.min

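/**
 * Detects Pokémon summary-screen UI elements by running a YOLOv8-style
 * TensorFlow Lite model over OpenCV Mats.
 *
 * Minimal usage sketch (assumes the [Detection] data class with `classId`,
 * `className`, `confidence`, and `boundingBox` fields defined elsewhere in
 * this package, and that the OpenCV native library is already loaded):
 *
 *     val detector = YOLOTFLiteDetector(context)
 *     if (detector.initialize()) {
 *         val results = detector.detect(frameMat) // frameMat: org.opencv.core.Mat
 *         results.forEach { Log.d("Demo", "${it.className}: ${it.confidence}") }
 *         detector.release()
 *     }
 */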
class YOLOTFLiteDetector(private val context: Context) {

    companion object {
        private const val TAG = "YOLOTFLiteDetector"
        private const val MODEL_FILE = "pokemon_model.tflite"
        private const val INPUT_SIZE = 640
        private const val CONFIDENCE_THRESHOLD = 0.05f // Deliberately low while validating the model conversion
        private const val NMS_THRESHOLD = 0.4f
        private const val NUM_CHANNELS = 3
        private const val NUM_DETECTIONS = 8400 // YOLOv8 default for a 640x640 input
        private const val NUM_CLASSES = 96 // Matches the classNames map below (indices 0-95)
    }

    private var interpreter: Interpreter? = null
    private var isInitialized = false

    // Input/output buffers
    private var inputBuffer: ByteBuffer? = null
    private var outputBuffer: FloatArray? = null

    // Class-index-to-name mapping; must stay in sync with the labels the model was trained on
    private val classNames = mapOf(
        0 to "ball_icon_pokeball",
        1 to "ball_icon_greatball",
        2 to "ball_icon_ultraball",
        3 to "ball_icon_masterball",
        4 to "ball_icon_safariball",
        5 to "ball_icon_levelball",
        6 to "ball_icon_lureball",
        7 to "ball_icon_moonball",
        8 to "ball_icon_friendball",
        9 to "ball_icon_loveball",
        10 to "ball_icon_heavyball",
        11 to "ball_icon_fastball",
        12 to "ball_icon_sportball",
        13 to "ball_icon_premierball",
        14 to "ball_icon_repeatball",
        15 to "ball_icon_timerball",
        16 to "ball_icon_nestball",
        17 to "ball_icon_netball",
        18 to "ball_icon_diveball",
        19 to "ball_icon_luxuryball",
        20 to "ball_icon_healball",
        21 to "ball_icon_quickball",
        22 to "ball_icon_duskball",
        23 to "ball_icon_cherishball",
        24 to "ball_icon_dreamball",
        25 to "ball_icon_beastball",
        26 to "ball_icon_strangeparts",
        27 to "ball_icon_parkball",
        28 to "ball_icon_gsball",
        29 to "pokemon_nickname",
        30 to "gender_icon_male",
        31 to "gender_icon_female",
        32 to "pokemon_level",
        33 to "language",
        34 to "last_game_stamp_home",
        35 to "last_game_stamp_lgp",
        36 to "last_game_stamp_lge",
        37 to "last_game_stamp_sw",
        38 to "last_game_stamp_sh",
        39 to "last_game_stamp_bank",
        40 to "last_game_stamp_bd",
        41 to "last_game_stamp_sp",
        42 to "last_game_stamp_pla",
        43 to "last_game_stamp_sc",
        44 to "last_game_stamp_vi",
        45 to "last_game_stamp_go",
        46 to "national_dex_number",
        47 to "pokemon_species",
        48 to "type_1",
        49 to "type_2",
        50 to "shiny_icon",
        51 to "origin_icon_vc",
        52 to "origin_icon_xyoras",
        53 to "origin_icon_smusum",
        54 to "origin_icon_lg",
        55 to "origin_icon_swsh",
        56 to "origin_icon_go",
        57 to "origin_icon_bdsp",
        58 to "origin_icon_pla",
        59 to "origin_icon_sv",
        60 to "pokerus_infected_icon",
        61 to "pokerus_cured_icon",
        62 to "hp_value",
        63 to "attack_value",
        64 to "defense_value",
        65 to "sp_atk_value",
        66 to "sp_def_value",
        67 to "speed_value",
        68 to "ability_name",
        69 to "nature_name",
        70 to "move_name",
        71 to "original_trainer_name",
        72 to "original_trainder_number", // (sic) misspelling kept to match the trained label
        73 to "alpha_mark",
        74 to "tera_water",
        75 to "tera_psychic",
        76 to "tera_ice",
        77 to "tera_fairy",
        78 to "tera_poison",
        79 to "tera_ghost",
        80 to "ball_icon_originball",
        81 to "tera_dragon",
        82 to "tera_steel",
        83 to "tera_grass",
        84 to "tera_normal",
        85 to "tera_fire",
        86 to "tera_electric",
        87 to "tera_fighting",
        88 to "tera_ground",
        89 to "tera_flying",
        90 to "tera_bug",
        91 to "tera_rock",
        92 to "tera_dark",
        93 to "low_confidence",
        94 to "ball_icon_pokeball_hisui",
        95 to "ball_icon_ultraball_husui" // (sic) misspelling kept to match the trained label
    )

    fun initialize(): Boolean {
        if (isInitialized) return true

        try {
            Log.i(TAG, "🤖 Initializing TensorFlow Lite YOLO detector...")

            // Copy the model out of assets so it can be memory-mapped
            Log.i(TAG, "📂 Copying model file: $MODEL_FILE")
            val modelPath = copyAssetToInternalStorage(MODEL_FILE)
            if (modelPath == null) {
                Log.e(TAG, "❌ Failed to copy TFLite model from assets")
                return false
            }
            Log.i(TAG, "✅ Model copied to: $modelPath")

            // Create interpreter
            Log.i(TAG, "📥 Loading TFLite model from: $modelPath")
            val modelFile = loadModelFile(modelPath)
            Log.i(TAG, "📥 Model file loaded, size: ${modelFile.capacity()} bytes")

            val options = Interpreter.Options()
            options.setNumThreads(4) // Use 4 CPU threads
            Log.i(TAG, "🔧 Creating TensorFlow Lite interpreter...")
            interpreter = Interpreter(modelFile, options)
            Log.i(TAG, "✅ Interpreter created successfully")

            // Log model info
            val inputTensor = interpreter!!.getInputTensor(0)
            val outputTensor = interpreter!!.getOutputTensor(0)
            Log.i(TAG, "📊 Input tensor shape: ${inputTensor.shape().contentToString()}")
            Log.i(TAG, "📊 Output tensor shape: ${outputTensor.shape().contentToString()}")

            // Allocate input/output buffers
            Log.i(TAG, "📦 Allocating buffers...")
            allocateBuffers()
            Log.i(TAG, "✅ Buffers allocated")

            // Test model with dummy input
            Log.i(TAG, "🧪 Testing model with dummy input...")
            testModelInputOutput()
            Log.i(TAG, "✅ Model test completed")

            isInitialized = true
            Log.i(TAG, "✅ TensorFlow Lite YOLO detector initialized successfully")
            Log.i(TAG, "📊 Model info: ${classNames.size} classes, input size: ${INPUT_SIZE}x${INPUT_SIZE}")

            return true
        } catch (e: Exception) {
            Log.e(TAG, "❌ Error initializing TensorFlow Lite detector", e)
            return false
        }
    }

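    // Performance note (not currently wired in): the stock TFLite Interpreter.Options
    // also exposes hardware acceleration, e.g. options.setUseNNAPI(true). Whether that
    // helps varies by device, so treat it as something to benchmark, not assume.
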
    private fun loadModelFile(modelPath: String): ByteBuffer {
        // TFLite accepts either a MappedByteBuffer (as produced here) or a direct,
        // native-ordered ByteBuffer; a plain heap buffer would be rejected.
        val fileInputStream = FileInputStream(modelPath)
        val fileChannel = fileInputStream.channel
        val modelBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0L, fileChannel.size())
        fileInputStream.close()
        return modelBuffer
    }

    private fun allocateBuffers() {
        // Get the actual tensor shapes from the model
        val inputTensor = interpreter!!.getInputTensor(0)
        val outputTensor = interpreter!!.getOutputTensor(0)

        val inputShape = inputTensor.shape()
        val outputShape = outputTensor.shape()

        Log.d(TAG, "📊 Actual input shape: ${inputShape.contentToString()}")
        Log.d(TAG, "📊 Actual output shape: ${outputShape.contentToString()}")

        // Input buffer: e.g. [1, 640, 640, 3] floats at 4 bytes each
        val inputSize = inputShape.fold(1) { acc, dim -> acc * dim } * 4
        inputBuffer = ByteBuffer.allocateDirect(inputSize)
        inputBuffer!!.order(ByteOrder.nativeOrder())

        // Output buffer: e.g. [1, 100, 8400] floats
        val outputSize = outputShape.fold(1) { acc, dim -> acc * dim }
        outputBuffer = FloatArray(outputSize)

        Log.d(TAG, "📦 Allocated input buffer: $inputSize bytes")
        Log.d(TAG, "📦 Allocated output buffer: $outputSize floats")
    }

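    // Sizing sanity check for the shapes this model reports: the input buffer is
    // 1 * 640 * 640 * 3 floats * 4 bytes = 4,915,200 bytes, and the output buffer
    // is 1 * 100 * 8400 = 840,000 floats.
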
    private fun testModelInputOutput() {
        try {
            // Fill the input buffer with a constant dummy value
            inputBuffer!!.rewind()
            repeat(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS) {
                inputBuffer!!.putFloat(0.5f) // Dummy normalized pixel value
            }

            // TensorFlow Lite wants the output as a multidimensional array
            val outputShape = interpreter!!.getOutputTensor(0).shape()
            Log.d(TAG, "🔍 Output tensor shape: ${outputShape.contentToString()}")

            // 3D output array: [batch][attributes][detections]
            val output = Array(outputShape[0]) {
                Array(outputShape[1]) {
                    FloatArray(outputShape[2])
                }
            }

            // Run inference
            inputBuffer!!.rewind()
            interpreter!!.run(inputBuffer, output)

            // Check the output value range
            val firstBatch = output[0]
            val maxOutput = firstBatch.flatMap { it.asIterable() }.maxOrNull() ?: 0f
            val minOutput = firstBatch.flatMap { it.asIterable() }.minOrNull() ?: 0f
            Log.i(TAG, "🧪 Model test: output range [${String.format("%.4f", minOutput)}, ${String.format("%.4f", maxOutput)}]")

            // Flatten the 3D array for postprocessing
            outputBuffer = firstBatch.flatMap { it.asIterable() }.toFloatArray()
            Log.d(TAG, "🧪 Converted output to flat array of size: ${outputBuffer!!.size}")
        } catch (e: Exception) {
            Log.e(TAG, "❌ Model test failed", e)
            throw e
        }
    }

    fun detect(inputMat: Mat): List<Detection> {
        if (!isInitialized || interpreter == null) {
            Log.w(TAG, "⚠️ TensorFlow Lite detector not initialized")
            return emptyList()
        }

        try {
            Log.d(TAG, "🔍 Running TFLite YOLO detection on ${inputMat.cols()}x${inputMat.rows()} image")

            // Preprocess the image into inputBuffer
            preprocessImage(inputMat)

            // TensorFlow Lite wants the output as a multidimensional array
            val outputShape = interpreter!!.getOutputTensor(0).shape()
            val output = Array(outputShape[0]) {
                Array(outputShape[1]) {
                    FloatArray(outputShape[2])
                }
            }

            // Run inference
            inputBuffer!!.rewind()
            interpreter!!.run(inputBuffer, output)

            // Flatten [attributes][detections] for postprocessing
            val flatOutput = output[0].flatMap { it.asIterable() }.toFloatArray()

            // Post-process results
            val detections = postprocess(flatOutput, inputMat.cols(), inputMat.rows())

            Log.i(TAG, "✅ TFLite YOLO detection complete: ${detections.size} objects detected")

            return detections
        } catch (e: Exception) {
            Log.e(TAG, "❌ Error during TFLite YOLO detection", e)
            return emptyList()
        }
    }

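    // Note on the preprocessing below: a plain resize stretches the frame instead of
    // letterboxing it. Ultralytics-style YOLOv8 pipelines are usually trained with
    // letterbox padding, so aspect-ratio distortion here is a plausible source of
    // accuracy loss worth checking against the training setup.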
    // Preprocessing: convert to RGB, resize to 640x640, normalize to [0, 1],
    // and write the pixels into inputBuffer as CHW floats.
    private fun preprocessImage(mat: Mat) {
        // Convert to a 3-channel RGB Mat. The input may be 4-channel (screen capture)
        // or 3-channel (file decode).
        val rgbMat = Mat()
        if (mat.channels() == 4) {
            // Assumes the capture is BGRA. Note: Utils.bitmapToMat and RGBA_8888
            // screen captures actually produce RGBA Mats, so if colors look swapped,
            // COLOR_RGBA2RGB is the conversion to try here.
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB)
        } else {
            // Convert BGR (OpenCV's default) to RGB
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB)
        }

        // Resize to the model input size (640x640)
        val resized = Mat()
        Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()))

        Log.d(TAG, "🖼️ Preprocessed image: ${resized.cols()}x${resized.rows()}, channels: ${resized.channels()}")

        // Pull the pixel data out of the Mat in HWC (height, width, channel) byte order
        val pixels = ByteArray(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS)
        resized.get(0, 0, pixels)

        // Write the floats in CHW (channel, height, width) order.
        // Caution: the logged input tensor shape [1, 640, 640, 3] suggests the model
        // expects NHWC; if detections come out as garbage, filling in HWC order
        // (the commented-out variant below) is the first alternative to try.
        inputBuffer!!.rewind()
        for (c in 0 until NUM_CHANNELS) {
            for (y in 0 until INPUT_SIZE) {
                for (x in 0 until INPUT_SIZE) {
                    // Index into the HWC 'pixels' array
                    val pixelIndex = (y * INPUT_SIZE + x) * NUM_CHANNELS + c
                    // Byte -> unsigned int (0-255) -> normalized float
                    val pixelValue = (pixels[pixelIndex].toInt() and 0xFF) / 255.0f
                    inputBuffer!!.putFloat(pixelValue)
                }
            }
        }

        // Debug: check the first few values
        inputBuffer!!.rewind()
        val testValues = FloatArray(10)
        inputBuffer!!.asFloatBuffer().get(testValues)
        val testStr = testValues.joinToString(", ") { String.format("%.4f", it) }
        Log.d(TAG, "🌐 Input buffer first 10 values (CHW): [$testStr]")

        // Clean up
        rgbMat.release()
        resized.release()
    }

    /*
    // Earlier variant kept for reference: fills inputBuffer in HWC order,
    // matching an NHWC [1, 640, 640, 3] input tensor.
    private fun preprocessImage(mat: Mat) {
        // Convert to RGB
        val rgbMat = Mat()
        if (mat.channels() == 4) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB)
        } else if (mat.channels() == 3) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB)
        } else {
            mat.copyTo(rgbMat)
        }

        // Resize to input size
        val resized = Mat()
        Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()))

        Log.d(TAG, "🖼️ Preprocessed image: ${resized.cols()}x${resized.rows()}, channels: ${resized.channels()}")

        // Convert to ByteBuffer in HWC format [640, 640, 3]
        inputBuffer!!.rewind()

        val rgbBytes = ByteArray(INPUT_SIZE * INPUT_SIZE * 3)
        resized.get(0, 0, rgbBytes)

        // The data is already in HWC order from OpenCV, so normalize in place
        for (i in rgbBytes.indices) {
            val pixelValue = (rgbBytes[i].toInt() and 0xFF) / 255.0f
            inputBuffer!!.putFloat(pixelValue)
        }

        // Debug: check the first few values
        inputBuffer!!.rewind()
        val testValues = FloatArray(10)
        inputBuffer!!.asFloatBuffer().get(testValues)
        val testStr = testValues.joinToString(", ") { String.format("%.4f", it) }
        Log.d(TAG, "🌐 Input buffer first 10 values: [$testStr]")

        // Clean up
        rgbMat.release()
        resized.release()
    }
    */

    /*
    // Earlier postprocess variant kept for reference: treats attribute 4 onward as
    // class scores with no separate objectness row, the standard YOLOv8 head layout.
    private fun postprocess(output: FloatArray, originalWidth: Int, originalHeight: Int): List<Detection> {
        val confidences = mutableListOf<Float>()
        val boxes = mutableListOf<Rect>()
        val classIds = mutableListOf<Int>()

        Log.d(TAG, "🔍 Processing detections from output array of size ${output.size}")
        Log.d(TAG, "🔍 Original image size: ${originalWidth}x${originalHeight}")

        // YOLOv8 outputs normalized coordinates (0-1), so scale directly to the original image size
        val numFeatures = 100    // From the actual model output
        val numDetections = 8400 // From the actual model output

        var validDetections = 0

        // Process transposed output [1, 100, 8400]:
        // attribute rows are [x, y, w, h, class0, class1, ...]
        for (i in 0 until numDetections) {
            // In transposed format: feature_idx * numDetections + detection_idx
            val centerX = output[0 * numDetections + i] * originalWidth
            val centerY = output[1 * numDetections + i] * originalHeight
            val width = output[2 * numDetections + i] * originalWidth
            val height = output[3 * numDetections + i] * originalHeight

            // Debug the first few detections
            if (i < 3) {
                val rawX = output[0 * numDetections + i]
                val rawY = output[1 * numDetections + i]
                val rawW = output[2 * numDetections + i]
                val rawH = output[3 * numDetections + i]
                Log.d(TAG, "🔍 Detection $i: raw x=${String.format("%.3f", rawX)}, y=${String.format("%.3f", rawY)}, w=${String.format("%.3f", rawW)}, h=${String.format("%.3f", rawH)}")
                Log.d(TAG, "🔍 Detection $i: scaled x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}")
            }

            // YOLOv8 format: no separate objectness, the max class score is the confidence
            var maxClassScore = 0f
            var classId = 0

            for (j in 4 until numFeatures) { // Class scores start right after x, y, w, h
                val classIdx = j * numDetections + i
                if (classIdx >= output.size) break

                val classScore = output[classIdx]
                if (classScore > maxClassScore) {
                    maxClassScore = classScore
                    classId = j - 4 // 0-based class index
                }
            }

            if (i < 3) {
                Log.d(TAG, "🔍 Detection $i: maxClass=${String.format("%.4f", maxClassScore)}, classId=$classId")
            }

            if (maxClassScore > CONFIDENCE_THRESHOLD && classId < classNames.size) {
                val x = (centerX - width / 2).toInt()
                val y = (centerY - height / 2).toInt()

                boxes.add(Rect(x, y, width.toInt(), height.toInt()))
                confidences.add(maxClassScore)
                classIds.add(classId)
                validDetections++

                if (validDetections <= 3) {
                    Log.d(TAG, "✅ Valid transposed detection: class=$classId, conf=${String.format("%.4f", maxClassScore)}")
                }
            }
        }

        Log.d(TAG, "🎯 Found $validDetections valid detections above the confidence threshold")

        // Apply Non-Maximum Suppression (simple version)
        val finalDetections = applyNMS(boxes, confidences, classIds)

        Log.d(TAG, "🎯 Post-processing complete: ${finalDetections.size} final detections after NMS")

        return finalDetections.sortedByDescending { it.confidence }
    }
    */

    // Postprocess the flat [100 x 8400] output into Detection objects.
    // Layout note: the model reports output shape [1, 100, 8400], i.e. 100 attributes
    // per detection across 8400 candidate detections, transposed so each attribute
    // occupies a contiguous row: attribute j of detection i lives at output[j * 8400 + i].
    //
    // The reading below assumes attribute 4 is an objectness logit and attributes 5-99
    // are 95 class logits. Caution: a stock Ultralytics YOLOv8 head has no objectness
    // row; its output is [1, 4 + numClasses, 8400], which for 100 attributes implies
    // 96 classes starting at attribute 4 (the interpretation in the commented-out
    // variant above). If classes come out shifted by one, revisit this assumption first.
    private fun postprocess(output: FloatArray, originalWidth: Int, originalHeight: Int): List<Detection> {
        val confidences = mutableListOf<Float>()
        val boxes = mutableListOf<Rect>()
        val classIds = mutableListOf<Int>()

        Log.d(TAG, "🔍 Processing detections from output array of size ${output.size}")
        Log.d(TAG, "🔍 Original image size: ${originalWidth}x${originalHeight}")

        val numAttributesPerDetection = 100  // outputShape[1]
        val totalDetections = NUM_DETECTIONS // outputShape[2] = 8400

        // Loop through each of the 8400 candidate detections
        for (i in 0 until totalDetections) {
            val centerX = output[0 * totalDetections + i]        // normalized x center
            val centerY = output[1 * totalDetections + i]        // normalized y center
            val width = output[2 * totalDetections + i]          // normalized width
            val height = output[3 * totalDetections + i]         // normalized height
            val objectnessConf = output[4 * totalDetections + i] // objectness logit (see caution above)

            // Debug raw values before class scoring
            if (i < 5) {
                Log.d(TAG, "🔍 Detection $i (pre-scale): x=${String.format("%.3f", centerX)}, y=${String.format("%.3f", centerY)}, w=${String.format("%.3f", width)}, h=${String.format("%.3f", height)}, obj_conf=${String.format("%.4f", objectnessConf)}")
            }

            var maxClassScore = 0f
            var classId = -1 // -1 so an empty score row cannot produce a bogus class

            // Attributes 5-99 are treated as class logits; apply sigmoid and take the max
            for (j in 5 until numAttributesPerDetection) {
                val classScore = output[j * totalDetections + i]
                val classScoreSigmoid = (1.0 / (1.0 + Math.exp(-classScore.toDouble()))).toFloat()

                if (classScoreSigmoid > maxClassScore) {
                    maxClassScore = classScoreSigmoid
                    classId = j - 5 // 0-based class index
                }
            }

            val objectnessSigmoid = (1.0 / (1.0 + Math.exp(-objectnessConf.toDouble()))).toFloat()

            // Final confidence: objectness multiplied by the best class score
            val finalConfidence = objectnessSigmoid * maxClassScore

            if (i < 5) {
                Log.d(TAG, "🔍 Detection $i (post-score): maxClass=${String.format("%.4f", maxClassScore)}, finalConf=${String.format("%.4f", finalConfidence)}, classId=$classId")
            }

            // Apply confidence threshold
            if (finalConfidence > CONFIDENCE_THRESHOLD && classId != -1 && classId < classNames.size) {
                // Convert normalized (0-1) center/size to pixel corner coordinates
                val x = ((centerX - width / 2) * originalWidth).toInt()
                val y = ((centerY - height / 2) * originalHeight).toInt()
                val w = (width * originalWidth).toInt()
                val h = (height * originalHeight).toInt()

                // Clamp to image bounds
                val x1 = max(0, x)
                val y1 = max(0, y)
                val x2 = min(originalWidth, x + w)
                val y2 = min(originalHeight, y + h)

                // Queue for NMS
                boxes.add(Rect(x1, y1, x2 - x1, y2 - y1))
                confidences.add(finalConfidence)
                classIds.add(classId)
            }
        }

        Log.d(TAG, "🎯 Found ${boxes.size} detections above the confidence threshold before NMS")

        // Non-Maximum Suppression via OpenCV's NMSBoxes
        val finalDetections = applyNMS_OpenCV(boxes, confidences, classIds)

        Log.d(TAG, "🎯 Post-processing complete: ${finalDetections.size} final detections after NMS")

        return finalDetections.sortedByDescending { it.confidence }
    }

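    // Index sketch for the flat transposed layout used above, under the stated
    // [100 x 8400] assumption: attribute j of detection i sits at j * 8400 + i,
    // so e.g. the width (attribute 2) of detection 10 is output[2 * 8400 + 10].
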
    // NMS via OpenCV's Dnn.NMSBoxes (preferred over the manual applyNMS below).
    // Note: NMSBoxes is class-agnostic, so boxes of different classes can suppress
    // each other; the manual version below suppresses per class instead.
    private fun applyNMS_OpenCV(boxes: List<Rect>, confidences: List<Float>, classIds: List<Int>): List<Detection> {
        val finalDetections = mutableListOf<Detection>()

        // NMSBoxes wants Rect2d, so convert the integer Rects
        val boxes2d = boxes.map { Rect2d(it.x.toDouble(), it.y.toDouble(), it.width.toDouble(), it.height.toDouble()) }

        val boxesMat = MatOfRect2d()
        boxesMat.fromList(boxes2d)

        val confsMat = MatOfFloat()
        confsMat.fromList(confidences)

        val indices = MatOfInt()

        Dnn.NMSBoxes(
            boxesMat,
            confsMat,
            CONFIDENCE_THRESHOLD, // score threshold: boxes below this are ignored by NMS
            NMS_THRESHOLD,        // IoU threshold: overlapping boxes above this are suppressed
            indices
        )

        val ind = indices.toArray() // indices of the boxes to keep

        for (i in ind.indices) {
            val idx = ind[i]
            val className = classNames[classIds[idx]] ?: "unknown_${classIds[idx]}"
            finalDetections.add(
                Detection(
                    classId = classIds[idx],
                    className = className,
                    confidence = confidences[idx],
                    boundingBox = boxes[idx] // keep the original integer Rect
                )
            )
        }

        // Release Mats to prevent native memory leaks
        boxesMat.release()
        confsMat.release()
        indices.release()

        return finalDetections
    }

    // Manual per-class NMS fallback (currently unused; postprocess calls applyNMS_OpenCV)
    private fun applyNMS(boxes: List<Rect>, confidences: List<Float>, classIds: List<Int>): List<Detection> {
        val detections = mutableListOf<Detection>()

        // Greedy NMS: walk boxes in descending confidence, suppressing overlaps
        val indices = confidences.indices.sortedByDescending { confidences[it] }
        val suppressed = BooleanArray(boxes.size)

        for (i in indices) {
            if (suppressed[i]) continue

            val className = classNames[classIds[i]] ?: "unknown_${classIds[i]}"
            detections.add(
                Detection(
                    classId = classIds[i],
                    className = className,
                    confidence = confidences[i],
                    boundingBox = boxes[i]
                )
            )

            // Suppress overlapping boxes of the same class
            for (j in indices) {
                if (i != j && !suppressed[j] && classIds[i] == classIds[j]) {
                    val iou = calculateIoU(boxes[i], boxes[j])
                    if (iou > NMS_THRESHOLD) {
                        suppressed[j] = true
                    }
                }
            }
        }

        return detections
    }

    private fun calculateIoU(box1: Rect, box2: Rect): Float {
        // Intersection rectangle corners
        val x1 = max(box1.x, box2.x)
        val y1 = max(box1.y, box2.y)
        val x2 = min(box1.x + box1.width, box2.x + box2.width)
        val y2 = min(box1.y + box1.height, box2.y + box2.height)

        val intersection = max(0, x2 - x1) * max(0, y2 - y1)
        val area1 = box1.width * box1.height
        val area2 = box2.width * box2.height
        val union = area1 + area2 - intersection

        return if (union > 0) intersection.toFloat() / union.toFloat() else 0f
    }

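    // Worked IoU example: boxes (0,0,10,10) and (5,5,10,10) overlap in a 5x5 patch,
    // giving IoU = 25 / (100 + 100 - 25) ≈ 0.143; that is below NMS_THRESHOLD (0.4),
    // so neither box would suppress the other.
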
    fun testWithStaticImage(): List<Detection> {
        if (!isInitialized) {
            Log.e(TAG, "❌ TensorFlow Lite detector not initialized for test")
            return emptyList()
        }

        try {
            Log.i(TAG, "🧪 TESTING WITH STATIC IMAGE (TFLite)")

            // Load the test image from assets
            val inputStream = context.assets.open("test_pokemon.jpg")
            val bitmap = BitmapFactory.decodeStream(inputStream)
            inputStream.close()

            if (bitmap == null) {
                Log.e(TAG, "❌ Failed to load test_pokemon.jpg from assets")
                return emptyList()
            }

            Log.i(TAG, "📸 Loaded test image: ${bitmap.width}x${bitmap.height}")

            // Convert bitmap to an OpenCV Mat (RGBA)
            val mat = Mat()
            Utils.bitmapToMat(bitmap, mat)

            Log.i(TAG, "🔄 Converted to Mat: ${mat.cols()}x${mat.rows()}, channels: ${mat.channels()}")

            // Run detection
            val detections = detect(mat)

            Log.i(TAG, "🎯 TFLite TEST RESULT: ${detections.size} detections found")
            detections.forEachIndexed { index, detection ->
                Log.i(TAG, "  $index: ${detection.className} (${String.format("%.3f", detection.confidence)}) at [${detection.boundingBox.x}, ${detection.boundingBox.y}, ${detection.boundingBox.width}, ${detection.boundingBox.height}]")
            }

            // Clean up
            mat.release()
            bitmap.recycle()

            return detections
        } catch (e: Exception) {
            Log.e(TAG, "❌ Error in TFLite static image test", e)
            return emptyList()
        }
    }

    private fun copyAssetToInternalStorage(assetName: String): String? {
        return try {
            val inputStream = context.assets.open(assetName)
            val file = context.getFileStreamPath(assetName)
            val outputStream = FileOutputStream(file)

            inputStream.copyTo(outputStream)
            inputStream.close()
            outputStream.close()

            file.absolutePath
        } catch (e: IOException) {
            Log.e(TAG, "Error copying asset $assetName", e)
            null
        }
    }

    fun release() {
        interpreter?.close()
        interpreter = null
        isInitialized = false
        Log.d(TAG, "TensorFlow Lite detector released")
    }
}