You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
225 lines
7.6 KiB
225 lines
7.6 KiB
#!/usr/bin/env python3
|
|
"""
|
|
Test ONNX model against static images to isolate OpenCV capture issues
|
|
This bypasses Android screen capture and tests pure ONNX inference
|
|
"""
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
import os
|
|
from pathlib import Path
|
|
|
|
# Force CPU-only execution to avoid CUDA compatibility issues.
# NOTE(review): this assignment runs after `import onnxruntime` above;
# presumably the variable is only consulted when a session is created, so the
# ordering is fine -- confirm against the installed onnxruntime build.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
|
|
|
def letterbox_preprocess(img, target_size=(640, 640)):
    """Scale an image to fit ``target_size`` and pad the rest with gray.

    The image is resized uniformly (aspect ratio preserved) so that it fits
    inside the target canvas, then pasted in the center of a canvas filled
    with the value 114. According to the original author this mirrors the
    letterboxing done by the Android implementation.

    Args:
        img: Image array whose first two dimensions are (height, width). The
            canvas is created as 3-channel ``uint8``, so a 3-channel image is
            expected.
        target_size: Canvas size as ``(height, width)``.

    Returns:
        A tuple ``(padded, scale, (pad_x, pad_y))``: the padded canvas, the
        scale factor applied to the input, and the left/top offsets at which
        the resized image was placed.
    """
    target_h, target_w = target_size[0], target_size[1]
    h, w = img.shape[:2]

    # One uniform factor: the tighter of the two axis ratios, so the whole
    # image fits inside the canvas.
    scale = min(target_h / h, target_w / w)

    # Truncate toward zero as before, but never below one pixel: for extreme
    # aspect ratios (e.g. a 1x2000 strip) the short side would otherwise
    # become 0 and cv2.resize rejects an empty destination size.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Gray canvas (value 114 in every channel).
    padded = np.full((target_h, target_w, 3), 114, dtype=np.uint8)

    # Split the leftover space evenly; any odd pixel ends up on the
    # right/bottom side because of the floor division.
    pad_x = (target_w - new_w) // 2
    pad_y = (target_h - new_h) // 2

    padded[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized

    return padded, scale, (pad_x, pad_y)
|
|
|
|
def test_onnx_static(model_path, image_path, confidence_threshold=0.01,
                     shiny_class_id=50):
    """Run one ONNX model on one image file and print a detection report.

    The image is read with OpenCV, letterboxed with ``letterbox_preprocess``,
    converted to an RGB float32 tensor in ``[0, 1]`` with layout
    ``(1, 3, H, W)`` and fed to the model's first input. Only outputs whose
    first tensor is 3-D with 6 values per detection are analysed; the script
    treats those values as ``(x, y, w, h, conf, class)``.

    Args:
        model_path: Path to the ``.onnx`` model file.
        image_path: Path to the image to test.
        confidence_threshold: Detections with a confidence (column 4) not
            strictly greater than this value are dropped.
        shiny_class_id: Class id that is reported separately as "shiny icon".

    Returns:
        ``None`` if the image or model cannot be loaded, inference fails, or
        the output is not in the 6-value format; an empty list if nothing
        passes the confidence threshold; otherwise the array of detections
        that passed the threshold (one row of 6 values each).
    """
    print(f"🔧 Testing ONNX model: {Path(model_path).name}")
    print(f"📸 Image: {Path(image_path).name}")

    # cv2.imread signals failure by returning None instead of raising.
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"❌ Could not load image: {image_path}")
        return None

    print(f" Original image size: {img.shape}")

    processed_img, scale, padding = letterbox_preprocess(img)
    print(f" Processed size: {processed_img.shape}")
    print(f" Scale factor: {scale:.4f}")
    print(f" Padding (x, y): {padding}")

    # BGR -> RGB, scale to [0, 1], HWC -> CHW, then add the batch axis.
    img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
    img_norm = img_rgb.astype(np.float32) / 255.0
    img_chw = np.transpose(img_norm, (2, 0, 1))
    img_batch = np.expand_dims(img_chw, axis=0)

    print(f" Final tensor: {img_batch.shape}, range: [{img_batch.min():.3f}, {img_batch.max():.3f}]")

    try:
        # Request the CPU provider explicitly: this script is meant to run
        # CPU-only, and GPU-enabled onnxruntime builds refuse to create a
        # session when no provider list is given.
        session = ort.InferenceSession(
            str(model_path), providers=["CPUExecutionProvider"]
        )
        input_name = session.get_inputs()[0].name
        print(f" Model loaded, input: {input_name}")
    except Exception as e:
        # Diagnostic script: report the problem and let the caller move on
        # to the next model.
        print(f"❌ Failed to load ONNX model: {e}")
        return None

    try:
        outputs = session.run(None, {input_name: img_batch})
        print(f" Inference successful, {len(outputs)} outputs")
    except Exception as e:
        print(f"❌ Inference failed: {e}")
        return None

    if len(outputs) == 0:
        print("❌ No outputs from model")
        return None

    detections = outputs[0]
    print(f" Detection tensor shape: {detections.shape}")

    if len(detections.shape) != 3:
        print(f"❌ Unexpected detection shape: {detections.shape}")
        return None

    num_values = detections.shape[2]
    detection_data = detections[0]  # Drop the batch dimension

    if num_values != 6:
        print(f" ⚠️ Raw format detected ({num_values} values) - not processed")
        return None

    print(" Format: NMS (x, y, w, h, conf, class)")

    # Keep only the rows whose confidence exceeds the threshold.
    valid_detections = detection_data[detection_data[:, 4] > confidence_threshold]
    print(f" Valid detections (conf > {confidence_threshold}): {len(valid_detections)}")

    if len(valid_detections) == 0:
        print(" ❌ No detections above confidence threshold")
        return []

    classes = valid_detections[:, 5].astype(int)
    confidences = valid_detections[:, 4]

    # Convert to plain Python ints so the list prints as [0, 50] rather than
    # as NumPy scalar reprs.
    print(f" Classes detected: {sorted(set(classes.tolist()))}")

    shiny_detections = valid_detections[classes == shiny_class_id]
    if len(shiny_detections) > 0:
        print(f" ✨ SHINY ICONS FOUND: {len(shiny_detections)}")
        for rank, (x, y, w, h, conf, _) in enumerate(shiny_detections, start=1):
            print(f" Shiny {rank}: conf={conf:.6f}, box=[{x:.1f}, {y:.1f}, {w:.1f}, {h:.1f}]")
    else:
        print(f" ❌ NO SHINY ICONS (class {shiny_class_id}) detected")

    # Highest-confidence detections first, at most ten of them.
    top_indices = np.argsort(confidences)[::-1][:10]
    print(" 🎯 Top 10 detections:")
    for rank, (x, y, w, h, conf, cls) in enumerate(valid_detections[top_indices], start=1):
        print(f" {rank}. Class {int(cls)}: conf={conf:.4f}, box=[{x:.1f}, {y:.1f}, {w:.1f}, {h:.1f}]")

    return valid_detections
