2025-05-19 19:20:43 +05:30

109 lines
3.1 KiB
Python

import os
import uuid
from fastapi import APIRouter, UploadFile, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from app.core import (
docx_to_chunks,
get_embeddings,
build_faiss_index,
save_index,
save_chunks,
load_index,
load_chunks,
openrouter_chat
)
# Router that the endpoints below register on via @router.get / @router.post.
# NOTE(review): presumably mounted on the FastAPI app elsewhere
# (e.g. app.include_router) — not visible in this file.
router: APIRouter = APIRouter()
# --- request / response models --------------------------------------
class UploadResponse(BaseModel):
    """Response body of ``POST /upload``."""

    # Human-readable status text (the /upload handler returns "Document indexed").
    message: str
    # ID to pass as ``doc_id`` in later /query requests.
    doc_id: str
class QueryRequest(BaseModel):
    """Request body of ``POST /query``."""

    # ID of a previously indexed document (as returned by /upload or /documents).
    doc_id: str
    # The natural-language question to answer from the document.
    query: str
    # How many nearest chunks to retrieve as context. Defaults to 5, but the
    # field is Optional, so a client may send an explicit null — consumers
    # must handle None (and values < 1) themselves.
    top_k: Optional[int] = 5
class QueryResponse(BaseModel):
    """Response body of ``POST /query``."""

    # The chat model's answer to the question.
    answer: str
    # Positions (in the document's stored chunk list) of the chunks that were
    # retrieved and used as context for the answer.
    sources: List[int]
class DocumentsResponse(BaseModel):
    """Response body of ``GET /documents``."""

    # IDs of the indexed documents (file stems of the stored ``.index`` files).
    documents: List[str]
# --- endpoints ------------------------------------------------------
@router.get("/documents", response_model=DocumentsResponse)
async def list_documents():
    """List all indexed documents.

    A document ID is the file stem of every ``*.index`` file found in the
    ``indices`` directory. If that directory does not exist, there is nothing
    indexed to list and an empty list is returned rather than an error.

    Returns:
        ``{"documents": [...]}`` matching ``DocumentsResponse``, with the IDs
        sorted so the response order does not depend on the filesystem.

    Raises:
        HTTPException: 500 if the ``indices`` directory exists but cannot be
            read (e.g. permissions, or it is not a directory).
    """
    try:
        filenames = os.listdir("indices")
    except FileNotFoundError:
        # No index directory yet: an empty listing, not a server error.
        return {"documents": []}
    except OSError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    doc_ids = sorted(
        os.path.splitext(name)[0]  # strip the ".index" extension
        for name in filenames
        if name.endswith(".index")
    )
    return {"documents": doc_ids}
@router.post("/upload", response_model=UploadResponse)
async def upload_endpoint(file: UploadFile):
    """Index an uploaded ``.docx`` file and return its document ID.

    The upload is written to a temporary file under ``temp/``, split into
    chunks, embedded, and stored as an index plus its chunks under the
    document ID. The temporary file is always removed afterwards.

    The document ID is a name-based UUID (uuid5) of the uploaded filename, so
    re-uploading a file with the same name yields the same ID (and presumably
    replaces what save_index / save_chunks stored before — not visible here).

    Args:
        file: The multipart upload. Its filename must end in ``.docx``
            (checked case-insensitively).

    Returns:
        ``{"message": "Document indexed", "doc_id": ...}`` matching
        ``UploadResponse``.

    Raises:
        HTTPException: 400 if the upload has no filename, is not a ``.docx``
            file, or produces no text chunks.
    """
    filename = file.filename
    # UploadFile.filename may be None; treat that like an unsupported file
    # instead of crashing on None.endswith(...).
    if not filename or not filename.lower().endswith(".docx"):
        raise HTTPException(status_code=400, detail="Only .docx files are supported")

    # Generate doc_id from filename
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, filename))

    os.makedirs("temp", exist_ok=True)
    # doc_id is deterministic, so naming the temp file after it would let two
    # concurrent uploads of the same filename overwrite and delete each
    # other's file. Use a per-request random name instead.
    tmp_path = os.path.join("temp", f"{uuid.uuid4().hex}.docx")
    try:
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        chunks = docx_to_chunks(tmp_path)
        if not chunks:
            # Nothing to embed; fail clearly instead of building an empty index.
            raise HTTPException(status_code=400, detail="Document contains no text to index")
        embeddings = get_embeddings(chunks)

        index = build_faiss_index(embeddings)
        save_index(index, doc_id)
        save_chunks(chunks, doc_id)
        return {"message": "Document indexed", "doc_id": doc_id}
    finally:
        # Cleanup temp file (it may not exist if the write itself failed).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
@router.post("/query", response_model=QueryResponse)
async def query_endpoint(request: QueryRequest):
    """Answer a question about one indexed document from retrieved context.

    The question is embedded, the ``top_k`` nearest chunks are looked up in the
    document's index, and those chunks are sent to the chat model with an
    instruction to answer using only that context.

    Args:
        request: The document ID, the question, and how many chunks to
            retrieve. A ``top_k`` of ``None`` falls back to 5.

    Returns:
        ``{"answer": ..., "sources": [...]}`` matching ``QueryResponse``.
        ``sources`` holds the positions of the chunks actually used as context.

    Raises:
        HTTPException: 400 for an invalid ``doc_id`` or a ``top_k`` below 1;
            500 for any failure while loading, searching, or generating.
    """
    doc_id = request.doc_id
    # doc_id is client-controlled and handed to load_index / load_chunks,
    # which presumably build file paths from it; reject anything that could
    # act as a path component.
    if not doc_id or doc_id in (".", "..") or "/" in doc_id or "\\" in doc_id:
        raise HTTPException(status_code=400, detail="Invalid doc_id")

    # top_k is Optional in the model, so an explicit null is possible; fall
    # back to the model's default of 5.
    top_k = 5 if request.top_k is None else request.top_k
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1")

    try:
        index = load_index(doc_id)
        chunks = load_chunks(doc_id)

        q_vec = get_embeddings([request.query])
        _distances, idxs = index.search(q_vec, k=top_k)

        # FAISS pads results with -1 when the index holds fewer than k
        # vectors. chunks[-1] would silently pull in the last chunk, so keep
        # only positions that really exist in the chunk list.
        hits = [int(i) for i in idxs[0] if 0 <= i < len(chunks)]
        context = "\n\n".join(chunks[i] for i in hits)

        prompt = (f"Use ONLY the context to answer the question.\n\n"
                  f"Context:\n{context}\n\nQuestion: {request.query}\n\nAnswer:")
        answer = await openrouter_chat(prompt)
        return {"answer": answer, "sources": hits}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e