diff --git a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt
index 4294a11..02aa9c3 100644
--- a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt
+++ b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt
@@ -501,27 +501,13 @@ class YOLOInferenceEngine(
             // Step 2: Apply slight noise reduction (Ultralytics uses this)
             val denoised = Mat()
 
-            // Ensure proper format for bilateral filter
-            val processed_mat = when
-            {
-                letterboxed.type() == CvType.CV_8UC3 -> letterboxed
-                letterboxed.type() == CvType.CV_8UC4 ->
-                {
-                    val converted = Mat()
-                    Imgproc.cvtColor(letterboxed, converted, Imgproc.COLOR_BGRA2BGR)
-                    letterboxed.release()
-                    converted
-                }
-                letterboxed.type() == CvType.CV_8UC1 -> letterboxed
-                else ->
-                {
-                    // Convert to 8-bit if needed
-                    val converted = Mat()
-                    letterboxed.convertTo(converted, CvType.CV_8UC3)
-                    letterboxed.release()
-                    converted
-                }
-            }
+            // Ensure proper BGR format
+            val processed_mat = ensureBGRFormat(letterboxed)
+            if (processed_mat !== letterboxed)
+            {
+                // Conversion produced a new Mat, so the intermediate letterboxed Mat is no longer needed
+                letterboxed.release()
+            }
 
             // Apply gentle smoothing (more reliable than bilateral filter)
             if (processed_mat.type() == CvType.CV_8UC3 || processed_mat.type() == CvType.CV_8UC1)
@@ -628,17 +614,7 @@ class YOLOInferenceEngine(
      */
     private fun preprocessOriginalStyle(inputMat: Mat): Mat
     {
-        val resized = Mat()
-        try
-        {
-            Imgproc.resize(inputMat, resized, Size(config.inputSize.toDouble(), config.inputSize.toDouble()))
-        }
-        catch (e: Exception)
-        {
-            Log.e(TAG, "❌ Error in original preprocessing", e)
-            inputMat.copyTo(resized)
-        }
-        return resized
+        return safeResize(inputMat, Size(config.inputSize.toDouble(), config.inputSize.toDouble()))
     }
 
     /**
@@ -686,20 +662,8 @@ class YOLOInferenceEngine(
      */
     private fun matToTensorArray(mat: Mat): Array<Array<Array<FloatArray>>>
     {
-        // Convert to RGB
-        val rgb_mat = Mat()
-        if (mat.channels() == 4)
-        {
-            Imgproc.cvtColor(mat, rgb_mat, Imgproc.COLOR_BGRA2RGB)
-        }
-        else if (mat.channels() == 3)
-        {
-            Imgproc.cvtColor(mat, rgb_mat, Imgproc.COLOR_BGR2RGB)
-        }
-        else
-        {
-            mat.copyTo(rgb_mat)
-        }
+        // Convert to RGB using utility method
+        val rgb_mat = ensureRGBFormat(mat)
 
         try
         {
@@ -743,6 +707,136 @@ class YOLOInferenceEngine(
         return parseNMSOutput(flat_output, originalSize.width.toInt(), originalSize.height.toInt(), INPUT_SIZE)
     }
 
+    // ---------------------------------------------------------------------
+    // Utility methods for common preprocessing operations
+    // ---------------------------------------------------------------------
+
+    /**
+     * Returns an 8-bit BGR version of [inputMat], converting only when necessary.
+     *
+     * Ownership: a CV_8UC3 input is returned as the same instance; every other type yields a
+     * newly allocated Mat. [inputMat] is never released here, so the caller must release it
+     * when the returned Mat is a different instance.
+     *
+     * Note: the fallback branch uses Mat.convertTo, which changes bit depth only and keeps the
+     * channel count, so an unusual input type comes back 8-bit but not necessarily 3-channel.
+     *
+     * @param inputMat source image
+     * @return [inputMat] itself when it is already CV_8UC3, otherwise a new converted Mat
+     */
+    private fun ensureBGRFormat(inputMat: Mat): Mat
+    {
+        return when (inputMat.type())
+        {
+            CvType.CV_8UC3 -> inputMat
+            CvType.CV_8UC4 -> convertColorToNewMat(inputMat, Imgproc.COLOR_BGRA2BGR)
+            CvType.CV_8UC1 -> convertColorToNewMat(inputMat, Imgproc.COLOR_GRAY2BGR)
+            else ->
+            {
+                // Convert to 8-bit if needed
+                val converted = Mat()
+                inputMat.convertTo(converted, CvType.CV_8UC3)
+                converted
+            }
+        }
+    }
+
+    /**
+     * Runs Imgproc.cvtColor from [source] into a freshly allocated Mat.
+     *
+     * The new Mat is released before the exception is rethrown if the conversion fails, so a
+     * failed conversion does not leave a native allocation behind.
+     *
+     * @param source image to convert; left untouched
+     * @param conversionCode one of the Imgproc.COLOR_* conversion codes
+     * @return a new Mat holding the converted image, owned by the caller
+     */
+    private fun convertColorToNewMat(source: Mat, conversionCode: Int): Mat
+    {
+        val converted = Mat()
+        try
+        {
+            Imgproc.cvtColor(source, converted, conversionCode)
+        }
+        catch (e: Exception)
+        {
+            converted.release()
+            throw e
+        }
+        return converted
+    }
+
+    /**
+     * Runs [operation], logging any exception and returning the result of [fallback] instead.
+     *
+     * Exceptions thrown by [fallback] itself are not caught and propagate to the caller.
+     *
+     * @param operation primary computation to attempt
+     * @param fallback computation used when [operation] throws an [Exception]
+     * @param errorMessage text logged (with the exception) when [operation] fails
+     * @return the value produced by [operation], or by [fallback] on failure
+     */
+    private inline fun <T> safeMatOperation(
+        operation: () -> T,
+        fallback: () -> T,
+        errorMessage: String
+    ): T
+    {
+        return try
+        {
+            operation()
+        }
+        catch (e: Exception)
+        {
+            Log.e(TAG, "❌ $errorMessage", e)
+            fallback()
+        }
+    }
+
+    /**
+     * Produces an RGB copy of [inputMat] for ONNX model input.
+     *
+     * Always returns a newly allocated Mat, including for inputs that need no channel
+     * reordering, so the caller can release the result without affecting [inputMat].
+     *
+     * @param inputMat source image; 3-channel input is treated as BGR and 4-channel as BGRA
+     * @return a new Mat owned by the caller
+     */
+    private fun ensureRGBFormat(inputMat: Mat): Mat
+    {
+        return when (inputMat.channels())
+        {
+            3 -> convertColorToNewMat(inputMat, Imgproc.COLOR_BGR2RGB)
+            4 -> convertColorToNewMat(inputMat, Imgproc.COLOR_BGRA2RGB)
+            // 1-channel or other formats: deep copy so the result never aliases the input
+            else -> inputMat.clone()
+        }
+    }
+
+    /**
+     * Resizes [inputMat] to [targetSize], falling back to an unscaled copy if resizing throws.
+     *
+     * @param inputMat source image; left untouched
+     * @param targetSize desired output dimensions
+     * @return a new Mat owned by the caller: the resized image, or a copy of [inputMat] on failure
+     */
+    private fun safeResize(inputMat: Mat, targetSize: Size): Mat
+    {
+        // Single output Mat shared by both paths, so a failed resize does not leak an allocation
+        val resized = Mat()
+        return safeMatOperation(
+            operation = {
+                Imgproc.resize(inputMat, resized, targetSize)
+                resized
+            },
+            fallback = {
+                inputMat.copyTo(resized)
+                resized
+            },
+            errorMessage = "Error resizing image"
+        )
+    }
+
     /**
      * Data class for transformed coordinates
      */