import os
import torch
import faiss
from huggingface_hub import InferenceClient
from transformers import AutoModel, AutoTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

# Embedding model used to encode both the indexed documents and incoming queries
embedding_model_name = "intfloat/multilingual-e5-large"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name)
embedding_model.eval()  # inference only; disables dropout

def embed_texts(texts):
    """Generate embeddings for a list of texts."""
    inputs = embedding_tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Plain mean pooling over all token states (padding included). This must
    # match however the stored FAISS index was built; for single-text batches
    # there is no padding, so the mean is exact.
    embeddings = torch.mean(outputs.last_hidden_state, dim=1)
    return embeddings.numpy()
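
# Quick shape check (a sketch; multilingual-e5-large produces 1024-dim vectors).
# Note the E5 family conventionally expects "query: " / "passage: " prefixes;
# whether the prebuilt index used them is an assumption that must be matched
# at query time:
#
#   vecs = embed_texts(["query: hello world"])
#   print(vecs.shape)  # (1, 1024)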

# Load the FAISS index and the row-index -> filename mapping
def load_faiss_index_and_mapping(index_path="document_index.faiss", mapping_path="document_mapping.txt"):
    """Loads the FAISS index and document mapping from files."""
    faiss_index = faiss.read_index(index_path)
    document_mapping = {}  # Maps FAISS row index -> source filename
    with open(mapping_path, "r") as f:
        for line in f:
            index, filename = line.strip().split("\t")
            document_mapping[int(index)] = filename
    return faiss_index, document_mapping
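
# The mapping file is expected to hold one tab-separated "index<TAB>filename"
# pair per line (an assumption inferred from the parsing above), e.g.:
#
#   0	doc_001.txt
#   1	doc_002.txt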

# Load the index and mapping at startup
faiss_index, document_mapping = load_faiss_index_and_mapping()

# Load raw document texts in FAISS row order so list positions line up with the index
def load_documents(document_mapping, folder_path="Data"):
    """Loads documents based on the document mapping."""
    documents = []
    for index in sorted(document_mapping.keys()):
        filename = document_mapping[index]
        file_path = os.path.join(folder_path, filename)
        with open(file_path, "r", encoding="utf-8") as file:
            documents.append(file.read())
    return documents

documents = load_documents(document_mapping)
print(f"Loaded {len(documents)} documents.")

# Hugging Face Inference API token, read from the API_TOKEN environment variable
secret = os.environ["API_TOKEN"]
client = InferenceClient(api_key=secret)

def generate_response(query, retrieved_docs):
    """Generate an answer from the retrieved context via the HF Inference API."""
    context = " ".join(retrieved_docs)
    # French prompt: "Answer the following question concisely, using only the
    # relevant information from the provided context."
    prompt = (
        f"<s>Répondez à la question suivante de manière concise en utilisant uniquement les informations pertinentes du contexte fourni.\n\n"
        f"Contexte : {context}\n\n"
        f"Question : {query}\n\n"
        f"Réponse :"
    )
    messages = [
        # System prompt (French): "You are an advanced French-language model,
        # designed to provide clear, complete, grammatically correct, and
        # helpful answers while remaining courteous."
        {"role": "system", "content": "Vous êtes un modèle de langage avancé en français, conçu pour fournir des réponses claires, complètes, grammaticalement correctes, et utiles, tout en restant courtois."},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message.content
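
# Streaming variant (a sketch, not used by the endpoint below): the same
# chat-completion call accepts stream=True, in which case it yields chunks
# whose incremental text lives in choices[0].delta.content:
#
#   for chunk in client.chat.completions.create(
#       model="meta-llama/Llama-3.2-3B-Instruct", messages=messages, stream=True
#   ):
#       print(chunk.choices[0].delta.content or "", end="")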

# Query the index and retrieve relevant documents
def retrieve_documents(query, k=3):
    """Retrieve the top-k most relevant documents from the FAISS index."""
    query_embedding = embed_texts([query])
    distances, indices = faiss_index.search(query_embedding.astype("float32"), k)
    return [documents[i] for i in indices[0]]
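
# Illustrative call (the query text is hypothetical):
#
#   top_docs = retrieve_documents("fintech startups", k=3)
#   # -> list of up to 3 raw document strings, nearest first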

def rag_pipeline(query):
    """Complete RAG pipeline: retrieve context, then generate an answer."""
    # Step 1: Retrieve the single most relevant document
    relevant_docs = retrieve_documents(query, 1)
    # Step 2: Generate a response using the retrieved document as context
    response = generate_response(query, relevant_docs)
    print("Query:", query)
    print("Response:", response)
    return response

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace "*" with specific domains in production for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/generate")
async def generate(query: str | None = None):
    if not query:
        raise HTTPException(status_code=400, detail="Query parameter is required")
    response = rag_pipeline(query)
    return JSONResponse(content={"response": response})
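
# Example request once the server is running (the query value is illustrative):
#
#   curl "http://localhost:7860/generate?query=fintech%20startups"
#
# Minimal local entry point -- a sketch assuming uvicorn is installed; port
# 7860 follows the Hugging Face Spaces convention and is an assumption here.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)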