|
|
|
|
def test_multiple_models(image_path):
    """Run every known ONNX export on one image and print a comparison.

    Each model path in the built-in list is checked for existence; existing
    models are passed to ``test_onnx_static`` together with ``image_path``,
    missing ones are reported and recorded as ``None``. A per-model summary
    (total detections and how many belong to class 50) is printed at the end.
    Nothing is returned.
    """
    banner = "=" * 80

    print(banner)
    print("🔍 STATIC IMAGE ONNX TESTING")
    print(banner)

    candidate_models = (
        "app/src/main/assets/best.onnx",
        "raw_models/exports/best_no_nms.onnx",
        "raw_models/exports/best_nms_relaxed.onnx",
        "raw_models/exports/best_nms_very_relaxed.onnx",
    )

    # Maps model path -> result of test_onnx_static, or None when the file
    # is missing.
    outcome_by_model = {}

    for candidate in candidate_models:
        if not Path(candidate).exists():
            print(f"\n⚠️ Model not found: {candidate}")
            outcome_by_model[candidate] = None
            continue

        print("\n" + "=" * 60)
        outcome_by_model[candidate] = test_onnx_static(candidate, image_path)

    print("\n" + banner)
    print("📊 COMPARISON SUMMARY")
    print(banner)

    for candidate, found in outcome_by_model.items():
        label = Path(candidate).name

        if found is None:
            print(f"❌ {label}: Failed or not found")
        elif len(found) == 0:
            print(f"🔵 {label}: No detections")
        else:
            # Column 5 holds the class id of each detection.
            class_ids = found[:, 5].astype(int)
            shiny_total = np.sum(class_ids == 50)
            print(f"✅ {label}: {len(found)} total, {shiny_total} shiny icons")

    print(banner)
|
|
|
|
if __name__ == "__main__":
    # Candidate test images, checked in this order; the first one that
    # exists on disk is used.
    test_image_candidates = [
        "test_images/shiny_test.jpg",
        "test_images/test.jpg",
        "screenshots/shiny.jpg",
        "screenshots/test.png",
    ]

    test_image_found = next(
        (path for path in test_image_candidates if Path(path).exists()),
        None,
    )

    if not test_image_found:
        print("❌ No test image found. Available options:")
        for path in test_image_candidates:
            print(f" {path}")
        print("\nPlease provide a test image with shiny icon at one of these paths.")
        print("You can use the debug_model_comparison.py script to capture a screenshot.")
    else:
        print(f"🎯 Using test image: {test_image_found}")
        test_multiple_models(test_image_found)
|