# Source listing: 36 lines, 1.1 KiB, Python
from typing import Optional

import cv2
import numpy as np
from ultralytics import YOLO


class YOLOManager:
    """Thin wrapper around an Ultralytics YOLO model for per-frame detection and tracking."""

    def __init__(self, model_path: str = "yolov8n.pt"):
        """
        Initializes the YOLO model for inference.

        Args:
            model_path (str): Path to the trained YOLO model weights (.pt file).
        """
        print(f"Loading YOLO model from {model_path}...")
        self.model = YOLO(model_path)

    @staticmethod
    def _require_frame(frame):
        """Raise ValueError when *frame* is None instead of passing it to the model."""
        # Ultralytics treats source=None as "missing" and falls back to its bundled
        # sample images, which would silently return results for the wrong picture.
        # A None frame typically means cv2.VideoCapture.read() failed upstream.
        if frame is None:
            raise ValueError("frame is None; expected an image array (did the frame read fail?)")

    def track(self, frame, conf: float = 0.25, iou: float = 0.5, tracker: str = "bytetrack.yaml"):
        """
        Runs YOLO tracking on a single frame.

        Args:
            frame: Numpy array (image). Must not be None.
            conf (float): Confidence threshold.
            iou (float): IoU threshold.
            tracker (str): Tracker config passed to Ultralytics
                (defaults to "bytetrack.yaml").

        Returns:
            The Ultralytics Results object for this frame.

        Raises:
            ValueError: If frame is None.
        """
        self._require_frame(frame)
        # persist=True keeps tracker state between calls so IDs carry across frames.
        results = self.model.track(
            frame,
            persist=True,
            conf=conf,
            iou=iou,
            tracker=tracker,
            verbose=False,
        )
        return results[0]

    def detect(self, frame, conf: Optional[float] = None, iou: Optional[float] = None):
        """
        Standard detection without tracking.

        Args:
            frame: Numpy array (image). Must not be None.
            conf (float, optional): Confidence threshold; when None the model's
                own predict default is used.
            iou (float, optional): IoU threshold; when None the model's own
                predict default is used.

        Returns:
            The Ultralytics Results object for this frame.

        Raises:
            ValueError: If frame is None.
        """
        self._require_frame(frame)
        # Forward thresholds only when given, so default calls behave exactly as before.
        overrides = {}
        if conf is not None:
            overrides["conf"] = conf
        if iou is not None:
            overrides["iou"] = iou
        results = self.model.predict(frame, verbose=False, **overrides)
        return results[0]