96 lines
2.6 KiB
Python
96 lines
2.6 KiB
Python
import os
|
|
import uuid
|
|
import json
|
|
import numpy as np
|
|
import faiss
|
|
import httpx
|
|
from fastapi import APIRouter, UploadFile, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional
|
|
|
|
from app import (
|
|
docx_to_chunks,
|
|
get_embeddings,
|
|
build_faiss_index,
|
|
save_index,
|
|
save_chunks,
|
|
load_index,
|
|
load_chunks,
|
|
openrouter_chat
|
|
)
|
|
|
|
# Router that the /upload and /query endpoints in this module are registered on.
router = APIRouter()
|
|
|
|
# --- request / response models --------------------------------------
|
|
|
|
class UploadResponse(BaseModel):
    # Response body of POST /upload.

    # Human-readable status text for the upload.
    message: str
    # Identifier the index and chunks were saved under; /query takes it back.
    doc_id: str
|
|
|
|
class QueryRequest(BaseModel):
    # Request body of POST /query.

    # Identifier of the previously uploaded document to search.
    doc_id: str
    # Question text; it is embedded and matched against the stored index.
    query: str
    # Number of nearest chunks to retrieve (passed to the index search as k).
    top_k: Optional[int] = 5
|
|
|
|
class QueryResponse(BaseModel):
    # Response body of POST /query.

    # Text returned by the chat model for the assembled prompt.
    answer: str
    # Positions (within the stored chunk list) of the chunks used as context.
    sources: List[int]
|
|
|
|
# --- endpoints ------------------------------------------------------
|
|
|
|
@router.post("/upload", response_model=UploadResponse)
async def upload_endpoint(file: UploadFile):
    """Index an uploaded .docx file and return the id it was stored under.

    The upload is written to a temporary file, split into chunks with
    ``docx_to_chunks``, embedded with ``get_embeddings``, and the resulting
    index and chunks are persisted through ``save_index`` / ``save_chunks``.
    The temporary file is removed whether or not processing succeeds.

    Args:
        file: The uploaded document; its name must end in ``.docx``
            (compared case-insensitively).

    Returns:
        A dict with ``message`` and ``doc_id`` keys, matching
        ``UploadResponse``.

    Raises:
        HTTPException: 400 when the filename is missing or does not end in
            ``.docx``, or when the document yields no chunks.
    """
    # UploadFile.filename may be None (e.g. a part sent without a filename);
    # treat that like any other unsupported upload instead of crashing.
    filename = file.filename or ""
    if not filename.lower().endswith(".docx"):
        raise HTTPException(status_code=400, detail="Only .docx files are supported")

    # The id is a name-based UUID, so the same filename always maps to the
    # same doc_id.
    # NOTE(review): presumably a re-upload replaces whatever was saved under
    # that id -- confirm against save_index / save_chunks.
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, filename))

    os.makedirs("temp", exist_ok=True)
    tmp_path = f"temp/{doc_id}.docx"

    try:
        # Persist the upload so docx_to_chunks can read it from a path.
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        chunks = docx_to_chunks(tmp_path)
        if not chunks:
            # Nothing to embed or search; reject rather than store an index
            # that can never answer a query.
            raise HTTPException(
                status_code=400,
                detail="Document contains no indexable text",
            )

        embeddings = get_embeddings(chunks)

        index = build_faiss_index(embeddings)
        save_index(index, doc_id)
        save_chunks(chunks, doc_id)

        return {"message": "Document indexed", "doc_id": doc_id}
    finally:
        # Always remove the temporary copy, including on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
@router.post("/query", response_model=QueryResponse)
async def query_endpoint(request: QueryRequest):
    """Answer a question using the chunks of a previously indexed document.

    Loads the stored index and chunks for ``request.doc_id``, embeds the
    query, retrieves the nearest chunks, and asks the chat model to answer
    from that context only.

    Args:
        request: Document id, question text and the number of chunks to
            retrieve.

    Returns:
        A dict with ``answer`` (the model output) and ``sources`` (the
        chunk positions used as context), matching ``QueryResponse``.

    Raises:
        HTTPException: 400 when ``top_k`` is not positive; 500 (carrying the
            underlying error text) when loading, searching or the chat call
            fails. An HTTPException raised by a helper is passed through
            unchanged.
    """
    # top_k is declared Optional, so an explicit null must fall back to the
    # model's default instead of reaching the index search as None.
    top_k = 5 if request.top_k is None else request.top_k
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")

    try:
        index = load_index(request.doc_id)
        chunks = load_chunks(request.doc_id)

        q_vec = get_embeddings([request.query])
        _distances, idxs = index.search(q_vec, k=top_k)

        # FAISS fills the result row with -1 when fewer than k neighbours
        # exist. Indexing chunks with -1 would silently pull in the last
        # chunk, so drop those placeholders.
        hits = [int(i) for i in idxs[0] if i >= 0]
        context = "\n\n".join(chunks[i] for i in hits)

        prompt = (f"Use ONLY the context to answer the question.\n\n"
                  f"Context:\n{context}\n\nQuestion: {request.query}\n\nAnswer:")
        answer = await openrouter_chat(prompt)

        return {"answer": answer, "sources": hits}
    except HTTPException:
        # Keep status codes chosen deliberately further down the stack.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e