You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
2.2 KiB

#!/usr/bin/env python3
"""
Inspect ONNX model structure to verify class mappings
"""
import onnx
import numpy as np
def _format_dims(value_info):
    """Return the declared shape of a graph input/output as a list.

    Fixed dimensions are returned as ints. Symbolic (dynamic) dimensions are
    returned as their ``dim_param`` name, or ``"?"`` when the dimension is
    completely unknown, instead of the misleading ``0`` that reading
    ``dim_value`` yields for an unset field.
    """
    dims = []
    for dim in value_info.type.tensor_type.shape.dim:
        if dim.HasField("dim_value"):
            dims.append(dim.dim_value)
        else:
            dims.append(dim.dim_param or "?")
    return dims


def inspect_onnx_model(model_path):
    """Print a human-readable summary of an ONNX model's structure.

    Prints the IR version and producer, the declared shapes of the graph's
    inputs and outputs (with a heuristic interpretation of 3-D outputs),
    the model's metadata properties, and node/initializer counts.

    Args:
        model_path: Path to the ``.onnx`` file to inspect.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist. It is re-raised
            (not printed) so callers can report a missing file separately
            from a model that exists but failed to load. Any other exception
            is printed and swallowed, so inspection is best-effort.
    """
    print(f"Inspecting ONNX model: {model_path}")
    try:
        model = onnx.load(model_path)

        print("\n📋 Model Info:")
        print(f"IR Version: {model.ir_version}")
        print(f"Producer: {model.producer_name} {model.producer_version}")

        print("\n📥 Inputs:")
        for input_info in model.graph.input:
            print(f" {input_info.name}: {_format_dims(input_info)}")

        print("\n📤 Outputs:")
        for output_info in model.graph.output:
            shape = _format_dims(output_info)
            print(f" {output_info.name}: {shape}")
            # Heuristic interpretation of 3-D outputs. Symbolic dims are
            # strings, so numeric comparisons are guarded with isinstance.
            if len(shape) == 3 and shape[2] == 6:
                print(f" → NMS format: [batch, {shape[1]} detections, 6 values (x,y,w,h,conf,class)]")
            elif len(shape) == 3 and isinstance(shape[1], int) and shape[1] > 90:
                # NOTE(review): the 90-channel threshold looks tuned to this
                # project's class count (4 box coords + classes) — confirm.
                print(f" → Raw format: [batch, {shape[1]} channels, {shape[2]} anchors]")
                print(f" → Channels: 4 coords + {shape[1]-4} classes")

        # Metadata may carry class-name information, depending on the exporter.
        print("\n🏷️ Metadata:")
        for prop in model.metadata_props:
            print(f" {prop.key}: {prop.value}")

        print(f"\n🔍 Model Summary: {len(model.graph.node)} nodes, {len(model.graph.initializer)} initializers")
    except FileNotFoundError:
        # Propagate so the caller's "model not found" handling can run; the
        # generic handler below would otherwise swallow it.
        raise
    except Exception as e:
        print(f"❌ Error inspecting model: {e}")
if __name__ == "__main__":
    # Model files to inspect, relative to the current working directory.
    candidate_models = (
        "app/src/main/assets/best.onnx",
        "raw_models/exports/best_no_nms.onnx",
        "raw_models/exports/best_nms_relaxed.onnx",
        "raw_models/exports/best_nms_very_relaxed.onnx",
    )
    divider = "\n" + "=" * 60 + "\n"

    for path in candidate_models:
        try:
            inspect_onnx_model(path)
        except FileNotFoundError:
            # Report and move on; no divider is printed for a missing model.
            print(f"⚠️ Model not found: {path}\n")
        else:
            print(divider)