You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
2.2 KiB
61 lines
2.2 KiB
|
5 months ago
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Inspect ONNX model structure to verify class mappings
|
||
|
|
"""
|
||
|
|
|
||
|
|
import onnx
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
def _format_dims(value_info):
    """Return the tensor dimensions of an ONNX ``ValueInfoProto`` as a list.

    Fixed dimensions are returned as ints. Symbolic (dynamic) dimensions are
    returned as their ``dim_param`` name, or ``"?"`` when the dimension has
    neither a value nor a name. Reading ``dim_value`` alone would report
    every dynamic dimension (e.g. a dynamic batch axis) as ``0``.
    """
    dims = []
    for dim in value_info.type.tensor_type.shape.dim:
        if dim.HasField("dim_value"):
            dims.append(dim.dim_value)
        else:
            dims.append(dim.dim_param or "?")
    return dims


def _describe_output_layout(shape):
    """Print a best-guess interpretation of a detection-model output shape.

    Heuristics only, applied to rank-3 shapes:
    * last dimension == 6 -> treated as post-NMS detections;
    * axis 1 > 90 -> treated as a raw head of 4 box coordinates plus
      per-class scores laid out as (batch, channels, anchors).
    Symbolic dimensions (strings) never satisfy the numeric rule.
    """
    if len(shape) != 3:
        return
    if shape[2] == 6:
        # NOTE(review): the meaning/order of the 6 values depends on the
        # exporter; the label below is the author's assumption — confirm.
        print(f" → NMS format: [batch, {shape[1]} detections, 6 values (x,y,w,h,conf,class)]")
    elif isinstance(shape[1], int) and shape[1] > 90:
        print(f" → Raw format: [batch, {shape[1]} channels, {shape[2]} anchors]")
        print(f" → Channels: 4 coords + {shape[1]-4} classes")


def inspect_onnx_model(model_path):
    """Print a human-readable summary of an ONNX model's structure.

    Shows the IR version and producer, the real graph inputs (weights that
    are also listed as inputs are skipped), the graph outputs with a
    heuristic interpretation of detection-head layouts, all
    ``metadata_props`` key/value pairs, and node/initializer counts.

    Args:
        model_path: Path to the ``.onnx`` file to load.

    Any exception raised while loading or reading the model (including a
    missing file) is caught and reported on stdout instead of propagated.
    A batch of inspections therefore keeps going past one bad model.
    """
    print(f"Inspecting ONNX model: {model_path}")

    try:
        model = onnx.load(model_path)
        graph = model.graph

        print("\n📋 Model Info:")
        print(f"IR Version: {model.ir_version}")
        print(f"Producer: {model.producer_name} {model.producer_version}")

        # Models exported for older IR versions list every weight as a graph
        # input too; skip those so only the actual model inputs are shown.
        initializer_names = {init.name for init in graph.initializer}
        print("\n📥 Inputs:")
        for input_info in graph.input:
            if input_info.name in initializer_names:
                continue
            print(f" {input_info.name}: {_format_dims(input_info)}")

        print("\n📤 Outputs:")
        for output_info in graph.output:
            shape = _format_dims(output_info)
            print(f" {output_info.name}: {shape}")
            _describe_output_layout(shape)

        # Free-form exporter metadata; class mappings, if the exporter
        # recorded any, would presumably appear here — verify per exporter.
        print("\n🏷️ Metadata:")
        for prop in model.metadata_props:
            print(f" {prop.key}: {prop.value}")

        print(f"\n🔍 Model Summary: {len(graph.node)} nodes, {len(graph.initializer)} initializers")

    except Exception as e:
        # Top-level boundary of a diagnostic script: report and continue.
        print(f"❌ Error inspecting model: {e}")
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Local imports: only the command-line entry point needs these.
    import os
    import sys

    # Model paths may be given on the command line; with no arguments the
    # project's usual export locations are checked.
    models_to_check = sys.argv[1:] or [
        "app/src/main/assets/best.onnx",
        "raw_models/exports/best_no_nms.onnx",
        "raw_models/exports/best_nms_relaxed.onnx",
        "raw_models/exports/best_nms_very_relaxed.onnx",
    ]

    for model_path in models_to_check:
        # inspect_onnx_model() handles every exception itself, so a missing
        # file would never reach an `except FileNotFoundError` here.
        # Check for the file up front instead.
        if not os.path.isfile(model_path):
            print(f"⚠️ Model not found: {model_path}\n")
            continue
        inspect_onnx_model(model_path)
        print("\n" + "=" * 60 + "\n")
|