You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

68 lines
2.8 KiB

import faiss
import numpy as np
import pickle
from pathlib import Path
from typing import List, Optional, Tuple
class FaceStore:
    """Persistent nearest-neighbour store for face embeddings.

    Embeddings are L2-normalised and kept in a FAISS inner-product index, so
    the score returned by a search is the cosine similarity of the query and
    the stored vector. A parallel list, ``face_data``, maps each FAISS row id
    to the record (name and normalised embedding) it was added with. Both are
    written to disk after every ``add_face`` call and reloaded on
    construction when both files exist.
    """

    def __init__(
        self,
        dimension: int = 512,  # 512 for ArcFace
        store_path: Path = Path("face_store.pkl"),
        index_path: Path = Path("face_index.faiss"),
    ):
        """Create an empty store, then load previously saved state if present.

        Args:
            dimension: Length of every embedding vector (512 for ArcFace).
            store_path: Pickle file holding the name/embedding records
                (str or Path; relative paths resolve against the CWD).
            index_path: File holding the serialised FAISS index.

        Raises:
            ValueError: If saved state exists but is inconsistent with
                ``dimension`` or with itself (see ``load_if_exists``).
        """
        self.dimension = dimension
        # Inner product of unit-length vectors equals their cosine similarity.
        self.index = faiss.IndexFlatIP(dimension)
        # Row i of self.index corresponds to self.face_data[i].
        self.face_data: List[dict] = []
        self.store_path = Path(store_path)
        self.index_path = Path(index_path)
        self.load_if_exists()

    def load_if_exists(self):
        """Restore the records and FAISS index previously written by ``save``.

        Nothing is loaded unless *both* files exist. If only one is present,
        the store stays empty and the next ``save`` overwrites that file.

        Raises:
            ValueError: If the saved index has a different dimension than
                ``self.dimension``, or if the number of records differs from
                the number of vectors in the index (searching such a store
                would return wrong names or fail with IndexError).
        """
        if not (self.store_path.exists() and self.index_path.exists()):
            return
        # SECURITY: pickle.load can execute arbitrary code. Only load store
        # files written by this application and not writable by untrusted users.
        with open(self.store_path, "rb") as f:
            face_data = pickle.load(f)
        index = faiss.read_index(str(self.index_path))
        # Validate before assigning so a failed load leaves the empty store intact.
        if index.d != self.dimension:
            raise ValueError(
                f"saved index {self.index_path} has dimension {index.d}, "
                f"expected {self.dimension}"
            )
        if index.ntotal != len(face_data):
            raise ValueError(
                f"saved index {self.index_path} holds {index.ntotal} vectors but "
                f"{self.store_path} holds {len(face_data)} records"
            )
        self.face_data = face_data
        self.index = index

    def save(self):
        """Write the records and the FAISS index to disk.

        Each file is written to a ``.tmp`` sibling first and then moved into
        place with ``Path.replace``, so an interrupted write cannot leave a
        truncated store or index file behind.
        """
        store_tmp = self.store_path.with_name(self.store_path.name + ".tmp")
        index_tmp = self.index_path.with_name(self.index_path.name + ".tmp")
        with open(store_tmp, "wb") as f:
            pickle.dump(self.face_data, f)
        faiss.write_index(self.index, str(index_tmp))
        store_tmp.replace(self.store_path)
        index_tmp.replace(self.index_path)

    def normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
        """Return a float32, C-contiguous, L2-normalised copy of *embedding*.

        A 1-D vector becomes a single row of shape ``(1, d)``. A 2-D array is
        normalised row by row. All-zero rows stay zero (there is no direction
        to normalise). The caller's array is never modified.

        Raises:
            ValueError: If *embedding* is not 1-D or 2-D.
        """
        vectors = np.asarray(embedding, dtype=np.float32)
        if vectors.ndim == 1:
            vectors = vectors.reshape(1, -1)
        if vectors.ndim != 2:
            raise ValueError(
                f"expected a 1-D or 2-D embedding array, got shape {vectors.shape}"
            )
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Divide zero rows by 1 instead of 0 so they stay zero rather than NaN.
        norms[norms == 0] = 1.0
        # FAISS requires C-contiguous float32 input; division of a
        # Fortran-ordered array would otherwise keep that layout.
        return np.ascontiguousarray(vectors / norms, dtype=np.float32)

    def _prepare_single(self, embedding: np.ndarray) -> np.ndarray:
        """Normalise *embedding* and check that it is exactly one vector of ``self.dimension``.

        Raises:
            ValueError: If the input is not a single vector of the store's dimension.
        """
        vector = self.normalize_embedding(embedding)
        if vector.shape != (1, self.dimension):
            raise ValueError(
                f"expected one embedding of length {self.dimension}, "
                f"got array of shape {np.shape(embedding)}"
            )
        return vector

    def add_face(self, name: str, embedding: np.ndarray) -> None:
        """Normalise *embedding*, store it under *name*, and persist the store.

        Raises:
            ValueError: If *embedding* is not a single vector of length
                ``self.dimension``. Checked up front so a multi-row array cannot
                add several index rows but only one record, which would
                desynchronise ``face_data`` from the index.
        """
        vector = self._prepare_single(embedding)
        # Add to the index first: if that fails, face_data stays in sync.
        self.index.add(vector)
        self.face_data.append({"name": name, "embedding": vector.flatten()})
        self.save()
        print(f"Added face for {name}. Total faces: {self.index.ntotal}")

    def search_face(self, embedding: np.ndarray, threshold: float = 0.5) -> Optional[Tuple[str, float]]:
        """Find the stored face most similar to *embedding*.

        Args:
            embedding: Query vector of length ``self.dimension`` (1-D or a single row).
            threshold: Minimum cosine similarity; the best match must be strictly greater.

        Returns:
            ``(name, similarity)``, where similarity is the cosine similarity
            in [-1, 1]. Returns None when the store is empty or the best match
            does not exceed *threshold*.

        Raises:
            ValueError: If *embedding* is not a single vector of length ``self.dimension``.
        """
        if self.index.ntotal == 0:
            return None
        query = self._prepare_single(embedding)
        similarities, indices = self.index.search(query, 1)
        similarity = similarities[0][0]
        best = int(indices[0][0])
        print(f"Best match similarity: {similarity}, threshold: {threshold}")
        # FAISS reports index -1 when it found no neighbour.
        if best >= 0 and similarity > threshold:
            return self.face_data[best]["name"], float(similarity)
        return None