72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
from backend.models.yolo_manager import YOLOManager
|
|
import cv2
|
|
import os
|
|
|
|
def test_model(model_path="backend/models/best.pt", source="test_video.mp4"):
    """Run the YOLO model on an image, a video, or a camera and display the results.

    Sources whose file extension marks them as still images are run through
    ``yolo.detect`` and shown until a key is pressed. Everything else -- video
    files, stream URLs, or an integer camera index such as ``0`` -- is opened with
    ``cv2.VideoCapture`` and run frame by frame through ``yolo.track``; press 'q'
    in the window to stop early.

    Args:
        model_path: Path to the model weights. If no file exists at this path,
            the stock ``"yolov8n.pt"`` weights name is used instead.
        source: Image/video path (``str`` or path-like), or an integer camera
            index passed straight to ``cv2.VideoCapture``.

    Returns:
        None. If ``source`` cannot be read/opened, a message is printed and the
        function returns early instead of raising.
    """
    # NOTE(review): the ``test_`` prefix means pytest may collect this function if
    # the file is named test_*.py; the name is kept so existing callers still work.
    if not os.path.exists(model_path):
        print(f"Model not found at {model_path}. Using standard yolov8n.pt for demo.")
        model_path = "yolov8n.pt"

    yolo = YOLOManager(model_path)

    if isinstance(source, os.PathLike):
        # Hand OpenCV a plain string rather than a pathlib object.
        source = os.fspath(source)

    if _is_image_source(source):
        _show_image_detections(yolo, source)
    else:
        _show_video_tracking(yolo, source)


def _is_image_source(source):
    """Return True when *source* is a string path whose extension marks a still image."""
    if not isinstance(source, str):
        # Camera indices (e.g. 0) have no extension; os.path.splitext would raise TypeError.
        return False
    ext = os.path.splitext(source)[1].lower()
    return ext in (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")


def _show_image_detections(yolo, image_path):
    """Run ``yolo.detect`` on one image file and show the annotated result until a key press."""
    frame = cv2.imread(image_path)
    if frame is None:
        print(f"Could not read image: {image_path}")
        return

    # Assumes detect() returns an object exposing .plot() that renders the
    # detections onto an image (e.g. an ultralytics Results) -- confirm in YOLOManager.
    results = yolo.detect(frame)
    cv2.imshow("YOLO Detection", results.plot())
    cv2.waitKey(0)
    cv2.destroyAllWindows()


def _show_video_tracking(yolo, source):
    """Run ``yolo.track`` frame by frame on a video/stream/camera and display it live.

    Stops when the capture stops returning frames or when 'q' is pressed.
    """
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        print(f"Could not open video: {source}")
        return

    print("Press 'q' to exit.")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # End of the file, or the camera/stream stopped delivering frames.
                break

            # Either yolo.track or yolo.detect can be used here; track() presumably
            # keeps object IDs across frames -- confirm in YOLOManager.
            results = yolo.track(frame)
            cv2.imshow("YOLO Tracking", results.plot())

            # & 0xFF keeps only the low byte of the key code before comparing.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
    finally:
        # Release the capture and close windows even if inference or display raises.
        cap.release()
        cv2.destroyAllWindows()
|
|
|
|
if __name__ == "__main__":
    # CHANGE THIS to your test file, or set it to "0" for the webcam (camera index 0).
    TEST_FILE = "d:/path/to/your/test/video_or_image.jpg"

    if os.path.exists(TEST_FILE):
        # An existing file always wins, even if it happens to be named "0".
        chosen_source = TEST_FILE
    elif TEST_FILE == "0":
        chosen_source = 0  # Webcam (camera index 0)
    else:
        print(f"File {TEST_FILE} not found.")
        # Strip surrounding whitespace, then surrounding double/single quotes, so a
        # pasted "Copy as path" value or a shell-quoted path works as typed.
        entered = input("Enter path to image/video (or 0 for webcam): ").strip().strip("\"'")
        chosen_source = 0 if entered == "0" else entered

    test_model(source=chosen_source)
|