85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
import os
|
|
import json
|
|
import numpy as np
|
|
import faiss
|
|
import httpx
|
|
from fastapi import HTTPException
|
|
from langchain_community.document_loaders import Docx2txtLoader
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from dotenv import load_dotenv
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Load environment variables
|
|
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Embedding model shared by all callers; loaded once at import time.
MODEL = SentenceTransformer("all-MiniLM-L6-v2")

# SECURITY: never hard-code API secrets in source. The key is read from the
# environment (populated by load_dotenv() above or by the shell).
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    # Fail fast at import: every chat request would otherwise fail with an
    # opaque 401 from OpenRouter.
    raise RuntimeError("OPENROUTER_API_KEY environment variable is not set")

# On-disk storage locations for FAISS indices and their chunk JSON files.
os.makedirs("indices", exist_ok=True)
os.makedirs("chunks", exist_ok=True)
|
|
|
|
def docx_to_chunks(file_path):
    """Load a .docx file and split its text into overlapping chunks.

    Returns a list of plain-text chunk strings (target size 1000 chars,
    200-char overlap between consecutive chunks).
    """
    documents = Docx2txtLoader(file_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return [doc.page_content for doc in splitter.split_documents(documents)]
|
|
|
|
def get_embeddings(texts):
    """Embed a list of texts with the module-level SentenceTransformer.

    Returns a float32 numpy array, one row per input text, ready for FAISS.
    """
    vectors = MODEL.encode(texts, convert_to_tensor=False)
    # FAISS requires float32 input; cast explicitly.
    return np.array(vectors, dtype="float32")
|
|
|
|
def build_faiss_index(vectors):
    """Build an exact L2-distance FAISS index over the given vectors.

    vectors: float32 array of shape (n, dim).
    """
    n_dims = vectors.shape[1]
    flat_index = faiss.IndexFlatL2(n_dims)
    flat_index.add(vectors)
    return flat_index
|
|
|
|
def save_index(index, doc_id):
    """Persist a FAISS index to indices/<doc_id>.index."""
    faiss.write_index(index, f"indices/{doc_id}.index")
|
|
|
|
def save_chunks(chunks, doc_id):
    """Persist a document's text chunks as JSON at chunks/<doc_id>.json.

    chunks: list of plain-text strings.
    doc_id: identifier used as the file stem.
    """
    # Be self-sufficient: don't rely on module-level setup having run.
    os.makedirs("chunks", exist_ok=True)
    chunks_path = f"chunks/{doc_id}.json"
    # Explicit UTF-8 avoids platform-default-encoding failures, and
    # ensure_ascii=False keeps non-ASCII document text readable on disk.
    with open(chunks_path, "w", encoding="utf-8") as f:
        json.dump(chunks, f, ensure_ascii=False)
|
|
|
|
def load_index(doc_id):
    """Load the persisted FAISS index for doc_id.

    Raises a 404 HTTPException when no index has been built for this doc_id.
    """
    path = f"indices/{doc_id}.index"
    if os.path.exists(path):
        return faiss.read_index(path)
    raise HTTPException(status_code=404, detail=f"Index not found for doc_id: {doc_id}")
|
|
|
|
def load_chunks(doc_id):
    """Load the saved text chunks for doc_id from chunks/<doc_id>.json.

    Returns the list of chunk strings.
    Raises a 404 HTTPException when no chunks were saved for this doc_id.
    """
    chunks_path = f"chunks/{doc_id}.json"
    if not os.path.exists(chunks_path):
        raise HTTPException(status_code=404, detail=f"Chunks not found for doc_id: {doc_id}")
    # Read as UTF-8 explicitly so decoding matches the writer regardless of
    # the platform's default locale encoding.
    with open(chunks_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
async def openrouter_chat(prompt):
    """Send a single-turn chat request to OpenRouter and return the reply text.

    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    request_headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "HTTP-Referer": "https://yourapp.example",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "google/gemma-3-27b-it:free",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.3,
        "max_tokens": 512,
    }
    # Generous timeout: free-tier models can be slow to respond.
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(
            "https://openrouter.ai/api/v1/chat/completions",
            headers=request_headers,
            json=payload,
        )
        response.raise_for_status()
        body = response.json()
    return body["choices"][0]["message"]["content"]