From 79a33a700bc4e5b2404fdc64cc3ef5618d8859be Mon Sep 17 00:00:00 2001 From: Quildra Date: Sat, 2 Aug 2025 15:25:19 +0100 Subject: [PATCH] refactor: implement dynamic model metadata extraction from ONNX runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace hardcoded NUM_DETECTIONS, NUM_CLASSES, NMS_OUTPUT_FEATURES_PER_DETECTION with runtime extracted values from ONNX model session - Add extractModelMetadata() method to dynamically determine model properties - Update parseNMSOutput, matToTensorArray, and postprocessResults to use dynamic variables - Fallback to constants if runtime extraction fails for safety - Replace remaining magic numbers with named constants Related todos: #extract-model-metadata 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../pokegoalshelper/ScreenCaptureService.kt | 27 ++- .../pokegoalshelper/ml/YOLOInferenceEngine.kt | 157 +++++++++++++++--- 2 files changed, 153 insertions(+), 31 deletions(-) diff --git a/app/src/main/java/com/quillstudios/pokegoalshelper/ScreenCaptureService.kt b/app/src/main/java/com/quillstudios/pokegoalshelper/ScreenCaptureService.kt index f5d1f4e..cc87849 100644 --- a/app/src/main/java/com/quillstudios/pokegoalshelper/ScreenCaptureService.kt +++ b/app/src/main/java/com/quillstudios/pokegoalshelper/ScreenCaptureService.kt @@ -94,6 +94,15 @@ class ScreenCaptureService : Service() { private const val TAG = "ScreenCaptureService" private const val NOTIFICATION_ID = 1001 private const val CHANNEL_ID = "screen_capture_channel" + + // Timeout Constants + private const val ANALYSIS_STUCK_TIMEOUT_MS = 30000L // 30 seconds + private const val OCR_TASK_TIMEOUT_SECONDS = 10L + private const val DEFAULT_CAPTURE_INTERVAL_MS = 2000L // 2 seconds + private const val OCR_INDIVIDUAL_TIMEOUT_SECONDS = 5L + + // UI Constants + private const val MIN_PIXEL_CHANNELS = 3 const val ACTION_START = "START_SCREEN_CAPTURE" const val ACTION_STOP = "STOP_SCREEN_CAPTURE" @@ -120,7 +129,7 @@ class ScreenCaptureService : Service() { private var enhancedFloatingFAB: EnhancedFloatingFAB? = null private val handler = Handler(Looper.getMainLooper()) - private var captureInterval = 2000L // Capture every 2 seconds + private var captureInterval = DEFAULT_CAPTURE_INTERVAL_MS private var autoProcessing = false // Disable automatic processing // Thread pool for OCR processing (4 threads for parallel text extraction) @@ -383,7 +392,7 @@ class ScreenCaptureService : Service() { val centerY = mat.rows() / 2 val centerX = mat.cols() / 2 val pixel = mat.get(centerY, centerX) - if (pixel != null && pixel.size >= 3) { + if (pixel != null && pixel.size >= MIN_PIXEL_CHANNELS) { val b = pixel[0].toInt() val g = pixel[1].toInt() val r = pixel[2].toInt() @@ -412,7 +421,7 @@ class ScreenCaptureService : Service() { // Check if analysis has been stuck for too long (30 seconds max) val currentTime = System.currentTimeMillis() - if (isAnalyzing && (currentTime - analysisStartTime) > 30000) { + if (isAnalyzing && (currentTime - analysisStartTime) > ANALYSIS_STUCK_TIMEOUT_MS) { PGHLog.w(TAG, "⚠️ Analysis stuck for >30s, resetting flag") isAnalyzing = false } @@ -553,10 +562,10 @@ class ScreenCaptureService : Service() { val levelDetection = detectionMap["pokemon_level"]?.maxByOrNull { it.boundingBox.width } submitLevelOCRTask("level", mat, levelDetection, ocrResults, latch) - // Wait for all OCR tasks to complete (max 10 seconds total) - val completed = latch.await(10, TimeUnit.SECONDS) + // Wait for all OCR tasks to complete (max timeout) + val completed = latch.await(OCR_TASK_TIMEOUT_SECONDS, TimeUnit.SECONDS) if (!completed) { - PGHLog.w(TAG, "⏱️ Some OCR tasks timed out after 10 seconds") + PGHLog.w(TAG, "⏱️ Some OCR tasks timed out after ${OCR_TASK_TIMEOUT_SECONDS} seconds") } // Extract results @@ -952,10 +961,10 @@ class ScreenCaptureService : Service() { } // Wait for OCR to complete (max 5 seconds to allow ML Kit to work) - val completed = latch.await(5, TimeUnit.SECONDS) + val completed = latch.await(OCR_INDIVIDUAL_TIMEOUT_SECONDS, TimeUnit.SECONDS) if (!completed) { - PGHLog.e(TAG, "⏱️ OCR timeout for $purpose after 5 seconds") + PGHLog.e(TAG, "⏱️ OCR timeout for $purpose after ${OCR_INDIVIDUAL_TIMEOUT_SECONDS} seconds") return null } @@ -1246,7 +1255,7 @@ class ScreenCaptureService : Service() { // Proper executor shutdown with timeout ocrExecutor.shutdown() try { - if (!ocrExecutor.awaitTermination(5, TimeUnit.SECONDS)) { + if (!ocrExecutor.awaitTermination(OCR_INDIVIDUAL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { PGHLog.w(TAG, "OCR executor did not terminate gracefully, forcing shutdown") ocrExecutor.shutdownNow() } diff --git a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt index 00c0a29..8c5c862 100644 --- a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt +++ b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt @@ -66,6 +66,28 @@ class YOLOInferenceEngine( private const val ENABLE_CONFIDENCE_MAPPING = true private const val RAW_TO_MOBILE_SCALE = 0.75f // Based on observation that mobile shows lower conf + // Image Processing Constants + private const val GAUSSIAN_BLUR_KERNEL_SIZE = 3.0 + private const val GAUSSIAN_BLUR_SIGMA = 0.5 + private const val CLAHE_CLIP_LIMIT = 1.5 + private const val CLAHE_TILE_SIZE = 8.0 + private const val SHARPENING_CENTER_VALUE = 5.0 + private const val SHARPENING_EDGE_VALUE = -1.0 + private const val LETTERBOX_PADDING_GRAY = 114.0 + private const val PIXEL_NORMALIZATION_FACTOR = 255.0f + private const val COLOR_CHANNEL_MASK = 0xFF + + // NMS Output Parsing Constants + private const val NMS_OUTPUT_FEATURES_PER_DETECTION = 6 // [x1, y1, x2, y2, confidence, class_id] + private const val NMS_COORDINATE_X1_OFFSET = 0 + private const val NMS_COORDINATE_Y1_OFFSET = 1 + private const val NMS_COORDINATE_X2_OFFSET = 2 + private const val NMS_COORDINATE_Y2_OFFSET = 3 + private const val NMS_CONFIDENCE_OFFSET = 4 + private const val NMS_CLASS_ID_OFFSET = 5 + private const val MIN_DEBUG_CONFIDENCE = 0.1f + private const val MAX_DEBUG_DETECTIONS_TO_LOG = 3 + fun setCoordinateMode(mode: String) { COORD_TRANSFORM_MODE = mode @@ -96,6 +118,12 @@ class YOLOInferenceEngine( private var ortEnvironment: OrtEnvironment? = null private var isInitialized = false + // Dynamic model metadata (extracted at runtime) + private var modelInputSize: Int = INPUT_SIZE // fallback to constant + private var modelNumDetections: Int = NUM_DETECTIONS // fallback to constant + private var modelNumClasses: Int = NUM_CLASSES // fallback to constant + private var modelOutputFeatures: Int = NMS_OUTPUT_FEATURES_PER_DETECTION // fallback to constant + // Shared thread pool for preprocessing operations (prevents creating new pools per detection) private val preprocessingExecutor = Executors.newFixedThreadPool(config.threadPoolSize) @@ -232,6 +260,9 @@ class YOLOInferenceEngine( ortSession = ortEnvironment?.createSession(model_path, session_options) ?: throw RuntimeException("Failed to create ONNX session") + // Extract model metadata dynamically + extractModelMetadata() + PGHLog.i(TAG, "✅ ONNX YOLO detector initialized successfully") isInitialized = true }.onError { errorType, exception, message -> @@ -511,7 +542,7 @@ class YOLOInferenceEngine( if (processed_mat.type() == CvType.CV_8UC3 || processed_mat.type() == CvType.CV_8UC1) { // Use Gaussian blur as a more reliable alternative to bilateral filter - Imgproc.GaussianBlur(processed_mat, denoised, Size(3.0, 3.0), 0.5) + Imgproc.GaussianBlur(processed_mat, denoised, Size(GAUSSIAN_BLUR_KERNEL_SIZE, GAUSSIAN_BLUR_KERNEL_SIZE), GAUSSIAN_BLUR_SIGMA) processed_mat.release() PGHLog.d(TAG, "✅ Ultralytics preprocessing complete with Gaussian smoothing") return denoised @@ -558,7 +589,7 @@ class YOLOInferenceEngine( inputMat.copyTo(gray) } - val clahe = Imgproc.createCLAHE(1.5, Size(8.0, 8.0)) + val clahe = Imgproc.createCLAHE(CLAHE_CLIP_LIMIT, Size(CLAHE_TILE_SIZE, CLAHE_TILE_SIZE)) clahe.apply(gray, enhanced_gray) // Convert back to color @@ -592,7 +623,7 @@ class YOLOInferenceEngine( { // Create sharpening kernel val kernel = Mat(3, 3, CvType.CV_32F) - kernel.put(0, 0, 0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0) + kernel.put(0, 0, 0.0, SHARPENING_EDGE_VALUE, 0.0, SHARPENING_EDGE_VALUE, SHARPENING_CENTER_VALUE, SHARPENING_EDGE_VALUE, 0.0, SHARPENING_EDGE_VALUE, 0.0) // Apply filter Imgproc.filter2D(inputMat, sharpened, -1, kernel) @@ -638,7 +669,7 @@ class YOLOInferenceEngine( Imgproc.resize(inputMat, resized, Size(new_width.toDouble(), new_height.toDouble()), 0.0, 0.0, Imgproc.INTER_CUBIC) // Create letterbox with padding - val letterboxed = Mat(targetHeight, targetWidth, inputMat.type(), Scalar(114.0, 114.0, 114.0)) // Gray padding + val letterboxed = Mat(targetHeight, targetWidth, inputMat.type(), Scalar(LETTERBOX_PADDING_GRAY, LETTERBOX_PADDING_GRAY, LETTERBOX_PADDING_GRAY)) // Gray padding // Calculate padding offsets val offset_x = (targetWidth - new_width) / 2 @@ -666,23 +697,23 @@ class YOLOInferenceEngine( try { // Create array format [1, 3, height, width] - val data = Array(1) { Array(NUM_CHANNELS) { Array(INPUT_SIZE) { FloatArray(INPUT_SIZE) } } } + val data = Array(1) { Array(NUM_CHANNELS) { Array(modelInputSize) { FloatArray(modelInputSize) } } } // Get RGB bytes - val rgb_bytes = ByteArray(INPUT_SIZE * INPUT_SIZE * 3) + val rgb_bytes = ByteArray(modelInputSize * modelInputSize * 3) rgb_mat.get(0, 0, rgb_bytes) // Convert HWC to CHW format and normalize for (c in 0 until NUM_CHANNELS) { - for (h in 0 until INPUT_SIZE) + for (h in 0 until modelInputSize) { - for (w in 0 until INPUT_SIZE) + for (w in 0 until modelInputSize) { - val pixel_idx = (h * INPUT_SIZE + w) * 3 + c + val pixel_idx = (h * modelInputSize + w) * 3 + c data[0][c][h][w] = if (pixel_idx < rgb_bytes.size) { - (rgb_bytes[pixel_idx].toInt() and 0xFF) / 255.0f + (rgb_bytes[pixel_idx].toInt() and COLOR_CHANNEL_MASK) / PIXEL_NORMALIZATION_FACTOR } else 0.0f } } @@ -702,7 +733,89 @@ class YOLOInferenceEngine( private fun postprocessResults(output: Array>, originalSize: Size): List { val flat_output = output[0].flatMap { it.asIterable() }.toFloatArray() - return parseNMSOutput(flat_output, originalSize.width.toInt(), originalSize.height.toInt(), INPUT_SIZE) + return parseNMSOutput(flat_output, originalSize.width.toInt(), originalSize.height.toInt(), modelInputSize) + } + + /** + * Extract model metadata dynamically from ONNX session + */ + private fun extractModelMetadata() + { + try + { + val session = ortSession ?: return + + // Get input info + val inputInfo = session.inputInfo + val inputShape = inputInfo.values.firstOrNull()?.info as? ai.onnxruntime.TensorInfo + if (inputShape != null) + { + val shape = inputShape.shape + if (shape.size >= 3) + { + // Typically YOLO input is [batch, channels, height, width] or [batch, height, width, channels] + val extractedInputSize = maxOf(shape[2].toInt(), shape[3].toInt()) // Take max of height/width + if (extractedInputSize > 0) + { + modelInputSize = extractedInputSize + PGHLog.i(TAG, "📐 Extracted input size from model: $modelInputSize") + } + } + } + + // Get output info + val outputInfo = session.outputInfo + val outputTensorInfo = outputInfo.values.firstOrNull()?.info as? ai.onnxruntime.TensorInfo + if (outputTensorInfo != null) + { + val outputShape = outputTensorInfo.shape + PGHLog.i(TAG, "📊 Model output shape: ${outputShape.contentToString()}") + + if (outputShape.size >= 2) + { + // For NMS output: typically [batch, num_detections, features_per_detection] + // For raw output: typically [batch, num_detections, 4+1+num_classes] or similar + val numDetections = outputShape[1].toInt() + val featuresPerDetection = outputShape[2].toInt() + + if (numDetections > 0) + { + modelNumDetections = numDetections + PGHLog.i(TAG, "🔢 Extracted num detections from model: $modelNumDetections") + } + + if (featuresPerDetection > 0) + { + modelOutputFeatures = featuresPerDetection + PGHLog.i(TAG, "📊 Extracted output features per detection: $modelOutputFeatures") + + // Try to infer number of classes from output features + // NMS output: [x1, y1, x2, y2, confidence, class_id] = 6 features + // Raw output: [x, y, w, h, confidence, class1, class2, ..., classN] = 5 + num_classes + if (featuresPerDetection == 6) + { + PGHLog.i(TAG, "🎯 Detected NMS post-processed output format") + } + else if (featuresPerDetection > 5) + { + val inferredNumClasses = featuresPerDetection - 5 // 4 coords + 1 confidence + if (inferredNumClasses > 0 && inferredNumClasses <= 1000) // Reasonable range + { + modelNumClasses = inferredNumClasses + PGHLog.i(TAG, "🏷️ Inferred num classes from output: $modelNumClasses") + } + } + } + } + } + + PGHLog.i(TAG, "📋 Final model metadata - Input: ${modelInputSize}x${modelInputSize}, Detections: $modelNumDetections, Features: $modelOutputFeatures, Classes: $modelNumClasses") + + } + catch (e: Exception) + { + PGHLog.w(TAG, "⚠️ Failed to extract model metadata, using fallback constants", e) + } } /** @@ -892,10 +1005,10 @@ class YOLOInferenceEngine( { val detections = mutableListOf() - val num_detections = 300 // From model output [1, 300, 6] - val features_per_detection = 6 // [x1, y1, x2, y2, confidence, class_id] + val num_detections = modelNumDetections + val features_per_detection = modelOutputFeatures - PGHLog.d(TAG, "🔍 Parsing NMS output: 300 post-processed detections") + PGHLog.d(TAG, "🔍 Parsing NMS output: $num_detections post-processed detections") var valid_detections = 0 @@ -905,17 +1018,17 @@ class YOLOInferenceEngine( // Extract and transform coordinates from model output val coords = transformCoordinates( - rawX1 = output[base_idx], - rawY1 = output[base_idx + 1], - rawX2 = output[base_idx + 2], - rawY2 = output[base_idx + 3], + rawX1 = output[base_idx + NMS_COORDINATE_X1_OFFSET], + rawY1 = output[base_idx + NMS_COORDINATE_Y1_OFFSET], + rawX2 = output[base_idx + NMS_COORDINATE_X2_OFFSET], + rawY2 = output[base_idx + NMS_COORDINATE_Y2_OFFSET], originalWidth = originalWidth, originalHeight = originalHeight, inputScale = inputScale ) - val confidence = output[base_idx + 4] - val class_id = output[base_idx + 5].toInt() + val confidence = output[base_idx + NMS_CONFIDENCE_OFFSET] + val class_id = output[base_idx + NMS_CLASS_ID_OFFSET].toInt() // Apply confidence mapping if enabled val mapped_confidence = if (ENABLE_CONFIDENCE_MAPPING) @@ -938,7 +1051,7 @@ class YOLOInferenceEngine( } // Debug logging for all detections if enabled - if (SHOW_ALL_CONFIDENCES && mapped_confidence > 0.1f) + if (SHOW_ALL_CONFIDENCES && mapped_confidence > MIN_DEBUG_CONFIDENCE) { PGHLog.d(TAG, "🔍 [DEBUG] Class: $class_name (ID: $class_id), Confidence: %.3f, Original: %.3f".format(mapped_confidence, confidence)) } @@ -969,7 +1082,7 @@ class YOLOInferenceEngine( valid_detections++ - if (valid_detections <= 3) + if (valid_detections <= MAX_DEBUG_DETECTIONS_TO_LOG) { PGHLog.d(TAG, "✅ Valid NMS detection: class=$class_id ($class_name), conf=${String.format("%.4f", mapped_confidence)}") }