You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

160 lines
5.1 KiB

import io
from contextlib import asynccontextmanager
from typing import List

import insightface
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from insightface.app import FaceAnalysis
from PIL import Image, ImageOps
from pydantic import BaseModel

from face_store import FaceStore
  12. # Initializing InsightFace model with better detection settings with robust detection
  13. face_analyzer = FaceAnalysis(
  14. providers=['CPUExecutionProvider'],
  15. allowed_modules=['detection', 'recognition']
  16. )
  17. face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
  18. # Initializing face store
  19. face_store = FaceStore()
  20. @asynccontextmanager
  21. async def lifespan(app: FastAPI):
  22. print("Initializing face recognition system...")
  23. yield
  24. print("Cleaning up resources...")
  25. app = FastAPI(lifespan=lifespan)
  26. app.add_middleware(
  27. CORSMiddleware,
  28. allow_origins=["*"],
  29. allow_credentials=True,
  30. allow_methods=["*"],
  31. allow_headers=["*"],
  32. )
  33. class Visitor(BaseModel):
  34. name: str
  35. encoded_face: List[float]
  36. class Config:
  37. from_attributes = True
  38. def get_largest_face(faces):
  39. """Select the largest face from detected faces based on bounding box area."""
  40. if not faces:
  41. return None
  42. # Calculating areas of all faces
  43. areas = [(face, (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1]))
  44. for face in faces]
  45. # Returning face with largest area
  46. return max(areas, key=lambda x: x[1])[0]
  47. def process_image(image_data: bytes):
  48. """Process image and return embedding of the largest face."""
  49. try:
  50. image_stream = io.BytesIO(image_data)
  51. image_pil = Image.open(image_stream).convert("RGB")
  52. # Resizing image if too large (optional, adjust dimensions as needed)
  53. max_size = 1920
  54. if max(image_pil.size) > max_size:
  55. ratio = max_size / max(image_pil.size)
  56. new_size = tuple(int(dim * ratio) for dim in image_pil.size)
  57. image_pil = image_pil.resize(new_size, Image.Resampling.LANCZOS)
  58. image_np = np.array(image_pil)
  59. faces = face_analyzer.get(image_np)
  60. if not faces:
  61. return None, "No face detected"
  62. # Get the largest face
  63. largest_face = get_largest_face(faces)
  64. # Converting embedding to numpy array to ensure consistent format
  65. embedding = np.array(largest_face.embedding, dtype=np.float32)
  66. return embedding, None
  67. except Exception as e:
  68. return None, f"Error processing image: {str(e)}"
  69. @app.get("/")
  70. async def health_check():
  71. return {"message": "Face recognition API is running"}
  72. @app.post("/api/register")
  73. async def register_visitor(name: str = Form(...), image: UploadFile = File(...)):
  74. try:
  75. image_data = await image.read()
  76. embedding, error = process_image(image_data)
  77. if error:
  78. return {"message": error}
  79. # Converting embedding to numpy array if it isn't already
  80. embedding = np.array(embedding, dtype=np.float32)
  81. # Adding debug logging
  82. print(f"Registering face for {name}")
  83. print(f"Embedding shape: {embedding.shape}")
  84. print(f"Embedding type: {type(embedding)}")
  85. # Checking if face already exists
  86. existing_match = face_store.search_face(embedding)
  87. if existing_match:
  88. return {
  89. "message": "Visitor already exists",
  90. "name": existing_match[0]
  91. }
  92. # Registering new face
  93. face_store.add_face(name, embedding)
  94. # Verifying registration
  95. verification = face_store.search_face(embedding)
  96. if not verification:
  97. raise HTTPException(status_code=500, detail="Face registration failed verification")
  98. return {
  99. "message": "Visitor registered successfully",
  100. "name": name
  101. }
  102. except Exception as e:
  103. raise HTTPException(status_code=500, detail=str(e))
  104. @app.post("/api/search")
  105. async def search_visitor(image: UploadFile = File(...)):
  106. try:
  107. image_data = await image.read()
  108. embedding, error = process_image(image_data)
  109. if error:
  110. return {"message": error}
  111. # Converting embedding to numpy array if it isn't already
  112. embedding = np.array(embedding, dtype=np.float32)
  113. # Adding debug logging
  114. print(f"Searching for face")
  115. print(f"Embedding shape: {embedding.shape}")
  116. print(f"Embedding type: {type(embedding)}")
  117. match = face_store.search_face(embedding)
  118. if match:
  119. name, confidence = match
  120. return {
  121. "message": "Visitor found",
  122. "name": name,
  123. "confidence": confidence
  124. }
  125. return {"message": "Visitor not found"}
  126. except Exception as e:
  127. raise HTTPException(status_code=500, detail=str(e))