Browse Source
ARCH-002: Extract ML Inference Engine - Created MLInferenceEngine interface with async detection methods - Implemented YOLOInferenceEngine preserving ALL YOLOOnnxDetector functionality: * Complete 96-class mapping with exact original class names * All preprocessing techniques (ultralytics, enhanced, sharpened, original) * All coordinate transformation modes (HYBRID, LETTERBOX, DIRECT) * Weighted NMS and cross-class NMS for semantically related classes * Confidence mapping and mobile optimization * Debug features (class filtering, confidence logging) * Letterbox resize with proper aspect ratio preservation * CLAHE contrast enhancement and sharpening filters - Created ImagePreprocessor utility for reusable preprocessing operations * Configurable preprocessing with letterboxing, normalization, color conversion * Coordinate transformation utilities for model-to-image space conversion * Support for different preprocessing configurations - Updated ScreenCaptureService to use new MLInferenceEngine: * Replaced YOLOOnnxDetector with MLInferenceEngine dependency injection * Added class name to class ID mapping for compatibility * Maintained all existing detection pipeline functionality * Proper async/await integration with coroutines - Applied preferred code style throughout: * Opening braces on new lines for functions and statements * snake_case for local variables to distinguish from members/parameters * Consistent formatting matching project standards - Removed obsolete YOLO implementations (YOLODetector, YOLOTFLiteDetector) - Preserved all sophisticated ML features: TTA, multi-preprocessing, confidence mapping 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>arch-002-ml-inference-engine
6 changed files with 1581 additions and 1466 deletions
@ -1,700 +0,0 @@ |
|||
package com.quillstudios.pokegoalshelper |
|||
|
|||
import android.content.Context |
|||
import android.util.Log |
|||
import org.opencv.core.* |
|||
import org.opencv.dnn.Dnn |
|||
import org.opencv.dnn.Net |
|||
import org.opencv.imgproc.Imgproc |
|||
import java.io.FileOutputStream |
|||
import java.io.IOException |
|||
import android.graphics.Bitmap |
|||
import android.graphics.BitmapFactory |
|||
import org.opencv.android.Utils |
|||
|
|||
/**
 * One object found by the YOLO model in a single frame.
 *
 * @property classId     index into the detector's class-name map
 * @property className   label resolved from [classId] (or "unknown_<id>" fallback)
 * @property confidence  combined objectness * best class score, in 0..1
 * @property boundingBox axis-aligned box in original-image pixel coordinates
 */
