package com.quillstudios.pokegoalshelper.ml

import android.content.Context
import android.graphics.Bitmap
import com.quillstudios.pokegoalshelper.utils.PGHLog
import org.opencv.android.Utils
import org.opencv.core.*
import org.opencv.imgproc.Imgproc
import ai.onnxruntime.*
import java.io.FileOutputStream
import java.io.IOException
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlin.math.max
import kotlin.math.min

/**
 * YOLO ONNX-based implementation of MLInferenceEngine.
 * Preserves ALL functionality from the original YOLOOnnxDetector including:
 * - Complete 96-class mapping
 * - Multiple preprocessing techniques
 * - Coordinate transformation modes
 * - Weighted NMS and TTA
 * - Debug and testing features
 */
class YOLOInferenceEngine(
    private val context: Context,
    private val config: YOLOConfig = YOLOConfig()
) : MLInferenceEngine {

    companion object {
        private const val TAG = "YOLOInferenceEngine"
        private const val MODEL_FILE = "best.onnx"
        private const val CONFIDENCE_THRESHOLD = 0.55f
        private const val NMS_THRESHOLD = 0.3f
        private const val NUM_CHANNELS = 3

        // Enhanced accuracy settings for ONNX (fixed input size)
        private const val ENABLE_TTA = true // Test-time augmentation
        private const val MAX_INFERENCE_TIME_MS = 4500L // Leave 500ms for other processing

        // Coordinate transformation modes - HYBRID is the correct method
        var COORD_TRANSFORM_MODE = "HYBRID" // HYBRID and LETTERBOX work correctly

        // Class filtering for debugging
        var DEBUG_CLASS_FILTER: String? = null // Set to class name to show only that class
        var SHOW_ALL_CONFIDENCES = false // Show all detections with their confidences

        // Preprocessing enhancement techniques (single pass only)
        private const val ENABLE_NOISE_REDUCTION = true

        // Confidence threshold optimization for mobile ONNX vs raw processing
        private const val ENABLE_CONFIDENCE_MAPPING = true
        private const val RAW_TO_MOBILE_SCALE = 0.75f // Based on observation that mobile shows lower conf

        // Image Processing Constants
        private const val GAUSSIAN_BLUR_KERNEL_SIZE = 3.0
        private const val GAUSSIAN_BLUR_SIGMA = 0.5
        private const val CLAHE_CLIP_LIMIT = 1.5
        private const val CLAHE_TILE_SIZE = 8.0
        private const val SHARPENING_CENTER_VALUE = 5.0
        private const val SHARPENING_EDGE_VALUE = -1.0
        private const val LETTERBOX_PADDING_GRAY = 114.0
        private const val PIXEL_NORMALIZATION_FACTOR = 255.0f
        private const val COLOR_CHANNEL_MASK = 0xFF

        // NMS Output Parsing Constants
        private const val NMS_OUTPUT_FEATURES_PER_DETECTION = 6 // [x1, y1, x2, y2, confidence, class_id]
        private const val NMS_COORDINATE_X1_OFFSET = 0
        private const val NMS_COORDINATE_Y1_OFFSET = 1
        private const val NMS_COORDINATE_X2_OFFSET = 2
        private const val NMS_COORDINATE_Y2_OFFSET = 3
        private const val NMS_CONFIDENCE_OFFSET = 4
        private const val NMS_CLASS_ID_OFFSET = 5
        private const val MIN_DEBUG_CONFIDENCE = 0.1f
        private const val MAX_DEBUG_DETECTIONS_TO_LOG = 3

        fun setCoordinateMode(mode: String) {
            COORD_TRANSFORM_MODE = mode
            PGHLog.i(TAG, "🔧 Coordinate transform mode changed to: $mode")
        }

        fun toggleShowAllConfidences() {
            SHOW_ALL_CONFIDENCES = !SHOW_ALL_CONFIDENCES
            PGHLog.i(TAG, "📊 Show all confidences: $SHOW_ALL_CONFIDENCES")
        }
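        // Example debug session (a sketch; "pokemon_level" is assumed to exist in dataset.yaml):
        //   YOLOInferenceEngine.setClassFilter("pokemon_level")  // keep/log only this class
        //   YOLOInferenceEngine.toggleShowAllConfidences()       // dump every candidate's confidence
        //   YOLOInferenceEngine.setCoordinateMode("LETTERBOX")   // compare against the default HYBRID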
        fun setClassFilter(className: String?) {
            DEBUG_CLASS_FILTER = className
            if (className != null) {
                PGHLog.i(TAG, "🔍 Class filter set to: '$className' (ID will be shown in debug output)")
            } else {
                PGHLog.i(TAG, "🔍 Class filter set to: ALL CLASSES")
            }
        }
    }

    private var ortSession: OrtSession? = null
    private var ortEnvironment: OrtEnvironment? = null
    private var isInitialized = false

    // Dynamic model metadata (extracted at runtime from ONNX model)
    private var modelInputSize: Int = 640 // Default fallback
    private var modelNumDetections: Int = 300 // Default fallback
    private var modelNumClasses: Int = 96 // Default fallback (based on dataset.yaml)
    private var modelOutputFeatures: Int = NMS_OUTPUT_FEATURES_PER_DETECTION // Default fallback

    // Shared thread pool for preprocessing operations (prevents creating new pools per detection)
    private val preprocessingExecutor = Executors.newFixedThreadPool(config.threadPoolSize)

    private var confidenceThreshold = config.confidenceThreshold
    private var classFilter: String? = config.classFilter

    // Performance tracking
    private var lastInferenceTime = 0L
    private var totalInferenceTime = 0L
    private var inferenceCount = 0L

    // Classification manager for class name mappings from dataset.yaml
    private lateinit var classificationManager: ClassificationManager

    override suspend fun initialize(): MLResult<Unit> = withContext(Dispatchers.IO) {
        if (isInitialized) return@withContext MLResult.Success(Unit)

        mlTry(MLErrorType.INITIALIZATION_FAILED) {
            PGHLog.i(TAG, "🤖 Initializing ONNX YOLO detector...")

            // Initialize ONNX Runtime environment
            ortEnvironment = OrtEnvironment.getEnvironment()

            // Copy model from assets to internal storage
            PGHLog.i(TAG, "📂 Copying model file: ${config.modelFile}")
            val model_path = copyAssetToInternalStorage(config.modelFile)
                ?: throw RuntimeException("Failed to copy ONNX model from assets")
            PGHLog.i(TAG, "✅ Model copied to: $model_path")

            // Create ONNX session
            PGHLog.i(TAG, "📥 Loading ONNX model from: $model_path")
            val session_options = OrtSession.SessionOptions()
            session_options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT)
            ortSession = ortEnvironment?.createSession(model_path, session_options)
                ?: throw RuntimeException("Failed to create ONNX session")

            // Extract model metadata dynamically
            extractModelMetadata()

            // Initialize ClassificationManager
            classificationManager = ClassificationManager.getInstance(context)
            val classificationResult = classificationManager.initialize()
            if (classificationResult is MLResult.Error) {
                throw RuntimeException("Failed to initialize ClassificationManager: ${classificationResult.message}")
            }

            PGHLog.i(TAG, "✅ ONNX YOLO detector initialized successfully")
            isInitialized = true
        }.onError { errorType, exception, message ->
            PGHLog.e(TAG, "❌ $message", exception)
        }
    }

    override suspend fun detect(image: Bitmap): MLResult<List<Detection>> = withContext(Dispatchers.IO) {
        if (image.isRecycled) {
            return@withContext mlError(MLErrorType.INVALID_INPUT, "Bitmap is recycled")
        }

        mlTry(MLErrorType.INFERENCE_FAILED) {
            // Convert Bitmap to Mat first to maintain compatibility with original logic
            val input_mat = Mat()
            Utils.bitmapToMat(image, input_mat)
            try {
                detectWithMat(input_mat)
            } finally {
                input_mat.release()
            }
        }.onError { errorType, exception, message ->
            PGHLog.e(TAG, "❌ Detection failed: $message", exception)
        }
    }
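    // Pipeline sketch (shapes assume the default 640x640 model with NMS baked into the graph):
    //   Bitmap -> Mat -> letterbox to 640x640 -> BGR + Gaussian denoise
    //          -> RGB CHW float tensor [1, 3, 640, 640] -> ONNX run
    //          -> NMS output [1, 300, 6] -> confidence mapping + class/confidence filtering
    //          -> coordinate transform back to screen space -> List<Detection>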
    /**
     * Main detection method that preserves original Mat-based logic
     */
    private suspend fun detectWithMat(inputMat: Mat): List<Detection> {
        if (!isInitialized || ortSession == null) {
            PGHLog.w(TAG, "⚠️ ONNX detector not initialized")
            return emptyList()
        }

        val start_time = System.currentTimeMillis()

        return try {
            PGHLog.d(TAG, "🎯 Starting ONNX YOLO detection...")

            // Single preprocessing method - ultralytics standard
            val detections = detectWithPreprocessing(inputMat)

            // Update performance stats
            lastInferenceTime = System.currentTimeMillis() - start_time
            totalInferenceTime += lastInferenceTime
            inferenceCount++

            PGHLog.d(TAG, "🎯 Detection completed: ${detections.size} objects found in ${lastInferenceTime}ms")
            detections
        } catch (e: Exception) {
            PGHLog.e(TAG, "❌ Error during enhanced ONNX YOLO detection", e)
            emptyList()
        }
    }

    private fun detectWithPreprocessing(inputMat: Mat): List<Detection> {
        try {
            // Apply ultralytics preprocessing (only method used)
            val preprocessed_mat = preprocessUltralyticsStyle(inputMat)
            return try {
                runInference(preprocessed_mat, inputMat.size())
            } finally {
                preprocessed_mat.release()
            }
        } catch (e: Exception) {
            PGHLog.e(TAG, "❌ Error in preprocessing", e)
            return emptyList()
        }
    }

    private fun runInference(preprocessedMat: Mat, originalSize: Size): List<Detection> {
        // Convert Mat to tensor format
        val input_data = matToTensorArray(preprocessedMat)

        // Run ONNX inference
        val input_name = ortSession!!.inputNames.iterator().next()
        val input_tensor = OnnxTensor.createTensor(ortEnvironment, input_data)
        val inputs = mapOf(input_name to input_tensor)

        return try {
            // Close the Result as well so its native buffers are freed promptly
            ortSession!!.run(inputs).use { results ->
                val output_tensor = results[0].value as Array<Array<FloatArray>>

                // Post-process results with coordinate transformation
                postprocessResults(output_tensor, originalSize)
            }
        } finally {
            input_tensor.close()
        }
    }

    override fun setConfidenceThreshold(threshold: Float) {
        confidenceThreshold = threshold.coerceIn(0.0f, 1.0f)
        PGHLog.d(TAG, "🎚️ Confidence threshold set to: $confidenceThreshold")
    }

    override fun setClassFilter(className: String?) {
        classFilter = className
        DEBUG_CLASS_FILTER = className
        PGHLog.d(TAG, "🔍 Class filter set to: ${className ?: "none"}")
    }

    override fun isReady(): Boolean = isInitialized

    override fun getInferenceStats(): Pair<Long, Long> {
        val average_time = if (inferenceCount > 0) totalInferenceTime / inferenceCount else 0L
        return Pair(lastInferenceTime, average_time)
    }

    override fun cleanup() {
        PGHLog.d(TAG, "🧹 Cleaning up YOLO Inference Engine")
        try {
            // Shutdown thread pool with grace period
            preprocessingExecutor.shutdown()
            if (!preprocessingExecutor.awaitTermination(2, TimeUnit.SECONDS)) {
                PGHLog.w(TAG, "⚠️ Thread pool shutdown timeout, forcing shutdown")
                preprocessingExecutor.shutdownNow()
            }

            ortSession?.close()
            ortEnvironment?.close()
        } catch (e: Exception) {
            PGHLog.e(TAG, "Error during cleanup", e)
        } finally {
            ortSession = null
            ortEnvironment = null
            isInitialized = false
        }
    }
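    // Typical call site (a sketch; the coroutine scope and the MLResult.Success field name
    // are assumptions, not part of this file):
    //
    //   val engine = YOLOInferenceEngine(context)
    //   scope.launch {
    //       engine.initialize()                     // MLResult<Unit>
    //       val result = engine.detect(screenshot)  // MLResult<List<Detection>>
    //       if (result is MLResult.Success) {
    //           result.value.forEach { d -> PGHLog.d(TAG, "${d.className} @ ${d.confidence}") }
    //       }
    //       engine.cleanup()
    //   }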
    private fun copyAssetToInternalStorage(fileName: String): String? {
        return try {
            val internal_file = context.getFileStreamPath(fileName)
            if (!internal_file.exists()) {
                context.assets.open(fileName).use { input_stream ->
                    FileOutputStream(internal_file).use { output_stream ->
                        input_stream.copyTo(output_stream)
                    }
                }
            }
            internal_file.absolutePath
        } catch (e: IOException) {
            PGHLog.e(TAG, "Failed to copy asset to internal storage", e)
            null
        }
    }

    /**
     * Ultralytics-style preprocessing with letterbox resize and noise reduction
     */
    private fun preprocessUltralyticsStyle(inputMat: Mat): Mat {
        try {
            PGHLog.d(TAG, "🔧 Ultralytics preprocessing: input ${inputMat.cols()}x${inputMat.rows()}, type=${inputMat.type()}")

            // Step 1: Letterbox resize (preserves aspect ratio with padding)
            val letterboxed = letterboxResize(inputMat, config.inputSize, config.inputSize)

            // Step 2: Apply slight noise reduction (Ultralytics uses this)
            val denoised = Mat()

            // Ensure proper BGR format; release the intermediate if a new Mat was allocated
            val processed_mat = ensureBGRFormat(letterboxed)
            if (processed_mat !== letterboxed) {
                letterboxed.release()
            }

            // Apply gentle smoothing (more reliable than bilateral filter)
            if (processed_mat.type() == CvType.CV_8UC3 || processed_mat.type() == CvType.CV_8UC1) {
                // Use Gaussian blur as a more reliable alternative to bilateral filter
                Imgproc.GaussianBlur(processed_mat, denoised, Size(GAUSSIAN_BLUR_KERNEL_SIZE, GAUSSIAN_BLUR_KERNEL_SIZE), GAUSSIAN_BLUR_SIGMA)
                processed_mat.release()
                PGHLog.d(TAG, "✅ Ultralytics preprocessing complete with Gaussian smoothing")
                return denoised
            } else {
                PGHLog.w(TAG, "⚠️ Smoothing skipped - unsupported image type: ${processed_mat.type()}")
                denoised.release()
                return processed_mat
            }
        } catch (e: Exception) {
            PGHLog.e(TAG, "❌ Error in Ultralytics preprocessing", e)
            // Return a copy instead of the original to avoid memory issues
            val safe_copy = Mat()
            inputMat.copyTo(safe_copy)
            return safe_copy
        }
    }

    /**
     * Letterbox resize maintaining aspect ratio with gray padding
     */
    private fun letterboxResize(inputMat: Mat, targetWidth: Int, targetHeight: Int): Mat {
        val original_height = inputMat.rows()
        val original_width = inputMat.cols()

        // Calculate scale to fit within target size while preserving aspect ratio
        val scale = minOf(
            targetWidth.toDouble() / original_width,
            targetHeight.toDouble() / original_height
        )

        // Calculate new dimensions
        val new_width = (original_width * scale).toInt()
        val new_height = (original_height * scale).toInt()

        // Resize with high quality (similar to PIL LANCZOS)
        val resized = Mat()
        Imgproc.resize(inputMat, resized, Size(new_width.toDouble(), new_height.toDouble()), 0.0, 0.0, Imgproc.INTER_CUBIC)

        // Create letterbox with gray padding
        val letterboxed = Mat(targetHeight, targetWidth, inputMat.type(), Scalar(LETTERBOX_PADDING_GRAY, LETTERBOX_PADDING_GRAY, LETTERBOX_PADDING_GRAY))

        // Calculate padding offsets
        val offset_x = (targetWidth - new_width) / 2
        val offset_y = (targetHeight - new_height) / 2

        // Copy resized image to center of letterboxed image
        val roi = Rect(offset_x, offset_y, new_width, new_height)
        val roi_mat = Mat(letterboxed, roi)
        resized.copyTo(roi_mat)

        resized.release()
        roi_mat.release()
        return letterboxed
    }
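    // Worked example (hypothetical 1080x1920 portrait frame, 640x640 target):
    //   scale    = min(640/1080, 640/1920) = 1/3
    //   resized  = 360 x 640 (aspect ratio preserved)
    //   offset_x = (640 - 360) / 2 = 140, offset_y = 0
    // so the frame sits centered between two 140px-wide gray bars.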
    /**
     * Convert Mat to tensor array format for ONNX inference
     */
    private fun matToTensorArray(mat: Mat): Array<Array<Array<FloatArray>>> {
        // Convert to RGB using utility method
        val rgb_mat = ensureRGBFormat(mat)

        try {
            // Create array format [1, 3, height, width]
            val data = Array(1) { Array(NUM_CHANNELS) { Array(modelInputSize) { FloatArray(modelInputSize) } } }

            // Get RGB bytes
            val rgb_bytes = ByteArray(modelInputSize * modelInputSize * 3)
            rgb_mat.get(0, 0, rgb_bytes)

            // Convert HWC to CHW format and normalize
            for (c in 0 until NUM_CHANNELS) {
                for (h in 0 until modelInputSize) {
                    for (w in 0 until modelInputSize) {
                        val pixel_idx = (h * modelInputSize + w) * 3 + c
                        data[0][c][h][w] = if (pixel_idx < rgb_bytes.size) {
                            (rgb_bytes[pixel_idx].toInt() and COLOR_CHANNEL_MASK) / PIXEL_NORMALIZATION_FACTOR
                        } else 0.0f
                    }
                }
            }
            return data
        } finally {
            rgb_mat.release()
        }
    }

    /**
     * Post-process ONNX model output to Detection objects
     */
    private fun postprocessResults(output: Array<Array<FloatArray>>, originalSize: Size): List<Detection> {
        val flat_output = output[0].flatMap { it.asIterable() }.toFloatArray()
        return parseNMSOutput(flat_output, originalSize.width.toInt(), originalSize.height.toInt(), modelInputSize)
    }

    /**
     * Extract model metadata dynamically from ONNX session
     */
    private fun extractModelMetadata() {
        try {
            val session = ortSession ?: return

            // Get input info
            val inputInfo = session.inputInfo
            val inputShape = inputInfo.values.firstOrNull()?.info as? ai.onnxruntime.TensorInfo
            if (inputShape != null) {
                val shape = inputShape.shape
                // Need all four dims before reading shape[2] and shape[3]
                if (shape.size >= 4) {
                    // Typically YOLO input is [batch, channels, height, width] or [batch, height, width, channels]
                    val extractedInputSize = maxOf(shape[2].toInt(), shape[3].toInt()) // Take max of height/width
                    if (extractedInputSize > 0) {
                        modelInputSize = extractedInputSize
                        PGHLog.i(TAG, "📏 Extracted input size from model: $modelInputSize")
                    }
                }
            }

            // Get output info
            val outputInfo = session.outputInfo
            val outputTensorInfo = outputInfo.values.firstOrNull()?.info as? ai.onnxruntime.TensorInfo
            if (outputTensorInfo != null) {
                val outputShape = outputTensorInfo.shape
                PGHLog.i(TAG, "📊 Model output shape: ${outputShape.contentToString()}")

                // Need all three dims before reading outputShape[1] and outputShape[2]
                if (outputShape.size >= 3) {
                    // For NMS output: typically [batch, num_detections, features_per_detection]
                    // For raw output: typically [batch, num_detections, 4+1+num_classes] or similar
                    val numDetections = outputShape[1].toInt()
                    val featuresPerDetection = outputShape[2].toInt()

                    if (numDetections > 0) {
                        modelNumDetections = numDetections
                        PGHLog.i(TAG, "🔢 Extracted num detections from model: $modelNumDetections")
                    }

                    if (featuresPerDetection > 0) {
                        modelOutputFeatures = featuresPerDetection
                        PGHLog.i(TAG, "📊 Extracted output features per detection: $modelOutputFeatures")

                        // Try to infer number of classes from output features:
                        // NMS output: [x1, y1, x2, y2, confidence, class_id] = 6 features
                        // Raw output: [x, y, w, h, confidence, class1, class2, ..., classN] = 5 + num_classes
                        if (featuresPerDetection == NMS_OUTPUT_FEATURES_PER_DETECTION) {
                            PGHLog.i(TAG, "🎯 Detected NMS post-processed output format")
                        } else if (featuresPerDetection > 5) {
                            val inferredNumClasses = featuresPerDetection - 5 // 4 coords + 1 confidence
                            if (inferredNumClasses in 1..1000) { // Reasonable range
                                modelNumClasses = inferredNumClasses
                                PGHLog.i(TAG, "🏷️ Inferred num classes from output: $modelNumClasses")
                            }
                        }
                    }
                }
            }

            PGHLog.i(TAG, "📋 Final model metadata - Input: ${modelInputSize}x${modelInputSize}, Detections: $modelNumDetections, Features: $modelOutputFeatures, Classes: $modelNumClasses")
        } catch (e: Exception) {
            PGHLog.w(TAG, "⚠️ Failed to extract model metadata, using fallback constants", e)
        }
    }
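    // Worked example (hypothetical raw export): an output shape of [1, 300, 101] gives
    //   featuresPerDetection = 101 -> inferredNumClasses = 101 - 5 = 96,
    // matching the 96-class dataset.yaml; the NMS export used here reports [1, 300, 6]
    // instead, so the class count stays at its fallback and only the format is logged.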
    /**
     * Utility methods for common preprocessing operations
     */

    /**
     * Safely convert Mat to BGR format (3-channel) if needed
     */
    private fun ensureBGRFormat(inputMat: Mat): Mat {
        return when (inputMat.type()) {
            CvType.CV_8UC3 -> inputMat
            CvType.CV_8UC4 -> {
                val converted = Mat()
                Imgproc.cvtColor(inputMat, converted, Imgproc.COLOR_BGRA2BGR)
                converted
            }
            CvType.CV_8UC1 -> {
                val converted = Mat()
                Imgproc.cvtColor(inputMat, converted, Imgproc.COLOR_GRAY2BGR)
                converted
            }
            else -> {
                val converted = Mat()
                inputMat.convertTo(converted, CvType.CV_8UC3)
                converted
            }
        }
    }

    /**
     * Safely perform Mat operation with fallback
     */
    private inline fun <T> safeMatOperation(
        operation: () -> T,
        fallback: () -> T,
        errorMessage: String
    ): T {
        return try {
            operation()
        } catch (e: Exception) {
            PGHLog.e(TAG, "❌ $errorMessage", e)
            fallback()
        }
    }

    /**
     * Utility function to ensure Mat is in RGB format for ONNX model input
     */
    private fun ensureRGBFormat(inputMat: Mat): Mat {
        return when (inputMat.channels()) {
            3 -> {
                val converted = Mat()
                Imgproc.cvtColor(inputMat, converted, Imgproc.COLOR_BGR2RGB)
                converted
            }
            4 -> {
                val converted = Mat()
                Imgproc.cvtColor(inputMat, converted, Imgproc.COLOR_BGRA2RGB)
                converted
            }
            else -> inputMat // 1-channel or other formats pass through
        }
    }

    /**
     * Safe resize operation with fallback
     */
    private fun safeResize(inputMat: Mat, targetSize: Size): Mat {
        return safeMatOperation(
            operation = {
                val resized = Mat()
                Imgproc.resize(inputMat, resized, targetSize)
                resized
            },
            fallback = {
                val fallback = Mat()
                inputMat.copyTo(fallback)
                fallback
            },
            errorMessage = "Error resizing image"
        )
    }

    /**
     * Data class for transformed coordinates
     */
    private data class TransformedCoordinates(val x1: Float, val y1: Float, val x2: Float, val y2: Float)

    /**
     * Transform coordinates from model output to original image space
     */
    private fun transformCoordinates(
        rawX1: Float, rawY1: Float, rawX2: Float, rawY2: Float,
        originalWidth: Int, originalHeight: Int, inputScale: Int
    ): TransformedCoordinates {
        return when (COORD_TRANSFORM_MODE) {
            "LETTERBOX" -> {
                val letterbox_params = calculateLetterboxInverse(originalWidth, originalHeight, inputScale)
                val scale_x = letterbox_params[0]
                val scale_y = letterbox_params[1]
                val offset_x = letterbox_params[2]
                val offset_y = letterbox_params[3]
                TransformedCoordinates(
                    x1 = (rawX1 - offset_x) * scale_x,
                    y1 = (rawY1 - offset_y) * scale_y,
                    x2 = (rawX2 - offset_x) * scale_x,
                    y2 = (rawY2 - offset_y) * scale_y
                )
            }
            "DIRECT" -> {
                val direct_scale_x = originalWidth.toFloat() / inputScale.toFloat()
                val direct_scale_y = originalHeight.toFloat() / inputScale.toFloat()
                TransformedCoordinates(
                    x1 = rawX1 * direct_scale_x,
                    y1 = rawY1 * direct_scale_y,
                    x2 = rawX2 * direct_scale_x,
                    y2 = rawY2 * direct_scale_y
                )
            }
            else -> {
                // "HYBRID", which is also the default for unknown coordinate modes
                val letterbox_params = calculateLetterboxInverse(originalWidth, originalHeight, inputScale)
                val offset_x = letterbox_params[2]
                val offset_y = letterbox_params[3]
                val scale = minOf(inputScale.toDouble() / originalWidth, inputScale.toDouble() / originalHeight)
                val scaled_width = (originalWidth * scale)
                val scaled_height = (originalHeight * scale)
                val hybrid_scale_x = originalWidth.toFloat() / scaled_width.toFloat()
                val hybrid_scale_y = originalHeight.toFloat() / scaled_height.toFloat()
                TransformedCoordinates(
                    x1 = (rawX1 - offset_x) * hybrid_scale_x,
                    y1 = (rawY1 - offset_y) * hybrid_scale_y,
                    x2 = (rawX2 - offset_x) * hybrid_scale_x,
                    y2 = (rawY2 - offset_y) * hybrid_scale_y
                )
            }
        }
    }
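    // Worked example (HYBRID mode, same hypothetical 1080x1920 frame, 640 input):
    //   letterbox: scale = 1/3, offset_x = 140, offset_y = 0, scaled size = 360 x 640
    //   hybrid_scale_x = 1080 / 360 = 3.0, hybrid_scale_y = 1920 / 640 = 3.0
    //   raw corner (200, 100) -> ((200 - 140) * 3, (100 - 0) * 3) = (180, 300) in screen space.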
    /**
     * Parse NMS (Non-Maximum Suppression) output format
     */
    private fun parseNMSOutput(output: FloatArray, originalWidth: Int, originalHeight: Int, inputScale: Int): List<Detection> {
        val detections = mutableListOf<Detection>()
        val num_detections = modelNumDetections
        val features_per_detection = modelOutputFeatures

        PGHLog.d(TAG, "🔍 Parsing NMS output: $num_detections post-processed detections")
        var valid_detections = 0

        for (i in 0 until num_detections) {
            val base_idx = i * features_per_detection

            // Extract and transform coordinates from model output
            val coords = transformCoordinates(
                rawX1 = output[base_idx + NMS_COORDINATE_X1_OFFSET],
                rawY1 = output[base_idx + NMS_COORDINATE_Y1_OFFSET],
                rawX2 = output[base_idx + NMS_COORDINATE_X2_OFFSET],
                rawY2 = output[base_idx + NMS_COORDINATE_Y2_OFFSET],
                originalWidth = originalWidth,
                originalHeight = originalHeight,
                inputScale = inputScale
            )
            val confidence = output[base_idx + NMS_CONFIDENCE_OFFSET]
            val class_id = output[base_idx + NMS_CLASS_ID_OFFSET].toInt()

            // Apply confidence mapping if enabled
            val mapped_confidence = if (ENABLE_CONFIDENCE_MAPPING) {
                mapConfidenceForMobile(confidence)
            } else {
                confidence
            }

            // Get class name for filtering and debugging
            val class_name = classificationManager.getClassName(class_id) ?: "unknown_$class_id"

            // Debug logging for all detections if enabled
            if (SHOW_ALL_CONFIDENCES && mapped_confidence > MIN_DEBUG_CONFIDENCE) {
                PGHLog.d(TAG, "🔍 [DEBUG] Class: $class_name (ID: $class_id), Confidence: %.3f, Original: %.3f".format(mapped_confidence, confidence))
            }

            // Apply class filter if set
            val passes_class_filter = DEBUG_CLASS_FILTER == null || DEBUG_CLASS_FILTER == class_name

            // Filter by confidence threshold and class filter, and validate coordinates
            if (mapped_confidence > confidenceThreshold &&
                class_id >= 0 && class_id < classificationManager.getNumClasses() &&
                passes_class_filter
            ) {
                // Convert from corner coordinates (x1,y1,x2,y2) to BoundingBox format,
                // clamping coordinates to image boundaries
                val clamped_x1 = max(0.0f, min(coords.x1, originalWidth.toFloat()))
                val clamped_y1 = max(0.0f, min(coords.y1, originalHeight.toFloat()))
                val clamped_x2 = max(clamped_x1, min(coords.x2, originalWidth.toFloat()))
                val clamped_y2 = max(clamped_y1, min(coords.y2, originalHeight.toFloat()))

                // Validate bounding box dimensions and coordinates
                if (clamped_x2 > clamped_x1 && clamped_y2 > clamped_y1) {
                    detections.add(
                        Detection(
                            className = class_name,
                            confidence = mapped_confidence,
                            boundingBox = BoundingBox(clamped_x1, clamped_y1, clamped_x2, clamped_y2)
                        )
                    )
                    valid_detections++

                    if (valid_detections <= MAX_DEBUG_DETECTIONS_TO_LOG) {
                        PGHLog.d(TAG, "✅ Valid NMS detection: class=$class_id ($class_name), conf=${String.format("%.4f", mapped_confidence)}")
                    }
                }
            }
        }

        PGHLog.d(TAG, "🎯 NMS parsing complete: $valid_detections valid detections")
        return detections.sortedByDescending { it.confidence }
    }
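    // Worked example (hypothetical row): [140.0, 100.0, 340.0, 300.0, 0.45, 17.0] reads as
    // corners (140, 100)-(340, 300) in model input space, confidence 0.45, class_id 17.
    // The corners go through transformCoordinates() and the confidence through
    // mapConfidenceForMobile() before the threshold and class filter are applied.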
    /**
     * Calculate letterbox inverse transformation parameters
     */
    private fun calculateLetterboxInverse(originalWidth: Int, originalHeight: Int, inputScale: Int): Array<Float> {
        // Calculate the scale that was used during letterbox resize
        val scale = minOf(
            inputScale.toDouble() / originalWidth,
            inputScale.toDouble() / originalHeight
        )

        // Calculate the scaled dimensions (what the image became after resize but before padding)
        val scaled_width = (originalWidth * scale)
        val scaled_height = (originalHeight * scale)

        // Calculate padding offsets (in the model input space, e.g. 640x640)
        val offset_x = (inputScale - scaled_width) / 2.0
        val offset_y = (inputScale - scaled_height) / 2.0

        // Same for both X and Y since letterbox uses uniform scaling
        val scale_back_x = 1.0 / scale
        val scale_back_y = 1.0 / scale

        return arrayOf(scale_back_x.toFloat(), scale_back_y.toFloat(), offset_x.toFloat(), offset_y.toFloat())
    }

    /**
     * Apply confidence mapping for mobile optimization
     */
    private fun mapConfidenceForMobile(rawConfidence: Float): Float {
        // Apply scaling and optional curve adjustment
        var mapped = rawConfidence / RAW_TO_MOBILE_SCALE

        // Optional: Apply sigmoid-like curve to boost mid-range confidences
        mapped = (mapped * mapped) / (mapped * mapped + (1 - mapped) * (1 - mapped))

        // Clamp to valid range
        return min(1.0f, max(0.0f, mapped))
    }
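    // Worked example: rawConfidence = 0.45
    //   scaled = 0.45 / 0.75 = 0.60
    //   curved = 0.60^2 / (0.60^2 + 0.40^2) = 0.36 / 0.52 ≈ 0.692
    // so a raw 0.45 clears the 0.55 default threshold after mapping.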
    /**
     * Merge and filter detections using weighted NMS
     */
    private fun mergeAndFilterDetections(allDetections: List<Detection>): List<Detection> {
        if (allDetections.isEmpty()) return emptyList()

        // First, apply NMS within each class
        val detections_by_class = allDetections.groupBy { detection ->
            // Map class name back to ID for grouping
            classificationManager.getAllClassNames().entries.find { it.value == detection.className }?.key ?: -1
        }

        val class_nms_results = mutableListOf<Detection>()
        for ((class_id, class_detections) in detections_by_class) {
            if (class_id >= 0) {
                val nms_results = applyWeightedNMS(class_detections, class_detections.first().className)
                class_nms_results.addAll(nms_results)
            }
        }

        // Then, apply cross-class NMS for semantically related classes
        val final_detections = applyCrossClassNMS(class_nms_results)
        return final_detections.sortedByDescending { it.confidence }
    }

    /**
     * Apply weighted NMS within a class
     */
    private fun applyWeightedNMS(detections: List<Detection>, className: String): List<Detection> {
        val results = mutableListOf<Detection>()
        if (detections.isEmpty()) return results

        // Sort by confidence
        val sorted_detections = detections.sortedByDescending { it.confidence }
        val suppressed = BooleanArray(sorted_detections.size)

        for (i in sorted_detections.indices) {
            if (suppressed[i]) continue

            var final_confidence = sorted_detections[i].confidence
            var final_box = sorted_detections[i].boundingBox
            val overlapping_boxes = mutableListOf<Pair<BoundingBox, Float>>()
            overlapping_boxes.add(Pair(sorted_detections[i].boundingBox, sorted_detections[i].confidence))

            // Find overlapping boxes and combine them with weighted averaging
            for (j in sorted_detections.indices) {
                if (i != j && !suppressed[j]) {
                    val iou = calculateIoU(sorted_detections[i].boundingBox, sorted_detections[j].boundingBox)
                    if (iou > NMS_THRESHOLD) {
                        overlapping_boxes.add(Pair(sorted_detections[j].boundingBox, sorted_detections[j].confidence))
                        suppressed[j] = true
                    }
                }
            }

            // If multiple overlapping boxes, use weighted average
            if (overlapping_boxes.size > 1) {
                val total_weight = overlapping_boxes.sumOf { it.second.toDouble() }
                val weighted_left = overlapping_boxes.sumOf { it.first.left * it.second.toDouble() } / total_weight
                val weighted_top = overlapping_boxes.sumOf { it.first.top * it.second.toDouble() } / total_weight
                val weighted_right = overlapping_boxes.sumOf { it.first.right * it.second.toDouble() } / total_weight
                val weighted_bottom = overlapping_boxes.sumOf { it.first.bottom * it.second.toDouble() } / total_weight

                final_box = BoundingBox(
                    weighted_left.toFloat(),
                    weighted_top.toFloat(),
                    weighted_right.toFloat(),
                    weighted_bottom.toFloat()
                )
                final_confidence = (total_weight / overlapping_boxes.size).toFloat()

                PGHLog.d(TAG, "🔗 Merged ${overlapping_boxes.size} overlapping detections, final conf: ${String.format("%.3f", final_confidence)}")
            }

            results.add(
                Detection(
                    className = className,
                    confidence = final_confidence,
                    boundingBox = final_box
                )
            )
        }
        return results
    }

    /**
     * Apply cross-class NMS for semantically related classes
     */
    private fun applyCrossClassNMS(detections: List<Detection>): List<Detection> {
        val result = mutableListOf<Detection>()
        val suppressed = BooleanArray(detections.size)

        // Define semantically related class groups
        val level_related_classes = setOf("pokemon_level", "level_value", "digit", "number")
        val stat_related_classes = setOf("hp_value", "attack_value", "defense_value", "sp_atk_value", "sp_def_value", "speed_value")
        val text_related_classes = setOf("pokemon_nickname", "pokemon_species", "move_name", "ability_name", "nature_name")

        for (i in detections.indices) {
            if (suppressed[i]) continue

            val current_detection = detections[i]
            var best_detection = current_detection

            // Check for overlapping detections in related classes
            for (j in detections.indices) {
                if (i != j && !suppressed[j]) {
                    val other_detection = detections[j]
                    val iou = calculateIoU(current_detection.boundingBox, other_detection.boundingBox)

                    // If highly overlapping, check if they're semantically related
                    if (iou > 0.5f) { // High overlap threshold for cross-class NMS
                        val are_related = areClassesRelated(
                            current_detection.className, other_detection.className,
                            level_related_classes, stat_related_classes, text_related_classes
                        )
                        if (are_related) {
                            // Keep the higher confidence detection
                            if (other_detection.confidence > best_detection.confidence) {
                                best_detection = other_detection
                            }
                            suppressed[j] = true
                            PGHLog.d(TAG, "🔗 Cross-class NMS: merged ${current_detection.className} with ${other_detection.className}")
                        }
                    }
                }
            }
            result.add(best_detection)
        }
        return result
    }

    /**
     * Check if two classes are semantically related
     */
    private fun areClassesRelated(
        class1: String, class2: String,
        levelClasses: Set<String>, statClasses: Set<String>, textClasses: Set<String>
    ): Boolean {
        return (levelClasses.contains(class1) && levelClasses.contains(class2)) ||
            (statClasses.contains(class1) && statClasses.contains(class2)) ||
            (textClasses.contains(class1) && textClasses.contains(class2))
    }

    /**
     * Calculate Intersection over Union (IoU) for two bounding boxes
     */
    private fun calculateIoU(box1: BoundingBox, box2: BoundingBox): Float {
        val x1 = max(box1.left, box2.left)
        val y1 = max(box1.top, box2.top)
        val x2 = min(box1.right, box2.right)
        val y2 = min(box1.bottom, box2.bottom)

        val intersection = max(0f, x2 - x1) * max(0f, y2 - y1)
        val area1 = box1.width * box1.height
        val area2 = box2.width * box2.height
        val union = area1 + area2 - intersection

        return if (union > 0) intersection / union else 0f
    }
}