You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

67 lines
2.8 KiB

  1. import faiss
  2. import numpy as np
  3. import pickle
  4. from pathlib import Path
  5. from typing import List, Optional, Tuple
  6. class FaceStore:
  7. def __init__(self, dimension: int = 512): # 512 for ArcFace
  8. self.dimension = dimension
  9. # Use cosine similarity instead of L2 distance
  10. self.index = faiss.IndexFlatIP(dimension) # Inner Product = Cosine similarity for normalized vectors
  11. self.face_data = []
  12. self.store_path = Path("face_store.pkl")
  13. self.index_path = Path("face_index.faiss")
  14. self.load_if_exists()
  15. def load_if_exists(self):
  16. if self.store_path.exists() and self.index_path.exists():
  17. # Load face data
  18. with open(self.store_path, 'rb') as f:
  19. self.face_data = pickle.load(f)
  20. # Load FAISS index
  21. self.index = faiss.read_index(str(self.index_path))
  22. def save(self):
  23. # Save face data
  24. with open(self.store_path, 'wb') as f:
  25. pickle.dump(self.face_data, f)
  26. # Save FAISS index
  27. faiss.write_index(self.index, str(self.index_path))
  28. def normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
  29. """L2 normalize the embedding"""
  30. embedding = embedding.astype(np.float32)
  31. # Reshape to 2D if needed
  32. if embedding.ndim == 1:
  33. embedding = embedding.reshape(1, -1)
  34. # L2 normalize
  35. faiss.normalize_L2(embedding)
  36. return embedding
  37. def add_face(self, name: str, embedding: np.ndarray) -> None:
  38. # Normalizing the embedding before adding
  39. normalized_embedding = self.normalize_embedding(embedding)
  40. self.face_data.append({"name": name, "embedding": normalized_embedding.flatten()})
  41. self.index.add(normalized_embedding)
  42. self.save()
  43. print(f"Added face for {name}. Total faces: {self.index.ntotal}")
  44. def search_face(self, embedding: np.ndarray, threshold: float = 0.5) -> Optional[Tuple[str, float]]:
  45. if self.index.ntotal == 0:
  46. return None
  47. # Normalizing the query embedding
  48. normalized_embedding = self.normalize_embedding(embedding)
  49. # Searching using cosine similarity
  50. similarities, indices = self.index.search(normalized_embedding, 1)
  51. similarity = similarities[0][0]
  52. print(f"Best match similarity: {similarity}, threshold: {threshold}")
  53. # For cosine similarity, higher is better and max is 1.0 so we can optimize and keep on checking
  54. if similarity > threshold:
  55. matched_face = self.face_data[indices[0][0]]
  56. # Similarity is already between 0 and 1 for cosine
  57. return matched_face["name"], float(similarity)
  58. return None