diff --git a/app/build.gradle b/app/build.gradle index cfa99e2..1a215ed 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -76,6 +76,9 @@ dependencies { // ONNX Runtime implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.16.3' + // YAML parsing for dataset configuration + implementation 'org.yaml:snakeyaml:2.0' + // Material Design Components for Views (FloatingActionButton) implementation 'com.google.android.material:material:1.11.0' } \ No newline at end of file diff --git a/app/src/main/assets/dataset.yaml b/app/src/main/assets/dataset.yaml new file mode 100644 index 0000000..16ecfba --- /dev/null +++ b/app/src/main/assets/dataset.yaml @@ -0,0 +1,32 @@ +# dataset.yaml +train: /content/pokemon_yolov8_dataset/images/train +val: /content/pokemon_yolov8_dataset/images/val + +# Number of classes +nc: 96 + +# Class names (must be in the same order as LabelImg assigned them internally) +names: [ + 'ball_icon_pokeball', 'ball_icon_greatball', 'ball_icon_ultraball', 'ball_icon_masterball', + 'ball_icon_safariball', 'ball_icon_levelball', 'ball_icon_lureball', 'ball_icon_moonball', + 'ball_icon_friendball', 'ball_icon_loveball', 'ball_icon_heavyball', 'ball_icon_fastball', + 'ball_icon_sportball', 'ball_icon_premierball', 'ball_icon_repeatball', 'ball_icon_timerball', + 'ball_icon_nestball', 'ball_icon_netball', 'ball_icon_diveball', 'ball_icon_luxuryball', + 'ball_icon_healball', 'ball_icon_quickball', 'ball_icon_duskball', 'ball_icon_cherishball', + 'ball_icon_dreamball', 'ball_icon_beastball', 'ball_icon_strangeparts', 'ball_icon_parkball', + 'ball_icon_gsball', 'pokemon_nickname', 'gender_icon_male', 'gender_icon_female', + 'pokemon_level', 'language', 'last_game_stamp_home', 'last_game_stamp_lgp', + 'last_game_stamp_lge', 'last_game_stamp_sw', 'last_game_stamp_sh', 'last_game_stamp_bank', + 'last_game_stamp_bd', 'last_game_stamp_sp', 'last_game_stamp_pla', 'last_game_stamp_sc', + 'last_game_stamp_vi', 'last_game_stamp_go', 'national_dex_number', 'pokemon_species', + 'type_1', 'type_2', 'shiny_icon', 'origin_icon_vc', 'origin_icon_xyoras', + 'origin_icon_smusum', 'origin_icon_lg', 'origin_icon_swsh', 'origin_icon_go', + 'origin_icon_bdsp', 'origin_icon_pla', 'origin_icon_sv', 'pokerus_infected_icon', + 'pokerus_cured_icon', 'hp_value', 'attack_value', 'defense_value', 'sp_atk_value', + 'sp_def_value', 'speed_value', 'ability_name', 'nature_name', 'move_name', + 'original_trainer_name', 'original_trainder_number', 'alpha_mark', 'tera_water', + 'tera_psychic', 'tera_ice', 'tera_fairy', 'tera_poison', 'tera_ghost', 'ball_icon_originball', + 'tera_dragon', 'tera_steel', 'tera_grass', 'tera_normal', 'tera_fire', 'tera_electric', 'tera_fighting', + 'tera_ground', 'tera_flying', 'tera_bug', 'tera_rock', 'tera_dark', 'low_confidence', + 'ball_icon_pokeball_hisui', 'ball_icon_ultraball_husui' +] \ No newline at end of file diff --git a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/ClassificationManager.kt b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/ClassificationManager.kt new file mode 100644 index 0000000..50877a8 --- /dev/null +++ b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/ClassificationManager.kt @@ -0,0 +1,140 @@ +package com.quillstudios.pokegoalshelper.ml + +import android.content.Context +import com.quillstudios.pokegoalshelper.utils.PGHLog +import org.yaml.snakeyaml.Yaml +import java.io.IOException + +/** + * Manages class name mappings from the training dataset.yaml file. + * This ensures the class names match exactly what was used during model training. + */ +class ClassificationManager private constructor(private val context: Context) +{ + companion object + { + private const val TAG = "ClassificationManager" + private const val DATASET_YAML_PATH = "dataset.yaml" + + @Volatile + private var INSTANCE: ClassificationManager? = null + + fun getInstance(context: Context): ClassificationManager + { + return INSTANCE ?: synchronized(this) + { + INSTANCE ?: ClassificationManager(context.applicationContext).also { INSTANCE = it } + } + } + } + + private var classNames: Map = emptyMap() + private var numClasses: Int = 0 + private var isInitialized: Boolean = false + + /** + * Initialize the classification manager by loading class names from dataset.yaml + */ + suspend fun initialize(): MLResult + { + return try + { + val yaml_content = loadDatasetYaml() + val parsed_data = parseDatasetYaml(yaml_content) + + numClasses = parsed_data.first + classNames = parsed_data.second + isInitialized = true + + PGHLog.i(TAG, "✅ Classification manager initialized with $numClasses classes") + PGHLog.d(TAG, "📋 Loaded classes: ${classNames.values.take(5)}...") + + MLResult.Success(Unit) + } + catch (e: Exception) + { + PGHLog.e(TAG, "❌ Failed to initialize classification manager: ${e.message}") + MLResult.Error(MLErrorType.INITIALIZATION_FAILED, e, "Failed to load class names from dataset.yaml") + } + } + + /** + * Get class name for a given class ID + */ + fun getClassName(classId: Int): String? + { + if (!isInitialized) + { + PGHLog.w(TAG, "⚠️ Classification manager not initialized, returning null for class $classId") + return null + } + + return classNames[classId] + } + + /** + * Get all class names as a map + */ + fun getAllClassNames(): Map + { + return if (isInitialized) classNames else emptyMap() + } + + /** + * Get the number of classes + */ + fun getNumClasses(): Int + { + return numClasses + } + + /** + * Check if the manager is properly initialized + */ + fun isInitialized(): Boolean + { + return isInitialized + } + + /** + * Load the dataset.yaml file from assets + */ + private fun loadDatasetYaml(): String + { + return context.assets.open(DATASET_YAML_PATH).use { input_stream -> + input_stream.bufferedReader().use { reader -> + reader.readText() + } + } + } + + /** + * Parse the dataset.yaml content and extract class information + */ + private fun parseDatasetYaml(yaml_content: String): Pair> + { + val yaml = Yaml() + val data = yaml.load>(yaml_content) + + // Extract number of classes + val nc = data["nc"] as? Int + ?: throw IllegalArgumentException("Missing or invalid 'nc' field in dataset.yaml") + + // Extract class names list + val names = data["names"] as? List<*> + ?: throw IllegalArgumentException("Missing or invalid 'names' field in dataset.yaml") + + // Convert to map with indices + val class_map = names.mapIndexed { index, name -> + index to (name as? String ?: throw IllegalArgumentException("Invalid class name at index $index")) + }.toMap() + + // Validate consistency + if (class_map.size != nc) + { + PGHLog.w(TAG, "⚠️ Class count mismatch: nc=$nc but found ${class_map.size} names") + } + + return Pair(nc, class_map) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt index 8c5c862..5cc0a55 100644 --- a/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt +++ b/app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt @@ -135,105 +135,8 @@ class YOLOInferenceEngine( private var totalInferenceTime = 0L private var inferenceCount = 0L - // Complete class names mapping (96 classes) - EXACTLY as in original - private val classNames = mapOf( - 0 to "ball_icon_pokeball", - 1 to "ball_icon_greatball", - 2 to "ball_icon_ultraball", - 3 to "ball_icon_masterball", - 4 to "ball_icon_safariball", - 5 to "ball_icon_levelball", - 6 to "ball_icon_lureball", - 7 to "ball_icon_moonball", - 8 to "ball_icon_friendball", - 9 to "ball_icon_loveball", - 10 to "ball_icon_heavyball", - 11 to "ball_icon_fastball", - 12 to "ball_icon_sportball", - 13 to "ball_icon_premierball", - 14 to "ball_icon_repeatball", - 15 to "ball_icon_timerball", - 16 to "ball_icon_nestball", - 17 to "ball_icon_netball", - 18 to "ball_icon_diveball", - 19 to "ball_icon_luxuryball", - 20 to "ball_icon_healball", - 21 to "ball_icon_quickball", - 22 to "ball_icon_duskball", - 23 to "ball_icon_cherishball", - 24 to "ball_icon_dreamball", - 25 to "ball_icon_beastball", - 26 to "ball_icon_strangeparts", - 27 to "ball_icon_parkball", - 28 to "ball_icon_gsball", - 29 to "pokemon_nickname", - 30 to "gender_icon_male", - 31 to "gender_icon_female", - 32 to "pokemon_level", - 33 to "language", - 34 to "last_game_stamp_home", - 35 to "last_game_stamp_lgp", - 36 to "last_game_stamp_lge", - 37 to "last_game_stamp_sw", - 38 to "last_game_stamp_sh", - 39 to "last_game_stamp_bank", - 40 to "last_game_stamp_bd", - 41 to "last_game_stamp_sp", - 42 to "last_game_stamp_pla", - 43 to "last_game_stamp_sc", - 44 to "last_game_stamp_vi", - 45 to "last_game_stamp_go", - 46 to "national_dex_number", - 47 to "pokemon_species", - 48 to "type_1", - 49 to "type_2", - 50 to "shiny_icon", - 51 to "origin_icon_vc", - 52 to "origin_icon_xyoras", - 53 to "origin_icon_smusum", - 54 to "origin_icon_lg", - 55 to "origin_icon_swsh", - 56 to "origin_icon_go", - 57 to "origin_icon_bdsp", - 58 to "origin_icon_pla", - 59 to "origin_icon_sv", - 60 to "pokerus_infected_icon", - 61 to "pokerus_cured_icon", - 62 to "hp_value", - 63 to "attack_value", - 64 to "defense_value", - 65 to "sp_atk_value", - 66 to "sp_def_value", - 67 to "speed_value", - 68 to "ability_name", - 69 to "nature_name", - 70 to "move_name", - 71 to "original_trainer_name", - 72 to "original_trainder_number", - 73 to "alpha_mark", - 74 to "tera_water", - 75 to "tera_psychic", - 76 to "tera_ice", - 77 to "tera_fairy", - 78 to "tera_poison", - 79 to "tera_ghost", - 80 to "ball_icon_originball", - 81 to "tera_dragon", - 82 to "tera_steel", - 83 to "tera_grass", - 84 to "tera_normal", - 85 to "tera_fire", - 86 to "tera_electric", - 87 to "tera_fighting", - 88 to "tera_ground", - 89 to "tera_flying", - 90 to "tera_bug", - 91 to "tera_rock", - 92 to "tera_dark", - 93 to "low_confidence", - 94 to "ball_icon_pokeball_hisui", - 95 to "ball_icon_ultraball_husui" - ) + // Classification manager for class name mappings from dataset.yaml + private lateinit var classificationManager: ClassificationManager override suspend fun initialize(): MLResult = withContext(Dispatchers.IO) { @@ -263,6 +166,13 @@ class YOLOInferenceEngine( // Extract model metadata dynamically extractModelMetadata() + // Initialize ClassificationManager + classificationManager = ClassificationManager.getInstance(context) + val classificationResult = classificationManager.initialize() + if (classificationResult is MLResult.Error) { + throw RuntimeException("Failed to initialize ClassificationManager: ${classificationResult.message}") + } + PGHLog.i(TAG, "✅ ONNX YOLO detector initialized successfully") isInitialized = true }.onError { errorType, exception, message -> @@ -1041,14 +951,7 @@ class YOLOInferenceEngine( } // Get class name for filtering and debugging - val class_name = if (class_id >= 0 && class_id < classNames.size) - { - classNames[class_id] ?: "unknown_$class_id" - } - else - { - "unknown_$class_id" - } + val class_name = classificationManager.getClassName(class_id) ?: "unknown_$class_id" // Debug logging for all detections if enabled if (SHOW_ALL_CONFIDENCES && mapped_confidence > MIN_DEBUG_CONFIDENCE) @@ -1060,7 +963,7 @@ class YOLOInferenceEngine( val passes_class_filter = DEBUG_CLASS_FILTER == null || DEBUG_CLASS_FILTER == class_name // Filter by confidence threshold, class filter, and validate coordinates - if (mapped_confidence > confidenceThreshold && class_id >= 0 && class_id < classNames.size && passes_class_filter) + if (mapped_confidence > confidenceThreshold && class_id >= 0 && class_id < classificationManager.getNumClasses() && passes_class_filter) { // Convert from corner coordinates (x1,y1,x2,y2) to BoundingBox format // Clamp coordinates to image boundaries @@ -1144,7 +1047,7 @@ class YOLOInferenceEngine( // First, apply NMS within each class val detections_by_class = allDetections.groupBy { detection -> // Map class name back to ID for grouping - classNames.entries.find { it.value == detection.className }?.key ?: -1 + classificationManager.getAllClassNames().entries.find { it.value == detection.className }?.key ?: -1 } val class_nms_results = mutableListOf()