Browse Source

refactor: create dedicated classification module using dataset.yaml

- Add ClassificationManager singleton to handle class name mappings
- Load class names dynamically from training dataset.yaml (source of truth)
- Remove hardcoded 96-class mapping from YOLOInferenceEngine
- Add SnakeYAML dependency for YAML parsing
- Update all classNames references to use ClassificationManager
- Ensure consistency between model training data and runtime classification

Related todos: #create-classification-module

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
arch-002-ml-inference-engine
Quildra 5 months ago
parent
commit
aa4acf0025
  1. 3
      app/build.gradle
  2. 32
      app/src/main/assets/dataset.yaml
  3. 140
      app/src/main/java/com/quillstudios/pokegoalshelper/ml/ClassificationManager.kt
  4. 121
      app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt

3
app/build.gradle

@ -76,6 +76,9 @@ dependencies {
// ONNX Runtime // ONNX Runtime
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.16.3' implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.16.3'
// YAML parsing for dataset configuration
implementation 'org.yaml:snakeyaml:2.0'
// Material Design Components for Views (FloatingActionButton) // Material Design Components for Views (FloatingActionButton)
implementation 'com.google.android.material:material:1.11.0' implementation 'com.google.android.material:material:1.11.0'
} }

32
app/src/main/assets/dataset.yaml

@ -0,0 +1,32 @@
# dataset.yaml
train: /content/pokemon_yolov8_dataset/images/train
val: /content/pokemon_yolov8_dataset/images/val
# Number of classes
nc: 96
# Class names (must be in the same order as LabelImg assigned them internally)
names: [
'ball_icon_pokeball', 'ball_icon_greatball', 'ball_icon_ultraball', 'ball_icon_masterball',
'ball_icon_safariball', 'ball_icon_levelball', 'ball_icon_lureball', 'ball_icon_moonball',
'ball_icon_friendball', 'ball_icon_loveball', 'ball_icon_heavyball', 'ball_icon_fastball',
'ball_icon_sportball', 'ball_icon_premierball', 'ball_icon_repeatball', 'ball_icon_timerball',
'ball_icon_nestball', 'ball_icon_netball', 'ball_icon_diveball', 'ball_icon_luxuryball',
'ball_icon_healball', 'ball_icon_quickball', 'ball_icon_duskball', 'ball_icon_cherishball',
'ball_icon_dreamball', 'ball_icon_beastball', 'ball_icon_strangeparts', 'ball_icon_parkball',
'ball_icon_gsball', 'pokemon_nickname', 'gender_icon_male', 'gender_icon_female',
'pokemon_level', 'language', 'last_game_stamp_home', 'last_game_stamp_lgp',
'last_game_stamp_lge', 'last_game_stamp_sw', 'last_game_stamp_sh', 'last_game_stamp_bank',
'last_game_stamp_bd', 'last_game_stamp_sp', 'last_game_stamp_pla', 'last_game_stamp_sc',
'last_game_stamp_vi', 'last_game_stamp_go', 'national_dex_number', 'pokemon_species',
'type_1', 'type_2', 'shiny_icon', 'origin_icon_vc', 'origin_icon_xyoras',
'origin_icon_smusum', 'origin_icon_lg', 'origin_icon_swsh', 'origin_icon_go',
'origin_icon_bdsp', 'origin_icon_pla', 'origin_icon_sv', 'pokerus_infected_icon',
'pokerus_cured_icon', 'hp_value', 'attack_value', 'defense_value', 'sp_atk_value',
'sp_def_value', 'speed_value', 'ability_name', 'nature_name', 'move_name',
'original_trainer_name', 'original_trainder_number', 'alpha_mark', 'tera_water',
'tera_psychic', 'tera_ice', 'tera_fairy', 'tera_poison', 'tera_ghost', 'ball_icon_originball',
'tera_dragon', 'tera_steel', 'tera_grass', 'tera_normal', 'tera_fire', 'tera_electric', 'tera_fighting',
'tera_ground', 'tera_flying', 'tera_bug', 'tera_rock', 'tera_dark', 'low_confidence',
'ball_icon_pokeball_hisui', 'ball_icon_ultraball_husui'
]

140
app/src/main/java/com/quillstudios/pokegoalshelper/ml/ClassificationManager.kt