data class Detection(
    val classId: Int,
    val className: String,
    val confidence: Float,
    val boundingBox: Rect
)
|||
|
|||
class YOLODetector(private val context: Context) { |
|||
|
|||
companion object {
    private const val TAG = "YOLODetector"
    // ONNX model bundled under assets/, copied to internal storage in initialize().
    private const val MODEL_FILE = "pokemon_model.onnx"
    // Square input edge the model was exported with; blobFromImage resizes to this.
    private const val INPUT_SIZE = 640
    private const val CONFIDENCE_THRESHOLD = 0.1f // Lower threshold for debugging
    // IoU threshold for non-maximum suppression in postprocess().
    private const val NMS_THRESHOLD = 0.4f
}
|||
|
|||
/**
 * Parses a feature-major ("transposed") YOLO output tensor where
 * data[f * cols + d] is feature f of detection d, and appends candidates
 * passing the confidence threshold to [boxes]/[confidences]/[classIds]
 * (NMS is applied later by the caller).
 *
 * @param rows   features per detection: x, y, w, h, objectness, then class scores
 * @param cols   number of candidate detections
 * @param xScale factor mapping model-space x/width to original-image pixels
 * @param yScale factor mapping model-space y/height to original-image pixels
 */
private fun parseTransposedOutput(
    data: FloatArray,
    rows: Int,
    cols: Int,
    xScale: Float,
    yScale: Float,
    boxes: MutableList<Rect>,
    confidences: MutableList<Float>,
    classIds: MutableList<Int>
) {
    // For transposed output: rows=features(100), cols=detections(8400)
    // Data layout: [x1, x2, x3, ...], [y1, y2, y3, ...], [w1, w2, w3, ...], etc.

    Log.d(TAG, "🔄 Parsing transposed output: $rows features x $cols detections")

    var validDetections = 0
    for (i in 0 until cols) { // Loop through detections
        // Guard against a short buffer (partial tensor read).
        if (i >= data.size / rows) break

        // Extract center-format coordinates from the transposed layout.
        val centerX = data[0 * cols + i] * xScale // x row
        val centerY = data[1 * cols + i] * yScale // y row
        val width = data[2 * cols + i] * xScale // width row
        val height = data[3 * cols + i] * yScale // height row
        val confidence = data[4 * cols + i] // confidence row

        // Debug first few detections
        if (i < 3) {
            Log.d(TAG, "🔍 Transposed detection $i: conf=${String.format("%.4f", confidence)}, x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}")
        }

        if (confidence > CONFIDENCE_THRESHOLD) {
            // Find class with highest score
            var maxClassScore = 0f
            var classId = 0

            for (j in 5 until rows) { // Start from row 5 (after x,y,w,h,conf)
                if (j * cols + i >= data.size) break
                val classScore = data[j * cols + i]
                if (classScore > maxClassScore) {
                    maxClassScore = classScore
                    classId = j - 5
                }
            }

            // Final score = objectness * best class probability; thresholded again.
            val finalConfidence = confidence * maxClassScore

            if (finalConfidence > CONFIDENCE_THRESHOLD) {
                // Convert center-format to top-left corner for Rect.
                val x = (centerX - width / 2).toInt()
                val y = (centerY - height / 2).toInt()

                boxes.add(Rect(x, y, width.toInt(), height.toInt()))
                confidences.add(finalConfidence)
                classIds.add(classId)
                validDetections++

                if (validDetections <= 3) {
                    Log.d(TAG, "✅ Valid transposed detection: class=$classId, conf=${String.format("%.4f", finalConfidence)}")
                }
            }
        }
    }

    Log.d(TAG, "🎯 Transposed parsing found $validDetections valid detections")
}
|||
|
|||
// Loaded OpenCV DNN network; null until initialize() succeeds.
private var net: Net? = null
private var isInitialized = false

// Class-id -> label map from training.
// NOTE(review): the original comment claimed "COMPLETE 93 CLASSES", but the map
// holds 96 entries (ids 0..95) — confirm against the exported model's head size.
// "original_trainder_number" and "ball_icon_ultraball_husui" look like typos
// carried over from the training labels; do NOT rename them here without
// retraining, since downstream code may match on these exact strings.
private val classNames = mapOf(
    0 to "ball_icon_pokeball",
    1 to "ball_icon_greatball",
    2 to "ball_icon_ultraball",
    3 to "ball_icon_masterball",
    4 to "ball_icon_safariball",
    5 to "ball_icon_levelball",
    6 to "ball_icon_lureball",
    7 to "ball_icon_moonball",
    8 to "ball_icon_friendball",
    9 to "ball_icon_loveball",
    10 to "ball_icon_heavyball",
    11 to "ball_icon_fastball",
    12 to "ball_icon_sportball",
    13 to "ball_icon_premierball",
    14 to "ball_icon_repeatball",
    15 to "ball_icon_timerball",
    16 to "ball_icon_nestball",
    17 to "ball_icon_netball",
    18 to "ball_icon_diveball",
    19 to "ball_icon_luxuryball",
    20 to "ball_icon_healball",
    21 to "ball_icon_quickball",
    22 to "ball_icon_duskball",
    23 to "ball_icon_cherishball",
    24 to "ball_icon_dreamball",
    25 to "ball_icon_beastball",
    26 to "ball_icon_strangeparts",
    27 to "ball_icon_parkball",
    28 to "ball_icon_gsball",
    29 to "pokemon_nickname",
    30 to "gender_icon_male",
    31 to "gender_icon_female",
    32 to "pokemon_level",
    33 to "language",
    34 to "last_game_stamp_home",
    35 to "last_game_stamp_lgp",
    36 to "last_game_stamp_lge",
    37 to "last_game_stamp_sw",
    38 to "last_game_stamp_sh",
    39 to "last_game_stamp_bank",
    40 to "last_game_stamp_bd",
    41 to "last_game_stamp_sp",
    42 to "last_game_stamp_pla",
    43 to "last_game_stamp_sc",
    44 to "last_game_stamp_vi",
    45 to "last_game_stamp_go",
    46 to "national_dex_number",
    47 to "pokemon_species",
    48 to "type_1",
    49 to "type_2",
    50 to "shiny_icon",
    51 to "origin_icon_vc",
    52 to "origin_icon_xyoras",
    53 to "origin_icon_smusum",
    54 to "origin_icon_lg",
    55 to "origin_icon_swsh",
    56 to "origin_icon_go",
    57 to "origin_icon_bdsp",
    58 to "origin_icon_pla",
    59 to "origin_icon_sv",
    60 to "pokerus_infected_icon",
    61 to "pokerus_cured_icon",
    62 to "hp_value",
    63 to "attack_value",
    64 to "defense_value",
    65 to "sp_atk_value",
    66 to "sp_def_value",
    67 to "speed_value",
    68 to "ability_name",
    69 to "nature_name",
    70 to "move_name",
    71 to "original_trainer_name",
    72 to "original_trainder_number",
    73 to "alpha_mark",
    74 to "tera_water",
    75 to "tera_psychic",
    76 to "tera_ice",
    77 to "tera_fairy",
    78 to "tera_poison",
    79 to "tera_ghost",
    80 to "ball_icon_originball",
    81 to "tera_dragon",
    82 to "tera_steel",
    83 to "tera_grass",
    84 to "tera_normal",
    85 to "tera_fire",
    86 to "tera_electric",
    87 to "tera_fighting",
    88 to "tera_ground",
    89 to "tera_flying",
    90 to "tera_bug",
    91 to "tera_rock",
    92 to "tera_dark",
    93 to "low_confidence",
    94 to "ball_icon_pokeball_hisui",
    95 to "ball_icon_ultraball_husui"
)
|||
|
|||
/**
 * Loads the ONNX model from assets and prepares the OpenCV DNN backend.
 *
 * Idempotent: returns true immediately if already initialized. Returns false
 * (never throws) when the asset copy or model load fails, so callers can
 * degrade gracefully.
 *
 * Fix: the original queried unconnectedOutLayersNames a second time and
 * logged the result as "Model input layers" — that property lists OUTPUT
 * layers, so the duplicate, mislabeled query has been removed.
 */
fun initialize(): Boolean {
    if (isInitialized) return true

    try {
        Log.i(TAG, "🤖 Initializing YOLO detector...")

        // Copy model from assets to internal storage if needed
        val modelPath = copyAssetToInternalStorage(MODEL_FILE)
        if (modelPath == null) {
            Log.e(TAG, "❌ Failed to copy model from assets")
            return false
        }

        // Load the ONNX model
        Log.i(TAG, "📥 Loading ONNX model from: $modelPath")
        net = Dnn.readNetFromONNX(modelPath)

        if (net == null || net!!.empty()) {
            Log.e(TAG, "❌ Failed to load ONNX model")
            return false
        }

        // Verify model loaded correctly
        val layerNames = net!!.layerNames
        Log.i(TAG, "🧠 Model loaded with ${layerNames.size} layers")

        val outputNames = net!!.unconnectedOutLayersNames
        Log.i(TAG, "📝 Output layers: ${outputNames?.toString()}")

        // CPU-only OpenCV backend; no GPU delegate is configured here.
        net!!.setPreferableBackend(Dnn.DNN_BACKEND_OPENCV)
        net!!.setPreferableTarget(Dnn.DNN_TARGET_CPU)

        // Smoke-test that the net accepts a 640x640 float input.
        // NOTE(review): this feeds a raw HxW Mat rather than a 4D blob;
        // setInput accepting it does not guarantee forward() would succeed.
        try {
            val dummyBlob = Mat.zeros(Size(640.0, 640.0), CvType.CV_32FC3)
            net!!.setInput(dummyBlob)
            Log.i(TAG, "✅ Model accepts 640x640 CV_32FC3 input")
            dummyBlob.release()
        } catch (e: Exception) {
            Log.w(TAG, "⚠️ Model input test failed: ${e.message}")
        }

        isInitialized = true
        Log.i(TAG, "✅ YOLO detector initialized successfully")
        Log.i(TAG, "📊 Model info: ${classNames.size} classes, input size: ${INPUT_SIZE}x${INPUT_SIZE}")

        return true

    } catch (e: Exception) {
        Log.e(TAG, "❌ Error initializing YOLO detector", e)
        return false
    }
}
|||
|
|||
/**
 * Runs one synchronous inference pass over [inputMat] and returns detections
 * in original-image pixel coordinates, sorted by descending confidence.
 * Returns an empty list if the detector is uninitialized or inference throws.
 */
fun detect(inputMat: Mat): List<Detection> {
    if (!isInitialized || net == null) {
        Log.w(TAG, "⚠️ YOLO detector not initialized")
        return emptyList()
    }

    try {
        Log.d(TAG, "🔍 Running YOLO detection on ${inputMat.cols()}x${inputMat.rows()} image")

        // Preprocess image
        val blob = preprocessImage(inputMat)

        // Set input to the network
        net!!.setInput(blob)

        // Sanity-check the first few blob floats before inference —
        // guards against the all-zero-blob issue handled in preprocessImage().
        val blobTestData = FloatArray(10)
        blob.get(0, 0, blobTestData)
        val hasRealData = blobTestData.any { it != 0f }
        Log.w(TAG, "⚠️ CRITICAL: Blob sent to model has real data: $hasRealData")

        if (!hasRealData) {
            Log.e(TAG, "❌ FATAL: All blob creation methods failed - this is likely an OpenCV bug or model issue")
            Log.e(TAG, "❌ Try these solutions:")
            Log.e(TAG, "   1. Re-export your ONNX model with different settings")
            Log.e(TAG, "   2. Try a different OpenCV version")
            Log.e(TAG, "   3. Use a different inference framework (TensorFlow Lite)")
            Log.e(TAG, "   4. Check if your model expects different input format")
        }

        // Run forward pass
        val outputs = mutableListOf<Mat>()
        net!!.forward(outputs, net!!.unconnectedOutLayersNames)

        Log.d(TAG, "🧠 Model inference complete, got ${outputs.size} output tensors")

        // Post-process results (threshold, decode, NMS).
        val detections = postprocess(outputs, inputMat.cols(), inputMat.rows())

        // Release native Mats; the caller keeps ownership of inputMat.
        blob.release()
        outputs.forEach { it.release() }

        Log.i(TAG, "✅ YOLO detection complete: ${detections.size} objects detected")

        return detections

    } catch (e: Exception) {
        Log.e(TAG, "❌ Error during YOLO detection", e)
        return emptyList()
    }
}
|||
|
|||
/**
 * Converts [mat] to RGB and builds a normalized 1x3x640x640 float blob for
 * the DNN. Contains layered fallbacks (swapRB variants, manual blob) that
 * were added to work around an observed all-zero-blob failure mode.
 *
 * The caller owns the returned Mat and must release() it.
 *
 * NOTE(review): the image is converted to RGB here, yet the primary
 * blobFromImage call passes swapRB=true, which swaps channels again
 * (back to BGR). Verify the channel order against the training pipeline —
 * this may be a double swap.
 */
private fun preprocessImage(mat: Mat): Mat {
    // Convert to RGB if needed (screen captures arrive as BGRA).
    val rgbMat = Mat()
    if (mat.channels() == 4) {
        Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB)
    } else if (mat.channels() == 3) {
        Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB)
    } else {
        mat.copyTo(rgbMat)
    }

    // Ensure the matrix is continuous for blob creation
    if (!rgbMat.isContinuous) {
        val continuousMat = Mat()
        rgbMat.copyTo(continuousMat)
        rgbMat.release()
        continuousMat.copyTo(rgbMat)
        continuousMat.release()
    }

    Log.d(TAG, "🖼️ Input image: ${rgbMat.cols()}x${rgbMat.rows()}, channels: ${rgbMat.channels()}")
    Log.d(TAG, "🖼️ Mat type: ${rgbMat.type()}, depth: ${rgbMat.depth()}")

    // Debug: Check if image data is not all zeros (use ByteArray for CV_8U data)
    try {
        val testData = ByteArray(3)
        rgbMat.get(100, 100, testData) // Sample some pixels
        val testValues = testData.map { (it.toInt() and 0xFF).toString() }.joinToString(", ")
        Log.d(TAG, "🖼️ Sample RGB values at (100,100): [$testValues]")
    } catch (e: Exception) {
        Log.w(TAG, "⚠️ Could not sample pixel values: ${e.message}")
    }

    // Create blob from image - match training preprocessing exactly
    val blob = Dnn.blobFromImage(
        rgbMat,
        1.0 / 255.0,              // Scale factor (normalize to 0-1)
        Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()), // Size
        Scalar(0.0, 0.0, 0.0),    // Mean subtraction (none for YOLO)
        true,                     // Swap R and B channels for OpenCV
        false,                    // Crop
        CvType.CV_32F             // Data type
    )

    Log.d(TAG, "🌐 Blob created: [${blob.size(0)}, ${blob.size(1)}, ${blob.size(2)}, ${blob.size(3)}]")

    // Debug: Check blob values to ensure they're not all zeros
    val blobData = FloatArray(Math.min(30, blob.total().toInt()))
    blob.get(0, 0, blobData)
    val blobValues = blobData.map { String.format("%.4f", it) }.joinToString(", ")
    Log.d(TAG, "🌐 First 30 blob values: [$blobValues]")

    // Check if blob is completely zero
    val nonZeroCount = blobData.count { it != 0f }
    Log.d(TAG, "🌐 Non-zero blob values: $nonZeroCount/${blobData.size}")

    // Fallback chain for the all-zero-blob failure mode.
    if (nonZeroCount == 0) {
        Log.w(TAG, "⚠️ Blob is all zeros! Trying alternative blob creation...")

        // Fallback 1: same call without the channel swap.
        val blob2 = Dnn.blobFromImage(
            rgbMat,
            1.0 / 255.0,
            Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()),
            Scalar(0.0, 0.0, 0.0),
            false, // No channel swap
            false,
            CvType.CV_32F
        )

        val blobData2 = FloatArray(10)
        blob2.get(0, 0, blobData2)
        val blobValues2 = blobData2.map { String.format("%.4f", it) }.joinToString(", ")
        Log.d(TAG, "🌐 Alternative blob (swapRB=false): [$blobValues2]")

        if (blobData2.any { it != 0f }) {
            Log.i(TAG, "✅ Alternative blob has data! Using swapRB=false")
            blob.release()
            rgbMat.release()
            return blob2
        }
        blob2.release()

        // Fallback 2: manual resize + float conversion, then blobFromImage.
        Log.w(TAG, "⚠️ Both blob methods failed! Trying manual blob creation...")
        val manualBlob = createManualBlob(rgbMat)
        if (manualBlob != null) {
            val manualData = FloatArray(10)
            manualBlob.get(0, 0, manualData)
            val manualValues = manualData.map { String.format("%.4f", it) }.joinToString(", ")
            Log.d(TAG, "🌐 Manual blob: [$manualValues]")

            if (manualData.any { it != 0f }) {
                Log.i(TAG, "✅ Manual blob has data! Using manual method")
                blob.release()
                rgbMat.release()
                return manualBlob
            }
            manualBlob.release()
        }
    }

    rgbMat.release()
    return blob
}
|||
|
|||
/**
 * Last-resort blob builder: resizes to 640x640 and normalizes to CV_32F
 * manually before calling blobFromImage with no further scaling.
 * Returns null on failure; the caller owns (and must release) the result.
 */
