Compare commits: main...faiss-inde (1 commit, SHA1 ad9e4b553a)
.gitignore (vendored): 2 lines changed
@@ -144,6 +144,8 @@ venv.bak/
# Rope project settings
.ropeproject

+config.py
+
# mkdocs documentation
/site
app.py: 80 lines changed
@@ -1,12 +1,11 @@
import os
import uuid
import json
import numpy as np
import faiss
import httpx
-from fastapi import FastAPI, UploadFile, Form
+from fastapi import FastAPI, HTTPException
from langchain_community.document_loaders import Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
@@ -19,6 +18,10 @@ OPENROUTER_API_KEY = "sk-or-v1-7420d7a6c2f5ab366682e2270c543a1802399485ec0a56c4f

app = FastAPI(title="RAG-via-FastAPI")

+# Create necessary directories
+os.makedirs("indices", exist_ok=True)
+os.makedirs("chunks", exist_ok=True)
+
# --- helpers ---------------------------------------------------------

def docx_to_chunks(file_path):
@@ -39,6 +42,28 @@ def build_faiss_index(vectors):
    index.add(vectors)
    return index

+def save_index(index, doc_id):
+    index_path = f"indices/{doc_id}.index"
+    faiss.write_index(index, index_path)
+
+def save_chunks(chunks, doc_id):
+    chunks_path = f"chunks/{doc_id}.json"
+    with open(chunks_path, "w") as f:
+        json.dump(chunks, f)
+
+def load_index(doc_id):
+    index_path = f"indices/{doc_id}.index"
+    if not os.path.exists(index_path):
+        raise HTTPException(status_code=404, detail=f"Index not found for doc_id: {doc_id}")
+    return faiss.read_index(index_path)
+
+def load_chunks(doc_id):
+    chunks_path = f"chunks/{doc_id}.json"
+    if not os.path.exists(chunks_path):
+        raise HTTPException(status_code=404, detail=f"Chunks not found for doc_id: {doc_id}")
+    with open(chunks_path, "r") as f:
+        return json.load(f)
+
async def openrouter_chat(prompt):
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
@@ -46,7 +71,7 @@ async def openrouter_chat(prompt):
        "Content-Type": "application/json"
    }
    data = {
-        "model": "google/gemma-3-27b-it:free",  # any OpenRouter-hosted model
+        "model": "google/gemma-3-27b-it:free",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
@@ -63,50 +88,3 @@ async def openrouter_chat(prompt):
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]
-
-# --- request / response models --------------------------------------
-
-class QueryRequest(BaseModel):
-    query: str
-
-class RAGResponse(BaseModel):
-    answer: str
-    sources: list[int]
-
-# --- endpoints ------------------------------------------------------
-
-@app.post("/rag", response_model=RAGResponse)
-async def rag_endpoint(file: UploadFile, query: str = Form(...)):
-    # Create temp directory if it doesn't exist
-    os.makedirs("temp", exist_ok=True)
-
-    # Use a local temp directory instead of /tmp for Windows compatibility
-    tmp_path = f"temp/{uuid.uuid4()}.docx"
-
-    try:
-        # Save uploaded file
-        with open(tmp_path, "wb") as f:
-            f.write(await file.read())
-
-        # 1. chunk
-        chunks = docx_to_chunks(tmp_path)
-
-        # 2. embed
-        embeddings = get_embeddings(chunks)
-
-        # 3. search
-        index = build_faiss_index(embeddings)
-        q_vec = get_embeddings([query])
-        distances, idxs = index.search(q_vec, k=3)
-        context = "\n\n".join(chunks[i] for i in idxs[0])
-
-        # 4. generate
-        prompt = (f"Use ONLY the context to answer the question.\n\n"
-                  f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
-        answer = await openrouter_chat(prompt)
-
-        return {"answer": answer, "sources": idxs[0].tolist()}
-
-    finally:
-        # Cleanup temp file
-        if os.path.exists(tmp_path):
-            os.remove(tmp_path)
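The save/load helpers added above persist a FAISS index and its chunk list side by side under a shared doc_id. Below is a minimal sketch of the round trip they enable; the demo chunks and doc_id are invented for illustration and are not part of the commit.

```python
# Illustrative round trip for the new persistence helpers (demo data only).
import json
import os

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

os.makedirs("indices", exist_ok=True)
os.makedirs("chunks", exist_ok=True)

model = SentenceTransformer("all-MiniLM-L6-v2")
chunks = ["Refunds are sent by NEFT.", "Cheques clear locally in listed cities."]
vectors = np.array(model.encode(chunks), dtype="float32")

index = faiss.IndexFlatL2(vectors.shape[1])
index.add(vectors)

doc_id = "demo-doc"  # hypothetical id passed to the helpers
faiss.write_index(index, f"indices/{doc_id}.index")     # what save_index does
with open(f"chunks/{doc_id}.json", "w") as f:           # what save_chunks does
    json.dump(chunks, f)

restored = faiss.read_index(f"indices/{doc_id}.index")  # what load_index does
assert restored.ntotal == len(chunks)                   # index survives intact
```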
app/__init__.py
@@ -1,3 +1,9 @@
"""
RAG (Retrieval Augmented Generation) application package.
"""
+
+from fastapi import FastAPI
+from app.api.routes import router
+
+app = FastAPI(title="RAG-via-FastAPI")
+app.include_router(router)
app/api/__init__.py
@@ -0,0 +1,3 @@
"""
API routes package.
"""
app/api/routes.py
@@ -1,72 +1,109 @@
import os
import uuid
-import time
-import logging
-from fastapi import APIRouter, UploadFile, Form
-from app.core.models import RAGResponse
-from app.services.document import docx_to_chunks
-from app.services.embedding import embedding_service
-from app.services.llm import chat_completion
-from app.core.config import settings
+from fastapi import APIRouter, UploadFile, HTTPException
+from pydantic import BaseModel
+from typing import List, Optional
+
+from app.core import (
+    docx_to_chunks,
+    get_embeddings,
+    build_faiss_index,
+    save_index,
+    save_chunks,
+    load_index,
+    load_chunks,
+    openrouter_chat
+)

-logging.basicConfig(level=logging.INFO)
router = APIRouter()

-def log_latency(func):
-    def wrapper(*args, **kwargs):
-        start_time = time.time()
-        result = func(*args, **kwargs)
-        end_time = time.time()
-        latency_ms = (end_time - start_time) * 1000
-        logging.info(f"Latency for {func.__name__}: {latency_ms:.2f} ms")
-        return result
-    return wrapper
+# --- request / response models --------------------------------------

-@log_latency
-def search_faiss(index, query_vector, k, chunks):  # Added chunks parameter
-    """Search the FAISS index for similar vectors."""
-    distances, idxs = index.search(query_vector, k)
-    # Log top k results with their distances and text preview
-    for i in range(k):
-        chunk_text = chunks[idxs[0][i]]
-        preview = ' '.join(chunk_text.split()[:20])  # Get first 50 words
-        logging.info(f"Top {i+1} result - Index: {idxs[0][i]}, Distance: {distances[0][i]:.4f}")
-        logging.info(f"Text preview: {preview}...")
-    return distances, idxs
+class UploadResponse(BaseModel):
+    message: str
+    doc_id: str

-@router.post("/rag", response_model=RAGResponse)
-async def rag_endpoint(file: UploadFile, query: str = Form(...)):
+class QueryRequest(BaseModel):
+    doc_id: str
+    query: str
+    top_k: Optional[int] = 5
+
+class QueryResponse(BaseModel):
+    answer: str
+    sources: List[int]
+
+class DocumentsResponse(BaseModel):
+    documents: List[str]
+
+# --- endpoints ------------------------------------------------------
+
+@router.get("/documents", response_model=DocumentsResponse)
+async def list_documents():
+    """List all indexed documents."""
+    try:
+        # Get all .index files from the indices directory
+        doc_ids = [
+            os.path.splitext(f)[0]  # Remove .index extension
+            for f in os.listdir("indices")
+            if f.endswith(".index")
+        ]
+        return {"documents": doc_ids}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@router.post("/upload", response_model=UploadResponse)
+async def upload_endpoint(file: UploadFile):
+    if not file.filename.endswith('.docx'):
+        raise HTTPException(status_code=400, detail="Only .docx files are supported")
+
+    # Generate doc_id from filename
+    doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, file.filename))
+
    # Create temp directory if it doesn't exist
    os.makedirs("temp", exist_ok=True)
-
-    # Use a local temp directory for Windows compatibility
-    tmp_path = f"temp/{uuid.uuid4()}.docx"
+    tmp_path = f"temp/{doc_id}.docx"

    try:
        # Save uploaded file
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

-        # 1. chunk
+        # Process document
        chunks = docx_to_chunks(tmp_path)
+        embeddings = get_embeddings(chunks)
+
+        # Save index and chunks
+        index = build_faiss_index(embeddings)
+        save_index(index, doc_id)
+        save_chunks(chunks, doc_id)

-        # 2. embed
-        embeddings = embedding_service.get_embeddings(chunks)
-
-        # 3. search
-        index = embedding_service.build_faiss_index(embeddings)
-        q_vec = embedding_service.get_embeddings([query])
-        distances, idxs = search_faiss(index, q_vec, settings.TOP_K_RESULTS, chunks)
-
-        # 4. generate
-        context = "\n\n".join(chunks[i] for i in idxs[0])
-        prompt = (f"Use ONLY the context to answer the question.\n\n"
-                  f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
-        answer = await chat_completion(prompt)
-
-        return {"answer": answer, "sources": idxs[0].tolist()}
+        return {"message": "Document indexed", "doc_id": doc_id}

    finally:
        # Cleanup temp file
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
+
+@router.post("/query", response_model=QueryResponse)
+async def query_endpoint(request: QueryRequest):
+    try:
+        # Load stored data
+        index = load_index(request.doc_id)
+        chunks = load_chunks(request.doc_id)
+
+        # Process query
+        q_vec = get_embeddings([request.query])
+        distances, idxs = index.search(q_vec, k=request.top_k)
+
+        # Get relevant chunks
+        context = "\n\n".join(chunks[i] for i in idxs[0])
+
+        # Generate answer
+        prompt = (f"Use ONLY the context to answer the question.\n\n"
+                  f"Context:\n{context}\n\nQuestion: {request.query}\n\nAnswer:")
+        answer = await openrouter_chat(prompt)
+
+        return {"answer": answer, "sources": idxs[0].tolist()}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
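The refactor above turns the old single /rag call into a two-step flow: upload once, then query by doc_id. A hedged client sketch of that flow, assuming the server is running locally on port 8000 and that sample.docx is a stand-in path (both are assumptions, not part of the commit):

```python
# Illustrative client for the new two-step flow (base URL and file path assumed).
import httpx

BASE = "http://localhost:8000"  # assumed local dev server

# Step 1: upload and index the document; the response carries the doc_id.
with open("sample.docx", "rb") as f:  # hypothetical test document
    r = httpx.post(f"{BASE}/upload", files={"file": ("sample.docx", f)})
r.raise_for_status()
doc_id = r.json()["doc_id"]

# Step 2: query against the stored index by doc_id.
r = httpx.post(f"{BASE}/query",
               json={"doc_id": doc_id, "query": "How are refunds issued?", "top_k": 3})
r.raise_for_status()
print(r.json()["answer"])
print(r.json()["sources"])  # chunk indices used as context
```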
app/core.py (new file): 85 lines
@@ -0,0 +1,85 @@
import os
import json
import numpy as np
import faiss
import httpx
from fastapi import HTTPException
from langchain_community.document_loaders import Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# Load environment variables
load_dotenv()

# Initialize the embedding model
MODEL = SentenceTransformer("all-MiniLM-L6-v2")
OPENROUTER_API_KEY = "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"

# Create necessary directories
os.makedirs("indices", exist_ok=True)
os.makedirs("chunks", exist_ok=True)

def docx_to_chunks(file_path):
    pages = Docx2txtLoader(file_path).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    ).split_documents(pages)
    return [c.page_content for c in chunks]

def get_embeddings(texts):
    # Generate embeddings using SentenceTransformer
    embeddings = MODEL.encode(texts, convert_to_tensor=False)
    return np.array(embeddings, dtype="float32")

def build_faiss_index(vectors):
    dim = vectors.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vectors)
    return index

def save_index(index, doc_id):
    index_path = f"indices/{doc_id}.index"
    faiss.write_index(index, index_path)

def save_chunks(chunks, doc_id):
    chunks_path = f"chunks/{doc_id}.json"
    with open(chunks_path, "w") as f:
        json.dump(chunks, f)

def load_index(doc_id):
    index_path = f"indices/{doc_id}.index"
    if not os.path.exists(index_path):
        raise HTTPException(status_code=404, detail=f"Index not found for doc_id: {doc_id}")
    return faiss.read_index(index_path)

def load_chunks(doc_id):
    chunks_path = f"chunks/{doc_id}.json"
    if not os.path.exists(chunks_path):
        raise HTTPException(status_code=404, detail=f"Chunks not found for doc_id: {doc_id}")
    with open(chunks_path, "r") as f:
        return json.load(f)

async def openrouter_chat(prompt):
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "HTTP-Referer": "https://yourapp.example",
        "Content-Type": "application/json"
    }
    data = {
        "model": "google/gemma-3-27b-it:free",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.3,
        "max_tokens": 512
    }
    async with httpx.AsyncClient(timeout=120) as client:
        r = await client.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=data
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
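app/core.py packs the whole chunk-embed-index-search pipeline into one module. A quick sketch of how its pieces compose, with invented sample chunks (illustrative only; importing the module also creates the indices/ and chunks/ directories as a side effect):

```python
# Sketch of the embed -> index -> search pipeline exposed by app/core.py.
from app.core import build_faiss_index, get_embeddings

chunks = [
    "Refunds for Pay on Delivery orders are processed via NEFT.",
    "Cheque refunds are issued as at-par Deutsche Bank cheques.",
    "Return shipping costs up to Rs. 100 are refunded.",
]

vectors = get_embeddings(chunks)    # float32 array, (3, 384) for all-MiniLM-L6-v2
index = build_faiss_index(vectors)  # exact IndexFlatL2 over the chunk vectors

q_vec = get_embeddings(["How do cheque refunds work?"])
distances, idxs = index.search(q_vec, 2)  # top-2 nearest chunks by L2 distance
print([chunks[i] for i in idxs[0]])       # most relevant chunk text first
```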
@@ -0,0 +1,89 @@
"""
Core functionality for the RAG application.
"""

import os
import json
import numpy as np
import faiss
import httpx
from fastapi import HTTPException
from langchain_community.document_loaders import Docx2txtLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# Load environment variables
load_dotenv()

# Initialize the embedding model
MODEL = SentenceTransformer("all-MiniLM-L6-v2")
OPENROUTER_API_KEY = "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"

# Create necessary directories
os.makedirs("indices", exist_ok=True)
os.makedirs("chunks", exist_ok=True)

def docx_to_chunks(file_path):
    pages = Docx2txtLoader(file_path).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=200
    ).split_documents(pages)
    return [c.page_content for c in chunks]

def get_embeddings(texts):
    # Generate embeddings using SentenceTransformer
    embeddings = MODEL.encode(texts, convert_to_tensor=False)
    return np.array(embeddings, dtype="float32")

def build_faiss_index(vectors):
    dim = vectors.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(vectors)
    return index

def save_index(index, doc_id):
    index_path = f"indices/{doc_id}.index"
    faiss.write_index(index, index_path)

def save_chunks(chunks, doc_id):
    chunks_path = f"chunks/{doc_id}.json"
    with open(chunks_path, "w") as f:
        json.dump(chunks, f)

def load_index(doc_id):
    index_path = f"indices/{doc_id}.index"
    if not os.path.exists(index_path):
        raise HTTPException(status_code=404, detail=f"Index not found for doc_id: {doc_id}")
    return faiss.read_index(index_path)

def load_chunks(doc_id):
    chunks_path = f"chunks/{doc_id}.json"
    if not os.path.exists(chunks_path):
        raise HTTPException(status_code=404, detail=f"Chunks not found for doc_id: {doc_id}")
    with open(chunks_path, "r") as f:
        return json.load(f)

async def openrouter_chat(prompt):
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "HTTP-Referer": "https://yourapp.example",
        "Content-Type": "application/json"
    }
    data = {
        "model": "google/gemma-3-27b-it:free",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.3,
        "max_tokens": 512
    }
    async with httpx.AsyncClient(timeout=120) as client:
        r = await client.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=headers,
            json=data
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
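One design note on build_faiss_index: IndexFlatL2 ranks by (squared) Euclidean distance, so smaller scores are better. For unit-length embeddings that ordering coincides with cosine similarity, since ||a - b||^2 = 2 - 2<a, b>; if cosine scores were wanted explicitly, one common alternative (a sketch, not what this commit does) is to L2-normalize and use an inner-product index:

```python
# Alternative sketch (NOT what the commit does): cosine similarity via IndexFlatIP.
import faiss
import numpy as np

def build_cosine_index(vectors: np.ndarray) -> faiss.IndexFlatIP:
    vecs = vectors.astype("float32", copy=True)  # copy so the caller's data is untouched
    faiss.normalize_L2(vecs)                     # in-place row-wise L2 normalization
    index = faiss.IndexFlatIP(vecs.shape[1])     # inner product of unit vectors == cosine
    index.add(vecs)
    return index

# Query vectors must be normalized the same way before calling index.search(...).
```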
chunks/f22de691-bec3-51ac-9dcc-7f3c440eca7b.json (new file): 1 line
@@ -0,0 +1 @@
["Refund for Pay on Delivery Orders:\n\nFor Pay on Delivery orders, refunds will be processed either to your bank account (via National Electronic Funds Transfer - NEFT) or to your Amazon account (as Amazon Pay balance).\n\nTo receive the refund to your bank account, update your bank account details via the Returns Centre when returning an item.\n\nSteps to add bank account:\n\nThrough Mobile App:\n\nGo to Orders\n\nClick on the order you want to return\n\nSelect Return or Replace items\n\nChoose Refund to your bank account\n\nSelect Choose a bank account\n\nSelect Add a new bank account and enter the bank details\n\nThrough Website:\n\nGo to Orders\n\nClick on the order you want to return\n\nSelect Return or Replacement items\n\nChoose Refund to your bank account\n\nSelect Choose a bank account\n\nSelect Add a new bank account and enter the bank details\n\nNote: Refunds cannot be processed to third-party accounts. The name on your Amazon account should match the name of the bank account holder.\n\n\n\nPaper Cheque Clearing:", "Note: Refunds cannot be processed to third-party accounts. The name on your Amazon account should match the name of the bank account holder.\n\n\n\nPaper Cheque Clearing:\n\nAll cheque refunds will be in the form of \"at par\" Deutsche Bank cheques. These cheques are cleared locally in the following cities:\n\nAhmedabad, Aurangabad, Bangalore, Chennai, Delhi, Gurgaon, Kolhapur, Kolkata, Ludhiana, Moradabad, Mumbai, Noida, Pune, Salem, Surat, Vellore\n\nIf presenting the cheque in another city:\n\nEnsure the cheque is sent for outstation clearing\n\nIf dropping the cheque in a clearance box:\n\nIn cities listed above, use the Local Cheques box\n\nIn other cities, use the Outstation Cheques box\n\nFailing to follow these instructions may result in the cheque not being processed and a penalty.\n\nNote: Once a cheque is issued, Amazon will email tracking details of the refund cheque within 4 business days from the date of refund.\n\n\n\nShipping Cost Refunds:", "Note: Once a cheque is issued, Amazon will email tracking details of the refund cheque within 4 business days from the date of refund.\n\n\n\nShipping Cost Refunds:\n\nFor Fulfilled by Amazon and Prime Eligible items, return shipping costs up to Rs. 100 will be refunded.\n\nGift-wrapping charges, if any, will also be refunded.\n\nRefunds in such cases will be issued via cheque.\n\nNote: For return shipping charges over Rs.100 (for large/heavy items), contact Amazon for an additional refund with proof of payment (e.g., courier receipt).\n\nFor Seller-Fulfilled items:\n\nYou can request the seller to reimburse return shipping costs.\n\nThe seller may ask for the courier receipt as proof."]
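This checked-in JSON is exactly what save_chunks writes: a flat list of chunk strings, so the integer sources returned by /query index straight into it. A small sketch of resolving sources back to text; the doc_id matches the committed file, while the example indices are invented:

```python
# Map /query "sources" indices back to the stored chunk text.
import json

doc_id = "f22de691-bec3-51ac-9dcc-7f3c440eca7b"  # the committed document id
with open(f"chunks/{doc_id}.json") as f:
    chunks = json.load(f)

sources = [1, 0]  # example indices, as a /query response might return
for i in sources:
    print(f"[chunk {i}] {chunks[i][:80]}...")
```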
indices/f22de691-bec3-51ac-9dcc-7f3c440eca7b.index (new file): BIN, binary file not shown
main.py: 8 lines changed
@@ -1,5 +1,5 @@
-from fastapi import FastAPI
-from app.api.routes import router
+import uvicorn
+from app import app

-app = FastAPI(title="RAG-via-FastAPI")
-app.include_router(router)
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
routes.py (new file): 96 lines
@@ -0,0 +1,96 @@
import os
import uuid
import json
import numpy as np
import faiss
import httpx
from fastapi import APIRouter, UploadFile, HTTPException
from pydantic import BaseModel
from typing import List, Optional

from app import (
    docx_to_chunks,
    get_embeddings,
    build_faiss_index,
    save_index,
    save_chunks,
    load_index,
    load_chunks,
    openrouter_chat
)

router = APIRouter()

# --- request / response models --------------------------------------

class UploadResponse(BaseModel):
    message: str
    doc_id: str

class QueryRequest(BaseModel):
    doc_id: str
    query: str
    top_k: Optional[int] = 5

class QueryResponse(BaseModel):
    answer: str
    sources: List[int]

# --- endpoints ------------------------------------------------------

@router.post("/upload", response_model=UploadResponse)
async def upload_endpoint(file: UploadFile):
    if not file.filename.endswith('.docx'):
        raise HTTPException(status_code=400, detail="Only .docx files are supported")

    # Generate doc_id from filename
    doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, file.filename))

    # Create temp directory if it doesn't exist
    os.makedirs("temp", exist_ok=True)
    tmp_path = f"temp/{doc_id}.docx"

    try:
        # Save uploaded file
        with open(tmp_path, "wb") as f:
            f.write(await file.read())

        # Process document
        chunks = docx_to_chunks(tmp_path)
        embeddings = get_embeddings(chunks)

        # Save index and chunks
        index = build_faiss_index(embeddings)
        save_index(index, doc_id)
        save_chunks(chunks, doc_id)

        return {"message": "Document indexed", "doc_id": doc_id}

    finally:
        # Cleanup temp file
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

@router.post("/query", response_model=QueryResponse)
async def query_endpoint(request: QueryRequest):
    try:
        # Load stored data
        index = load_index(request.doc_id)
        chunks = load_chunks(request.doc_id)

        # Process query
        q_vec = get_embeddings([request.query])
        distances, idxs = index.search(q_vec, k=request.top_k)

        # Get relevant chunks
        context = "\n\n".join(chunks[i] for i in idxs[0])

        # Generate answer
        prompt = (f"Use ONLY the context to answer the question.\n\n"
                  f"Context:\n{context}\n\nQuestion: {request.query}\n\nAnswer:")
        answer = await openrouter_chat(prompt)

        return {"answer": answer, "sources": idxs[0].tolist()}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
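A property of upload_endpoint worth noting: uuid.uuid5 hashes the filename deterministically, so re-uploading a file with the same name always yields the same doc_id and silently overwrites that document's index and chunk files. A quick demonstration:

```python
# uuid5 is deterministic: the same filename always maps to the same doc_id.
import uuid

a = str(uuid.uuid5(uuid.NAMESPACE_DNS, "refund-policy.docx"))  # hypothetical name
b = str(uuid.uuid5(uuid.NAMESPACE_DNS, "refund-policy.docx"))
assert a == b  # a re-upload reuses (and overwrites) indices/<doc_id>.index
```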