@ -0,0 +1,140 @@
package com.quillstudios.pokegoalshelper.ml
import android.content.Context
import com.quillstudios.pokegoalshelper.utils.PGHLog
import org.yaml.snakeyaml.Yaml
import java.io.IOException
/**
* Manages class name mappings from the training dataset.yaml file.
* This ensures the class names match exactly what was used during model training.
*/
class ClassificationManager private constructor(private val context: Context)
{
companion object
{
private const val TAG = "ClassificationManager"
private const val DATASET_YAML_PATH = "dataset.yaml"
@Volatile
private var INSTANCE: ClassificationManager? = null
fun getInstance(context: Context): ClassificationManager
{
return INSTANCE ?: synchronized(this)
{
INSTANCE ?: ClassificationManager(context.applicationContext).also { INSTANCE = it }
}
}
}
private var classNames: Map<Int, String> = emptyMap()
private var numClasses: Int = 0
private var isInitialized: Boolean = false
/**
* Initialize the classification manager by loading class names from dataset.yaml
*/
suspend fun initialize(): MLResult<Unit>
{
return try
{
val yaml_content = loadDatasetYaml()
val parsed_data = parseDatasetYaml(yaml_content)
numClasses = parsed_data.first
classNames = parsed_data.second
isInitialized = true
PGHLog.i(TAG, "✅ Classification manager initialized with $numClasses classes")
PGHLog.d(TAG, "📋 Loaded classes: ${classNames.values.take(5)}...")
MLResult.Success(Unit)
}
catch (e: Exception)
{
PGHLog.e(TAG, "❌ Failed to initialize classification manager: ${e.message}")
MLResult.Error(MLErrorType.INITIALIZATION_FAILED, e, "Failed to load class names from dataset.yaml")
}
}
/**
* Get class name for a given class ID
*/
fun getClassName(classId: Int): String?
{
if (!isInitialized)
{
PGHLog.w(TAG, "⚠️ Classification manager not initialized, returning null for class $classId")
return null
}
return classNames[classId]
}
/**
* Get all class names as a map
*/
fun getAllClassNames(): Map<Int, String>
{
return if (isInitialized) classNames else emptyMap()
}
/**
* Get the number of classes
*/
fun getNumClasses(): Int
{
return numClasses
}
/**
* Check if the manager is properly initialized
*/
fun isInitialized(): Boolean
{
return isInitialized
}
/**
* Load the dataset.yaml file from assets
*/
private fun loadDatasetYaml(): String
{
return context.assets.open(DATASET_YAML_PATH).use { input_stream ->
input_stream.bufferedReader().use { reader ->
reader.readText()
}
}
}
/**
* Parse the dataset.yaml content and extract class information
*/
private fun parseDatasetYaml(yaml_content: String): Pair<Int, Map<Int, String>>
{
val yaml = Yaml()
val data = yaml.load<Map<String, Any>>(yaml_content)
// Extract number of classes
val nc = data["nc"] as? Int
?: throw IllegalArgumentException("Missing or invalid 'nc' field in dataset.yaml")
// Extract class names list
val names = data["names"] as? List<*>
?: throw IllegalArgumentException("Missing or invalid 'names' field in dataset.yaml")
// Convert to map with indices
val class_map = names.mapIndexed { index, name ->
index to (name as? String ?: throw IllegalArgumentException("Invalid class name at index $index"))
}.toMap()
// Validate consistency
if (class_map.size != nc)
{
PGHLog.w(TAG, "⚠️ Class count mismatch: nc=$nc but found ${class_map.size} names")
}
return Pair(nc, class_map)
}
}

121
app/src/main/java/com/quillstudios/pokegoalshelper/ml/YOLOInferenceEngine.kt

