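# Source-viewer header (commit author, hash, message, date; file stats),
# kept as comments so this module is valid Python.
# Joydeep Pandey 6e16fc99c9 Feature done
# Average latency 25ms
# 2025-05-19 17:20:17 +05:30
#
# 113 lines
# 3.4 KiB
# Python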

import asyncio
import os
import uuid

import faiss
import httpx
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, Form, HTTPException, UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import Docx2txtLoader
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Load environment variables (e.g. from a local .env file) so the settings
# below can be supplied without hard-coding them in source.
load_dotenv()

# Sentence-embedding model shared by every request; created once at import.
MODEL = SentenceTransformer("all-MiniLM-L6-v2")

# OpenRouter credential, read from the environment / .env rather than being
# committed to the repository (anyone who can read the source could use a
# hard-coded key). The key that was previously hard-coded here must be
# treated as leaked and revoked.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")

app = FastAPI(title="RAG-via-FastAPI")
# --- helpers ---------------------------------------------------------
def docx_to_chunks(file_path):
    """Load a .docx file and split its text into overlapping chunks.

    Args:
        file_path: Path to the Word document on disk.

    Returns:
        A list of chunk strings produced by a RecursiveCharacterTextSplitter
        configured with chunk_size=1000 and chunk_overlap=200.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    documents = Docx2txtLoader(file_path).load()
    return [piece.page_content for piece in splitter.split_documents(documents)]
def get_embeddings(texts):
    """Embed *texts* with the shared SentenceTransformer ``MODEL``.

    Returns a float32 NumPy array (one row per text for a non-empty list).
    This is the dtype the FAISS index built by ``build_faiss_index`` is fed.
    """
    vectors = MODEL.encode(texts, convert_to_tensor=False)
    return np.array(vectors, dtype=np.float32)
def build_faiss_index(vectors):
    """Create a flat (exhaustive) L2 FAISS index and add *vectors* to it.

    The index dimension is taken from ``vectors.shape[1]``, so *vectors* is
    expected to be a 2-D array of shape (n, dim).
    """
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index
async def openrouter_chat(
    prompt,
    *,
    model="google/gemma-3-27b-it:free",  # any OpenRouter-hosted model
    temperature=0.3,
    max_tokens=512,
    system_prompt="You are a helpful assistant.",
):
    """Send *prompt* to the OpenRouter chat-completions API and return the reply.

    Args:
        prompt: Content of the user message.
        model: OpenRouter model id to query.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Upper bound on generated tokens, forwarded to the API.
        system_prompt: Content of the system message sent before *prompt*.

    Returns:
        The ``message.content`` of the first returned choice.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is not configured, or the
            response body has no ``choices`` (e.g. an error payload).
        httpx.HTTPStatusError: If the API answers with a 4xx/5xx status.
    """
    # Fail with a clear message instead of sending "Bearer None" and getting
    # an opaque 401 back.
    if not OPENROUTER_API_KEY:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set; add it to the environment or .env"
        )

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "HTTP-Referer": "https://yourapp.example",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=payload,
        )
    response.raise_for_status()

    body = response.json()
    choices = body.get("choices")
    if not choices:
        # Surface the API's error object (if any) instead of a bare KeyError.
        raise RuntimeError(f"OpenRouter returned no choices: {body.get('error', body)}")
    return choices[0]["message"]["content"]
# --- request / response models --------------------------------------
class QueryRequest(BaseModel):
    """JSON request body carrying a single natural-language query.

    The ``/rag`` endpoint visible in this file takes its query as a form
    field instead of this model.
    """

    query: str
class RAGResponse(BaseModel):
    """Response schema of the ``/rag`` endpoint.

    Attributes:
        answer: Text generated by the LLM from the retrieved context.
        sources: Indices of the document chunks used as that context.
    """

    answer: str
    sources: list[int]
# --- endpoints ------------------------------------------------------
# Number of nearest chunks passed to the LLM as context.
_RAG_TOP_K = 3


def _retrieve(chunks, query, top_k):
    """Return ``(context, source_ids)`` for the *top_k* chunks nearest *query*.

    Blocking CPU work (embedding + FAISS search); the endpoint runs it in a
    worker thread. *chunks* must be non-empty.
    """
    index = build_faiss_index(get_embeddings(chunks))
    q_vec = get_embeddings([query])
    # Never request more neighbours than there are vectors. FAISS pads missing
    # results with id -1, and chunks[-1] would silently pick the last chunk.
    _, idxs = index.search(q_vec, k=min(top_k, len(chunks)))
    source_ids = [int(i) for i in idxs[0] if i >= 0]
    context = "\n\n".join(chunks[i] for i in source_ids)
    return context, source_ids


@app.post("/rag", response_model=RAGResponse)
async def rag_endpoint(file: UploadFile, query: str = Form(...)):
    """Answer *query* from the uploaded .docx via retrieval-augmented generation.

    Steps: save the upload to a temp file, split it into chunks, embed and
    index the chunks, retrieve the ones nearest the query, and ask the LLM
    to answer using only that context. The temp file is always removed.

    Returns:
        ``{"answer": str, "sources": list[int]}``, where *sources* are the
        indices of the retrieved chunks.

    Raises:
        HTTPException: 400 if the document yields no text chunks.
    """
    # Use a local temp directory instead of /tmp for Windows compatibility.
    os.makedirs("temp", exist_ok=True)
    tmp_path = os.path.join("temp", f"{uuid.uuid4()}.docx")
    try:
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        # Parsing, embedding and searching block the CPU; run them off the
        # event loop so other requests keep being served meanwhile.
        chunks = await asyncio.to_thread(docx_to_chunks, tmp_path)
        if not chunks:
            # An empty index would otherwise crash build_faiss_index (HTTP 500).
            raise HTTPException(
                status_code=400,
                detail="The uploaded document contains no extractable text.",
            )
        context, sources = await asyncio.to_thread(_retrieve, chunks, query, _RAG_TOP_K)

        prompt = (
            "Use ONLY the context to answer the question.\n\n"
            f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
        )
        answer = await openrouter_chat(prompt)
        return {"answer": answer, "sources": sources}
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)