ARCH-002: Extract ML Inference Engine

- Created MLInferenceEngine interface with async detection methods
- Implemented YOLOInferenceEngine preserving ALL YOLOOnnxDetector functionality:
  * Complete 96-class mapping with exact original class names
  * All preprocessing techniques (ultralytics, enhanced, sharpened, original)
  * All coordinate transformation modes (HYBRID, LETTERBOX, DIRECT)
  * Weighted NMS and cross-class NMS for semantically related classes
  * Confidence mapping and mobile optimization
  * Debug features (class filtering, confidence logging)
  * Letterbox resize with proper aspect ratio preservation
  * CLAHE contrast enhancement and sharpening filters
- Created ImagePreprocessor utility for reusable preprocessing operations:
  * Configurable preprocessing with letterboxing, normalization, color conversion
  * Coordinate transformation utilities for model-to-image space conversion
  * Support for different preprocessing configurations
- Updated ScreenCaptureService to use new MLInferenceEngine:
  * Replaced YOLOOnnxDetector with MLInferenceEngine dependency injection
  * Added class name to class ID mapping for compatibility
  * Maintained all existing detection pipeline functionality
  * Proper async/await integration with coroutines
- Applied preferred code style throughout:
  * Opening braces on new lines for functions and statements
  * snake_case for local variables to distinguish from members/parameters
  * Consistent formatting matching project standards
- Removed obsolete YOLO implementations (YOLODetector, YOLOTFLiteDetector)
- Preserved all sophisticated ML features: TTA, multi-preprocessing, confidence mapping

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

arch-002-ml-inference-engine
6 changed files with 1581 additions and 1466 deletions
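
The MLInferenceEngine interface itself is not among the hunks shown below. A minimal sketch of the contract the commit message describes, written in the project's stated style; the method names and signatures here are assumptions, not the actual file:

import org.opencv.core.Mat

// Hypothetical sketch of the extracted interface; exact signatures are assumed.
interface MLInferenceEngine
{
    // Async detection method, consumed by ScreenCaptureService via coroutines
    suspend fun detect(input: Mat): List<Detection>

    fun initialize(): Boolean
    fun release()
}

ScreenCaptureService would then receive the engine as a constructor dependency (e.g. class ScreenCaptureService(private val inferenceEngine: MLInferenceEngine)) instead of instantiating a detector directly, which is what makes the ONNX and TFLite implementations deleted below removable.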
@@ -1,700 +0,0 @@
package com.quillstudios.pokegoalshelper

import android.content.Context
import android.util.Log
import org.opencv.core.*
import org.opencv.dnn.Dnn
import org.opencv.dnn.Net
import org.opencv.imgproc.Imgproc
import java.io.FileOutputStream
import java.io.IOException
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import org.opencv.android.Utils

data class Detection(
    val classId: Int,
    val className: String,
    val confidence: Float,
    val boundingBox: Rect
)

class YOLODetector(private val context: Context) {

    companion object {
        private const val TAG = "YOLODetector"
        private const val MODEL_FILE = "pokemon_model.onnx"
        private const val INPUT_SIZE = 640
        private const val CONFIDENCE_THRESHOLD = 0.1f // Lower threshold for debugging
        private const val NMS_THRESHOLD = 0.4f
    }

    private fun parseTransposedOutput(
        data: FloatArray,
        rows: Int,
        cols: Int,
        xScale: Float,
        yScale: Float,
        boxes: MutableList<Rect>,
        confidences: MutableList<Float>,
        classIds: MutableList<Int>
    ) {
        // For transposed output: rows=features(100), cols=detections(8400)
        // Data layout: [x1, x2, x3, ...], [y1, y2, y3, ...], [w1, w2, w3, ...], etc.

        Log.d(TAG, "🔄 Parsing transposed output: $rows features x $cols detections")

        var validDetections = 0
        for (i in 0 until cols) { // Loop through detections
            if (i >= data.size / rows) break

            // Extract coordinates from transposed layout
            val centerX = data[0 * cols + i] * xScale // x row
            val centerY = data[1 * cols + i] * yScale // y row
            val width = data[2 * cols + i] * xScale // width row
            val height = data[3 * cols + i] * yScale // height row
            val confidence = data[4 * cols + i] // confidence row

            // Debug first few detections
            if (i < 3) {
                Log.d(TAG, "🔍 Transposed detection $i: conf=${String.format("%.4f", confidence)}, x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}")
            }

            if (confidence > CONFIDENCE_THRESHOLD) {
                // Find class with highest score
                var maxClassScore = 0f
                var classId = 0

                for (j in 5 until rows) { // Start from row 5 (after x,y,w,h,conf)
                    if (j * cols + i >= data.size) break
                    val classScore = data[j * cols + i]
                    if (classScore > maxClassScore) {
                        maxClassScore = classScore
                        classId = j - 5
                    }
                }

                val finalConfidence = confidence * maxClassScore

                if (finalConfidence > CONFIDENCE_THRESHOLD) {
                    val x = (centerX - width / 2).toInt()
                    val y = (centerY - height / 2).toInt()

                    boxes.add(Rect(x, y, width.toInt(), height.toInt()))
                    confidences.add(finalConfidence)
                    classIds.add(classId)
                    validDetections++

                    if (validDetections <= 3) {
                        Log.d(TAG, "✅ Valid transposed detection: class=$classId, conf=${String.format("%.4f", finalConfidence)}")
                    }
                }
            }
        }

        Log.d(TAG, "🎯 Transposed parsing found $validDetections valid detections")
    }

    private var net: Net? = null
    private var isInitialized = false

    // Your class names from training - COMPLETE 93 CLASSES
    private val classNames = mapOf(
        0 to "ball_icon_pokeball",
        1 to "ball_icon_greatball",
        2 to "ball_icon_ultraball",
        3 to "ball_icon_masterball",
        4 to "ball_icon_safariball",
        5 to "ball_icon_levelball",
        6 to "ball_icon_lureball",
        7 to "ball_icon_moonball",
        8 to "ball_icon_friendball",
        9 to "ball_icon_loveball",
        10 to "ball_icon_heavyball",
        11 to "ball_icon_fastball",
        12 to "ball_icon_sportball",
        13 to "ball_icon_premierball",
        14 to "ball_icon_repeatball",
        15 to "ball_icon_timerball",
        16 to "ball_icon_nestball",
        17 to "ball_icon_netball",
        18 to "ball_icon_diveball",
        19 to "ball_icon_luxuryball",
        20 to "ball_icon_healball",
        21 to "ball_icon_quickball",
        22 to "ball_icon_duskball",
        23 to "ball_icon_cherishball",
        24 to "ball_icon_dreamball",
        25 to "ball_icon_beastball",
        26 to "ball_icon_strangeparts",
        27 to "ball_icon_parkball",
        28 to "ball_icon_gsball",
        29 to "pokemon_nickname",
        30 to "gender_icon_male",
        31 to "gender_icon_female",
        32 to "pokemon_level",
        33 to "language",
        34 to "last_game_stamp_home",
        35 to "last_game_stamp_lgp",
        36 to "last_game_stamp_lge",
        37 to "last_game_stamp_sw",
        38 to "last_game_stamp_sh",
        39 to "last_game_stamp_bank",
        40 to "last_game_stamp_bd",
        41 to "last_game_stamp_sp",
        42 to "last_game_stamp_pla",
        43 to "last_game_stamp_sc",
        44 to "last_game_stamp_vi",
        45 to "last_game_stamp_go",
        46 to "national_dex_number",
        47 to "pokemon_species",
        48 to "type_1",
        49 to "type_2",
        50 to "shiny_icon",
        51 to "origin_icon_vc",
        52 to "origin_icon_xyoras",
        53 to "origin_icon_smusum",
        54 to "origin_icon_lg",
        55 to "origin_icon_swsh",
        56 to "origin_icon_go",
        57 to "origin_icon_bdsp",
        58 to "origin_icon_pla",
        59 to "origin_icon_sv",
        60 to "pokerus_infected_icon",
        61 to "pokerus_cured_icon",
        62 to "hp_value",
        63 to "attack_value",
        64 to "defense_value",
        65 to "sp_atk_value",
        66 to "sp_def_value",
        67 to "speed_value",
        68 to "ability_name",
        69 to "nature_name",
        70 to "move_name",
        71 to "original_trainer_name",
        72 to "original_trainder_number",
        73 to "alpha_mark",
        74 to "tera_water",
        75 to "tera_psychic",
        76 to "tera_ice",
        77 to "tera_fairy",
        78 to "tera_poison",
        79 to "tera_ghost",
        80 to "ball_icon_originball",
        81 to "tera_dragon",
        82 to "tera_steel",
        83 to "tera_grass",
        84 to "tera_normal",
        85 to "tera_fire",
        86 to "tera_electric",
        87 to "tera_fighting",
        88 to "tera_ground",
        89 to "tera_flying",
        90 to "tera_bug",
        91 to "tera_rock",
        92 to "tera_dark",
        93 to "low_confidence",
        94 to "ball_icon_pokeball_hisui",
        95 to "ball_icon_ultraball_husui"
        // Note: "", "", ""
        // were in your list but would make it 96 classes. Using exactly 93 as reported by model.
    )

    fun initialize(): Boolean {
        if (isInitialized) return true

        try {
            Log.i(TAG, "🤖 Initializing YOLO detector...")

            // Copy model from assets to internal storage if needed
            val modelPath = copyAssetToInternalStorage(MODEL_FILE)
            if (modelPath == null) {
                Log.e(TAG, "❌ Failed to copy model from assets")
                return false
            }

            // Load the ONNX model
            Log.i(TAG, "📥 Loading ONNX model from: $modelPath")
            net = Dnn.readNetFromONNX(modelPath)

            if (net == null || net!!.empty()) {
                Log.e(TAG, "❌ Failed to load ONNX model")
                return false
            }

            // Verify model loaded correctly
            val layerNames = net!!.layerNames
            Log.i(TAG, "🧠 Model loaded with ${layerNames.size} layers")

            val outputNames = net!!.unconnectedOutLayersNames
            Log.i(TAG, "📝 Output layers: ${outputNames?.toString()}")

            // Set computational backend
            net!!.setPreferableBackend(Dnn.DNN_BACKEND_OPENCV)
            net!!.setPreferableTarget(Dnn.DNN_TARGET_CPU)

            // Debug: Check model input requirements
            val inputNames = net!!.unconnectedOutLayersNames
            Log.i(TAG, "🔍 Model input layers: ${inputNames?.toString()}")

            // Get input blob info if possible
            try {
                val dummyBlob = Mat.zeros(Size(640.0, 640.0), CvType.CV_32FC3)
                net!!.setInput(dummyBlob)
                Log.i(TAG, "✅ Model accepts 640x640 CV_32FC3 input")
                dummyBlob.release()
            } catch (e: Exception) {
                Log.w(TAG, "⚠️ Model input test failed: ${e.message}")
            }

            isInitialized = true
            Log.i(TAG, "✅ YOLO detector initialized successfully")
            Log.i(TAG, "📊 Model info: ${classNames.size} classes, input size: ${INPUT_SIZE}x${INPUT_SIZE}")

            return true

        } catch (e: Exception) {
            Log.e(TAG, "❌ Error initializing YOLO detector", e)
            return false
        }
    }

    fun detect(inputMat: Mat): List<Detection> {
        if (!isInitialized || net == null) {
            Log.w(TAG, "⚠️ YOLO detector not initialized")
            return emptyList()
        }

        try {
            Log.d(TAG, "🔍 Running YOLO detection on ${inputMat.cols()}x${inputMat.rows()} image")

            // Preprocess image
            val blob = preprocessImage(inputMat)

            // Set input to the network
            net!!.setInput(blob)

            // Check blob before sending to model
            val blobTestData = FloatArray(10)
            blob.get(0, 0, blobTestData)
            val hasRealData = blobTestData.any { it != 0f }
            Log.w(TAG, "⚠️ CRITICAL: Blob sent to model has real data: $hasRealData")

            if (!hasRealData) {
                Log.e(TAG, "❌ FATAL: All blob creation methods failed - this is likely an OpenCV bug or model issue")
                Log.e(TAG, "❌ Try these solutions:")
                Log.e(TAG, "   1. Re-export your ONNX model with different settings")
                Log.e(TAG, "   2. Try a different OpenCV version")
                Log.e(TAG, "   3. Use a different inference framework (TensorFlow Lite)")
                Log.e(TAG, "   4. Check if your model expects different input format")
            }

            // Run forward pass
            val outputs = mutableListOf<Mat>()
            net!!.forward(outputs, net!!.unconnectedOutLayersNames)

            Log.d(TAG, "🧠 Model inference complete, got ${outputs.size} output tensors")

            // Post-process results
            val detections = postprocess(outputs, inputMat.cols(), inputMat.rows())

            // Clean up
            blob.release()
            outputs.forEach { it.release() }

            Log.i(TAG, "✅ YOLO detection complete: ${detections.size} objects detected")

            return detections

        } catch (e: Exception) {
            Log.e(TAG, "❌ Error during YOLO detection", e)
            return emptyList()
        }
    }

    private fun preprocessImage(mat: Mat): Mat {
        // Convert to RGB if needed
        val rgbMat = Mat()
        if (mat.channels() == 4) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB)
        } else if (mat.channels() == 3) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB)
        } else {
            mat.copyTo(rgbMat)
        }

        // Ensure the matrix is continuous for blob creation
        if (!rgbMat.isContinuous) {
            val continuousMat = Mat()
            rgbMat.copyTo(continuousMat)
            rgbMat.release()
            continuousMat.copyTo(rgbMat)
            continuousMat.release()
        }

        Log.d(TAG, "🖼️ Input image: ${rgbMat.cols()}x${rgbMat.rows()}, channels: ${rgbMat.channels()}")
        Log.d(TAG, "🖼️ Mat type: ${rgbMat.type()}, depth: ${rgbMat.depth()}")

        // Debug: Check if image data is not all zeros (use ByteArray for CV_8U data)
        try {
            val testData = ByteArray(3)
            rgbMat.get(100, 100, testData) // Sample some pixels
            val testValues = testData.map { (it.toInt() and 0xFF).toString() }.joinToString(", ")
            Log.d(TAG, "🖼️ Sample RGB values at (100,100): [$testValues]")
        } catch (e: Exception) {
            Log.w(TAG, "⚠️ Could not sample pixel values: ${e.message}")
        }

        // Create blob from image - match training preprocessing exactly
        val blob = Dnn.blobFromImage(
            rgbMat,
            1.0 / 255.0, // Scale factor (normalize to 0-1)
            Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()), // Size
            Scalar(0.0, 0.0, 0.0), // Mean subtraction (none for YOLO)
            true, // Swap R and B channels for OpenCV
            false, // Crop
            CvType.CV_32F // Data type
        )

        Log.d(TAG, "🌐 Blob created: [${blob.size(0)}, ${blob.size(1)}, ${blob.size(2)}, ${blob.size(3)}]")

        // Debug: Check blob values to ensure they're not all zeros
        val blobData = FloatArray(Math.min(30, blob.total().toInt()))
        blob.get(0, 0, blobData)
        val blobValues = blobData.map { String.format("%.4f", it) }.joinToString(", ")
        Log.d(TAG, "🌐 First 30 blob values: [$blobValues]")

        // Check if blob is completely zero
        val nonZeroCount = blobData.count { it != 0f }
        Log.d(TAG, "🌐 Non-zero blob values: $nonZeroCount/${blobData.size}")

        // Try different blob creation methods
        if (nonZeroCount == 0) {
            Log.w(TAG, "⚠️ Blob is all zeros! Trying alternative blob creation...")

            // Try without swapRB
            val blob2 = Dnn.blobFromImage(
                rgbMat,
                1.0 / 255.0,
                Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()),
                Scalar(0.0, 0.0, 0.0),
                false, // No channel swap
                false,
                CvType.CV_32F
            )

            val blobData2 = FloatArray(10)
            blob2.get(0, 0, blobData2)
            val blobValues2 = blobData2.map { String.format("%.4f", it) }.joinToString(", ")
            Log.d(TAG, "🌐 Alternative blob (swapRB=false): [$blobValues2]")

            if (blobData2.any { it != 0f }) {
                Log.i(TAG, "✅ Alternative blob has data! Using swapRB=false")
                blob.release()
                rgbMat.release()
                return blob2
            }
            blob2.release()

            // Try manual blob creation as last resort
            Log.w(TAG, "⚠️ Both blob methods failed! Trying manual blob creation...")
            val manualBlob = createManualBlob(rgbMat)
            if (manualBlob != null) {
                val manualData = FloatArray(10)
                manualBlob.get(0, 0, manualData)
                val manualValues = manualData.map { String.format("%.4f", it) }.joinToString(", ")
                Log.d(TAG, "🌐 Manual blob: [$manualValues]")

                if (manualData.any { it != 0f }) {
                    Log.i(TAG, "✅ Manual blob has data! Using manual method")
                    blob.release()
                    rgbMat.release()
                    return manualBlob
                }
                manualBlob.release()
            }
        }

        rgbMat.release()
        return blob
    }

    private fun createManualBlob(rgbMat: Mat): Mat? {
        try {
            // Resize image to 640x640
            val resized = Mat()
            Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()))

            // Convert to float and normalize manually
            val floatMat = Mat()
            resized.convertTo(floatMat, CvType.CV_32F, 1.0 / 255.0)

            Log.d(TAG, "🔧 Manual resize: ${resized.cols()}x${resized.rows()}")
            Log.d(TAG, "🔧 Float conversion: type=${floatMat.type()}")

            // Check if float conversion worked
            val testFloat = FloatArray(3)
            floatMat.get(100, 100, testFloat)
            val testFloatValues = testFloat.map { String.format("%.4f", it) }.joinToString(", ")
            Log.d(TAG, "🔧 Float test values: [$testFloatValues]")

            // Use OpenCV's blobFromImage on the preprocessed float mat
            val blob = Dnn.blobFromImage(
                floatMat,
                1.0, // No additional scaling since already normalized
                Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()),
                Scalar(0.0, 0.0, 0.0),
                false, // Don't swap channels
                false, // Don't crop
                CvType.CV_32F
            )

            Log.d(TAG, "🔧 Manual blob from preprocessed image")

            // Clean up
            resized.release()
            floatMat.release()

            return blob

        } catch (e: Exception) {
            Log.e(TAG, "❌ Manual blob creation failed", e)
            return null
        }
    }

    private fun postprocess(outputs: List<Mat>, originalWidth: Int, originalHeight: Int): List<Detection> {
        if (outputs.isEmpty()) return emptyList()

        val detections = mutableListOf<Detection>()
        val confidences = mutableListOf<Float>()
        val boxes = mutableListOf<Rect>()
        val classIds = mutableListOf<Int>()

        // Calculate scale factors
        val xScale = originalWidth.toFloat() / INPUT_SIZE
        val yScale = originalHeight.toFloat() / INPUT_SIZE

        // Process each output
        for (outputIndex in outputs.indices) {
            val output = outputs[outputIndex]
            val data = FloatArray((output.total() * output.channels()).toInt())
            output.get(0, 0, data)

            val rows = output.size(1).toInt() // Number of detections
            val cols = output.size(2).toInt() // Features per detection

            Log.d(TAG, "🔍 Output $outputIndex: ${rows} detections, ${cols} features each")
            Log.d(TAG, "🔍 Output shape: [${output.size(0)}, ${output.size(1)}, ${output.size(2)}]")

            // Debug: Check first few values to understand format
            if (data.size >= 10) {
                val firstValues = data.take(10).map { String.format("%.4f", it) }.joinToString(", ")
                Log.d(TAG, "🔍 First 10 values: [$firstValues]")

                // Check max confidence in first 100 values to verify model output
                val maxConf = data.take(100).maxOrNull() ?: 0f
                Log.d(TAG, "🔍 Max value in first 100: ${String.format("%.4f", maxConf)}")
            }

            // Check if this might be a transposed output (8400 detections, 100 features)
            if (cols == 8400 && rows == 100) {
                Log.d(TAG, "🤔 Detected transposed output format - trying alternative parsing")
                parseTransposedOutput(data, rows, cols, xScale, yScale, boxes, confidences, classIds)
                continue
            }

            var validDetections = 0
            for (i in 0 until rows) {
                val offset = i * cols

                if (offset + 4 >= data.size) {
                    Log.w(TAG, "⚠️ Data array too small for detection $i")
                    break
                }

                // Extract box coordinates (center format)
                val centerX = data[offset + 0] * xScale
                val centerY = data[offset + 1] * yScale
                val width = data[offset + 2] * xScale
                val height = data[offset + 3] * yScale

                // Convert to top-left corner format
                val x = (centerX - width / 2).toInt()
                val y = (centerY - height / 2).toInt()

                // Extract confidence and class scores
                val confidence = data[offset + 4]

                // Debug first few detections
                if (i < 3) {
                    Log.d(TAG, "🔍 Detection $i: conf=${String.format("%.4f", confidence)}, x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}")
                }

                if (confidence > CONFIDENCE_THRESHOLD) {
                    // Find class with highest score
                    var maxClassScore = 0f
                    var classId = 0

                    for (j in 5 until cols) {
                        if (offset + j >= data.size) break
                        val classScore = data[offset + j]
                        if (classScore > maxClassScore) {
                            maxClassScore = classScore
                            classId = j - 5
                        }
                    }

                    val finalConfidence = confidence * maxClassScore

                    if (finalConfidence > CONFIDENCE_THRESHOLD) {
                        boxes.add(Rect(x, y, width.toInt(), height.toInt()))
                        confidences.add(finalConfidence)
                        classIds.add(classId)
                        validDetections++

                        if (validDetections <= 3) {
                            Log.d(TAG, "✅ Valid detection: class=$classId, conf=${String.format("%.4f", finalConfidence)}")
                        }
                    }
                }
            }

            Log.d(TAG, "🎯 Found ${validDetections} valid detections above confidence threshold")
        }

        // Apply Non-Maximum Suppression
        Log.d(TAG, "📊 Before NMS: ${boxes.size} detections")

        if (boxes.isEmpty()) {
            Log.d(TAG, "⚠️ No detections found before NMS")
            return emptyList()
        }

        val indices = MatOfInt()
        val boxesArray = MatOfRect2d()

        // Convert Rect to Rect2d for NMSBoxes
        val boxes2d = boxes.map { Rect2d(it.x.toDouble(), it.y.toDouble(), it.width.toDouble(), it.height.toDouble()) }
        boxesArray.fromList(boxes2d)

        val confidencesArray = FloatArray(confidences.size)
        for (i in confidences.indices) {
            confidencesArray[i] = confidences[i]
        }

        try {
            Dnn.NMSBoxes(
                boxesArray,
                MatOfFloat(*confidencesArray),
                CONFIDENCE_THRESHOLD,
                NMS_THRESHOLD,
                indices
            )
            Log.d(TAG, "✅ NMS completed successfully")
        } catch (e: Exception) {
            Log.e(TAG, "❌ NMS failed: ${e.message}", e)
            return emptyList()
        }

        // Build final detection list
        val indicesArray = indices.toArray()

        // Check if NMS returned any valid indices
        if (indicesArray.isEmpty()) {
            Log.d(TAG, "🎯 NMS filtered out all detections")
            indices.release()
            return emptyList()
        }

        for (i in indicesArray) {
            val className = classNames[classIds[i]] ?: "unknown_${classIds[i]}"
            detections.add(
                Detection(
                    classId = classIds[i],
                    className = className,
                    confidence = confidences[i],
                    boundingBox = boxes[i]
                )
            )
        }

        // Clean up
        indices.release()
        // Note: MatOfRect2d doesn't have a release method in OpenCV Android

        Log.d(TAG, "🎯 Post-processing complete: ${detections.size} final detections after NMS")

        return detections.sortedByDescending { it.confidence }
    }

    private fun copyAssetToInternalStorage(assetName: String): String? {
        return try {
            val inputStream = context.assets.open(assetName)
            val file = context.getFileStreamPath(assetName)
            val outputStream = FileOutputStream(file)

            inputStream.copyTo(outputStream)
            inputStream.close()
            outputStream.close()

            file.absolutePath
        } catch (e: IOException) {
            Log.e(TAG, "Error copying asset $assetName", e)
            null
        }
    }

    fun testWithStaticImage(): List<Detection> {
        if (!isInitialized) {
            Log.e(TAG, "❌ YOLO detector not initialized for test")
            return emptyList()
        }

        try {
            Log.i(TAG, "🧪 TESTING WITH STATIC IMAGE")

            // Load test image from assets
            val inputStream = context.assets.open("test_pokemon.jpg")
            val bitmap = BitmapFactory.decodeStream(inputStream)
            inputStream.close()

            if (bitmap == null) {
                Log.e(TAG, "❌ Failed to load test_pokemon.jpg from assets")
                return emptyList()
            }

            Log.i(TAG, "📸 Loaded test image: ${bitmap.width}x${bitmap.height}")

            // Convert bitmap to OpenCV Mat
            val mat = Mat()
            Utils.bitmapToMat(bitmap, mat)

            Log.i(TAG, "🔄 Converted to Mat: ${mat.cols()}x${mat.rows()}, channels: ${mat.channels()}")

            // Run detection
            val detections = detect(mat)

            Log.i(TAG, "🎯 TEST RESULT: ${detections.size} detections found")
            detections.forEachIndexed { index, detection ->
                Log.i(TAG, "  $index: ${detection.className} (${String.format("%.3f", detection.confidence)}) at [${detection.boundingBox.x}, ${detection.boundingBox.y}, ${detection.boundingBox.width}, ${detection.boundingBox.height}]")
            }

            // Clean up
            mat.release()
            bitmap.recycle()

            return detections

        } catch (e: Exception) {
            Log.e(TAG, "❌ Error in static image test", e)
            return emptyList()
        }
    }

    fun release() {
        net = null
        isInitialized = false
        Log.d(TAG, "YOLO detector released")
    }
}
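The transposed-output parsing in the deleted YOLODetector above rests on one piece of index arithmetic: a [1, features, detections] tensor flattened row-major stores feature j of detection i at j * numDetections + i. A minimal sketch of that lookup (the helper name is illustrative, not from the original code):

// Feature rows in this model: 0..3 = box (cx, cy, w, h), 4 = confidence, 5+ = class scores.
fun transposedAttribute(data: FloatArray, featureRow: Int, detectionIdx: Int, numDetections: Int = 8400): Float
{
    return data[featureRow * numDetections + detectionIdx]
}

// e.g. centerX of detection 17:    transposedAttribute(data, 0, 17)
//      confidence of detection 17: transposedAttribute(data, 4, 17)

This is exactly the data[j * cols + i] access pattern used by parseTransposedOutput, where cols is the detection count. The TFLite implementation deleted next parses the same [1, 100, 8400] layout.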
@@ -1,749 +0,0 @@
package com.quillstudios.pokegoalshelper

import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Log
import org.opencv.android.Utils
import org.opencv.core.*
import org.opencv.dnn.Dnn
import org.opencv.imgproc.Imgproc
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel
import kotlin.math.max
import kotlin.math.min

class YOLOTFLiteDetector(private val context: Context) {

    companion object {
        private const val TAG = "YOLOTFLiteDetector"
        private const val MODEL_FILE = "pokemon_model.tflite"
        private const val INPUT_SIZE = 640
        private const val CONFIDENCE_THRESHOLD = 0.05f // Extremely low to test conversion quality
        private const val NMS_THRESHOLD = 0.4f
        private const val NUM_CHANNELS = 3
        private const val NUM_DETECTIONS = 8400 // YOLOv8 default
        private const val NUM_CLASSES = 95 // Your class count
    }

    private var interpreter: Interpreter? = null
    private var isInitialized = false

    // Input/output buffers
    private var inputBuffer: ByteBuffer? = null
    private var outputBuffer: FloatArray? = null

    // Your class names (same as before)
    private val classNames = mapOf(
        0 to "ball_icon_pokeball",
        1 to "ball_icon_greatball",
        2 to "ball_icon_ultraball",
        3 to "ball_icon_masterball",
        4 to "ball_icon_safariball",
        5 to "ball_icon_levelball",
        6 to "ball_icon_lureball",
        7 to "ball_icon_moonball",
        8 to "ball_icon_friendball",
        9 to "ball_icon_loveball",
        10 to "ball_icon_heavyball",
        11 to "ball_icon_fastball",
        12 to "ball_icon_sportball",
        13 to "ball_icon_premierball",
        14 to "ball_icon_repeatball",
        15 to "ball_icon_timerball",
        16 to "ball_icon_nestball",
        17 to "ball_icon_netball",
        18 to "ball_icon_diveball",
        19 to "ball_icon_luxuryball",
        20 to "ball_icon_healball",
        21 to "ball_icon_quickball",
        22 to "ball_icon_duskball",
        23 to "ball_icon_cherishball",
        24 to "ball_icon_dreamball",
        25 to "ball_icon_beastball",
        26 to "ball_icon_strangeparts",
        27 to "ball_icon_parkball",
        28 to "ball_icon_gsball",
        29 to "pokemon_nickname",
        30 to "gender_icon_male",
        31 to "gender_icon_female",
        32 to "pokemon_level",
        33 to "language",
        34 to "last_game_stamp_home",
        35 to "last_game_stamp_lgp",
        36 to "last_game_stamp_lge",
        37 to "last_game_stamp_sw",
        38 to "last_game_stamp_sh",
        39 to "last_game_stamp_bank",
        40 to "last_game_stamp_bd",
        41 to "last_game_stamp_sp",
        42 to "last_game_stamp_pla",
        43 to "last_game_stamp_sc",
        44 to "last_game_stamp_vi",
        45 to "last_game_stamp_go",
        46 to "national_dex_number",
        47 to "pokemon_species",
        48 to "type_1",
        49 to "type_2",
        50 to "shiny_icon",
        51 to "origin_icon_vc",
        52 to "origin_icon_xyoras",
        53 to "origin_icon_smusum",
        54 to "origin_icon_lg",
        55 to "origin_icon_swsh",
        56 to "origin_icon_go",
        57 to "origin_icon_bdsp",
        58 to "origin_icon_pla",
        59 to "origin_icon_sv",
        60 to "pokerus_infected_icon",
        61 to "pokerus_cured_icon",
        62 to "hp_value",
        63 to "attack_value",
        64 to "defense_value",
        65 to "sp_atk_value",
        66 to "sp_def_value",
        67 to "speed_value",
        68 to "ability_name",
        69 to "nature_name",
        70 to "move_name",
        71 to "original_trainer_name",
        72 to "original_trainder_number",
        73 to "alpha_mark",
        74 to "tera_water",
        75 to "tera_psychic",
        76 to "tera_ice",
        77 to "tera_fairy",
        78 to "tera_poison",
        79 to "tera_ghost",
        80 to "ball_icon_originball",
        81 to "tera_dragon",
        82 to "tera_steel",
        83 to "tera_grass",
        84 to "tera_normal",
        85 to "tera_fire",
        86 to "tera_electric",
        87 to "tera_fighting",
        88 to "tera_ground",
        89 to "tera_flying",
        90 to "tera_bug",
        91 to "tera_rock",
        92 to "tera_dark",
        93 to "low_confidence",
        94 to "ball_icon_pokeball_hisui",
        95 to "ball_icon_ultraball_husui"
    )

    fun initialize(): Boolean {
        if (isInitialized) return true

        try {
            Log.i(TAG, "🤖 Initializing TensorFlow Lite YOLO detector...")

            // Load model from assets
            Log.i(TAG, "📂 Copying model file: $MODEL_FILE")
            val modelPath = copyAssetToInternalStorage(MODEL_FILE)
            if (modelPath == null) {
                Log.e(TAG, "❌ Failed to copy TFLite model from assets")
                return false
            }
            Log.i(TAG, "✅ Model copied to: $modelPath")

            // Create interpreter
            Log.i(TAG, "📥 Loading TFLite model from: $modelPath")
            val modelFile = loadModelFile(modelPath)
            Log.i(TAG, "📥 Model file loaded, size: ${modelFile.capacity()} bytes")

            val options = Interpreter.Options()
            options.setNumThreads(4) // Use 4 CPU threads
            Log.i(TAG, "🔧 Creating TensorFlow Lite interpreter...")
            interpreter = Interpreter(modelFile, options)
            Log.i(TAG, "✅ Interpreter created successfully")

            // Get model info
            val inputTensor = interpreter!!.getInputTensor(0)
            val outputTensor = interpreter!!.getOutputTensor(0)
            Log.i(TAG, "📊 Input tensor shape: ${inputTensor.shape().contentToString()}")
            Log.i(TAG, "📊 Output tensor shape: ${outputTensor.shape().contentToString()}")

            // Allocate input/output buffers
            Log.i(TAG, "📦 Allocating buffers...")
            allocateBuffers()
            Log.i(TAG, "✅ Buffers allocated")

            // Test model with dummy input
            Log.i(TAG, "🧪 Testing model with dummy input...")
            testModelInputOutput()
            Log.i(TAG, "✅ Model test completed")

            isInitialized = true
            Log.i(TAG, "✅ TensorFlow Lite YOLO detector initialized successfully")
            Log.i(TAG, "📊 Model info: ${classNames.size} classes, input size: ${INPUT_SIZE}x${INPUT_SIZE}")

            return true

        } catch (e: Exception) {
            Log.e(TAG, "❌ Error initializing TensorFlow Lite detector", e)
            e.printStackTrace()
            return false
        }
    }

    private fun loadModelFile(modelPath: String): ByteBuffer {
        val fileInputStream = FileInputStream(modelPath)
        val fileChannel = fileInputStream.channel
        val startOffset = 0L
        val declaredLength = fileChannel.size()
        val modelBuffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
        fileInputStream.close()
        return modelBuffer
    }

    private fun allocateBuffers() {
        // Get actual tensor shapes from the model
        val inputTensor = interpreter!!.getInputTensor(0)
        val outputTensor = interpreter!!.getOutputTensor(0)

        val inputShape = inputTensor.shape()
        val outputShape = outputTensor.shape()

        Log.d(TAG, "📊 Actual input shape: ${inputShape.contentToString()}")
        Log.d(TAG, "📊 Actual output shape: ${outputShape.contentToString()}")

        // Input buffer: [1, 640, 640, 3] * 4 bytes per float
        val inputSize = inputShape.fold(1) { acc, dim -> acc * dim } * 4
        inputBuffer = ByteBuffer.allocateDirect(inputSize)
        inputBuffer!!.order(ByteOrder.nativeOrder())

        // Output buffer: [1, 100, 8400]
        val outputSize = outputShape.fold(1) { acc, dim -> acc * dim }
        outputBuffer = FloatArray(outputSize)

        Log.d(TAG, "📦 Allocated input buffer: ${inputSize} bytes")
        Log.d(TAG, "📦 Allocated output buffer: ${outputSize} floats")
    }

    private fun testModelInputOutput() {
        try {
            // Fill input buffer with dummy data
            inputBuffer!!.rewind()
            repeat(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS) { // HWC format
                inputBuffer!!.putFloat(0.5f) // Dummy normalized pixel value
            }

            // Create output array as multidimensional array for TensorFlow Lite
            val outputShape = interpreter!!.getOutputTensor(0).shape()
            Log.d(TAG, "🔍 Output tensor shape: ${outputShape.contentToString()}")

            // Create output as 3D array: [batch][features][detections]
            val output = Array(outputShape[0]) {
                Array(outputShape[1]) {
                    FloatArray(outputShape[2])
                }
            }

            // Run inference
            inputBuffer!!.rewind()
            interpreter!!.run(inputBuffer, output)

            // Check output
            val firstBatch = output[0]
            val maxOutput = firstBatch.flatMap { it.asIterable() }.maxOrNull() ?: 0f
            val minOutput = firstBatch.flatMap { it.asIterable() }.minOrNull() ?: 0f
            Log.i(TAG, "🧪 Model test: output range [${String.format("%.4f", minOutput)}, ${String.format("%.4f", maxOutput)}]")

            // Convert 3D array to flat array for postprocessing
            outputBuffer = firstBatch.flatMap { it.asIterable() }.toFloatArray()
            Log.d(TAG, "🧪 Converted output to flat array of size: ${outputBuffer!!.size}")

        } catch (e: Exception) {
            Log.e(TAG, "❌ Model test failed", e)
            throw e
        }
    }

    fun detect(inputMat: Mat): List<Detection> {
        if (!isInitialized || interpreter == null) {
            Log.w(TAG, "⚠️ TensorFlow Lite detector not initialized")
            return emptyList()
        }

        try {
            Log.d(TAG, "🔍 Running TFLite YOLO detection on ${inputMat.cols()}x${inputMat.rows()} image")

            // Preprocess image
            preprocessImage(inputMat)

            // Create output array as multidimensional array for TensorFlow Lite
            val outputShape = interpreter!!.getOutputTensor(0).shape()
            val output = Array(outputShape[0]) {
                Array(outputShape[1]) {
                    FloatArray(outputShape[2])
                }
            }

            // Run inference
            inputBuffer!!.rewind()
            interpreter!!.run(inputBuffer, output)

            // Convert 3D array to flat array for postprocessing
            val flatOutput = output[0].flatMap { it.asIterable() }.toFloatArray()

            // Post-process results
            val detections = postprocess(flatOutput, inputMat.cols(), inputMat.rows())

            Log.i(TAG, "✅ TFLite YOLO detection complete: ${detections.size} objects detected")

            return detections

        } catch (e: Exception) {
            Log.e(TAG, "❌ Error during TFLite YOLO detection", e)
            return emptyList()
        }
    }

    // Corrected preprocessImage function
    private fun preprocessImage(mat: Mat) {
        // Convert to RGB
        val rgbMat = Mat()
        // It's safer to always explicitly convert to the expected 3-channel type
        // Assuming input `mat` can be from RGBA (screen capture) or BGR (file)
        if (mat.channels() == 4) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB) // Assuming screen capture is BGRA
        } else { // Handle 3-channel BGR or RGB direct
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB) // Convert BGR (OpenCV default) to RGB
        }

        // Resize to input size (640x640)
        val resized = Mat()
        Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()))

        Log.d(TAG, "🖼️ Preprocessed image: ${resized.cols()}x${resized.rows()}, channels: ${resized.channels()}")

        // Prepare a temporary byte array to get pixel data from Mat
        val pixels = ByteArray(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS)
        resized.get(0, 0, pixels) // Get pixel data in HWC (Height, Width, Channel) byte order

        // Convert to ByteBuffer in CHW (Channel, Height, Width) float format
        inputBuffer!!.rewind()

        for (c in 0 until NUM_CHANNELS) { // Iterate channels
            for (y in 0 until INPUT_SIZE) { // Iterate height
                for (x in 0 until INPUT_SIZE) { // Iterate width
                    // Calculate index in the HWC 'pixels' array
                    val pixelIndex = (y * INPUT_SIZE + x) * NUM_CHANNELS + c
                    // Get byte value, convert to unsigned int (0-255), then to float, then normalize
                    val pixelValue = (pixels[pixelIndex].toInt() and 0xFF) / 255.0f
                    inputBuffer!!.putFloat(pixelValue)
                }
            }
        }

        // Debug: Check first few values
        inputBuffer!!.rewind()
        val testValues = FloatArray(10)
        inputBuffer!!.asFloatBuffer().get(testValues)
        val testStr = testValues.map { String.format("%.4f", it) }.joinToString(", ")
        Log.d(TAG, "🌐 Input buffer first 10 values (CHW): [$testStr]")

        // Clean up
        rgbMat.release()
        resized.release()
    }

    /*
    private fun preprocessImage(mat: Mat) {
        // Convert to RGB
        val rgbMat = Mat()
        if (mat.channels() == 4) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB)
        } else if (mat.channels() == 3) {
            Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB)
        } else {
            mat.copyTo(rgbMat)
        }

        // Resize to input size
        val resized = Mat()
        Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()))

        Log.d(TAG, "🖼️ Preprocessed image: ${resized.cols()}x${resized.rows()}, channels: ${resized.channels()}")

        // Convert to ByteBuffer in HWC format [640, 640, 3]
        inputBuffer!!.rewind()

        val rgbBytes = ByteArray(INPUT_SIZE * INPUT_SIZE * 3)
        resized.get(0, 0, rgbBytes)

        // Convert to float and normalize in HWC format (Height, Width, Channels)
        // The data is already in HWC format from OpenCV
        for (i in rgbBytes.indices) {
            val pixelValue = (rgbBytes[i].toInt() and 0xFF) / 255.0f
            inputBuffer!!.putFloat(pixelValue)
        }

        // Debug: Check first few values
        inputBuffer!!.rewind()
        val testValues = FloatArray(10)
        inputBuffer!!.asFloatBuffer().get(testValues)
        val testStr = testValues.map { String.format("%.4f", it) }.joinToString(", ")
        Log.d(TAG, "🌐 Input buffer first 10 values: [$testStr]")

        // Clean up
        rgbMat.release()
        resized.release()
    }
    */

    /*
    private fun postprocess(output: FloatArray, originalWidth: Int, originalHeight: Int): List<Detection> {
        val detections = mutableListOf<Detection>()
        val confidences = mutableListOf<Float>()
        val boxes = mutableListOf<Rect>()
        val classIds = mutableListOf<Int>()

        Log.d(TAG, "🔍 Processing detections from output array of size ${output.size}")
        Log.d(TAG, "🔍 Original image size: ${originalWidth}x${originalHeight}")

        // YOLOv8 outputs normalized coordinates (0-1), so we scale directly to original image size
        val numFeatures = 100 // From actual model output
        val numDetections = 8400 // From actual model output

        var validDetections = 0

        // Process transposed output: [1, 100, 8400]
        // Features are: [x, y, w, h, conf, class0, class1, ..., class94]
        for (i in 0 until numDetections) {
            // In transposed format: feature_idx * numDetections + detection_idx
            // YOLOv8 outputs normalized coordinates (0-1), scale to original image size
            val centerX = output[0 * numDetections + i] * originalWidth // x row
            val centerY = output[1 * numDetections + i] * originalHeight // y row
            val width = output[2 * numDetections + i] * originalWidth // w row
            val height = output[3 * numDetections + i] * originalHeight // h row
            val confidence = output[4 * numDetections + i] // confidence row

            // Debug first few detections
            if (i < 3) {
                val rawX = output[0 * numDetections + i]
                val rawY = output[1 * numDetections + i]
                val rawW = output[2 * numDetections + i]
                val rawH = output[3 * numDetections + i]
                Log.d(TAG, "🔍 Detection $i: raw x=${String.format("%.3f", rawX)}, y=${String.format("%.3f", rawY)}, w=${String.format("%.3f", rawW)}, h=${String.format("%.3f", rawH)}")
                Log.d(TAG, "🔍 Detection $i: scaled x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}")
            }

            // Try different YOLOv8 format: no separate confidence, max class score is the confidence
            var maxClassScore = 0f
            var classId = 0

            for (j in 4 until numFeatures) { // Start from feature 4 (after x,y,w,h), no separate conf
                val classIdx = j * numDetections + i
                if (classIdx >= output.size) break

                val classScore = output[classIdx]
                if (classScore > maxClassScore) {
                    maxClassScore = classScore
                    classId = j - 4 // Convert to 0-based class index
                }
            }

            // Debug first few with max class scores
            if (i < 3) {
                Log.d(TAG, "🔍 Detection $i: maxClass=${String.format("%.4f", maxClassScore)}, classId=$classId")
            }

            if (maxClassScore > CONFIDENCE_THRESHOLD && classId < classNames.size) {
                val x = (centerX - width / 2).toInt()
                val y = (centerY - height / 2).toInt()

                boxes.add(Rect(x, y, width.toInt(), height.toInt()))
                confidences.add(maxClassScore)
                classIds.add(classId)
                validDetections++

                if (validDetections <= 3) {
                    Log.d(TAG, "✅ Valid transposed detection: class=$classId, conf=${String.format("%.4f", maxClassScore)}")
                }
            }
        }

        Log.d(TAG, "🎯 Found ${validDetections} valid detections above confidence threshold")

        // Apply Non-Maximum Suppression (simple version)
        val finalDetections = applyNMS(boxes, confidences, classIds)

        Log.d(TAG, "🎯 Post-processing complete: ${finalDetections.size} final detections after NMS")

        return finalDetections.sortedByDescending { it.confidence }
    }
    */

    // In postprocess function
    private fun postprocess(output: FloatArray, originalWidth: Int, originalHeight: Int): List<Detection> {
        val detections = mutableListOf<Detection>()
        val confidences = mutableListOf<Float>()
        val boxes = mutableListOf<Rect>()
        val classIds = mutableListOf<Int>()

        Log.d(TAG, "🔍 Processing detections from output array of size ${output.size}")
        Log.d(TAG, "🔍 Original image size: ${originalWidth}x${originalHeight}")

        // Corrected Interpretation based on YOUR observed output shape [1, 100, 8400]
        // This means attributes (box, confidence, class scores) are in the second dimension (index 1)
        // and detections are in the third dimension (index 2).
        // So, you need to iterate through 'detections' (8400) and for each, access its 'attributes' (100).
        val numAttributesPerDetection = 100 // This is your 'outputShape[1]'
        val totalDetections = 8400 // This is your 'outputShape[2]'

        // Loop through each of the 8400 potential detections
        for (i in 0 until totalDetections) {
            // Get the attributes for the i-th detection
            // The data for the i-th detection starts at index 'i' in the 'output' flat array,
            // then it's interleaved. This is why it's better to process from the 3D array directly.

            // Re-think: If `output[0]` from interpreter.run is `Array(100) { FloatArray(8400) }`
            // Then it's attributes_per_detection x total_detections.
            // So, output[0][0] is the x-coords for all detections, output[0][1] is y-coords for all detections.
            // Let's assume this structure from your `output` 3D array:
            // output[0][attribute_idx][detection_idx]

            val centerX = output[0 * totalDetections + i] // x-coordinate for detection 'i'
            val centerY = output[1 * totalDetections + i] // y-coordinate for detection 'i'
            val width = output[2 * totalDetections + i] // width for detection 'i'
            val height = output[3 * totalDetections + i] // height for detection 'i'
            val objectnessConf = output[4 * totalDetections + i] // Objectness confidence for detection 'i'

            // Debug raw values and scaled values before class scores
            if (i < 5) { // Log first 5 detections
                Log.d(TAG, "🔍 Detection $i (pre-scale): x=${String.format("%.3f", centerX)}, y=${String.format("%.3f", centerY)}, w=${String.format("%.3f", width)}, h=${String.format("%.3f", height)}, obj_conf=${String.format("%.4f", objectnessConf)}")
            }

            var maxClassScore = 0f
            var classId = -1 // Initialize with -1 to catch issues

            // Loop through class scores (starting from index 5 in the attributes list)
            // Indices 5 to 99 are class scores (95 classes total)
            for (j in 5 until numAttributesPerDetection) {
                // Get the class score for detection 'i' and class 'j-5'
                val classScore = output[j * totalDetections + i]

                val classScore_sigmoid = 1.0f / (1.0f + Math.exp(-classScore.toDouble())).toFloat() // Apply sigmoid

                if (classScore_sigmoid > maxClassScore) {
                    maxClassScore = classScore_sigmoid
                    classId = j - 5 // Convert to 0-based class index
                }
            }

            val objectnessConf_sigmoid = 1.0f / (1.0f + Math.exp(-objectnessConf.toDouble())).toFloat() // Apply sigmoid

            // Final confidence: Objectness score multiplied by the max class score
            val finalConfidence = objectnessConf_sigmoid * maxClassScore

            // Debug final confidence for first few detections
            if (i < 5) {
                Log.d(TAG, "🔍 Detection $i (post-score): maxClass=${String.format("%.4f", maxClassScore)}, finalConf=${String.format("%.4f", finalConfidence)}, classId=$classId")
            }

            // Apply confidence threshold
            if (finalConfidence > CONFIDENCE_THRESHOLD && classId != -1 && classId < classNames.size) {
                // Convert normalized coordinates (0-1) to pixel coordinates based on original image size
                val x = ((centerX - width / 2) * originalWidth).toInt()
                val y = ((centerY - height / 2) * originalHeight).toInt()
                val w = (width * originalWidth).toInt()
                val h = (height * originalHeight).toInt()

                // Ensure coordinates are within image bounds
                val x1 = max(0, x)
                val y1 = max(0, y)
                val x2 = min(originalWidth, x + w)
                val y2 = min(originalHeight, y + h)

                // Add to lists for NMS
                boxes.add(Rect(x1, y1, x2 - x1, y2 - y1))
                confidences.add(finalConfidence)
                classIds.add(classId)
            }
        }

        Log.d(TAG, "🎯 Found ${boxes.size} detections above confidence threshold before NMS")

        // Apply Non-Maximum Suppression (using OpenCV's NMSBoxes which is more robust)
        val finalDetections = applyNMS_OpenCV(boxes, confidences, classIds)

        Log.d(TAG, "🎯 Post-processing complete: ${finalDetections.size} final detections after NMS")

        return finalDetections.sortedByDescending { it.confidence }
    }

    // Replace your applyNMS function with this (or rename your old one and call this one)
    private fun applyNMS_OpenCV(boxes: List<Rect>, confidences: List<Float>, classIds: List<Int>): List<Detection> {
        val finalDetections = mutableListOf<Detection>()

        // Convert List<Rect> to List<Rect2d>
        val boxes2d = boxes.map { Rect2d(it.x.toDouble(), it.y.toDouble(), it.width.toDouble(), it.height.toDouble()) }

        // Correct way to convert List<Rect2d> to MatOfRect2d
        val boxesMat = MatOfRect2d()
        boxesMat.fromList(boxes2d) // Use fromList to populate the MatOfRect2d

        val confsMat = MatOfFloat()
        confsMat.fromList(confidences) // This part was already correct

        val indices = MatOfInt()

        // OpenCV NMSBoxes
        Dnn.NMSBoxes(
            boxesMat,
            confsMat,
            CONFIDENCE_THRESHOLD, // Confidence threshold (boxes below this are ignored by NMS)
            NMS_THRESHOLD, // IoU threshold (boxes with IoU above this are suppressed)
            indices
        )

        val ind = indices.toArray() // Get array of indices to keep

        for (i in ind.indices) {
            val idx = ind[i]
            val className = classNames[classIds[idx]] ?: "unknown_${classIds[idx]}"
            finalDetections.add(
                Detection(
                    classId = classIds[idx],
                    className = className,
                    confidence = confidences[idx],
                    boundingBox = boxes[idx] // Keep original Rect for the final Detection object if you prefer ints
                )
            )
        }

        // Release Mats to prevent memory leaks
        boxesMat.release()
        confsMat.release()
        indices.release()

        return finalDetections
    }

    private fun applyNMS(boxes: List<Rect>, confidences: List<Float>, classIds: List<Int>): List<Detection> {
        val detections = mutableListOf<Detection>()

        // Simple NMS implementation
        val indices = confidences.indices.sortedByDescending { confidences[it] }
        val suppressed = BooleanArray(boxes.size)

        for (i in indices) {
            if (suppressed[i]) continue

            val className = classNames[classIds[i]] ?: "unknown_${classIds[i]}"
            detections.add(
                Detection(
                    classId = classIds[i],
                    className = className,
                    confidence = confidences[i],
                    boundingBox = boxes[i]
                )
            )

            // Suppress overlapping boxes
            for (j in indices) {
                if (i != j && !suppressed[j] && classIds[i] == classIds[j]) {
                    val iou = calculateIoU(boxes[i], boxes[j])
                    if (iou > NMS_THRESHOLD) {
                        suppressed[j] = true
                    }
                }
            }
        }

        return detections
    }

    private fun calculateIoU(box1: Rect, box2: Rect): Float {
        val x1 = max(box1.x, box2.x)
        val y1 = max(box1.y, box2.y)
        val x2 = min(box1.x + box1.width, box2.x + box2.width)
        val y2 = min(box1.y + box1.height, box2.y + box2.height)

        val intersection = max(0, x2 - x1) * max(0, y2 - y1)
        val area1 = box1.width * box1.height
        val area2 = box2.width * box2.height
        val union = area1 + area2 - intersection

        return if (union > 0) intersection.toFloat() / union.toFloat() else 0f
    }

    fun testWithStaticImage(): List<Detection> {
        if (!isInitialized) {
            Log.e(TAG, "❌ TensorFlow Lite detector not initialized for test")
            return emptyList()
        }

        try {
            Log.i(TAG, "🧪 TESTING WITH STATIC IMAGE (TFLite)")

            // Load test image from assets
            val inputStream = context.assets.open("test_pokemon.jpg")
            val bitmap = BitmapFactory.decodeStream(inputStream)
            inputStream.close()

            if (bitmap == null) {
                Log.e(TAG, "❌ Failed to load test_pokemon.jpg from assets")
                return emptyList()
            }

            Log.i(TAG, "📸 Loaded test image: ${bitmap.width}x${bitmap.height}")

            // Convert bitmap to OpenCV Mat
            val mat = Mat()
            Utils.bitmapToMat(bitmap, mat)

            Log.i(TAG, "🔄 Converted to Mat: ${mat.cols()}x${mat.rows()}, channels: ${mat.channels()}")

            // Run detection
            val detections = detect(mat)

            Log.i(TAG, "🎯 TFLite TEST RESULT: ${detections.size} detections found")
            detections.forEachIndexed { index, detection ->
                Log.i(TAG, "  $index: ${detection.className} (${String.format("%.3f", detection.confidence)}) at [${detection.boundingBox.x}, ${detection.boundingBox.y}, ${detection.boundingBox.width}, ${detection.boundingBox.height}]")
            }

            // Clean up
            mat.release()
            bitmap.recycle()

            return detections

        } catch (e: Exception) {
            Log.e(TAG, "❌ Error in TFLite static image test", e)
            return emptyList()
        }
    }

    private fun copyAssetToInternalStorage(assetName: String): String? {
        return try {
            val inputStream = context.assets.open(assetName)
            val file = context.getFileStreamPath(assetName)
            val outputStream = FileOutputStream(file)

            inputStream.copyTo(outputStream)
            inputStream.close()
            outputStream.close()

            file.absolutePath
        } catch (e: IOException) {
            Log.e(TAG, "Error copying asset $assetName", e)
            null
        }
    }

    fun release() {
        interpreter?.close()
        interpreter = null
        isInitialized = false
        Log.d(TAG, "TensorFlow Lite detector released")
    }
}
@ -0,0 +1,211 @@
|
package com.quillstudios.pokegoalshelper.ml

import android.graphics.Bitmap
import android.util.Log
import org.opencv.android.Utils
import org.opencv.core.*
import org.opencv.imgproc.Imgproc
import kotlin.math.min

/**
 * Utility class for image preprocessing operations used in ML inference.
 * Extracted for better separation of concerns and reusability.
 */
object ImagePreprocessor
{
    private const val TAG = "ImagePreprocessor"

    /**
     * Data class representing preprocessing configuration.
     */
    data class PreprocessConfig(
        val targetSize: Int = 640,
        val numChannels: Int = 3,
        val normalizeRange: Pair<Float, Float> = Pair(0.0f, 1.0f),
        val useLetterboxing: Boolean = true,
        val colorConversion: Int = Imgproc.COLOR_BGR2RGB
    )
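    // Configuration sketch (assumed model requirements, for illustration only):
    // a model expecting 320x320 input normalized to [-1, 1] without letterboxing
    // could be served with:
    //
    //   val config = PreprocessConfig(
    //       targetSize = 320,
    //       normalizeRange = Pair(-1.0f, 1.0f),
    //       useLetterboxing = false
    //   )
    //
    // The normalization math below maps 0 -> -1 and 255 -> 1 for this range.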
    /**
     * Result of preprocessing operation containing the processed data and metadata.
     */
    data class PreprocessResult(
        val data: Array<Array<Array<FloatArray>>>,
        val scale: Float,
        val offsetX: Float,
        val offsetY: Float,
        val originalWidth: Int,
        val originalHeight: Int
    )

    /**
     * Preprocess a bitmap for ML inference with the specified configuration.
     *
     * @param bitmap The input bitmap to preprocess
     * @param config Preprocessing configuration
     * @return PreprocessResult containing processed data and transformation metadata
     */
    fun preprocessBitmap(bitmap: Bitmap, config: PreprocessConfig = PreprocessConfig()): PreprocessResult
    {
        val original_width = bitmap.width
        val original_height = bitmap.height

        // Convert bitmap to Mat
        val original_mat = Mat()
        Utils.bitmapToMat(bitmap, original_mat)

        try
        {
            val (processed_mat, scale, offset_x, offset_y) = if (config.useLetterboxing)
            {
                applyLetterboxing(original_mat, config.targetSize)
            }
            else
            {
                // Simple resize without letterboxing
                val resized_mat = Mat()
                Imgproc.resize(original_mat, resized_mat, Size(config.targetSize.toDouble(), config.targetSize.toDouble()))
                val scale_x = config.targetSize.toFloat() / original_width
                val scale_y = config.targetSize.toFloat() / original_height
                val avg_scale = (scale_x + scale_y) / 2.0f
                ResizeResult(resized_mat, avg_scale, 0.0f, 0.0f)
            }

            // Apply color conversion
            if (config.colorConversion != -1)
            {
                Imgproc.cvtColor(processed_mat, processed_mat, config.colorConversion)
            }

            // Normalize
            val normalized_mat = Mat()
            val normalization_factor = (config.normalizeRange.second - config.normalizeRange.first) / 255.0
            processed_mat.convertTo(normalized_mat, CvType.CV_32F, normalization_factor, config.normalizeRange.first.toDouble())

            // Convert to array format [1, channels, height, width]
            val data = matToArray(normalized_mat, config.numChannels, config.targetSize)

            // Clean up intermediate Mats
            processed_mat.release()
            normalized_mat.release()

            return PreprocessResult(
                data = data,
                scale = scale,
                offsetX = offset_x,
                offsetY = offset_y,
                originalWidth = original_width,
                originalHeight = original_height
            )
        }
        finally
        {
            original_mat.release()
        }
    }
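    // Usage sketch (illustrative; `screenshot` is a hypothetical caller-side Bitmap):
    //
    //   val result = ImagePreprocessor.preprocessBitmap(screenshot)
    //   val input = result.data // [1, 3, 640, 640] float tensor in NCHW layout
    //
    // `result.scale`, `result.offsetX`, and `result.offsetY` are retained so that
    // model-space detections can be mapped back onto the original screenshot.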
    /**
     * Transform coordinates from model space back to original image space.
     *
     * @param modelX X coordinate in model space
     * @param modelY Y coordinate in model space
     * @param preprocessResult The preprocessing result containing transformation metadata
     * @return Pair of (originalX, originalY) coordinates
     */
    fun transformCoordinates(
        modelX: Float,
        modelY: Float,
        preprocessResult: PreprocessResult
    ): Pair<Float, Float>
    {
        val original_x = ((modelX - preprocessResult.offsetX) / preprocessResult.scale)
            .coerceIn(0f, preprocessResult.originalWidth.toFloat())
        val original_y = ((modelY - preprocessResult.offsetY) / preprocessResult.scale)
            .coerceIn(0f, preprocessResult.originalHeight.toFloat())

        return Pair(original_x, original_y)
    }
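    // Worked example (illustrative numbers): a 1080x1920 source letterboxed into
    // 640x640 gives scale = 640 / 1920 ≈ 0.3333, resized width = 360, and
    // offsetX = (640 - 360) / 2 = 140, offsetY = 0. A model-space point at
    // (320, 320) then maps back to
    //   x = (320 - 140) / 0.3333 ≈ 540, y = (320 - 0) / 0.3333 ≈ 960,
    // the centre of the original image, as expected.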
    /**
     * Transform a bounding box from model space back to original image space.
     *
     * @param boundingBox Bounding box in model space
     * @param preprocessResult The preprocessing result containing transformation metadata
     * @return Transformed bounding box in original image space
     */
    fun transformBoundingBox(
        boundingBox: BoundingBox,
        preprocessResult: PreprocessResult
    ): BoundingBox
    {
        val (left, top) = transformCoordinates(boundingBox.left, boundingBox.top, preprocessResult)
        val (right, bottom) = transformCoordinates(boundingBox.right, boundingBox.bottom, preprocessResult)

        return BoundingBox(left, top, right, bottom)
    }
    private data class ResizeResult(
        val mat: Mat,
        val scale: Float,
        val offsetX: Float,
        val offsetY: Float
    )

    private fun applyLetterboxing(inputMat: Mat, targetSize: Int): ResizeResult
    {
        val scale = min(targetSize.toFloat() / inputMat.width(), targetSize.toFloat() / inputMat.height())

        val new_width = (inputMat.width() * scale).toInt()
        val new_height = (inputMat.height() * scale).toInt()

        // Resize while maintaining aspect ratio
        val resized_mat = Mat()
        Imgproc.resize(inputMat, resized_mat, Size(new_width.toDouble(), new_height.toDouble()))

        // Create letterboxed image (centered)
        val letterbox_mat = Mat.zeros(targetSize, targetSize, CvType.CV_8UC3)
        val offset_x = (targetSize - new_width) / 2
        val offset_y = (targetSize - new_height) / 2

        val roi = Rect(offset_x, offset_y, new_width, new_height)
        resized_mat.copyTo(letterbox_mat.submat(roi))

        // Clean up intermediate mat
        resized_mat.release()

        return ResizeResult(letterbox_mat, scale, offset_x.toFloat(), offset_y.toFloat())
    }
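    // Worked example for applyLetterboxing (illustrative numbers): a 1080x2400
    // frame with targetSize = 640 gives scale = min(640/1080, 640/2400) ≈ 0.2667,
    // so the frame is resized to 288x640 and centred with
    // offsetX = (640 - 288) / 2 = 176, offsetY = 0; the left and right margins
    // stay black (zero-filled).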
    private fun matToArray(mat: Mat, numChannels: Int, size: Int): Array<Array<Array<FloatArray>>>
    {
        val data = Array(1) { Array(numChannels) { Array(size) { FloatArray(size) } } }

        for (c in 0 until numChannels)
        {
            for (h in 0 until size)
            {
                for (w in 0 until size)
                {
                    val pixel = mat.get(h, w)
                    data[0][c][h][w] = if (pixel != null && c < pixel.size) pixel[c].toFloat() else 0.0f
                }
            }
        }

        return data
    }
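    // Layout note (illustrative, assuming the default COLOR_BGR2RGB conversion):
    // matToArray converts the HWC OpenCV Mat into the NCHW layout that ONNX
    // vision models commonly expect, e.g. for a 640x640 RGB input:
    //
    //   data[0][0][h][w] -> R value at (h, w)
    //   data[0][1][h][w] -> G value at (h, w)
    //   data[0][2][h][w] -> B value at (h, w)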
    /**
     * Log preprocessing statistics for debugging.
     */
    fun logPreprocessStats(result: PreprocessResult)
    {
        Log.d(TAG, """
            📊 Preprocessing Stats:
            - Original size: ${result.originalWidth}x${result.originalHeight}
            - Scale factor: ${result.scale}
            - Offset: (${result.offsetX}, ${result.offsetY})
            - Data shape: [${result.data.size}, ${result.data[0].size}, ${result.data[0][0].size}, ${result.data[0][0][0].size}]
        """.trimIndent())
    }
}
@ -0,0 +1,77 @@
|
package com.quillstudios.pokegoalshelper.ml

import android.graphics.Bitmap

/**
 * Interface for ML model inference operations.
 * Separates ML concerns from the main service for better architecture.
 */
interface MLInferenceEngine
{
    /**
     * Initialize the ML model and prepare for inference.
     * @return true if initialization was successful, false otherwise
     */
    suspend fun initialize(): Boolean

    /**
     * Perform object detection on the provided image.
     * @param image The bitmap image to analyze
     * @return List of detected objects, empty if no objects found
     */
    suspend fun detect(image: Bitmap): List<Detection>

    /**
     * Set the confidence threshold for detections.
     * @param threshold Minimum confidence value (0.0 to 1.0)
     */
    fun setConfidenceThreshold(threshold: Float)

    /**
     * Set the class filter for detections.
     * @param className Class name to filter by, null for no filtering
     */
    fun setClassFilter(className: String?)

    /**
     * Check if the engine is ready for inference.
     * @return true if initialized and ready, false otherwise
     */
    fun isReady(): Boolean

    /**
     * Get the current inference time statistics.
     * @return Pair of (last inference time ms, average inference time ms)
     */
    fun getInferenceStats(): Pair<Long, Long>

    /**
     * Clean up all resources. Should be called when the engine is no longer needed.
     */
    fun cleanup()
}
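// Usage sketch from a coroutine scope (illustrative; `engine` and `frame` are
// hypothetical caller-side values, e.g. an injected YOLOInferenceEngine):
//
//   if (engine.initialize())
//   {
//       engine.setConfidenceThreshold(0.25f)
//       val detections: List<Detection> = engine.detect(frame)
//       val (last_ms, avg_ms) = engine.getInferenceStats()
//   }
//   engine.cleanup()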
/**
 * Data class representing a detected object.
 */
data class Detection(
    val className: String,
    val confidence: Float,
    val boundingBox: BoundingBox
)

/**
 * Data class representing a bounding box.
 */
data class BoundingBox(
    val left: Float,
    val top: Float,
    val right: Float,
    val bottom: Float
)
{
    val width: Float get() = right - left
    val height: Float get() = bottom - top
    val centerX: Float get() = left + width / 2
    val centerY: Float get() = top + height / 2
}
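// Derived-property example (illustrative numbers):
//   BoundingBox(left = 10f, top = 20f, right = 110f, bottom = 220f)
//   -> width = 100f, height = 200f, centerX = 60f, centerY = 120f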
File diff suppressed because it is too large