#!/usr/bin/env python3
"""
Test ONNX model against static images to isolate OpenCV capture issues
This bypasses Android screen capture and tests pure ONNX inference
"""
import cv2
import numpy as np
import onnxruntime as ort
import os
from pathlib import Path
# Force CPU-only execution to avoid CUDA compatibility issues
os.environ['CUDA_VISIBLE_DEVICES'] = ''
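
# Alternatively (not required here), the CPU provider can be requested per session,
# e.g. ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]).
# The environment variable above keeps every session created in this script on the CPU.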


def letterbox_preprocess(img, target_size=(640, 640)):
    """Exact letterbox preprocessing matching Android implementation"""
    h, w = img.shape[:2]

    # Calculate scale factor
    scale = min(target_size[0] / h, target_size[1] / w)

    # Calculate new dimensions
    new_w = int(w * scale)
    new_h = int(h * scale)

    # Resize image
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Create padded image
    padded = np.full((target_size[0], target_size[1], 3), 114, dtype=np.uint8)

    # Calculate padding offsets
    pad_x = (target_size[1] - new_w) // 2
    pad_y = (target_size[0] - new_h) // 2

    # Place resized image in center
    padded[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized

    return padded, scale, (pad_x, pad_y)
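

# Sketch of the inverse mapping (hypothetical helper, not called by the tests below):
# it assumes detections are center-format boxes in pixel coordinates of the padded
# 640x640 frame and undoes the scaling and padding applied by letterbox_preprocess().
def unletterbox_box(cx, cy, w, h, scale, padding):
    """Map a (cx, cy, w, h) box from the letterboxed frame back to the original image."""
    pad_x, pad_y = padding
    return (cx - pad_x) / scale, (cy - pad_y) / scale, w / scale, h / scale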


def test_onnx_static(model_path, image_path, confidence_threshold=0.01):
    """Test ONNX model on static image with detailed output"""
    print(f"🔧 Testing ONNX model: {Path(model_path).name}")
    print(f"📸 Image: {Path(image_path).name}")

    # Load image
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"❌ Could not load image: {image_path}")
        return None

    print(f" Original image size: {img.shape}")

    # Preprocess
    processed_img, scale, padding = letterbox_preprocess(img)
    print(f" Processed size: {processed_img.shape}")
    print(f" Scale factor: {scale:.4f}")
    print(f" Padding (x, y): {padding}")

    # Convert for ONNX (RGB, normalize, CHW, batch)
    img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
    img_norm = img_rgb.astype(np.float32) / 255.0
    img_chw = np.transpose(img_norm, (2, 0, 1))
    img_batch = np.expand_dims(img_chw, axis=0)
    print(f" Final tensor: {img_batch.shape}, range: [{img_batch.min():.3f}, {img_batch.max():.3f}]")
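
    # The layout above assumes an Ultralytics-style export: float32, NCHW, batch size 1,
    # values scaled to [0, 1]. A model exported with a different input contract
    # (e.g. NHWC or 0-255 inputs) would need different preprocessing.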

    # Load ONNX model
    try:
        session = ort.InferenceSession(str(model_path))
        input_name = session.get_inputs()[0].name
        print(f" Model loaded, input: {input_name}")
    except Exception as e:
        print(f"❌ Failed to load ONNX model: {e}")
        return None

    # Run inference
    try:
        outputs = session.run(None, {input_name: img_batch})
        print(f" Inference successful, {len(outputs)} outputs")
    except Exception as e:
        print(f"❌ Inference failed: {e}")
        return None

    # Process outputs
    if len(outputs) == 0:
        print("❌ No outputs from model")
        return None

    detections = outputs[0]
    print(f" Detection tensor shape: {detections.shape}")

    if len(detections.shape) != 3:
        print(f"❌ Unexpected detection shape: {detections.shape}")
        return None

    batch_size, num_detections, num_values = detections.shape
    detection_data = detections[0]  # Remove batch dimension

    if num_values == 6:  # NMS format
        print(" Format: NMS (x, y, w, h, conf, class)")

        # Filter by confidence
        valid_mask = detection_data[:, 4] > confidence_threshold
        valid_detections = detection_data[valid_mask]
        print(f" Valid detections (conf > {confidence_threshold}): {len(valid_detections)}")

        if len(valid_detections) == 0:
            print(" ❌ No detections above confidence threshold")
            return []

        # Analyze by class
        classes = valid_detections[:, 5].astype(int)
        confidences = valid_detections[:, 4]
        class_counts = {}
        for cls_id in classes:
            class_counts[cls_id] = class_counts.get(cls_id, 0) + 1
        print(f" Classes detected: {sorted(class_counts.keys())}")

        # Focus on shiny icons (class 50)
        shiny_mask = classes == 50
        shiny_detections = valid_detections[shiny_mask]

        if len(shiny_detections) > 0:
            print(f" ✨ SHINY ICONS FOUND: {len(shiny_detections)}")
            for i, det in enumerate(shiny_detections):
                x, y, w, h, conf, cls = det
                print(f" Shiny {i+1}: conf={conf:.6f}, box=[{x:.1f}, {y:.1f}, {w:.1f}, {h:.1f}]")
        else:
            print(" ❌ NO SHINY ICONS (class 50) detected")

        # Show top detections
        if len(valid_detections) > 0:
            # Sort by confidence
            sorted_indices = np.argsort(confidences)[::-1]
            top_detections = valid_detections[sorted_indices[:10]]
            print(" 🎯 Top 10 detections:")
            for i, det in enumerate(top_detections):
                x, y, w, h, conf, cls = det
                print(f" {i+1}. Class {int(cls)}: conf={conf:.4f}, box=[{x:.1f}, {y:.1f}, {w:.1f}, {h:.1f}]")

        return valid_detections
    else:
        print(f" ⚠️ Raw format detected ({num_values} values) - not processed")
        return None
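

# A raw (no-NMS) export such as best_no_nms.onnx generally packs box coordinates and
# per-class scores for every candidate prediction, so decoding it would mean picking
# the best class per candidate, thresholding, and running NMS. The exact tensor layout
# depends on how the model was exported, which is why the raw branch above is skipped.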


def test_multiple_models(image_path):
    """Test multiple ONNX models on the same image"""
    print("="*80)
    print("🔍 STATIC IMAGE ONNX TESTING")
    print("="*80)

    models_to_test = [
        "app/src/main/assets/best.onnx",
        "raw_models/exports/best_no_nms.onnx",
        "raw_models/exports/best_nms_relaxed.onnx",
        "raw_models/exports/best_nms_very_relaxed.onnx"
    ]

    results = {}
    for model_path in models_to_test:
        if Path(model_path).exists():
            print(f"\n{'='*60}")
            detections = test_onnx_static(model_path, image_path)
            results[model_path] = detections
        else:
            print(f"\n⚠️ Model not found: {model_path}")
            results[model_path] = None

    # Summary comparison
    print(f"\n{'='*80}")
    print("📊 COMPARISON SUMMARY")
    print("="*80)

    for model_path, detections in results.items():
        model_name = Path(model_path).name

        if detections is None:
            print(f"{model_name}: Failed or not found")
            continue

        if len(detections) == 0:
            print(f"🔵 {model_name}: No detections")
            continue

        # Count shiny icons
        classes = detections[:, 5].astype(int) if len(detections) > 0 else []
        shiny_count = np.sum(classes == 50) if len(classes) > 0 else 0
        total_count = len(detections)
        print(f"{model_name}: {total_count} total, {shiny_count} shiny icons")

    print("="*80)


if __name__ == "__main__":
    # Look for test images
    test_image_candidates = [
        "test_images/shiny_test.jpg",
        "test_images/test.jpg",
        "screenshots/shiny.jpg",
        "screenshots/test.png"
    ]

    test_image_found = None
    for candidate in test_image_candidates:
        if Path(candidate).exists():
            test_image_found = candidate
            break

    if test_image_found:
        print(f"🎯 Using test image: {test_image_found}")
        test_multiple_models(test_image_found)
    else:
        print("❌ No test image found. Available options:")
        for candidate in test_image_candidates:
            print(f" {candidate}")
        print("\nPlease provide a test image with a shiny icon at one of these paths.")
        print("You can use the debug_model_comparison.py script to capture a screenshot.")