63 lines
2.3 KiB
Python
63 lines
2.3 KiB
Python
from transformers import CLIPProcessor, CLIPModel
|
|
import torch
|
|
from PIL import Image
|
|
|
|
class CLIPManager:
    """Wraps a Hugging Face CLIP model for zero-shot image classification.

    Loads the model once at construction (on GPU when available) and exposes
    helpers to score an image against arbitrary text labels.
    """

    def __init__(self, model_id: str = "openai/clip-vit-base-patch32"):
        """
        Initializes the CLIP model and processor.

        Args:
            model_id (str): Hugging Face model ID.
        """
        print(f"Loading CLIP model: {model_id}...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = CLIPModel.from_pretrained(model_id).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_id)
        print(f"CLIP loaded on {self.device}.")

    def classify_image(
        self, image: Image.Image, candidate_labels: list[str]
    ) -> dict[str, float]:
        """
        Classifies an image against a list of text labels.

        Args:
            image (PIL.Image): The cropped image to classify.
            candidate_labels (list[str]): List of strings to compare against.

        Returns:
            dict: {label: score} sorted by confidence, descending.
            Empty dict when no labels are given.
        """
        if not candidate_labels:
            return {}

        inputs = self.processor(
            text=candidate_labels, images=image, return_tensors="pt", padding=True
        ).to(self.device)

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)

        logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities

        # .tolist() yields plain Python floats directly, avoiding the
        # numpy round-trip and per-element float() conversions.
        scores = probs[0].cpu().tolist()
        result = dict(zip(candidate_labels, scores))

        # Sort by score descending
        sorted_result = dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
        return sorted_result

    def get_best_match(
        self,
        image: Image.Image,
        candidate_labels: list[str],
        threshold: float = 0.5,
    ):
        """
        Returns the single best match if it exceeds the threshold.

        Args:
            image (PIL.Image): The image to classify.
            candidate_labels (list[str]): Labels to compare against.
            threshold (float): Minimum confidence for the top label to be
                accepted (inclusive).

        Returns:
            tuple: (label, score). ``(None, 0.0)`` when no labels were given;
            ``("Uncertain", score)`` when the top score is below the threshold.
        """
        results = self.classify_image(image, candidate_labels)
        if not results:
            return None, 0.0

        # classify_image returns labels sorted by descending score, so the
        # first entry is the best match — read it without building two
        # throwaway lists of all keys and all values.
        best_label, best_score = next(iter(results.items()))

        if best_score >= threshold:
            return best_label, best_score
        return "Uncertain", best_score
|