39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
import functools
import logging
import time

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from app.core.config import settings
|
|
|
|
# Configure the root logger when this module is imported so that INFO
# records (such as the latency lines from log_latency) are actually shown.
# Per the stdlib contract, basicConfig is a no-op if the root logger
# already has handlers.
logging.basicConfig(level="INFO")
|
|
|
|
def log_latency(func):
    """Wrap *func* so every call logs its elapsed time in milliseconds.

    The returned wrapper accepts the same arguments as *func*, returns its
    result unchanged, and carries *func*'s metadata (``__name__``,
    ``__doc__``, ``__wrapped__``) via ``functools.wraps``. After each call it
    logs ``"Latency for <name>: <ms> ms"`` at INFO level through the root
    logger -- also when *func* raises, in which case the exception
    propagates unchanged.

    Args:
        func: The callable to time.

    Returns:
        The timing wrapper around *func*.
    """
    # Resolved once at decoration time; callables such as functools.partial
    # have no __name__, so fall back to their repr instead of crashing.
    name = getattr(func, "__name__", repr(func))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted during the call.
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            latency_ms = (time.perf_counter() - start) * 1000
            # %-style arguments are only formatted if the record is emitted.
            logging.info("Latency for %s: %.2f ms", name, latency_ms)

    return wrapper
|
|
|
|
class EmbeddingService:
    """Encode texts with a SentenceTransformer and index them with FAISS.

    The model named by ``settings.MODEL_NAME`` is loaded once, in
    ``__init__``, and reused by every call.
    """

    def __init__(self):
        self.model = SentenceTransformer(settings.MODEL_NAME)

    @log_latency
    def get_embeddings(self, texts: list[str]) -> np.ndarray:
        """Generate embeddings for a list of texts.

        Args:
            texts: Strings to encode.

        Returns:
            A float32 array with one row per text. For an empty list the
            result has shape ``(0, dim)`` whenever the model reports its
            embedding dimension, so it stays usable by ``build_faiss_index``.
        """
        if not isinstance(texts, str) and len(texts) == 0:
            # An empty encode() result may convert to a 1-D array of shape
            # (0,), which carries no embedding dimension; build the 2-D
            # empty array explicitly instead.
            dim = self.model.get_sentence_embedding_dimension()
            if dim is not None:
                return np.empty((0, dim), dtype=np.float32)
        embeddings = self.model.encode(texts, convert_to_tensor=False)
        # asarray avoids the extra copy np.array made when encode() already
        # returned a float32 ndarray.
        return np.asarray(embeddings, dtype=np.float32)

    @log_latency
    def build_faiss_index(self, vectors: np.ndarray) -> faiss.Index:
        """Build an exact L2 FAISS index (``IndexFlatL2``) from vectors.

        Args:
            vectors: Array-like of shape ``(n, dim)``, one row per vector.

        Returns:
            A ``faiss.IndexFlatL2`` of dimension ``dim`` holding all ``n`` rows.

        Raises:
            ValueError: If *vectors* is not two-dimensional.
        """
        # FAISS consumes C-contiguous float32 buffers; this is a no-op when
        # the input is already in that layout.
        vectors = np.ascontiguousarray(vectors, dtype=np.float32)
        if vectors.ndim != 2:
            raise ValueError(
                f"expected a 2-D (n, dim) array, got shape {vectors.shape}"
            )
        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        return index
|
|
|
|
# Shared module-level instance. Constructing it loads the model named by
# settings.MODEL_NAME as soon as this module is first imported.
embedding_service = EmbeddingService()