2026-02-09 12:50:10 +05:30

40 lines
1.5 KiB
Python

from ultralytics import YOLO
def train_yolo(
    data_yaml_path: str,
    model_size: str = "yolov8n.pt",
    epochs: int = 50,
    imgsz: int = 640,
    export_format: str = "onnx",
):
    """Train, validate and export an Ultralytics YOLO model on a custom dataset.

    Args:
        data_yaml_path: Path to the dataset YAML file. It is passed as ``data``
            to both ``model.train`` and ``model.val``.
        model_size: Weights/config to start from (e.g. ``yolov8n.pt``,
            ``yolov8s.pt``). It is passed straight to ``YOLO(...)``.
        epochs: Number of training epochs.
        imgsz: Input image size used for training and validation
            (defaults to 640, the previously hard-coded value).
        export_format: Format passed to ``model.export``
            (defaults to ``"onnx"``, the previously hard-coded value).

    Returns:
        A ``(metrics, export_path)`` tuple: the object returned by
        ``model.val()`` and the value returned by ``model.export()``.
    """
    print(f"Loading {model_size}...")
    model = YOLO(model_size)

    print(f"Starting training for {epochs} epochs using {data_yaml_path}...")
    model.train(data=data_yaml_path, epochs=epochs, imgsz=imgsz)

    print("Training complete. Validating...")
    # Pass the dataset and image size explicitly, so validation does not
    # depend on whatever args happen to be stored with the reloaded weights.
    metrics = model.val(data=data_yaml_path, imgsz=imgsz)
    print(f"Validation metrics: {metrics}")

    print("Exporting model...")
    path = model.export(format=export_format)
    print(f"Model exported to {path}")

    # Hand the results back so callers can use them programmatically
    # instead of scraping stdout.
    return metrics, path
if __name__ == "__main__":
    import argparse

    # Default dataset location. It is relative to the *current working
    # directory* and assumes the script is launched from the 'backend' folder,
    # so the dataset lives one level up under '../datasets/...'.
    # Use --data (e.g. with an absolute path) instead of editing this file.
    default_dataset = "../datasets/road_signs_potholes/data.yaml"

    parser = argparse.ArgumentParser(
        description="Train, validate and export a YOLO model on a custom dataset."
    )
    parser.add_argument(
        "--data",
        default=default_dataset,
        help="Path to the dataset YAML file (default: %(default)s)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        # Higher than train_yolo's default of 50, for better results on a small dataset.
        default=100,
        help="Number of training epochs (default: %(default)s)",
    )
    parser.add_argument(
        "--model",
        default="yolov8n.pt",
        help="Pre-trained weights to start from (default: %(default)s)",
    )
    args = parser.parse_args()

    dataset_path = args.data
    print(f"Using dataset: {dataset_path}")
    train_yolo(dataset_path, model_size=args.model, epochs=args.epochs)