"""Train, validate, and export a YOLO model on a custom dataset (pothole / road-sign detection)."""
from ultralytics import YOLO


def train_yolo(
    data_yaml_path: str,
    model_size: str = "yolov8n.pt",
    epochs: int = 50,
    *,
    imgsz: int = 640,
    export_format: str = "onnx",
):
    """
    Train a YOLO model on a custom dataset, then validate and export it.

    Args:
        data_yaml_path (str): Path to the dataset.yaml file.
        model_size (str): Pre-trained model to start from (e.g., yolov8n.pt, yolov8s.pt).
        epochs (int): Number of training epochs.
        imgsz (int): Image size passed to ``model.train``. Keyword-only;
            defaults to 640, the value this function previously hard-coded.
        export_format (str): Format passed to ``model.export``. Keyword-only;
            defaults to "onnx", the value this function previously hard-coded.

    Returns:
        Whatever ``model.export`` returns. It is printed below as the export
        location, so presumably it is the exported model's path.
    """
    print(f"Loading {model_size}...")
    model = YOLO(model_size)

    print(f"Starting training for {epochs} epochs using {data_yaml_path}...")
    model.train(data=data_yaml_path, epochs=epochs, imgsz=imgsz)

    print("Training complete. Validating...")
    # Called with no arguments. This relies on the model object carrying over
    # its training configuration (dataset, image size) after train().
    # NOTE(review): confirm this against the ultralytics version in use.
    metrics = model.val()
    print(f"Validation metrics: {metrics}")

    print("Exporting model...")
    path = model.export(format=export_format)
    print(f"Model exported to {path}")
    return path


def parse_args(argv=None):
    """Parse command-line options for the training script.

    Every option defaults to the value this script previously hard-coded.
    Running it with no arguments therefore trains on the same dataset, with
    the same model and the same number of epochs as before.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data``, ``model`` and ``epochs`` attributes.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Train, validate and export a YOLO model on a custom dataset."
    )
    parser.add_argument(
        "--data",
        default="../datasets/road_signs_potholes/data.yaml",
        help=(
            "Path to the dataset data.yaml. Relative paths resolve against the "
            "current working directory; the default assumes you run from the "
            "'backend' folder. An absolute path also works."
        ),
    )
    parser.add_argument(
        "--model",
        default="yolov8n.pt",
        help="Pre-trained weights to start from (e.g. yolov8n.pt, yolov8s.pt).",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        # Higher than train_yolo's own default of 50, for better results on small data.
        default=100,
        help="Number of training epochs.",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Script entry point: parse options, then train, validate and export."""
    args = parse_args(argv)
    print(f"Using dataset: {args.data}")
    train_yolo(args.data, model_size=args.model, epochs=args.epochs)


if __name__ == "__main__":
    main()