private fun createManualBlob(rgbMat: Mat): Mat? {
    try {
        // Resize image to 640x640
        val resized = Mat()
        Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()))

        // Convert to float and normalize manually (0..255 -> 0..1).
        val floatMat = Mat()
        resized.convertTo(floatMat, CvType.CV_32F, 1.0/255.0)

        Log.d(TAG, "🔧 Manual resize: ${resized.cols()}x${resized.rows()}")
        Log.d(TAG, "🔧 Float conversion: type=${floatMat.type()}")

        // Check if float conversion worked
        val testFloat = FloatArray(3)
        floatMat.get(100, 100, testFloat)
        val testFloatValues = testFloat.map { String.format("%.4f", it) }.joinToString(", ")
        Log.d(TAG, "🔧 Float test values: [$testFloatValues]")

        // blobFromImage only reshapes here — data is already float-normalized.
        val blob = Dnn.blobFromImage(
            floatMat,
            1.0,   // No additional scaling since already normalized
            Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble()),
            Scalar(0.0, 0.0, 0.0),
            false, // Don't swap channels
            false, // Don't crop
            CvType.CV_32F
        )

        Log.d(TAG, "🔧 Manual blob from preprocessed image")

        // Release intermediates; blob is returned to the caller.
        resized.release()
        floatMat.release()

        return blob

    } catch (e: Exception) {
        Log.e(TAG, "❌ Manual blob creation failed", e)
        return null
    }
}
|||
|
|||
private fun postprocess(outputs: List<Mat>, originalWidth: Int, originalHeight: Int): List<Detection> { |
|||
if (outputs.isEmpty()) return emptyList() |
|||
|
|||
val detections = mutableListOf<Detection>() |
|||
val confidences = mutableListOf<Float>() |
|||
val boxes = mutableListOf<Rect>() |
|||
val classIds = mutableListOf<Int>() |
|||
|
|||
// Calculate scale factors |
|||
val xScale = originalWidth.toFloat() / INPUT_SIZE |
|||
val yScale = originalHeight.toFloat() / INPUT_SIZE |
|||
|
|||
// Process each output |
|||
for (outputIndex in outputs.indices) { |
|||
val output = outputs[outputIndex] |
|||
val data = FloatArray((output.total() * output.channels()).toInt()) |
|||
output.get(0, 0, data) |
|||
|
|||
val rows = output.size(1).toInt() // Number of detections |
|||
val cols = output.size(2).toInt() // Features per detection |
|||
|
|||
Log.d(TAG, "🔍 Output $outputIndex: ${rows} detections, ${cols} features each") |
|||
Log.d(TAG, "🔍 Output shape: [${output.size(0)}, ${output.size(1)}, ${output.size(2)}]") |
|||
|
|||
// Debug: Check first few values to understand format |
|||
if (data.size >= 10) { |
|||
val firstValues = data.take(10).map { String.format("%.4f", it) }.joinToString(", ") |
|||
Log.d(TAG, "🔍 First 10 values: [$firstValues]") |
|||
|
|||
// Check max confidence in first 100 values to verify model output |
|||
val maxConf = data.take(100).maxOrNull() ?: 0f |
|||
Log.d(TAG, "🔍 Max value in first 100: ${String.format("%.4f", maxConf)}") |
|||
} |
|||
|
|||
// Check if this might be a transposed output (8400 detections, 100 features) |
|||
if (cols == 8400 && rows == 100) { |
|||
Log.d(TAG, "🤔 Detected transposed output format - trying alternative parsing") |
|||
parseTransposedOutput(data, rows, cols, xScale, yScale, boxes, confidences, classIds) |
|||
continue |
|||
} |
|||
|
|||
var validDetections = 0 |
|||
for (i in 0 until rows) { |
|||
val offset = i * cols |
|||
|
|||
if (offset + 4 >= data.size) { |
|||
Log.w(TAG, "⚠️ Data array too small for detection $i") |
|||
break |
|||
} |
|||
|
|||
// Extract box coordinates (center format) |
|||
val centerX = data[offset + 0] * xScale |
|||
val centerY = data[offset + 1] * yScale |
|||
val width = data[offset + 2] * xScale |
|||
val height = data[offset + 3] * yScale |
|||
|
|||
// Convert to top-left corner format |
|||
val x = (centerX - width / 2).toInt() |
|||
val y = (centerY - height / 2).toInt() |
|||
|
|||
// Extract confidence and class scores |
|||
val confidence = data[offset + 4] |
|||
|
|||
// Debug first few detections |
|||
if (i < 3) { |
|||
Log.d(TAG, "🔍 Detection $i: conf=${String.format("%.4f", confidence)}, x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}") |
|||
} |
|||
|
|||
if (confidence > CONFIDENCE_THRESHOLD) { |
|||
// Find class with highest score |
|||
var maxClassScore = 0f |
|||
var classId = 0 |
|||
|
|||
for (j in 5 until cols) { |
|||
if (offset + j >= data.size) break |
|||
val classScore = data[offset + j] |
|||
if (classScore > maxClassScore) { |
|||
maxClassScore = classScore |
|||
classId = j - 5 |
|||
} |
|||
} |
|||
|
|||
val finalConfidence = confidence * maxClassScore |
|||
|
|||
if (finalConfidence > CONFIDENCE_THRESHOLD) { |
|||
boxes.add(Rect(x, y, width.toInt(), height.toInt())) |
|||
confidences.add(finalConfidence) |
|||
classIds.add(classId) |
|||
validDetections++ |
|||
|
|||
if (validDetections <= 3) { |
|||
Log.d(TAG, "✅ Valid detection: class=$classId, conf=${String.format("%.4f", finalConfidence)}") |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
Log.d(TAG, "🎯 Found ${validDetections} valid detections above confidence threshold") |
|||
} |
|||
|
|||
// Apply Non-Maximum Suppression |
|||
Log.d(TAG, "📊 Before NMS: ${boxes.size} detections") |
|||
|
|||
if (boxes.isEmpty()) { |
|||
Log.d(TAG, "⚠️ No detections found before NMS") |
|||
return emptyList() |
|||
} |
|||
|
|||
val indices = MatOfInt() |
|||
val boxesArray = MatOfRect2d() |
|||
|
|||
// Convert Rect to Rect2d for NMSBoxes |
|||
val boxes2d = boxes.map { Rect2d(it.x.toDouble(), it.y.toDouble(), it.width.toDouble(), it.height.toDouble()) } |
|||
boxesArray.fromList(boxes2d) |
|||
|
|||
val confidencesArray = FloatArray(confidences.size) |
|||
for (i in confidences.indices) { |
|||
confidencesArray[i] = confidences[i] |
|||
} |
|||
|
|||
try { |
|||
Dnn.NMSBoxes( |
|||
boxesArray, |
|||
MatOfFloat(*confidencesArray), |
|||
CONFIDENCE_THRESHOLD, |
|||
NMS_THRESHOLD, |
|||
indices |
|||
) |
|||
Log.d(TAG, "✅ NMS completed successfully") |
|||
} catch (e: Exception) { |
|||
Log.e(TAG, "❌ NMS failed: ${e.message}", e) |
|||
return emptyList() |
|||
} |
|||
|
|||
// Build final detection list |
|||
val indicesArray = indices.toArray() |
|||
|
|||
// Check if NMS returned any valid indices |
|||
if (indicesArray.isEmpty()) { |
|||
Log.d(TAG, "🎯 NMS filtered out all detections") |
|||
indices.release() |
|||
return emptyList() |
|||
} |
|||
|
|||
for (i in indicesArray) { |
|||
val className = classNames[classIds[i]] ?: "unknown_${classIds[i]}" |
|||
detections.add( |
|||
Detection( |
|||
classId = classIds[i], |
|||
className = className, |
|||
confidence = confidences[i], |
|||
boundingBox = boxes[i] |
|||
) |
|||
) |
|||
} |
|||
|
|||
// Clean up |
|||
indices.release() |
|||
// Note: MatOfRect2d doesn't have a release method in OpenCV Android |
|||
|
|||
Log.d(TAG, "🎯 Post-processing complete: ${detections.size} final detections after NMS") |
|||
|
|||
return detections.sortedByDescending { it.confidence } |
|||
} |
|||
|
|||
private fun copyAssetToInternalStorage(assetName: String): String? { |
|||
return try { |
|||
val inputStream = context.assets.open(assetName) |
|||
val file = context.getFileStreamPath(assetName) |
|||
val outputStream = FileOutputStream(file) |
|||
|
|||
inputStream.copyTo(outputStream) |
|||
inputStream.close() |
|||
outputStream.close() |
|||
|
|||
file.absolutePath |
|||
} catch (e: IOException) { |
|||
Log.e(TAG, "Error copying asset $assetName", e) |
|||
null |
|||
} |
|||
} |
|||
|
|||
/**
 * Diagnostic helper: runs the full detect() pipeline on the bundled
 * assets/test_pokemon.jpg and logs every result. Returns the detections
 * (empty on any failure) so callers can assert against a known image.
 */
fun testWithStaticImage(): List<Detection> {
    if (!isInitialized) {
        Log.e(TAG, "❌ YOLO detector not initialized for test")
        return emptyList()
    }

    try {
        Log.i(TAG, "🧪 TESTING WITH STATIC IMAGE")

        // Load test image from assets
        val inputStream = context.assets.open("test_pokemon.jpg")
        val bitmap = BitmapFactory.decodeStream(inputStream)
        inputStream.close()

        if (bitmap == null) {
            Log.e(TAG, "❌ Failed to load test_pokemon.jpg from assets")
            return emptyList()
        }

        Log.i(TAG, "📸 Loaded test image: ${bitmap.width}x${bitmap.height}")

        // Convert bitmap to OpenCV Mat
        val mat = Mat()
        Utils.bitmapToMat(bitmap, mat)

        Log.i(TAG, "🔄 Converted to Mat: ${mat.cols()}x${mat.rows()}, channels: ${mat.channels()}")

        // Run the same pipeline used for live frames.
        val detections = detect(mat)

        Log.i(TAG, "🎯 TEST RESULT: ${detections.size} detections found")
        detections.forEachIndexed { index, detection ->
            Log.i(TAG, "  $index: ${detection.className} (${String.format("%.3f", detection.confidence)}) at [${detection.boundingBox.x}, ${detection.boundingBox.y}, ${detection.boundingBox.width}, ${detection.boundingBox.height}]")
        }

        // Release native + bitmap memory eagerly.
        mat.release()
        bitmap.recycle()

        return detections

    } catch (e: Exception) {
        Log.e(TAG, "❌ Error in static image test", e)
        return emptyList()
    }
}
|||
|
|||
/**
 * Drops the loaded network and marks the detector uninitialized so a later
 * initialize() reloads from scratch.
 * NOTE(review): the Net is only dereferenced here — no explicit native
 * release is invoked; confirm the OpenCV Java binding reclaims it.
 */
fun release() {
    net = null
    isInitialized = false
    Log.d(TAG, "YOLO detector released")
}
|||
} |
|||
@ -1,749 +0,0 @@ |
|||
package com.quillstudios.pokegoalshelper |
|||
|
|||
import android.content.Context |
|||
import android.graphics.Bitmap |
|||
import android.graphics.BitmapFactory |
|||
import android.util.Log |
|||
import org.opencv.android.Utils |
|||
import org.opencv.core.* |
|||
import org.opencv.dnn.Dnn |
|||
import org.opencv.imgproc.Imgproc |
|||
import org.tensorflow.lite.Interpreter |
|||
import java.io.FileInputStream |
|||
import java.io.FileOutputStream |
|||
import java.io.IOException |
|||
import java.nio.ByteBuffer |
|||
import java.nio.ByteOrder |
|||
import java.nio.channels.FileChannel |
|||
import kotlin.math.max |
|||
import kotlin.math.min |
|||
|
|||
class YOLOTFLiteDetector(private val context: Context) { |
|||
|
|||
companion object {
    private const val TAG = "YOLOTFLiteDetector"
    // TFLite conversion of the same ONNX model used by YOLODetector.
    private const val MODEL_FILE = "pokemon_model.tflite"
    private const val INPUT_SIZE = 640
    private const val CONFIDENCE_THRESHOLD = 0.05f // Extremely low to test conversion quality
    private const val NMS_THRESHOLD = 0.4f
    private const val NUM_CHANNELS = 3
    private const val NUM_DETECTIONS = 8400 // YOLOv8 default
    // NOTE(review): classNames below holds 96 entries (ids 0..95) — confirm
    // whether 95 or 96 matches the converted model's head.
    private const val NUM_CLASSES = 95 // Your class count
}
|||
|
|||
// TFLite interpreter; null until initialize() succeeds.
private var interpreter: Interpreter? = null
private var isInitialized = false

// Reusable inference buffers sized in allocateBuffers() from the model's
// actual tensor shapes.
private var inputBuffer: ByteBuffer? = null
private var outputBuffer: FloatArray? = null

// Class-id -> label map; duplicates the map in YOLODetector and must stay in
// the model's training class order. "original_trainder_number" and
// "ball_icon_ultraball_husui" look like label typos carried over from
// training — do not rename without retraining.
private val classNames = mapOf(
    0 to "ball_icon_pokeball",
    1 to "ball_icon_greatball",
    2 to "ball_icon_ultraball",
    3 to "ball_icon_masterball",
    4 to "ball_icon_safariball",
    5 to "ball_icon_levelball",
    6 to "ball_icon_lureball",
    7 to "ball_icon_moonball",
    8 to "ball_icon_friendball",
    9 to "ball_icon_loveball",
    10 to "ball_icon_heavyball",
    11 to "ball_icon_fastball",
    12 to "ball_icon_sportball",
    13 to "ball_icon_premierball",
    14 to "ball_icon_repeatball",
    15 to "ball_icon_timerball",
    16 to "ball_icon_nestball",
    17 to "ball_icon_netball",
    18 to "ball_icon_diveball",
    19 to "ball_icon_luxuryball",
    20 to "ball_icon_healball",
    21 to "ball_icon_quickball",
    22 to "ball_icon_duskball",
    23 to "ball_icon_cherishball",
    24 to "ball_icon_dreamball",
    25 to "ball_icon_beastball",
    26 to "ball_icon_strangeparts",
    27 to "ball_icon_parkball",
    28 to "ball_icon_gsball",
    29 to "pokemon_nickname",
    30 to "gender_icon_male",
    31 to "gender_icon_female",
    32 to "pokemon_level",
    33 to "language",
    34 to "last_game_stamp_home",
    35 to "last_game_stamp_lgp",
    36 to "last_game_stamp_lge",
    37 to "last_game_stamp_sw",
    38 to "last_game_stamp_sh",
    39 to "last_game_stamp_bank",
    40 to "last_game_stamp_bd",
    41 to "last_game_stamp_sp",
    42 to "last_game_stamp_pla",
    43 to "last_game_stamp_sc",
    44 to "last_game_stamp_vi",
    45 to "last_game_stamp_go",
    46 to "national_dex_number",
    47 to "pokemon_species",
    48 to "type_1",
    49 to "type_2",
    50 to "shiny_icon",
    51 to "origin_icon_vc",
    52 to "origin_icon_xyoras",
    53 to "origin_icon_smusum",
    54 to "origin_icon_lg",
    55 to "origin_icon_swsh",
    56 to "origin_icon_go",
    57 to "origin_icon_bdsp",
    58 to "origin_icon_pla",
    59 to "origin_icon_sv",
    60 to "pokerus_infected_icon",
    61 to "pokerus_cured_icon",
    62 to "hp_value",
    63 to "attack_value",
    64 to "defense_value",
    65 to "sp_atk_value",
    66 to "sp_def_value",
    67 to "speed_value",
    68 to "ability_name",
    69 to "nature_name",
    70 to "move_name",
    71 to "original_trainer_name",
    72 to "original_trainder_number",
    73 to "alpha_mark",
    74 to "tera_water",
    75 to "tera_psychic",
    76 to "tera_ice",
    77 to "tera_fairy",
    78 to "tera_poison",
    79 to "tera_ghost",
    80 to "ball_icon_originball",
    81 to "tera_dragon",
    82 to "tera_steel",
    83 to "tera_grass",
    84 to "tera_normal",
    85 to "tera_fire",
    86 to "tera_electric",
    87 to "tera_fighting",
    88 to "tera_ground",
    89 to "tera_flying",
    90 to "tera_bug",
    91 to "tera_rock",
    92 to "tera_dark",
    93 to "low_confidence",
    94 to "ball_icon_pokeball_hisui",
    95 to "ball_icon_ultraball_husui"
)
|||
|
|||
/**
 * Loads the TFLite model, builds the interpreter, allocates I/O buffers,
 * and smoke-tests a dummy inference. Idempotent; returns false (never
 * throws) on any failure.
 */
fun initialize(): Boolean {
    if (isInitialized) return true

    try {
        Log.i(TAG, "🤖 Initializing TensorFlow Lite YOLO detector...")

        // Load model from assets
        Log.i(TAG, "📂 Copying model file: $MODEL_FILE")
        val modelPath = copyAssetToInternalStorage(MODEL_FILE)
        if (modelPath == null) {
            Log.e(TAG, "❌ Failed to copy TFLite model from assets")
            return false
        }
        Log.i(TAG, "✅ Model copied to: $modelPath")

        // Create interpreter from a memory-mapped model buffer.
        Log.i(TAG, "📥 Loading TFLite model from: $modelPath")
        val modelFile = loadModelFile(modelPath)
        Log.i(TAG, "📥 Model file loaded, size: ${modelFile.capacity()} bytes")

        val options = Interpreter.Options()
        options.setNumThreads(4) // Use 4 CPU threads
        Log.i(TAG, "🔧 Creating TensorFlow Lite interpreter...")
        interpreter = Interpreter(modelFile, options)
        Log.i(TAG, "✅ Interpreter created successfully")

        // Log actual tensor shapes for diagnosis of conversion issues.
        val inputTensor = interpreter!!.getInputTensor(0)
        val outputTensor = interpreter!!.getOutputTensor(0)
        Log.i(TAG, "📊 Input tensor shape: ${inputTensor.shape().contentToString()}")
        Log.i(TAG, "📊 Output tensor shape: ${outputTensor.shape().contentToString()}")

        // Allocate input/output buffers sized from those shapes.
        Log.i(TAG, "📦 Allocating buffers...")
        allocateBuffers()
        Log.i(TAG, "✅ Buffers allocated")

        // Test model with dummy input
        Log.i(TAG, "🧪 Testing model with dummy input...")
        testModelInputOutput()
        Log.i(TAG, "✅ Model test completed")

        isInitialized = true
        Log.i(TAG, "✅ TensorFlow Lite YOLO detector initialized successfully")
        Log.i(TAG, "📊 Model info: ${classNames.size} classes, input size: ${INPUT_SIZE}x${INPUT_SIZE}")

        return true

    } catch (e: Exception) {
        Log.e(TAG, "❌ Error initializing TensorFlow Lite detector", e)
        e.printStackTrace()
        return false
    }
}
|||
|
|||
/**
 * Memory-maps the TFLite model file at [modelPath] into a read-only buffer.
 *
 * Uses FileChannel.map so the model is paged lazily instead of copied onto the
 * heap. The stream is closed via use{} even when mapping throws — the original
 * leaked the file descriptor on failure. The mapping remains valid after the
 * channel is closed.
 *
 * @param modelPath absolute path of the model file on internal storage
 * @return a read-only MappedByteBuffer covering the whole file
 */
private fun loadModelFile(modelPath: String): ByteBuffer {
    FileInputStream(modelPath).use { fileInputStream ->
        val fileChannel = fileInputStream.channel
        // Map the entire file: offset 0 through the file's full size.
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, 0L, fileChannel.size())
    }
}
|||
|
|||
/**
 * Sizes and allocates the reusable inference buffers from the model's actual
 * tensor shapes (queried from the interpreter, not hard-coded).
 *
 * Must be called after the interpreter is created; uses `interpreter!!` and
 * will throw if called earlier.
 */
private fun allocateBuffers() {
    // Get actual tensor shapes from the model
    val inputTensor = interpreter!!.getInputTensor(0)
    val outputTensor = interpreter!!.getOutputTensor(0)

    val inputShape = inputTensor.shape()
    val outputShape = outputTensor.shape()

    Log.d(TAG, "📊 Actual input shape: ${inputShape.contentToString()}")
    Log.d(TAG, "📊 Actual output shape: ${outputShape.contentToString()}")

    // Input buffer: element count (e.g. [1, 640, 640, 3]) * 4 bytes per float
    val inputSize = inputShape.fold(1) { acc, dim -> acc * dim } * 4
    // Direct buffer in native byte order — required for zero-copy into TFLite.
    inputBuffer = ByteBuffer.allocateDirect(inputSize)
    inputBuffer!!.order(ByteOrder.nativeOrder())

    // Output buffer: total element count (e.g. [1, 100, 8400]) as a flat array
    val outputSize = outputShape.fold(1) { acc, dim -> acc * dim }
    outputBuffer = FloatArray(outputSize)

    Log.d(TAG, "📦 Allocated input buffer: ${inputSize} bytes")
    Log.d(TAG, "📦 Allocated output buffer: ${outputSize} floats")
}
|||
|
|||
/**
 * Smoke-tests the interpreter with a constant dummy input (every value 0.5)
 * and logs the output value range. Rethrows on failure so [initialize] aborts.
 *
 * Side effect: overwrites the [outputBuffer] member with the flattened output
 * of this dummy run.
 */
private fun testModelInputOutput() {
    try {
        // Fill input buffer with dummy data
        inputBuffer!!.rewind()
        repeat(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS) { // HWC format
            inputBuffer!!.putFloat(0.5f) // Dummy normalized pixel value
        }

        // Query output shape so the result array matches the model exactly.
        val outputShape = interpreter!!.getOutputTensor(0).shape()
        Log.d(TAG, "🔍 Output tensor shape: ${outputShape.contentToString()}")

        // Create output as 3D array: [batch][features][detections]
        // (TFLite fills multidimensional Java arrays directly.)
        val output = Array(outputShape[0]) {
            Array(outputShape[1]) {
                FloatArray(outputShape[2])
            }
        }

        // Run inference (rewind first so TFLite reads from position 0).
        inputBuffer!!.rewind()
        interpreter!!.run(inputBuffer, output)

        // Log the output range as a sanity check on activation scale.
        val firstBatch = output[0]
        val maxOutput = firstBatch.flatMap { it.asIterable() }.maxOrNull() ?: 0f
        val minOutput = firstBatch.flatMap { it.asIterable() }.minOrNull() ?: 0f
        Log.i(TAG, "🧪 Model test: output range [${String.format("%.4f", minOutput)}, ${String.format("%.4f", maxOutput)}]")

        // Convert 3D array to flat array for postprocessing
        outputBuffer = firstBatch.flatMap { it.asIterable() }.toFloatArray()
        Log.d(TAG, "🧪 Converted output to flat array of size: ${outputBuffer!!.size}")

    } catch (e: Exception) {
        Log.e(TAG, "❌ Model test failed", e)
        throw e
    }
}
|||
|
|||
/**
 * Runs one synchronous YOLO inference on [inputMat] and returns NMS-filtered
 * detections in original-image pixel coordinates.
 *
 * Pipeline: [preprocessImage] (fills [inputBuffer]) -> interpreter.run ->
 * flatten the [1][features][detections] output -> [postprocess].
 *
 * @param inputMat input image as an OpenCV Mat (3- or 4-channel)
 * @return detections sorted by confidence; empty when not initialized or on
 *         any error (errors are logged, never rethrown)
 */
fun detect(inputMat: Mat): List<Detection> {
    if (!isInitialized || interpreter == null) {
        Log.w(TAG, "⚠️ TensorFlow Lite detector not initialized")
        return emptyList()
    }

    try {
        Log.d(TAG, "🔍 Running TFLite YOLO detection on ${inputMat.cols()}x${inputMat.rows()} image")

        // Preprocess image into the shared inputBuffer member.
        preprocessImage(inputMat)

        // Shape the output array from the model's actual output tensor.
        val outputShape = interpreter!!.getOutputTensor(0).shape()
        val output = Array(outputShape[0]) {
            Array(outputShape[1]) {
                FloatArray(outputShape[2])
            }
        }

        // Run inference (rewind so TFLite reads the buffer from position 0).
        inputBuffer!!.rewind()
        interpreter!!.run(inputBuffer, output)

        // Convert 3D array (batch 0) to flat array for postprocessing.
        val flatOutput = output[0].flatMap { it.asIterable() }.toFloatArray()

        // Decode, threshold and NMS-filter the raw output.
        val detections = postprocess(flatOutput, inputMat.cols(), inputMat.rows())

        Log.i(TAG, "✅ TFLite YOLO detection complete: ${detections.size} objects detected")

        return detections

    } catch (e: Exception) {
        Log.e(TAG, "❌ Error during TFLite YOLO detection", e)
        return emptyList()
    }
}
|||
|
|||
// Corrected preprocessImage function |
|||
private fun preprocessImage(mat: Mat) { |
|||
// Convert to RGB |
|||
val rgbMat = Mat() |
|||
// It's safer to always explicitly convert to the expected 3-channel type |
|||
// Assuming input `mat` can be from RGBA (screen capture) or BGR (file) |
|||
if (mat.channels() == 4) { |
|||
Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB) // Assuming screen capture is BGRA |
|||
} else { // Handle 3-channel BGR or RGB direct |
|||
Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB) // Convert BGR (OpenCV default) to RGB |
|||
} |
|||
|
|||
// Resize to input size (640x640) |
|||
val resized = Mat() |
|||
Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble())) |
|||
|
|||
Log.d(TAG, "🖼️ Preprocessed image: ${resized.cols()}x${resized.rows()}, channels: ${resized.channels()}") |
|||
|
|||
// Prepare a temporary byte array to get pixel data from Mat |
|||
val pixels = ByteArray(INPUT_SIZE * INPUT_SIZE * NUM_CHANNELS) |
|||
resized.get(0, 0, pixels) // Get pixel data in HWC (Height, Width, Channel) byte order |
|||
|
|||
// Convert to ByteBuffer in CHW (Channel, Height, Width) float format |
|||
inputBuffer!!.rewind() |
|||
|
|||
for (c in 0 until NUM_CHANNELS) { // Iterate channels |
|||
for (y in 0 until INPUT_SIZE) { // Iterate height |
|||
for (x in 0 until INPUT_SIZE) { // Iterate width |
|||
// Calculate index in the HWC 'pixels' array |
|||
val pixelIndex = (y * INPUT_SIZE + x) * NUM_CHANNELS + c |
|||
// Get byte value, convert to unsigned int (0-255), then to float, then normalize |
|||
val pixelValue = (pixels[pixelIndex].toInt() and 0xFF) / 255.0f |
|||
inputBuffer!!.putFloat(pixelValue) |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Debug: Check first few values |
|||
inputBuffer!!.rewind() |
|||
val testValues = FloatArray(10) |
|||
inputBuffer!!.asFloatBuffer().get(testValues) |
|||
val testStr = testValues.map { String.format("%.4f", it) }.joinToString(", ") |
|||
Log.d(TAG, "🌐 Input buffer first 10 values (CHW): [$testStr]") |
|||
|
|||
// Clean up |
|||
rgbMat.release() |
|||
resized.release() |
|||
} |
|||
|
|||
/* |
|||
private fun preprocessImage(mat: Mat) { |
|||
// Convert to RGB |
|||
val rgbMat = Mat() |
|||
if (mat.channels() == 4) { |
|||
Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGRA2RGB) |
|||
} else if (mat.channels() == 3) { |
|||
Imgproc.cvtColor(mat, rgbMat, Imgproc.COLOR_BGR2RGB) |
|||
} else { |
|||
mat.copyTo(rgbMat) |
|||
} |
|||
|
|||
// Resize to input size |
|||
val resized = Mat() |
|||
Imgproc.resize(rgbMat, resized, Size(INPUT_SIZE.toDouble(), INPUT_SIZE.toDouble())) |
|||
|
|||
Log.d(TAG, "🖼️ Preprocessed image: ${resized.cols()}x${resized.rows()}, channels: ${resized.channels()}") |
|||
|
|||
// Convert to ByteBuffer in HWC format [640, 640, 3] |
|||
inputBuffer!!.rewind() |
|||
|
|||
|
|||
val rgbBytes = ByteArray(INPUT_SIZE * INPUT_SIZE * 3) |
|||
resized.get(0, 0, rgbBytes) |
|||
|
|||
// Convert to float and normalize in HWC format (Height, Width, Channels) |
|||
// The data is already in HWC format from OpenCV |
|||
for (i in rgbBytes.indices) { |
|||
val pixelValue = (rgbBytes[i].toInt() and 0xFF) / 255.0f |
|||
inputBuffer!!.putFloat(pixelValue) |
|||
} |
|||
|
|||
// Debug: Check first few values |
|||
inputBuffer!!.rewind() |
|||
val testValues = FloatArray(10) |
|||
inputBuffer!!.asFloatBuffer().get(testValues) |
|||
val testStr = testValues.map { String.format("%.4f", it) }.joinToString(", ") |
|||
Log.d(TAG, "🌐 Input buffer first 10 values: [$testStr]") |
|||
|
|||
// Clean up |
|||
rgbMat.release() |
|||
resized.release() |
|||
} |
|||
*/ |
|||
/* |
|||
private fun postprocess(output: FloatArray, originalWidth: Int, originalHeight: Int): List<Detection> { |
|||
val detections = mutableListOf<Detection>() |
|||
val confidences = mutableListOf<Float>() |
|||
val boxes = mutableListOf<Rect>() |
|||
val classIds = mutableListOf<Int>() |
|||
|
|||
Log.d(TAG, "🔍 Processing detections from output array of size ${output.size}") |
|||
Log.d(TAG, "🔍 Original image size: ${originalWidth}x${originalHeight}") |
|||
|
|||
// YOLOv8 outputs normalized coordinates (0-1), so we scale directly to original image size |
|||
val numFeatures = 100 // From actual model output |
|||
val numDetections = 8400 // From actual model output |
|||
|
|||
var validDetections = 0 |
|||
|
|||
// Process transposed output: [1, 100, 8400] |
|||
// Features are: [x, y, w, h, conf, class0, class1, ..., class94] |
|||
for (i in 0 until numDetections) { |
|||
// In transposed format: feature_idx * numDetections + detection_idx |
|||
// YOLOv8 outputs normalized coordinates (0-1), scale to original image size |
|||
val centerX = output[0 * numDetections + i] * originalWidth // x row |
|||
val centerY = output[1 * numDetections + i] * originalHeight // y row |
|||
val width = output[2 * numDetections + i] * originalWidth // w row |
|||
val height = output[3 * numDetections + i] * originalHeight // h row |
|||
val confidence = output[4 * numDetections + i] // confidence row |
|||
|
|||
// Debug first few detections |
|||
if (i < 3) { |
|||
val rawX = output[0 * numDetections + i] |
|||
val rawY = output[1 * numDetections + i] |
|||
val rawW = output[2 * numDetections + i] |
|||
val rawH = output[3 * numDetections + i] |
|||
Log.d(TAG, "🔍 Detection $i: raw x=${String.format("%.3f", rawX)}, y=${String.format("%.3f", rawY)}, w=${String.format("%.3f", rawW)}, h=${String.format("%.3f", rawH)}") |
|||
Log.d(TAG, "🔍 Detection $i: scaled x=${String.format("%.1f", centerX)}, y=${String.format("%.1f", centerY)}, w=${String.format("%.1f", width)}, h=${String.format("%.1f", height)}") |
|||
} |
|||
|
|||
// Try different YOLOv8 format: no separate confidence, max class score is the confidence |
|||
var maxClassScore = 0f |
|||
var classId = 0 |
|||
|
|||
for (j in 4 until numFeatures) { // Start from feature 4 (after x,y,w,h), no separate conf |
|||
val classIdx = j * numDetections + i |
|||
if (classIdx >= output.size) break |
|||
|
|||
val classScore = output[classIdx] |
|||
if (classScore > maxClassScore) { |
|||
maxClassScore = classScore |
|||
classId = j - 4 // Convert to 0-based class index |
|||
} |
|||
} |
|||
|
|||
// Debug first few with max class scores |
|||
if (i < 3) { |
|||
Log.d(TAG, "🔍 Detection $i: maxClass=${String.format("%.4f", maxClassScore)}, classId=$classId") |
|||
} |
|||
|
|||
if (maxClassScore > CONFIDENCE_THRESHOLD && classId < classNames.size) { |
|||
val x = (centerX - width / 2).toInt() |
|||
val y = (centerY - height / 2).toInt() |
|||
|
|||
boxes.add(Rect(x, y, width.toInt(), height.toInt())) |
|||
confidences.add(maxClassScore) |
|||
classIds.add(classId) |
|||
validDetections++ |
|||
|
|||
if (validDetections <= 3) { |
|||
Log.d(TAG, "✅ Valid transposed detection: class=$classId, conf=${String.format("%.4f", maxClassScore)}") |
|||
} |
|||
} |
|||
} |
|||
|
|||
Log.d(TAG, "🎯 Found ${validDetections} valid detections above confidence threshold") |
|||
|
|||
// Apply Non-Maximum Suppression (simple version) |
|||
val finalDetections = applyNMS(boxes, confidences, classIds) |
|||
|
|||
Log.d(TAG, "🎯 Post-processing complete: ${finalDetections.size} final detections after NMS") |
|||
|
|||
return finalDetections.sortedByDescending { it.confidence } |
|||
} |
|||
*/ |
|||
|
|||
// In postprocess function |
|||
private fun postprocess(output: FloatArray, originalWidth: Int, originalHeight: Int): List<Detection> { |
|||
val detections = mutableListOf<Detection>() |
|||
val confidences = mutableListOf<Float>() |
|||
val boxes = mutableListOf<Rect>() |
|||
val classIds = mutableListOf<Int>() |
|||
|
|||
Log.d(TAG, "🔍 Processing detections from output array of size ${output.size}") |
|||
Log.d(TAG, "🔍 Original image size: ${originalWidth}x${originalHeight}") |
|||
|
|||
// Corrected Interpretation based on YOUR observed output shape [1, 100, 8400] |
|||
// This means attributes (box, confidence, class scores) are in the second dimension (index 1) |
|||
// and detections are in the third dimension (index 2). |
|||
// So, you need to iterate through 'detections' (8400) and for each, access its 'attributes' (100). |
|||
val numAttributesPerDetection = 100 // This is your 'outputShape[1]' |
|||
val totalDetections = 8400 // This is your 'outputShape[2]' |
|||
|
|||
// Loop through each of the 8400 potential detections |
|||
for (i in 0 until totalDetections) { |
|||
// Get the attributes for the i-th detection |
|||
// The data for the i-th detection starts at index 'i' in the 'output' flat array, |
|||
// then it's interleaved. This is why it's better to process from the 3D array directly. |
|||
|
|||
// Re-think: If `output[0]` from interpreter.run is `Array(100) { FloatArray(8400) }` |
|||
// Then it's attributes_per_detection x total_detections. |
|||
// So, output[0][0] is the x-coords for all detections, output[0][1] is y-coords for all detections. |
|||
// Let's assume this structure from your `output` 3D array: |
|||
// output[0][attribute_idx][detection_idx] |
|||
|
|||
val centerX = output[0 * totalDetections + i] // x-coordinate for detection 'i' |
|||
val centerY = output[1 * totalDetections + i] // y-coordinate for detection 'i' |
|||
val width = output[2 * totalDetections + i] // width for detection 'i' |
|||
val height = output[3 * totalDetections + i] // height for detection 'i' |
|||
val objectnessConf = output[4 * totalDetections + i] // Objectness confidence for detection 'i' |
|||
|
|||
// Debug raw values and scaled values before class scores |
|||
if (i < 5) { // Log first 5 detections |
|||
Log.d(TAG, "🔍 Detection $i (pre-scale): x=${String.format("%.3f", centerX)}, y=${String.format("%.3f", centerY)}, w=${String.format("%.3f", width)}, h=${String.format("%.3f", height)}, obj_conf=${String.format("%.4f", objectnessConf)}") |
|||
} |
|||
|
|||
var maxClassScore = 0f |
|||
var classId = -1 // Initialize with -1 to catch issues |
|||
|
|||
// Loop through class scores (starting from index 5 in the attributes list) |
|||
// Indices 5 to 99 are class scores (95 classes total) |
|||
for (j in 5 until numAttributesPerDetection) { |
|||
// Get the class score for detection 'i' and class 'j-5' |
|||
val classScore = output[j * totalDetections + i] |
|||
|
|||
val classScore_sigmoid = 1.0f / (1.0f + Math.exp(-classScore.toDouble())).toFloat() // Apply sigmoid |
|||
|
|||
if (classScore_sigmoid > maxClassScore) { |
|||
maxClassScore = classScore_sigmoid |
|||
classId = j - 5 // Convert to 0-based class index |
|||
} |
|||
} |
|||
|
|||
val objectnessConf_sigmoid = 1.0f / (1.0f + Math.exp(-objectnessConf.toDouble())).toFloat() // Apply sigmoid |
|||
|
|||
// Final confidence: Objectness score multiplied by the max class score |
|||
val finalConfidence = objectnessConf_sigmoid * maxClassScore |
|||
|
|||
// Debug final confidence for first few detections |
|||
if (i < 5) { |
|||
Log.d(TAG, "🔍 Detection $i (post-score): maxClass=${String.format("%.4f", maxClassScore)}, finalConf=${String.format("%.4f", finalConfidence)}, classId=$classId") |
|||
} |
|||
|
|||
// Apply confidence threshold |
|||
if (finalConfidence > CONFIDENCE_THRESHOLD && classId != -1 && classId < classNames.size) { |
|||
// Convert normalized coordinates (0-1) to pixel coordinates based on original image size |
|||
val x = ((centerX - width / 2) * originalWidth).toInt() |
|||
val y = ((centerY - height / 2) * originalHeight).toInt() |
|||
val w = (width * originalWidth).toInt() |
|||
val h = (height * originalHeight).toInt() |
|||
|
|||
// Ensure coordinates are within image bounds |
|||
val x1 = max(0, x) |
|||
val y1 = max(0, y) |
|||
val x2 = min(originalWidth, x + w) |
|||
val y2 = min(originalHeight, y + h) |
|||
|
|||
// Add to lists for NMS |
|||
boxes.add(Rect(x1, y1, x2 - x1, y2 - y1)) |
|||
confidences.add(finalConfidence) |
|||
classIds.add(classId) |
|||
} |
|||
} |
|||
|
|||
Log.d(TAG, "🎯 Found ${boxes.size} detections above confidence threshold before NMS") |
|||
|
|||
// Apply Non-Maximum Suppression (using OpenCV's NMSBoxes which is more robust) |
|||
val finalDetections = applyNMS_OpenCV(boxes, confidences, classIds) |
|||
|
|||
Log.d(TAG, "🎯 Post-processing complete: ${finalDetections.size} final detections after NMS") |
|||
|
|||
return finalDetections.sortedByDescending { it.confidence } |
|||
} |
|||
|
|||
// Replace your applyNMS function with this (or rename your old one and call this one) |
|||
private fun applyNMS_OpenCV(boxes: List<Rect>, confidences: List<Float>, classIds: List<Int>): List<Detection> { |
|||
val finalDetections = mutableListOf<Detection>() |
|||
|
|||
// Convert List<Rect> to List<Rect2d> |
|||
val boxes2d = boxes.map { Rect2d(it.x.toDouble(), it.y.toDouble(), it.width.toDouble(), it.height.toDouble()) } |
|||
|
|||
// Correct way to convert List<Rect2d> to MatOfRect2d |
|||
val boxesMat = MatOfRect2d() |
|||
boxesMat.fromList(boxes2d) // Use fromList to populate the MatOfRect2d |
|||
|
|||
val confsMat = MatOfFloat() |
|||
confsMat.fromList(confidences) // This part was already correct |
|||
|
|||
val indices = MatOfInt() |
|||
|
|||
// OpenCV NMSBoxes |
|||
Dnn.NMSBoxes( |
|||
boxesMat, |
|||
confsMat, |
|||
CONFIDENCE_THRESHOLD, // Confidence threshold (boxes below this are ignored by NMS) |
|||
NMS_THRESHOLD, // IoU threshold (boxes with IoU above this are suppressed) |
|||
indices |
|||
) |
|||
|
|||
val ind = indices.toArray() // Get array of indices to keep |
|||
|
|||
for (i in ind.indices) { |
|||
val idx = ind[i] |
|||
val className = classNames[classIds[idx]] ?: "unknown_${classIds[idx]}" |
|||
finalDetections.add( |
|||
Detection( |
|||
classId = classIds[idx], |
|||
className = className, |
|||
confidence = confidences[idx], |
|||
boundingBox = boxes[idx] // Keep original Rect for the final Detection object if you prefer ints |
|||
) |
|||
) |
|||
} |
|||
|
|||
// Release Mats to prevent memory leaks |
|||
boxesMat.release() |
|||
confsMat.release() |
|||
indices.release() |
|||
|
|||
return finalDetections |
|||
} |
|||
|
|||
/**
 * Greedy per-class non-maximum suppression (pure-Kotlin fallback).
 *
 * Visits candidates from highest to lowest confidence; each unsuppressed
 * candidate is kept and suppresses any remaining same-class box whose IoU
 * with it exceeds NMS_THRESHOLD. Boxes of different classes never suppress
 * each other.
 *
 * @param boxes       candidate boxes in pixel coordinates
 * @param confidences parallel list of confidences
 * @param classIds    parallel list of class indices into [classNames]
 * @return kept detections in descending-confidence order
 */
private fun applyNMS(boxes: List<Rect>, confidences: List<Float>, classIds: List<Int>): List<Detection> {
    val kept = mutableListOf<Detection>()

    // Visit order: highest confidence first.
    val order = confidences.indices.sortedByDescending { idx -> confidences[idx] }
    val removed = BooleanArray(boxes.size)

    for (winner in order) {
        if (removed[winner]) continue

        // Fall back to a synthetic name when the class id has no mapping.
        val label = classNames[classIds[winner]] ?: "unknown_${classIds[winner]}"
        kept.add(
            Detection(
                classId = classIds[winner],
                className = label,
                confidence = confidences[winner],
                boundingBox = boxes[winner]
            )
        )

        // Knock out every other same-class box that overlaps the winner.
        for (other in order) {
            val isRival = other != winner && !removed[other] && classIds[other] == classIds[winner]
            if (isRival && calculateIoU(boxes[winner], boxes[other]) > NMS_THRESHOLD) {
                removed[other] = true
            }
        }
    }

    return kept
}
|||
|
|||
/**
 * Intersection-over-Union of two axis-aligned boxes.
 *
 * @return IoU in [0, 1]; 0 when the boxes do not overlap or the union area
 *         is zero (degenerate boxes)
 */
private fun calculateIoU(box1: Rect, box2: Rect): Float {
    // Bounds of the intersection rectangle (may be empty).
    val interLeft = max(box1.x, box2.x)
    val interTop = max(box1.y, box2.y)
    val interRight = min(box1.x + box1.width, box2.x + box2.width)
    val interBottom = min(box1.y + box1.height, box2.y + box2.height)

    // Clamp negative extents to zero so disjoint boxes yield area 0.
    val interArea = max(0, interRight - interLeft) * max(0, interBottom - interTop)
    val unionArea = box1.width * box1.height + box2.width * box2.height - interArea

    // Guard the division: degenerate boxes would make the union zero.
    return if (unionArea > 0) interArea.toFloat() / unionArea.toFloat() else 0f
}
|||
|
|||
/**
 * Debug helper: runs the full detection pipeline on the bundled
 * "test_pokemon.jpg" asset and logs every detection.
 *
 * @return the detections found, or an empty list when the detector is not
 *         initialized, the asset is missing, or any error occurs
 */
fun testWithStaticImage(): List<Detection> {
    if (!isInitialized) {
        Log.e(TAG, "❌ TensorFlow Lite detector not initialized for test")
        return emptyList()
    }

    try {
        Log.i(TAG, "🧪 TESTING WITH STATIC IMAGE (TFLite)")

        // Load test image from assets
        val inputStream = context.assets.open("test_pokemon.jpg")
        val bitmap = BitmapFactory.decodeStream(inputStream)
        inputStream.close()

        // decodeStream returns null on undecodable data.
        if (bitmap == null) {
            Log.e(TAG, "❌ Failed to load test_pokemon.jpg from assets")
            return emptyList()
        }

        Log.i(TAG, "📸 Loaded test image: ${bitmap.width}x${bitmap.height}")

        // Convert bitmap to OpenCV Mat
        val mat = Mat()
        Utils.bitmapToMat(bitmap, mat)

        Log.i(TAG, "🔄 Converted to Mat: ${mat.cols()}x${mat.rows()}, channels: ${mat.channels()}")

        // Run the normal detection pipeline on the static image.
        val detections = detect(mat)

        Log.i(TAG, "🎯 TFLite TEST RESULT: ${detections.size} detections found")
        detections.forEachIndexed { index, detection ->
            Log.i(TAG, "  $index: ${detection.className} (${String.format("%.3f", detection.confidence)}) at [${detection.boundingBox.x}, ${detection.boundingBox.y}, ${detection.boundingBox.width}, ${detection.boundingBox.height}]")
        }

        // Free the native Mat and the bitmap's pixel memory.
        mat.release()
        bitmap.recycle()

        return detections

    } catch (e: Exception) {
        Log.e(TAG, "❌ Error in TFLite static image test", e)
        return emptyList()
    }
}
|||
|
|||
/**
 * Copies [assetName] from the APK assets into app-private internal storage so
 * it can be opened as a regular file (TFLite needs a mappable path).
 *
 * Both streams are closed via use{} even when the copy fails part-way — the
 * original leaked both on an I/O error. An existing copy is overwritten.
 *
 * @param assetName asset file name, e.g. the model file
 * @return absolute path of the copied file, or null when the copy failed
 *         (the IOException is logged)
 */
private fun copyAssetToInternalStorage(assetName: String): String? {
    return try {
        val file = context.getFileStreamPath(assetName)
        context.assets.open(assetName).use { inputStream ->
            FileOutputStream(file).use { outputStream ->
                inputStream.copyTo(outputStream)
            }
        }
        file.absolutePath
    } catch (e: IOException) {
        Log.e(TAG, "Error copying asset $assetName", e)
        null
    }
}
|||
|
|||
/**
 * Releases the native interpreter and marks the detector uninitialized.
 * Safe to call repeatedly; [initialize] may be invoked again afterwards.
 */
fun release() {
    interpreter?.let { active -> active.close() }
    interpreter = null
    isInitialized = false
    Log.d(TAG, "TensorFlow Lite detector released")
}
|||
} |
|||
@ -0,0 +1,211 @@ |
|||
package com.quillstudios.pokegoalshelper.ml |
|||
|
|||
import android.graphics.Bitmap |
|||
import android.util.Log |
|||
import org.opencv.android.Utils |
|||
import org.opencv.core.* |
|||
import org.opencv.imgproc.Imgproc |
|||
import kotlin.math.min |
|||
|
|||
/** |
|||
* Utility class for image preprocessing operations used in ML inference. |
|||
* Extracted for better separation of concerns and reusability. |
|||
*/ |
|||
/**
 * Utility singleton for image preprocessing used in ML inference.
 *
 * Handles resize (with optional letterboxing), color conversion,
 * normalization, conversion to a [1, C, H, W] float array, and the inverse
 * coordinate transforms from model space back to original-image space.
 * Extracted for better separation of concerns and reusability.
 */
object ImagePreprocessor
{
    private const val TAG = "ImagePreprocessor"

    /**
     * Preprocessing configuration.
     *
     * @property targetSize     square model input size in pixels (width = height)
     * @property numChannels    number of channels written into the output array
     * @property normalizeRange output value range; pixels are mapped from 0..255
     * @property useLetterboxing preserve aspect ratio with padding when true;
     *                           plain stretch-resize when false
     * @property colorConversion OpenCV cvtColor code, or -1 to skip conversion
     */
    data class PreprocessConfig(
        val targetSize: Int = 640,
        val numChannels: Int = 3,
        val normalizeRange: Pair<Float, Float> = Pair(0.0f, 1.0f),
        val useLetterboxing: Boolean = true,
        val colorConversion: Int = Imgproc.COLOR_BGR2RGB
    )

    /**
     * Result of preprocessing: the [1, C, H, W] float data plus the
     * scale/offset metadata needed to map model coordinates back to the
     * original image (see [transformCoordinates]).
     */
    data class PreprocessResult(
        val data: Array<Array<Array<FloatArray>>>,
        val scale: Float,
        val offsetX: Float,
        val offsetY: Float,
        val originalWidth: Int,
        val originalHeight: Int
    )

    /**
     * Preprocess a bitmap for ML inference with the specified configuration.
     *
     * NOTE(review): Utils.bitmapToMat typically yields an RGBA Mat, while the
     * default colorConversion is COLOR_BGR2RGB (a 3-channel conversion) —
     * verify the conversion code matches the Mat's actual channel count.
     *
     * @param bitmap The input bitmap to preprocess
     * @param config Preprocessing configuration
     * @return PreprocessResult containing processed data and transformation metadata
     */
    fun preprocessBitmap(bitmap: Bitmap, config: PreprocessConfig = PreprocessConfig()): PreprocessResult
    {
        val original_width = bitmap.width
        val original_height = bitmap.height

        // Convert bitmap to Mat
        val original_mat = Mat()
        Utils.bitmapToMat(bitmap, original_mat)

        try
        {
            val (processed_mat, scale, offset_x, offset_y) = if (config.useLetterboxing)
            {
                applyLetterboxing(original_mat, config.targetSize)
            }
            else
            {
                // Simple stretch-resize without letterboxing.
                // NOTE(review): the averaged scale below is only exact for
                // square inputs; transformCoordinates divides by this single
                // value, so reverse-mapped coordinates are approximate when
                // the aspect ratio differs from 1:1.
                val resized_mat = Mat()
                Imgproc.resize(original_mat, resized_mat, Size(config.targetSize.toDouble(), config.targetSize.toDouble()))
                val scale_x = config.targetSize.toFloat() / original_width
                val scale_y = config.targetSize.toFloat() / original_height
                val avg_scale = (scale_x + scale_y) / 2.0f
                ResizeResult(resized_mat, avg_scale, 0.0f, 0.0f)
            }

            // Apply color conversion (skipped when configured as -1)
            if (config.colorConversion != -1)
            {
                Imgproc.cvtColor(processed_mat, processed_mat, config.colorConversion)
            }

            // Normalize: map 0..255 bytes into normalizeRange as 32-bit floats
            val normalized_mat = Mat()
            val normalization_factor = (config.normalizeRange.second - config.normalizeRange.first) / 255.0
            processed_mat.convertTo(normalized_mat, CvType.CV_32F, normalization_factor, config.normalizeRange.first.toDouble())

            // Convert to array format [1, channels, height, width]
            val data = matToArray(normalized_mat, config.numChannels, config.targetSize)

            // Clean up intermediate Mats (native memory)
            processed_mat.release()
            normalized_mat.release()

            return PreprocessResult(
                data = data,
                scale = scale,
                offsetX = offset_x,
                offsetY = offset_y,
                originalWidth = original_width,
                originalHeight = original_height
            )
        }
        finally
        {
            // Always release the source Mat, even when preprocessing throws.
            original_mat.release()
        }
    }

    /**
     * Transform coordinates from model space back to original image space,
     * undoing the letterbox offset and scale; results are clamped to the
     * original image bounds.
     *
     * @param modelX X coordinate in model space
     * @param modelY Y coordinate in model space
     * @param preprocessResult The preprocessing result containing transformation metadata
     * @return Pair of (originalX, originalY) coordinates
     */
    fun transformCoordinates(
        modelX: Float,
        modelY: Float,
        preprocessResult: PreprocessResult
    ): Pair<Float, Float>
    {
        val original_x = ((modelX - preprocessResult.offsetX) / preprocessResult.scale)
            .coerceIn(0f, preprocessResult.originalWidth.toFloat())
        val original_y = ((modelY - preprocessResult.offsetY) / preprocessResult.scale)
            .coerceIn(0f, preprocessResult.originalHeight.toFloat())

        return Pair(original_x, original_y)
    }

    /**
     * Transform a bounding box from model space back to original image space
     * by transforming its two corner points (each clamped independently).
     *
     * @param boundingBox Bounding box in model space
     * @param preprocessResult The preprocessing result containing transformation metadata
     * @return Transformed bounding box in original image space
     */
    fun transformBoundingBox(
        boundingBox: BoundingBox,
        preprocessResult: PreprocessResult
    ): BoundingBox
    {
        val (left, top) = transformCoordinates(boundingBox.left, boundingBox.top, preprocessResult)
        val (right, bottom) = transformCoordinates(boundingBox.right, boundingBox.bottom, preprocessResult)

        return BoundingBox(left, top, right, bottom)
    }

    // Internal carrier for a resized Mat plus the scale/offset needed to
    // invert the transform later.
    private data class ResizeResult(
        val mat: Mat,
        val scale: Float,
        val offsetX: Float,
        val offsetY: Float
    )

    // Resizes preserving aspect ratio and pads the remainder with zeros
    // (black), centering the image inside a targetSize x targetSize canvas.
    private fun applyLetterboxing(inputMat: Mat, targetSize: Int): ResizeResult
    {
        // Uniform scale: the larger dimension exactly fills targetSize.
        val scale = min(targetSize.toFloat() / inputMat.width(), targetSize.toFloat() / inputMat.height())

        val new_width = (inputMat.width() * scale).toInt()
        val new_height = (inputMat.height() * scale).toInt()

        // Resize while maintaining aspect ratio
        val resized_mat = Mat()
        Imgproc.resize(inputMat, resized_mat, Size(new_width.toDouble(), new_height.toDouble()))

        // Create letterboxed image (centered); zeros = black padding.
        val letterbox_mat = Mat.zeros(targetSize, targetSize, CvType.CV_8UC3)
        val offset_x = (targetSize - new_width) / 2
        val offset_y = (targetSize - new_height) / 2

        // Copy the resized image into the centered region of the canvas.
        val roi = Rect(offset_x, offset_y, new_width, new_height)
        resized_mat.copyTo(letterbox_mat.submat(roi))

        // Clean up intermediate mat
        resized_mat.release()

        return ResizeResult(letterbox_mat, scale, offset_x.toFloat(), offset_y.toFloat())
    }

    // Converts an HWC float Mat into a [1, C, H, W] nested array.
    // NOTE(review): per-pixel Mat.get() crosses the JNI boundary size*size
    // times — fine for correctness, slow for hot paths.
    private fun matToArray(mat: Mat, numChannels: Int, size: Int): Array<Array<Array<FloatArray>>>
    {
        val data = Array(1) { Array(numChannels) { Array(size) { FloatArray(size) } } }

        for (c in 0 until numChannels)
        {
            for (h in 0 until size)
            {
                for (w in 0 until size)
                {
                    // get() may return null or fewer channels than requested;
                    // missing values are written as 0.0f.
                    val pixel = mat.get(h, w)
                    data[0][c][h][w] = if (pixel != null && c < pixel.size) pixel[c].toFloat() else 0.0f
                }
            }
        }

        return data
    }

    /**
     * Log preprocessing statistics for debugging.
     */
    fun logPreprocessStats(result: PreprocessResult)
    {
        Log.d(TAG, """
            📊 Preprocessing Stats:
            - Original size: ${result.originalWidth}x${result.originalHeight}
            - Scale factor: ${result.scale}
            - Offset: (${result.offsetX}, ${result.offsetY})
            - Data shape: [${result.data.size}, ${result.data[0].size}, ${result.data[0][0].size}, ${result.data[0][0][0].size}]
        """.trimIndent())
    }
}
|||
@ -0,0 +1,77 @@ |
|||
package com.quillstudios.pokegoalshelper.ml |
|||
|
|||
import android.graphics.Bitmap |
|||
|
|||
/** |
|||
* Interface for ML model inference operations. |
|||
* Separates ML concerns from the main service for better architecture. |
|||
*/ |
|||
/**
 * Interface for ML model inference operations.
 * Separates ML concerns from the main service for better architecture.
 *
 * Expected lifecycle: [initialize] once, [detect] repeatedly, [cleanup] when
 * done. NOTE(review): thread-safety of implementations is not specified here —
 * document it on each implementation.
 */
interface MLInferenceEngine
{
    /**
     * Initialize the ML model and prepare for inference.
     * Must complete successfully before [detect] is called.
     * @return true if initialization was successful, false otherwise
     */
    suspend fun initialize(): Boolean

    /**
     * Perform object detection on the provided image.
     * @param image The bitmap image to analyze
     * @return List of detected objects, empty if no objects found
     */
    suspend fun detect(image: Bitmap): List<Detection>

    /**
     * Set the confidence threshold for detections.
     * Detections scoring below the threshold are dropped.
     * @param threshold Minimum confidence value (0.0 to 1.0)
     */
    fun setConfidenceThreshold(threshold: Float)

    /**
     * Set the class filter for detections (debug aid).
     * @param className Class name to filter by, null for no filtering
     */
    fun setClassFilter(className: String?)

    /**
     * Check if the engine is ready for inference.
     * @return true if initialized and ready, false otherwise
     */
    fun isReady(): Boolean

    /**
     * Get the current inference time statistics.
     * @return Pair of (last inference time ms, average inference time ms)
     */
    fun getInferenceStats(): Pair<Long, Long>

    /**
     * Clean up all resources. Should be called when engine is no longer needed.
     */
    fun cleanup()
}
|||
|
|||
/** |
|||
* Data class representing a detected object. |
|||
*/ |
|||
/**
 * Immutable value object for one detected object.
 *
 * @property className   human-readable class label
 * @property confidence  detection confidence, expected in [0, 1]
 * @property boundingBox location in image coordinates (see [BoundingBox])
 */
data class Detection(
    val className: String,
    val confidence: Float,
    val boundingBox: BoundingBox
)
|||
|
|||
/** |
|||
* Data class representing a bounding box. |
|||
*/ |
|||
/**
 * Axis-aligned bounding box in image coordinates (origin top-left).
 *
 * The four edges are stored; extent and center are derived on demand.
 * Invariant (not enforced): left <= right and top <= bottom for a
 * non-degenerate box.
 */
data class BoundingBox(
    val left: Float,
    val top: Float,
    val right: Float,
    val bottom: Float
)
{
    /** Horizontal extent of the box. */
    val width: Float
        get() = right - left

    /** Vertical extent of the box. */
    val height: Float
        get() = bottom - top

    /** Horizontal center, derived as left edge plus half the width. */
    val centerX: Float
        get() = left + width / 2

    /** Vertical center, derived as top edge plus half the height. */
    val centerY: Float
        get() = top + height / 2
}
|||
File diff suppressed because it is too large
Loading…
Reference in new issue