@ -135,105 +135,8 @@ class YOLOInferenceEngine(
private var totalInferenceTime = 0L private var totalInferenceTime = 0L
private var inferenceCount = 0L private var inferenceCount = 0L
// Complete class names mapping (96 classes) - EXACTLY as in original // Classification manager for class name mappings from dataset.yaml
private val classNames = mapOf( private lateinit var classificationManager: ClassificationManager
0 to "ball_icon_pokeball",
1 to "ball_icon_greatball",
2 to "ball_icon_ultraball",
3 to "ball_icon_masterball",
4 to "ball_icon_safariball",
5 to "ball_icon_levelball",
6 to "ball_icon_lureball",
7 to "ball_icon_moonball",
8 to "ball_icon_friendball",
9 to "ball_icon_loveball",
10 to "ball_icon_heavyball",
11 to "ball_icon_fastball",
12 to "ball_icon_sportball",
13 to "ball_icon_premierball",
14 to "ball_icon_repeatball",
15 to "ball_icon_timerball",
16 to "ball_icon_nestball",
17 to "ball_icon_netball",
18 to "ball_icon_diveball",
19 to "ball_icon_luxuryball",
20 to "ball_icon_healball",
21 to "ball_icon_quickball",
22 to "ball_icon_duskball",
23 to "ball_icon_cherishball",
24 to "ball_icon_dreamball",
25 to "ball_icon_beastball",
26 to "ball_icon_strangeparts",
27 to "ball_icon_parkball",
28 to "ball_icon_gsball",
29 to "pokemon_nickname",
30 to "gender_icon_male",
31 to "gender_icon_female",
32 to "pokemon_level",
33 to "language",
34 to "last_game_stamp_home",
35 to "last_game_stamp_lgp",
36 to "last_game_stamp_lge",
37 to "last_game_stamp_sw",
38 to "last_game_stamp_sh",
39 to "last_game_stamp_bank",
40 to "last_game_stamp_bd",
41 to "last_game_stamp_sp",
42 to "last_game_stamp_pla",
43 to "last_game_stamp_sc",
44 to "last_game_stamp_vi",
45 to "last_game_stamp_go",
46 to "national_dex_number",
47 to "pokemon_species",
48 to "type_1",
49 to "type_2",
50 to "shiny_icon",
51 to "origin_icon_vc",
52 to "origin_icon_xyoras",
53 to "origin_icon_smusum",
54 to "origin_icon_lg",
55 to "origin_icon_swsh",
56 to "origin_icon_go",
57 to "origin_icon_bdsp",
58 to "origin_icon_pla",
59 to "origin_icon_sv",
60 to "pokerus_infected_icon",
61 to "pokerus_cured_icon",
62 to "hp_value",
63 to "attack_value",
64 to "defense_value",
65 to "sp_atk_value",
66 to "sp_def_value",
67 to "speed_value",
68 to "ability_name",
69 to "nature_name",
70 to "move_name",
71 to "original_trainer_name",
72 to "original_trainder_number",
73 to "alpha_mark",
74 to "tera_water",
75 to "tera_psychic",
76 to "tera_ice",
77 to "tera_fairy",
78 to "tera_poison",
79 to "tera_ghost",
80 to "ball_icon_originball",
81 to "tera_dragon",
82 to "tera_steel",
83 to "tera_grass",
84 to "tera_normal",
85 to "tera_fire",
86 to "tera_electric",
87 to "tera_fighting",
88 to "tera_ground",
89 to "tera_flying",
90 to "tera_bug",
91 to "tera_rock",
92 to "tera_dark",
93 to "low_confidence",
94 to "ball_icon_pokeball_hisui",
95 to "ball_icon_ultraball_husui"
)
override suspend fun initialize(): MLResult<Unit> = withContext(Dispatchers.IO) override suspend fun initialize(): MLResult<Unit> = withContext(Dispatchers.IO)
{ {
@ -263,6 +166,13 @@ class YOLOInferenceEngine(
// Extract model metadata dynamically // Extract model metadata dynamically
extractModelMetadata() extractModelMetadata()
// Initialize ClassificationManager
classificationManager = ClassificationManager.getInstance(context)
val classificationResult = classificationManager.initialize()
if (classificationResult is MLResult.Error) {
throw RuntimeException("Failed to initialize ClassificationManager: ${classificationResult.message}")
}
PGHLog.i(TAG, "✅ ONNX YOLO detector initialized successfully") PGHLog.i(TAG, "✅ ONNX YOLO detector initialized successfully")
isInitialized = true isInitialized = true
}.onError { errorType, exception, message -> }.onError { errorType, exception, message ->
@ -1041,14 +951,7 @@ class YOLOInferenceEngine(
} }
// Get class name for filtering and debugging // Get class name for filtering and debugging
val class_name = if (class_id >= 0 && class_id < classNames.size) val class_name = classificationManager.getClassName(class_id) ?: "unknown_$class_id"
{
classNames[class_id] ?: "unknown_$class_id"
}
else
{
"unknown_$class_id"
}
// Debug logging for all detections if enabled // Debug logging for all detections if enabled
if (SHOW_ALL_CONFIDENCES && mapped_confidence > MIN_DEBUG_CONFIDENCE) if (SHOW_ALL_CONFIDENCES && mapped_confidence > MIN_DEBUG_CONFIDENCE)
@ -1060,7 +963,7 @@ class YOLOInferenceEngine(
val passes_class_filter = DEBUG_CLASS_FILTER == null || DEBUG_CLASS_FILTER == class_name val passes_class_filter = DEBUG_CLASS_FILTER == null || DEBUG_CLASS_FILTER == class_name
// Filter by confidence threshold, class filter, and validate coordinates // Filter by confidence threshold, class filter, and validate coordinates
if (mapped_confidence > confidenceThreshold && class_id >= 0 && class_id < classNames.size && passes_class_filter) if (mapped_confidence > confidenceThreshold && class_id >= 0 && class_id < classificationManager.getNumClasses() && passes_class_filter)
{ {
// Convert from corner coordinates (x1,y1,x2,y2) to BoundingBox format // Convert from corner coordinates (x1,y1,x2,y2) to BoundingBox format
// Clamp coordinates to image boundaries // Clamp coordinates to image boundaries
@ -1144,7 +1047,7 @@ class YOLOInferenceEngine(
// First, apply NMS within each class // First, apply NMS within each class
val detections_by_class = allDetections.groupBy { detection -> val detections_by_class = allDetections.groupBy { detection ->
// Map class name back to ID for grouping // Map class name back to ID for grouping
classNames.entries.find { it.value == detection.className }?.key ?: -1 classificationManager.getAllClassNames().entries.find { it.value == detection.className }?.key ?: -1
} }
val class_nms_results = mutableListOf<Detection>() val class_nms_results = mutableListOf<Detection>()

Loading…
Cancel
Save