2025-05-19 19:20:43 +05:30

96 lines
2.6 KiB
Python

import os
import uuid
import json
import numpy as np
import faiss
import httpx
from fastapi import APIRouter, UploadFile, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from app import (
docx_to_chunks,
get_embeddings,
build_faiss_index,
save_index,
save_chunks,
load_index,
load_chunks,
openrouter_chat
)
# Router on which the /upload and /query endpoints below are registered.
router = APIRouter()
# --- request / response models --------------------------------------
class UploadResponse(BaseModel):
    """Response body returned by the /upload endpoint."""

    # Human-readable status text.
    message: str
    # Identifier under which the document's index and chunks were saved.
    doc_id: str
class QueryRequest(BaseModel):
    """Request body accepted by the /query endpoint."""

    # Identifier of a previously uploaded document.
    doc_id: str
    # Natural-language question to answer from that document.
    query: str
    # Number of nearest chunks requested from the index search (defaults to 5).
    top_k: Optional[int] = 5
class QueryResponse(BaseModel):
    """Response body returned by the /query endpoint."""

    # Text produced by the chat model for the question.
    answer: str
    # Positions (within the stored chunk list) of the chunks used as context.
    sources: List[int]
# --- endpoints ------------------------------------------------------
@router.post("/upload", response_model=UploadResponse)
async def upload_endpoint(file: UploadFile):
    """Index an uploaded .docx file and return its document id.

    The upload is written to a temporary file, split into chunks with
    ``docx_to_chunks``, embedded with ``get_embeddings``, and the resulting
    index and chunks are persisted under ``doc_id`` via ``save_index`` and
    ``save_chunks``. The temporary file is always removed afterwards.

    Args:
        file: The uploaded file. Its name must end in ``.docx``
            (compared case-insensitively).

    Returns:
        A dict with ``message`` and ``doc_id`` keys, matching
        ``UploadResponse``.

    Raises:
        HTTPException: 400 if the upload has no filename, is not a
            ``.docx`` file, or produces no chunks to index.
    """
    # The filename can be missing on a multipart upload; treat that the same
    # as an unsupported extension instead of crashing on None.endswith().
    filename = file.filename or ""
    if not filename.lower().endswith(".docx"):
        raise HTTPException(status_code=400, detail="Only .docx files are supported")

    # uuid5 is deterministic: the same filename always maps to the same
    # doc_id. NOTE(review): re-uploading a same-named file presumably
    # overwrites the earlier index/chunks -- confirm against save_index.
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, filename))

    os.makedirs("temp", exist_ok=True)
    tmp_path = f"temp/{doc_id}.docx"
    try:
        # Persist the upload to disk because docx_to_chunks takes a path.
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        chunks = docx_to_chunks(tmp_path)
        # Nothing to embed or index: report it as a client error rather than
        # letting the embedding/index builders fail on empty input.
        if len(chunks) == 0:
            raise HTTPException(
                status_code=400,
                detail="Document contains no text to index",
            )

        embeddings = get_embeddings(chunks)
        index = build_faiss_index(embeddings)
        save_index(index, doc_id)
        save_chunks(chunks, doc_id)
        return {"message": "Document indexed", "doc_id": doc_id}
    finally:
        # Best-effort cleanup; the file may not exist if the write failed.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
@router.post("/query", response_model=QueryResponse)
async def query_endpoint(request: QueryRequest):
    """Answer a question from the stored chunks of one document.

    Loads the index and chunks saved for ``request.doc_id``, embeds the
    query, retrieves the nearest chunks, and asks the chat model to answer
    using only those chunks as context.

    Args:
        request: The document id, the question, and the number of chunks
            to retrieve (``top_k``; ``None`` falls back to 5).

    Returns:
        A dict with ``answer`` (model output) and ``sources`` (positions of
        the chunks used as context), matching ``QueryResponse``.

    Raises:
        HTTPException: 400 if ``top_k`` is less than 1; 500 (with the
            underlying error message as detail) for any other failure.
    """
    # top_k is Optional in QueryRequest, so an explicit null must not reach
    # index.search as k=None.
    top_k = 5 if request.top_k is None else request.top_k
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")

    try:
        index = load_index(request.doc_id)
        chunks = load_chunks(request.doc_id)

        q_vec = get_embeddings([request.query])
        _, idxs = index.search(q_vec, k=top_k)

        # FAISS pads the result with -1 when the index holds fewer than k
        # vectors; chunks[-1] would silently pull in the last chunk, so drop
        # the padding before building context and sources.
        hit_ids = [int(i) for i in idxs[0] if i >= 0]
        context = "\n\n".join(chunks[i] for i in hit_ids)

        prompt = (f"Use ONLY the context to answer the question.\n\n"
                  f"Context:\n{context}\n\nQuestion: {request.query}\n\nAnswer:")
        answer = await openrouter_chat(prompt)
        return {"answer": answer, "sources": hit_ids}
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e