
refactor: implement dynamic model metadata extraction from ONNX runtime

- Replace hardcoded NUM_DETECTIONS, NUM_CLASSES, NMS_OUTPUT_FEATURES_PER_DETECTION
  with runtime-extracted values from the ONNX model session
- Add extractModelMetadata() method to dynamically determine model properties
- Update parseNMSOutput, matToTensorArray, and postprocessResults to use dynamic variables
- Fall back to constants for safety if runtime extraction fails
- Replace remaining magic numbers with named constants

Related todos: #extract-model-metadata

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
arch-002-ml-inference-engine
Quildra, 5 months ago
commit 79a33a700b
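In essence, the commit replaces compile-time constants with shapes read from the ONNX session's input/output metadata, keeping the constants as fallbacks. Below is a minimal Kotlin sketch of that pattern, condensed from the YOLOInferenceEngine diff further down; the probeModelMetadata helper name and the 640 fallback input size are illustrative, not the commit's exact code.

```kotlin
import ai.onnxruntime.OrtSession
import ai.onnxruntime.TensorInfo

// Fallback values used when shape extraction fails; 640 is a typical YOLO input
// size assumed for illustration, 300 and 6 match the NMS output [1, 300, 6] in the diff.
private const val FALLBACK_INPUT_SIZE = 640
private const val FALLBACK_NUM_DETECTIONS = 300
private const val FALLBACK_FEATURES_PER_DETECTION = 6 // [x1, y1, x2, y2, confidence, class_id]

/** Reads tensor shapes from the ONNX session, falling back to constants on any failure. */
fun probeModelMetadata(session: OrtSession): Triple<Int, Int, Int> = try {
    // Input is typically [batch, channels, height, width]; take max(H, W) as the square input size.
    val inputShape = (session.inputInfo.values.first().info as TensorInfo).shape
    val inputSize = maxOf(inputShape[2], inputShape[3]).toInt()

    // NMS-style output is typically [batch, num_detections, features_per_detection].
    val outputShape = (session.outputInfo.values.first().info as TensorInfo).shape
    val numDetections = outputShape[1].toInt()
    val featuresPerDetection = outputShape[2].toInt()

    Triple(
        if (inputSize > 0) inputSize else FALLBACK_INPUT_SIZE, // dynamic dims show up as -1
        if (numDetections > 0) numDetections else FALLBACK_NUM_DETECTIONS,
        if (featuresPerDetection > 0) featuresPerDetection else FALLBACK_FEATURES_PER_DETECTION
    )
} catch (e: Exception) {
    Triple(FALLBACK_INPUT_SIZE, FALLBACK_NUM_DETECTIONS, FALLBACK_FEATURES_PER_DETECTION)
}
```

In the actual change below, the same extraction happens in extractModelMetadata(), and the results are stored in modelInputSize, modelNumDetections, modelNumClasses, and modelOutputFeatures, which the preprocessing and NMS-parsing paths then read instead of the compile-time constants.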

app/src/main/java/com/quillstudios/pokegoalshelper/ScreenCaptureService.kt (27 changed lines)

@@ -95,6 +95,15 @@ class ScreenCaptureService : Service() {
private const val NOTIFICATION_ID = 1001
private const val CHANNEL_ID = "screen_capture_channel"
// Timeout Constants
private const val ANALYSIS_STUCK_TIMEOUT_MS = 30000L // 30 seconds
private const val OCR_TASK_TIMEOUT_SECONDS = 10L
private const val DEFAULT_CAPTURE_INTERVAL_MS = 2000L // 2 seconds
private const val OCR_INDIVIDUAL_TIMEOUT_SECONDS = 5L
// UI Constants
private const val MIN_PIXEL_CHANNELS = 3
const val ACTION_START = "START_SCREEN_CAPTURE"
const val ACTION_STOP = "STOP_SCREEN_CAPTURE"
const val EXTRA_RESULT_DATA = "result_data"
@@ -120,7 +129,7 @@ class ScreenCaptureService : Service() {
private var enhancedFloatingFAB: EnhancedFloatingFAB? = null
private val handler = Handler(Looper.getMainLooper())
- private var captureInterval = 2000L // Capture every 2 seconds
+ private var captureInterval = DEFAULT_CAPTURE_INTERVAL_MS
private var autoProcessing = false // Disable automatic processing
// Thread pool for OCR processing (4 threads for parallel text extraction)
@@ -383,7 +392,7 @@ class ScreenCaptureService : Service() {
val centerY = mat.rows() / 2
val centerX = mat.cols() / 2
val pixel = mat.get(centerY, centerX)
- if (pixel != null && pixel.size >= 3) {
+ if (pixel != null && pixel.size >= MIN_PIXEL_CHANNELS) {
val b = pixel[0].toInt()
val g = pixel[1].toInt()
val r = pixel[2].toInt()
@@ -412,7 +421,7 @@ class ScreenCaptureService : Service() {
// Check if analysis has been stuck for too long (30 seconds max)
val currentTime = System.currentTimeMillis()
- if (isAnalyzing && (currentTime - analysisStartTime) > 30000) {
+ if (isAnalyzing && (currentTime - analysisStartTime) > ANALYSIS_STUCK_TIMEOUT_MS) {
PGHLog.w(TAG, "⚠️ Analysis stuck for >30s, resetting flag")
isAnalyzing = false
}
@@ -553,10 +562,10 @@ class ScreenCaptureService : Service() {
val levelDetection = detectionMap["pokemon_level"]?.maxByOrNull { it.boundingBox.width }
submitLevelOCRTask("level", mat, levelDetection, ocrResults, latch)
- // Wait for all OCR tasks to complete (max 10 seconds total)
- val completed = latch.await(10, TimeUnit.SECONDS)
+ // Wait for all OCR tasks to complete (max timeout)
+ val completed = latch.await(OCR_TASK_TIMEOUT_SECONDS, TimeUnit.SECONDS)
if (!completed) {
- PGHLog.w(TAG, "⏱️ Some OCR tasks timed out after 10 seconds")
+ PGHLog.w(TAG, "⏱️ Some OCR tasks timed out after ${OCR_TASK_TIMEOUT_SECONDS} seconds")
}
// Extract results
@@ -952,10 +961,10 @@ class ScreenCaptureService : Service() {
}
// Wait for OCR to complete (max 5 seconds to allow ML Kit to work)
- val completed = latch.await(5, TimeUnit.SECONDS)
+ val completed = latch.await(OCR_INDIVIDUAL_TIMEOUT_SECONDS, TimeUnit.SECONDS)
if (!completed) {
- PGHLog.e(TAG, "⏱️ OCR timeout for $purpose after 5 seconds")
+ PGHLog.e(TAG, "⏱️ OCR timeout for $purpose after ${OCR_INDIVIDUAL_TIMEOUT_SECONDS} seconds")
return null
}
@@ -1246,7 +1255,7 @@ class ScreenCaptureService : Service() {
// Proper executor shutdown with timeout
ocrExecutor.shutdown()
try {
- if (!ocrExecutor.awaitTermination(5, TimeUnit.SECONDS)) {
+ if (!ocrExecutor.awaitTermination(OCR_INDIVIDUAL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
PGHLog.w(TAG, "OCR executor did not terminate gracefully, forcing shutdown")
ocrExecutor.shutdownNow()
}

app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt (157 changed lines)

@@ -66,6 +66,28 @@ class YOLOInferenceEngine(
private const val ENABLE_CONFIDENCE_MAPPING = true
private const val RAW_TO_MOBILE_SCALE = 0.75f // Based on observation that mobile shows lower conf
// Image Processing Constants
private const val GAUSSIAN_BLUR_KERNEL_SIZE = 3.0
private const val GAUSSIAN_BLUR_SIGMA = 0.5
private const val CLAHE_CLIP_LIMIT = 1.5
private const val CLAHE_TILE_SIZE = 8.0
private const val SHARPENING_CENTER_VALUE = 5.0
private const val SHARPENING_EDGE_VALUE = -1.0
private const val LETTERBOX_PADDING_GRAY = 114.0
private const val PIXEL_NORMALIZATION_FACTOR = 255.0f
private const val COLOR_CHANNEL_MASK = 0xFF
// NMS Output Parsing Constants
private const val NMS_OUTPUT_FEATURES_PER_DETECTION = 6 // [x1, y1, x2, y2, confidence, class_id]
private const val NMS_COORDINATE_X1_OFFSET = 0
private const val NMS_COORDINATE_Y1_OFFSET = 1
private const val NMS_COORDINATE_X2_OFFSET = 2
private const val NMS_COORDINATE_Y2_OFFSET = 3
private const val NMS_CONFIDENCE_OFFSET = 4
private const val NMS_CLASS_ID_OFFSET = 5
private const val MIN_DEBUG_CONFIDENCE = 0.1f
private const val MAX_DEBUG_DETECTIONS_TO_LOG = 3
fun setCoordinateMode(mode: String)
{
COORD_TRANSFORM_MODE = mode
@@ -96,6 +118,12 @@ class YOLOInferenceEngine(
private var ortEnvironment: OrtEnvironment? = null
private var isInitialized = false
// Dynamic model metadata (extracted at runtime)
private var modelInputSize: Int = INPUT_SIZE // fallback to constant
private var modelNumDetections: Int = NUM_DETECTIONS // fallback to constant
private var modelNumClasses: Int = NUM_CLASSES // fallback to constant
private var modelOutputFeatures: Int = NMS_OUTPUT_FEATURES_PER_DETECTION // fallback to constant
// Shared thread pool for preprocessing operations (prevents creating new pools per detection)
private val preprocessingExecutor = Executors.newFixedThreadPool(config.threadPoolSize)
@@ -232,6 +260,9 @@ class YOLOInferenceEngine(
ortSession = ortEnvironment?.createSession(model_path, session_options)
?: throw RuntimeException("Failed to create ONNX session")
// Extract model metadata dynamically
extractModelMetadata()
PGHLog.i(TAG, "✅ ONNX YOLO detector initialized successfully")
isInitialized = true
}.onError { errorType, exception, message ->
@@ -511,7 +542,7 @@ class YOLOInferenceEngine(
if (processed_mat.type() == CvType.CV_8UC3 || processed_mat.type() == CvType.CV_8UC1)
{
// Use Gaussian blur as a more reliable alternative to bilateral filter
- Imgproc.GaussianBlur(processed_mat, denoised, Size(3.0, 3.0), 0.5)
+ Imgproc.GaussianBlur(processed_mat, denoised, Size(GAUSSIAN_BLUR_KERNEL_SIZE, GAUSSIAN_BLUR_KERNEL_SIZE), GAUSSIAN_BLUR_SIGMA)
processed_mat.release()
PGHLog.d(TAG, "✅ Ultralytics preprocessing complete with Gaussian smoothing")
return denoised
@@ -558,7 +589,7 @@ class YOLOInferenceEngine(
inputMat.copyTo(gray)
}
- val clahe = Imgproc.createCLAHE(1.5, Size(8.0, 8.0))
+ val clahe = Imgproc.createCLAHE(CLAHE_CLIP_LIMIT, Size(CLAHE_TILE_SIZE, CLAHE_TILE_SIZE))
clahe.apply(gray, enhanced_gray)
// Convert back to color
@@ -592,7 +623,7 @@ class YOLOInferenceEngine(
{
// Create sharpening kernel
val kernel = Mat(3, 3, CvType.CV_32F)
- kernel.put(0, 0, 0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0)
+ kernel.put(0, 0, 0.0, SHARPENING_EDGE_VALUE, 0.0, SHARPENING_EDGE_VALUE, SHARPENING_CENTER_VALUE, SHARPENING_EDGE_VALUE, 0.0, SHARPENING_EDGE_VALUE, 0.0)
// Apply filter
Imgproc.filter2D(inputMat, sharpened, -1, kernel)
@@ -638,7 +669,7 @@ class YOLOInferenceEngine(
Imgproc.resize(inputMat, resized, Size(new_width.toDouble(), new_height.toDouble()), 0.0, 0.0, Imgproc.INTER_CUBIC)
// Create letterbox with padding
- val letterboxed = Mat(targetHeight, targetWidth, inputMat.type(), Scalar(114.0, 114.0, 114.0)) // Gray padding
+ val letterboxed = Mat(targetHeight, targetWidth, inputMat.type(), Scalar(LETTERBOX_PADDING_GRAY, LETTERBOX_PADDING_GRAY, LETTERBOX_PADDING_GRAY)) // Gray padding
// Calculate padding offsets
val offset_x = (targetWidth - new_width) / 2
@@ -666,23 +697,23 @@ class YOLOInferenceEngine(
try
{
// Create array format [1, 3, height, width]
- val data = Array(1) { Array(NUM_CHANNELS) { Array(INPUT_SIZE) { FloatArray(INPUT_SIZE) } } }
+ val data = Array(1) { Array(NUM_CHANNELS) { Array(modelInputSize) { FloatArray(modelInputSize) } } }
// Get RGB bytes
- val rgb_bytes = ByteArray(INPUT_SIZE * INPUT_SIZE * 3)
+ val rgb_bytes = ByteArray(modelInputSize * modelInputSize * 3)
rgb_mat.get(0, 0, rgb_bytes)
// Convert HWC to CHW format and normalize
for (c in 0 until NUM_CHANNELS)
{
- for (h in 0 until INPUT_SIZE)
+ for (h in 0 until modelInputSize)
{
- for (w in 0 until INPUT_SIZE)
+ for (w in 0 until modelInputSize)
{
- val pixel_idx = (h * INPUT_SIZE + w) * 3 + c
+ val pixel_idx = (h * modelInputSize + w) * 3 + c
data[0][c][h][w] = if (pixel_idx < rgb_bytes.size)
{
- (rgb_bytes[pixel_idx].toInt() and 0xFF) / 255.0f
+ (rgb_bytes[pixel_idx].toInt() and COLOR_CHANNEL_MASK) / PIXEL_NORMALIZATION_FACTOR
} else 0.0f
}
}
@@ -702,7 +733,89 @@ class YOLOInferenceEngine(
private fun postprocessResults(output: Array<Array<FloatArray>>, originalSize: Size): List<Detection>
{
val flat_output = output[0].flatMap { it.asIterable() }.toFloatArray()
- return parseNMSOutput(flat_output, originalSize.width.toInt(), originalSize.height.toInt(), INPUT_SIZE)
+ return parseNMSOutput(flat_output, originalSize.width.toInt(), originalSize.height.toInt(), modelInputSize)
}
/**
* Extract model metadata dynamically from ONNX session
*/
private fun extractModelMetadata()
{
try
{
val session = ortSession ?: return
// Get input info
val inputInfo = session.inputInfo
val inputShape = inputInfo.values.firstOrNull()?.info as? ai.onnxruntime.TensorInfo
if (inputShape != null)
{
val shape = inputShape.shape
if (shape.size >= 3)
{
// Typically YOLO input is [batch, channels, height, width] or [batch, height, width, channels]
val extractedInputSize = maxOf(shape[2].toInt(), shape[3].toInt()) // Take max of height/width
if (extractedInputSize > 0)
{
modelInputSize = extractedInputSize
PGHLog.i(TAG, "📐 Extracted input size from model: $modelInputSize")
}
}
}
// Get output info
val outputInfo = session.outputInfo
val outputTensorInfo = outputInfo.values.firstOrNull()?.info as? ai.onnxruntime.TensorInfo
if (outputTensorInfo != null)
{
val outputShape = outputTensorInfo.shape
PGHLog.i(TAG, "📊 Model output shape: ${outputShape.contentToString()}")
if (outputShape.size >= 2)
{
// For NMS output: typically [batch, num_detections, features_per_detection]
// For raw output: typically [batch, num_detections, 4+1+num_classes] or similar
val numDetections = outputShape[1].toInt()
val featuresPerDetection = outputShape[2].toInt()
if (numDetections > 0)
{
modelNumDetections = numDetections
PGHLog.i(TAG, "🔢 Extracted num detections from model: $modelNumDetections")
}
if (featuresPerDetection > 0)
{
modelOutputFeatures = featuresPerDetection
PGHLog.i(TAG, "📊 Extracted output features per detection: $modelOutputFeatures")
// Try to infer number of classes from output features
// NMS output: [x1, y1, x2, y2, confidence, class_id] = 6 features
// Raw output: [x, y, w, h, confidence, class1, class2, ..., classN] = 5 + num_classes
if (featuresPerDetection == 6)
{
PGHLog.i(TAG, "🎯 Detected NMS post-processed output format")
}
else if (featuresPerDetection > 5)
{
val inferredNumClasses = featuresPerDetection - 5 // 4 coords + 1 confidence
if (inferredNumClasses > 0 && inferredNumClasses <= 1000) // Reasonable range
{
modelNumClasses = inferredNumClasses
PGHLog.i(TAG, "🏷️ Inferred num classes from output: $modelNumClasses")
}
}
}
}
}
PGHLog.i(TAG, "📋 Final model metadata - Input: ${modelInputSize}x${modelInputSize}, Detections: $modelNumDetections, Features: $modelOutputFeatures, Classes: $modelNumClasses")
}
catch (e: Exception)
{
PGHLog.w(TAG, "⚠️ Failed to extract model metadata, using fallback constants", e)
}
}
/**
@@ -892,10 +1005,10 @@ class YOLOInferenceEngine(
{
val detections = mutableListOf<Detection>()
- val num_detections = 300 // From model output [1, 300, 6]
- val features_per_detection = 6 // [x1, y1, x2, y2, confidence, class_id]
+ val num_detections = modelNumDetections
+ val features_per_detection = modelOutputFeatures
- PGHLog.d(TAG, "🔍 Parsing NMS output: 300 post-processed detections")
+ PGHLog.d(TAG, "🔍 Parsing NMS output: $num_detections post-processed detections")
var valid_detections = 0
@@ -905,17 +1018,17 @@ class YOLOInferenceEngine(
// Extract and transform coordinates from model output
val coords = transformCoordinates(
- rawX1 = output[base_idx],
- rawY1 = output[base_idx + 1],
- rawX2 = output[base_idx + 2],
- rawY2 = output[base_idx + 3],
+ rawX1 = output[base_idx + NMS_COORDINATE_X1_OFFSET],
+ rawY1 = output[base_idx + NMS_COORDINATE_Y1_OFFSET],
+ rawX2 = output[base_idx + NMS_COORDINATE_X2_OFFSET],
+ rawY2 = output[base_idx + NMS_COORDINATE_Y2_OFFSET],
originalWidth = originalWidth,
originalHeight = originalHeight,
inputScale = inputScale
)
- val confidence = output[base_idx + 4]
- val class_id = output[base_idx + 5].toInt()
+ val confidence = output[base_idx + NMS_CONFIDENCE_OFFSET]
+ val class_id = output[base_idx + NMS_CLASS_ID_OFFSET].toInt()
// Apply confidence mapping if enabled
val mapped_confidence = if (ENABLE_CONFIDENCE_MAPPING)
@@ -938,7 +1051,7 @@ class YOLOInferenceEngine(
}
// Debug logging for all detections if enabled
- if (SHOW_ALL_CONFIDENCES && mapped_confidence > 0.1f)
+ if (SHOW_ALL_CONFIDENCES && mapped_confidence > MIN_DEBUG_CONFIDENCE)
{
PGHLog.d(TAG, "🔍 [DEBUG] Class: $class_name (ID: $class_id), Confidence: %.3f, Original: %.3f".format(mapped_confidence, confidence))
}
@@ -969,7 +1082,7 @@ class YOLOInferenceEngine(
valid_detections++
- if (valid_detections <= 3)
+ if (valid_detections <= MAX_DEBUG_DETECTIONS_TO_LOG)
{
PGHLog.d(TAG, "✅ Valid NMS detection: class=$class_id ($class_name), conf=${String.format("%.4f", mapped_confidence)}")
}
