import os
import uuid

import faiss
import httpx
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, Form, UploadFile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import Docx2txtLoader
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Load environment variables (expects OPENROUTER_API_KEY in a local .env)
load_dotenv()

# Initialize the embedding model
MODEL = SentenceTransformer("all-MiniLM-L6-v2")

# Read the API key from the environment; never hard-code secrets in source
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")

app = FastAPI(title="RAG-via-FastAPI")

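# Fail fast when the key is missing (a minimal sketch; a FastAPI startup
# hook or a pydantic Settings object would work equally well).
if not OPENROUTER_API_KEY:
    raise RuntimeError("OPENROUTER_API_KEY is not set; add it to your .env")
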
# --- helpers ---------------------------------------------------------

def docx_to_chunks(file_path):
    # Load the .docx and split it into overlapping chunks for retrieval
    pages = Docx2txtLoader(file_path).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    ).split_documents(pages)
    return [c.page_content for c in chunks]


def get_embeddings(texts):
    # Generate embeddings using SentenceTransformer
    embeddings = MODEL.encode(texts, convert_to_tensor=False)
    return np.array(embeddings, dtype="float32")


def build_faiss_index(vectors):
    # Exact L2-distance index; adequate for a small per-request corpus
    dim = vectors.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vectors)
    return index
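
# Alternative (sketch): rank by cosine similarity instead of L2 distance.
# With L2-normalized vectors, inner product equals cosine similarity, so
# an IndexFlatIP over normalized embeddings does the job. The query vector
# must be normalized the same way before searching.
def build_faiss_cosine_index(vectors):
    faiss.normalize_L2(vectors)  # normalizes rows in place
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index
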

async def openrouter_chat(prompt):
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "HTTP-Referer": "https://yourapp.example",
        "Content-Type": "application/json"
    }
    data = {
        "model": "google/gemma-3-27b-it:free",  # any OpenRouter-hosted model
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.3,
        "max_tokens": 512
    }
    async with httpx.AsyncClient(timeout=120) as client:
        r = await client.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=data
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
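
# Optional (sketch): retry transient OpenRouter failures (429 / 5xx) with
# exponential backoff. Attempt count and delays here are illustrative only.
import asyncio  # stdlib; used only by the retry sketch below

async def openrouter_chat_with_retry(prompt, attempts=3):
    for attempt in range(attempts):
        try:
            return await openrouter_chat(prompt)
        except httpx.HTTPStatusError as exc:
            retryable = exc.response.status_code in (429, 500, 502, 503)
            if not retryable or attempt == attempts - 1:
                raise
            await asyncio.sleep(2 ** attempt)  # 1s, 2s, 4s, ...
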

# --- request / response models --------------------------------------

class QueryRequest(BaseModel):
    query: str


class RAGResponse(BaseModel):
    answer: str
    sources: list[int]

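# Illustrative response shape: {"answer": "...", "sources": [0, 2, 5]},
# where sources are indices into the uploaded document's chunk list.
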

# --- endpoints ------------------------------------------------------

@app.post("/rag", response_model=RAGResponse)
async def rag_endpoint(file: UploadFile, query: str = Form(...)):
    # Use a local temp directory instead of /tmp for Windows compatibility
    os.makedirs("temp", exist_ok=True)
    tmp_path = f"temp/{uuid.uuid4()}.docx"

    try:
        # Save the uploaded file so the loader can read it from disk
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        # 1. chunk
        chunks = docx_to_chunks(tmp_path)

        # 2. embed
        embeddings = get_embeddings(chunks)

        # 3. search (cap k at the chunk count so FAISS never pads with -1)
        index = build_faiss_index(embeddings)
        q_vec = get_embeddings([query])
        k = min(3, len(chunks))
        distances, idxs = index.search(q_vec, k)
        context = "\n\n".join(chunks[i] for i in idxs[0])

        # 4. generate
        prompt = (f"Use ONLY the context to answer the question.\n\n"
                  f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
        answer = await openrouter_chat(prompt)

        return {"answer": answer, "sources": idxs[0].tolist()}

    finally:
        # Clean up the temp file
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
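
# Example usage (sketch; the module name "main" and the port are assumptions):
#
#   uvicorn main:app --reload
#   curl -X POST http://127.0.0.1:8000/rag \
#        -F "file=@document.docx" \
#        -F "query=What does the document say about refunds?"
if __name__ == "__main__":
    import uvicorn  # assumed installed (pip install uvicorn)
    uvicorn.run(app, host="127.0.0.1", port=8000)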