Browse Source
- Add ClassificationManager singleton to handle class name mappings - Load class names dynamically from training dataset.yaml (source of truth) - Remove hardcoded 96-class mapping from YOLOInferenceEngine - Add SnakeYAML dependency for YAML parsing - Update all classNames references to use ClassificationManager - Ensure consistency between model training data and runtime classification Related todos: #create-classification-module 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>arch-002-ml-inference-engine
4 changed files with 187 additions and 109 deletions
@ -0,0 +1,32 @@ |
|||
# dataset.yaml |
|||
train: /content/pokemon_yolov8_dataset/images/train |
|||
val: /content/pokemon_yolov8_dataset/images/val |
|||
|
|||
# Number of classes |
|||
nc: 96 |
|||
|
|||
# Class names (must be in the same order as LabelImg assigned them internally) |
|||
names: [ |
|||
'ball_icon_pokeball', 'ball_icon_greatball', 'ball_icon_ultraball', 'ball_icon_masterball', |
|||
'ball_icon_safariball', 'ball_icon_levelball', 'ball_icon_lureball', 'ball_icon_moonball', |
|||
'ball_icon_friendball', 'ball_icon_loveball', 'ball_icon_heavyball', 'ball_icon_fastball', |
|||
'ball_icon_sportball', 'ball_icon_premierball', 'ball_icon_repeatball', 'ball_icon_timerball', |
|||
'ball_icon_nestball', 'ball_icon_netball', 'ball_icon_diveball', 'ball_icon_luxuryball', |
|||
'ball_icon_healball', 'ball_icon_quickball', 'ball_icon_duskball', 'ball_icon_cherishball', |
|||
'ball_icon_dreamball', 'ball_icon_beastball', 'ball_icon_strangeparts', 'ball_icon_parkball', |
|||
'ball_icon_gsball', 'pokemon_nickname', 'gender_icon_male', 'gender_icon_female', |
|||
'pokemon_level', 'language', 'last_game_stamp_home', 'last_game_stamp_lgp', |
|||
'last_game_stamp_lge', 'last_game_stamp_sw', 'last_game_stamp_sh', 'last_game_stamp_bank', |
|||
'last_game_stamp_bd', 'last_game_stamp_sp', 'last_game_stamp_pla', 'last_game_stamp_sc', |
|||
'last_game_stamp_vi', 'last_game_stamp_go', 'national_dex_number', 'pokemon_species', |
|||
'type_1', 'type_2', 'shiny_icon', 'origin_icon_vc', 'origin_icon_xyoras', |
|||
'origin_icon_smusum', 'origin_icon_lg', 'origin_icon_swsh', 'origin_icon_go', |
|||
'origin_icon_bdsp', 'origin_icon_pla', 'origin_icon_sv', 'pokerus_infected_icon', |
|||
'pokerus_cured_icon', 'hp_value', 'attack_value', 'defense_value', 'sp_atk_value', |
|||
'sp_def_value', 'speed_value', 'ability_name', 'nature_name', 'move_name', |
|||
'original_trainer_name', 'original_trainder_number', 'alpha_mark', 'tera_water', |
|||
'tera_psychic', 'tera_ice', 'tera_fairy', 'tera_poison', 'tera_ghost', 'ball_icon_originball', |
|||
'tera_dragon', 'tera_steel', 'tera_grass', 'tera_normal', 'tera_fire', 'tera_electric', 'tera_fighting', |
|||
'tera_ground', 'tera_flying', 'tera_bug', 'tera_rock', 'tera_dark', 'low_confidence', |
|||
'ball_icon_pokeball_hisui', 'ball_icon_ultraball_husui' |
|||
] |
|||
@ -0,0 +1,140 @@ |
|||
package com.quillstudios.pokegoalshelper.ml |
|||
|
|||
import android.content.Context |
|||
import com.quillstudios.pokegoalshelper.utils.PGHLog |
|||
import org.yaml.snakeyaml.Yaml |
|||
import java.io.IOException |
|||
|
|||
/** |
|||
* Manages class name mappings from the training dataset.yaml file. |
|||
* This ensures the class names match exactly what was used during model training. |
|||
*/ |
|||
class ClassificationManager private constructor(private val context: Context) |
|||
{ |
|||
companion object |
|||
{ |
|||
private const val TAG = "ClassificationManager" |
|||
private const val DATASET_YAML_PATH = "dataset.yaml" |
|||
|
|||
@Volatile |
|||
private var INSTANCE: ClassificationManager? = null |
|||
|
|||
fun getInstance(context: Context): ClassificationManager |
|||
{ |
|||
return INSTANCE ?: synchronized(this) |
|||
{ |
|||
INSTANCE ?: ClassificationManager(context.applicationContext).also { INSTANCE = it } |
|||
} |
|||
} |
|||
} |
|||
|
|||
private var classNames: Map<Int, String> = emptyMap() |
|||
private var numClasses: Int = 0 |
|||
private var isInitialized: Boolean = false |
|||
|
|||
/** |
|||
* Initialize the classification manager by loading class names from dataset.yaml |
|||
*/ |
|||
suspend fun initialize(): MLResult<Unit> |
|||
{ |
|||
return try |
|||
{ |
|||
val yaml_content = loadDatasetYaml() |
|||
val parsed_data = parseDatasetYaml(yaml_content) |
|||
|
|||
numClasses = parsed_data.first |
|||
classNames = parsed_data.second |
|||
isInitialized = true |
|||
|
|||
PGHLog.i(TAG, "✅ Classification manager initialized with $numClasses classes") |
|||
PGHLog.d(TAG, "📋 Loaded classes: ${classNames.values.take(5)}...") |
|||
|
|||
MLResult.Success(Unit) |
|||
} |
|||
catch (e: Exception) |
|||
{ |
|||
PGHLog.e(TAG, "❌ Failed to initialize classification manager: ${e.message}") |
|||
MLResult.Error(MLErrorType.INITIALIZATION_FAILED, e, "Failed to load class names from dataset.yaml") |
|||
} |
|||
} |
|||
|
|||
/** |
|||
* Get class name for a given class ID |
|||
*/ |
|||
fun getClassName(classId: Int): String? |
|||
{ |
|||
if (!isInitialized) |
|||
{ |
|||
PGHLog.w(TAG, "⚠️ Classification manager not initialized, returning null for class $classId") |
|||
return null |
|||
} |
|||
|
|||
return classNames[classId] |
|||
} |
|||
|
|||
/** |
|||
* Get all class names as a map |
|||
*/ |
|||
fun getAllClassNames(): Map<Int, String> |
|||
{ |
|||
return if (isInitialized) classNames else emptyMap() |
|||
} |
|||
|
|||
/** |
|||
* Get the number of classes |
|||
*/ |
|||
fun getNumClasses(): Int |
|||
{ |
|||
return numClasses |
|||
} |
|||
|
|||
/** |
|||
* Check if the manager is properly initialized |
|||
*/ |
|||
fun isInitialized(): Boolean |
|||
{ |
|||
return isInitialized |
|||
} |
|||
|
|||
/** |
|||
* Load the dataset.yaml file from assets |
|||
*/ |
|||
private fun loadDatasetYaml(): String |
|||
{ |
|||
return context.assets.open(DATASET_YAML_PATH).use { input_stream -> |
|||
input_stream.bufferedReader().use { reader -> |
|||
reader.readText() |
|||
} |
|||
} |
|||
} |
|||
|
|||
/** |
|||
* Parse the dataset.yaml content and extract class information |
|||
*/ |
|||
private fun parseDatasetYaml(yaml_content: String): Pair<Int, Map<Int, String>> |
|||
{ |
|||
val yaml = Yaml() |
|||
val data = yaml.load<Map<String, Any>>(yaml_content) |
|||
|
|||
// Extract number of classes |
|||
val nc = data["nc"] as? Int |
|||
?: throw IllegalArgumentException("Missing or invalid 'nc' field in dataset.yaml") |
|||
|
|||
// Extract class names list |
|||
val names = data["names"] as? List<*> |
|||
?: throw IllegalArgumentException("Missing or invalid 'names' field in dataset.yaml") |
|||
|
|||
// Convert to map with indices |
|||
val class_map = names.mapIndexed { index, name -> |
|||
index to (name as? String ?: throw IllegalArgumentException("Invalid class name at index $index")) |
|||
}.toMap() |
|||
|
|||
// Validate consistency |
|||
if (class_map.size != nc) |
|||
{ |
|||
PGHLog.w(TAG, "⚠️ Class count mismatch: nc=$nc but found ${class_map.size} names") |
|||
} |
|||
|
|||
return Pair(nc, class_map) |
|||
} |
|||
} |
|||
Loading…
Reference in new issue