Feature done
Average latency 25ms
This commit is contained in:
parent
b9de7c5ecf
commit
6e16fc99c9
112
app.py
Normal file
112
app.py
Normal file
@ -0,0 +1,112 @@
|
||||
import os
|
||||
import uuid
|
||||
import numpy as np
|
||||
import faiss
|
||||
import httpx
|
||||
from fastapi import FastAPI, UploadFile, Form
|
||||
from langchain_community.document_loaders import Docx2txtLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize the embedding model
|
||||
MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
OPENROUTER_API_KEY = "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"
|
||||
|
||||
app = FastAPI(title="RAG-via-FastAPI")
|
||||
|
||||
# --- helpers ---------------------------------------------------------
|
||||
|
||||
def docx_to_chunks(file_path):
|
||||
pages = Docx2txtLoader(file_path).load()
|
||||
chunks = RecursiveCharacterTextSplitter(
|
||||
chunk_size=1000, chunk_overlap=200
|
||||
).split_documents(pages)
|
||||
return [c.page_content for c in chunks]
|
||||
|
||||
def get_embeddings(texts):
|
||||
# Generate embeddings using SentenceTransformer
|
||||
embeddings = MODEL.encode(texts, convert_to_tensor=False)
|
||||
return np.array(embeddings, dtype="float32")
|
||||
|
||||
def build_faiss_index(vectors):
|
||||
dim = vectors.shape[1]
|
||||
index = faiss.IndexFlatL2(dim)
|
||||
index.add(vectors)
|
||||
return index
|
||||
|
||||
async def openrouter_chat(prompt):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {OPENROUTER_API_KEY}",
|
||||
"HTTP-Referer": "https://yourapp.example",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": "google/gemma-3-27b-it:free", # any OpenRouter-hosted model
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 512
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
r = await client.post(
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()["choices"][0]["message"]["content"]
|
||||
|
||||
# --- request / response models --------------------------------------
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
class RAGResponse(BaseModel):
|
||||
answer: str
|
||||
sources: list[int]
|
||||
|
||||
# --- endpoints ------------------------------------------------------
|
||||
|
||||
@app.post("/rag", response_model=RAGResponse)
|
||||
async def rag_endpoint(file: UploadFile, query: str = Form(...)):
|
||||
# Create temp directory if it doesn't exist
|
||||
os.makedirs("temp", exist_ok=True)
|
||||
|
||||
# Use a local temp directory instead of /tmp for Windows compatibility
|
||||
tmp_path = f"temp/{uuid.uuid4()}.docx"
|
||||
|
||||
try:
|
||||
# Save uploaded file
|
||||
with open(tmp_path, "wb") as f:
|
||||
f.write(await file.read())
|
||||
|
||||
# 1. chunk
|
||||
chunks = docx_to_chunks(tmp_path)
|
||||
|
||||
# 2. embed
|
||||
embeddings = get_embeddings(chunks)
|
||||
|
||||
# 3. search
|
||||
index = build_faiss_index(embeddings)
|
||||
q_vec = get_embeddings([query])
|
||||
distances, idxs = index.search(q_vec, k=3)
|
||||
context = "\n\n".join(chunks[i] for i in idxs[0])
|
||||
|
||||
# 4. generate
|
||||
prompt = (f"Use ONLY the context to answer the question.\n\n"
|
||||
f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
|
||||
answer = await openrouter_chat(prompt)
|
||||
|
||||
return {"answer": answer, "sources": idxs[0].tolist()}
|
||||
|
||||
finally:
|
||||
# Cleanup temp file
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
3
app/__init__.py
Normal file
3
app/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
RAG (Retrieval Augmented Generation) application package.
|
||||
"""
|
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
72
app/api/routes.py
Normal file
72
app/api/routes.py
Normal file
@ -0,0 +1,72 @@
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
import logging
|
||||
from fastapi import APIRouter, UploadFile, Form
|
||||
from app.core.models import RAGResponse
|
||||
from app.services.document import docx_to_chunks
|
||||
from app.services.embedding import embedding_service
|
||||
from app.services.llm import chat_completion
|
||||
from app.core.config import settings
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
router = APIRouter()
|
||||
|
||||
def log_latency(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
logging.info(f"Latency for {func.__name__}: {latency_ms:.2f} ms")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
@log_latency
|
||||
def search_faiss(index, query_vector, k, chunks): # Added chunks parameter
|
||||
"""Search the FAISS index for similar vectors."""
|
||||
distances, idxs = index.search(query_vector, k)
|
||||
# Log top k results with their distances and text preview
|
||||
for i in range(k):
|
||||
chunk_text = chunks[idxs[0][i]]
|
||||
preview = ' '.join(chunk_text.split()[:20]) # Get first 50 words
|
||||
logging.info(f"Top {i+1} result - Index: {idxs[0][i]}, Distance: {distances[0][i]:.4f}")
|
||||
logging.info(f"Text preview: {preview}...")
|
||||
return distances, idxs
|
||||
|
||||
@router.post("/rag", response_model=RAGResponse)
|
||||
async def rag_endpoint(file: UploadFile, query: str = Form(...)):
|
||||
# Create temp directory if it doesn't exist
|
||||
os.makedirs("temp", exist_ok=True)
|
||||
|
||||
# Use a local temp directory for Windows compatibility
|
||||
tmp_path = f"temp/{uuid.uuid4()}.docx"
|
||||
|
||||
try:
|
||||
# Save uploaded file
|
||||
with open(tmp_path, "wb") as f:
|
||||
f.write(await file.read())
|
||||
|
||||
# 1. chunk
|
||||
chunks = docx_to_chunks(tmp_path)
|
||||
|
||||
# 2. embed
|
||||
embeddings = embedding_service.get_embeddings(chunks)
|
||||
|
||||
# 3. search
|
||||
index = embedding_service.build_faiss_index(embeddings)
|
||||
q_vec = embedding_service.get_embeddings([query])
|
||||
distances, idxs = search_faiss(index, q_vec, settings.TOP_K_RESULTS, chunks)
|
||||
|
||||
# 4. generate
|
||||
context = "\n\n".join(chunks[i] for i in idxs[0])
|
||||
prompt = (f"Use ONLY the context to answer the question.\n\n"
|
||||
f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:")
|
||||
answer = await chat_completion(prompt)
|
||||
|
||||
return {"answer": answer, "sources": idxs[0].tolist()}
|
||||
|
||||
finally:
|
||||
# Cleanup temp file
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
13
app/core/config.py
Normal file
13
app/core/config.py
Normal file
@ -0,0 +1,13 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
OPENROUTER_API_KEY: str = "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"
|
||||
MODEL_NAME: str = "all-MiniLM-L6-v2"
|
||||
CHUNK_SIZE: int = 1000
|
||||
CHUNK_OVERLAP: int = 200
|
||||
TOP_K_RESULTS: int = 3
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
settings = Settings()
|
8
app/core/models.py
Normal file
8
app/core/models.py
Normal file
@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
class RAGResponse(BaseModel):
|
||||
answer: str
|
||||
sources: list[int]
|
0
app/services/__init__.py
Normal file
0
app/services/__init__.py
Normal file
27
app/services/document.py
Normal file
27
app/services/document.py
Normal file
@ -0,0 +1,27 @@
|
||||
from langchain_community.document_loaders import Docx2txtLoader
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from app.core.config import settings
|
||||
import time
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def log_latency(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
logging.info(f"Latency for {func.__name__}: {latency_ms:.2f} ms")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
@log_latency
|
||||
def docx_to_chunks(file_path: str) -> list[str]:
|
||||
"""Convert a DOCX file to text chunks."""
|
||||
pages = Docx2txtLoader(file_path).load()
|
||||
chunks = RecursiveCharacterTextSplitter(
|
||||
chunk_size=settings.CHUNK_SIZE,
|
||||
chunk_overlap=settings.CHUNK_OVERLAP
|
||||
).split_documents(pages)
|
||||
return [c.page_content for c in chunks]
|
39
app/services/embedding.py
Normal file
39
app/services/embedding.py
Normal file
@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from app.core.config import settings
|
||||
import logging
|
||||
import time
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def log_latency(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
logging.info(f"Latency for {func.__name__}: {latency_ms:.2f} ms")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
class EmbeddingService:
|
||||
|
||||
def __init__(self):
|
||||
self.model = SentenceTransformer(settings.MODEL_NAME)
|
||||
|
||||
@log_latency
|
||||
def get_embeddings(self, texts: list[str]) -> np.ndarray:
|
||||
"""Generate embeddings for a list of texts."""
|
||||
embeddings = self.model.encode(texts, convert_to_tensor=False)
|
||||
return np.array(embeddings, dtype="float32")
|
||||
|
||||
@log_latency
|
||||
def build_faiss_index(self, vectors: np.ndarray) -> faiss.Index:
|
||||
"""Build a FAISS index from vectors."""
|
||||
dim = vectors.shape[1]
|
||||
index = faiss.IndexFlatL2(dim)
|
||||
index.add(vectors)
|
||||
return index
|
||||
|
||||
embedding_service = EmbeddingService()
|
41
app/services/llm.py
Normal file
41
app/services/llm.py
Normal file
@ -0,0 +1,41 @@
|
||||
import httpx
|
||||
from app.core.config import settings
|
||||
import logging
|
||||
import time
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
def log_latency(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
latency_ms = (end_time - start_time) * 1000
|
||||
logging.info(f"Latency for {func.__name__}: {latency_ms:.2f} ms")
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
@log_latency
|
||||
async def chat_completion(prompt: str) -> str:
|
||||
"""Get completion from OpenRouter API."""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {settings.OPENROUTER_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": "google/gemma-3-27b-it:free",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
r = await client.post(
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()["choices"][0]["message"]["content"]
|
||||
|
5
main.py
Normal file
5
main.py
Normal file
@ -0,0 +1,5 @@
|
||||
from fastapi import FastAPI
|
||||
from app.api.routes import router
|
||||
|
||||
app = FastAPI(title="RAG-via-FastAPI")
|
||||
app.include_router(router)
|
22
requirements.txt
Normal file
22
requirements.txt
Normal file
@ -0,0 +1,22 @@
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
click==8.2.0
|
||||
colorama==0.4.6
|
||||
fastapi==0.115.12
|
||||
h11==0.16.0
|
||||
idna==3.10
|
||||
pydantic==2.11.4
|
||||
pydantic_core==2.33.2
|
||||
sniffio==1.3.1
|
||||
starlette==0.46.2
|
||||
typing-inspection==0.4.0
|
||||
typing_extensions==4.13.2
|
||||
uvicorn==0.34.2
|
||||
python-multipart
|
||||
langchain-community
|
||||
python-docx
|
||||
faiss-cpu
|
||||
sentence-transformers
|
||||
httpx
|
||||
python-dotenv
|
||||
pydantic-settings
|
Loading…
x
Reference in New Issue
Block a user