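"""RAG endpoint: accept a .docx upload plus a query, retrieve the most
relevant chunks with FAISS, and ask the LLM to answer from that context."""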
import functools
import logging
import os
import time
import uuid

from fastapi import APIRouter, Form, UploadFile

from app.core.config import settings
from app.core.models import RAGResponse
from app.services.document import docx_to_chunks
from app.services.embedding import embedding_service
from app.services.llm import chat_completion

logging.basicConfig(level=logging.INFO)

router = APIRouter()

def log_latency(func):
    """Log the wall-clock latency of the wrapped function in milliseconds."""
    @functools.wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        latency_ms = (end_time - start_time) * 1000
        logging.info(f"Latency for {func.__name__}: {latency_ms:.2f} ms")
        return result
    return wrapper

@log_latency
def search_faiss(index, query_vector, k, chunks):
    """Search the FAISS index for the k nearest chunks and log them.

    `chunks` holds the chunk texts in the same order as the vectors in the
    index, so each returned row index maps straight back to its source text.
    """
    distances, idxs = index.search(query_vector, k)
    # Log each result with its distance and a short text preview.
    for i in range(k):
        chunk_idx = idxs[0][i]
        if chunk_idx < 0:  # FAISS pads with -1 when the index holds fewer than k vectors
            break
        chunk_text = chunks[chunk_idx]
        preview = ' '.join(chunk_text.split()[:20])  # first 20 words
        logging.info(f"Top {i+1} result - Index: {chunk_idx}, Distance: {distances[0][i]:.4f}")
        logging.info(f"Text preview: {preview}...")
    return distances, idxs

@router.post("/rag", response_model=RAGResponse)
async def rag_endpoint(file: UploadFile, query: str = Form(...)):
    """Answer a query against an uploaded .docx via retrieval-augmented generation."""
    # Use a local temp directory (rather than /tmp) for Windows compatibility,
    # creating it if it doesn't exist yet.
    os.makedirs("temp", exist_ok=True)
    tmp_path = f"temp/{uuid.uuid4()}.docx"

    try:
        # Save the uploaded file to disk.
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        # 1. Chunk the document.
        chunks = docx_to_chunks(tmp_path)

        # 2. Embed the chunks.
        embeddings = embedding_service.get_embeddings(chunks)

        # 3. Build a FAISS index over the chunks and search it with the embedded query.
        index = embedding_service.build_faiss_index(embeddings)
        q_vec = embedding_service.get_embeddings([query])
        distances, idxs = search_faiss(index, q_vec, settings.TOP_K_RESULTS, chunks)

        # 4. Generate an answer grounded in the retrieved context.
        top_indices = [int(i) for i in idxs[0] if i >= 0]  # drop FAISS's -1 padding
        context = "\n\n".join(chunks[i] for i in top_indices)
        prompt = (f"Use ONLY the context to answer the question.\n\n"
                  f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
        answer = await chat_completion(prompt)

        return {"answer": answer, "sources": top_indices}

    finally:
        # Clean up the temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
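
# Example request (a sketch — assumes the app mounts this router at the root
# on localhost:8000; adjust host, port, and prefix to your deployment):
#
#   curl -X POST http://localhost:8000/rag \
#        -F "file=@document.docx" \
#        -F "query=What are the key findings?"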