import os
import torch
import faiss
from huggingface_hub import InferenceClient
from transformers import AutoModel, AutoTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# Embedding model: multilingual-e5-large produces 1024-dimensional sentence embeddings
embedding_model_name = "intfloat/multilingual-e5-large"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name)
def embed_texts(texts):
    """Generate embeddings for a list of texts.

    Note: E5 models are trained with "query: " / "passage: " input prefixes;
    prepending them to the texts generally improves retrieval quality.
    """
    inputs = embedding_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
        # Mean pooling over non-padding tokens only (the attention mask zeroes out padding)
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    return embeddings.numpy()
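# Quick sanity check (sketch): multilingual-e5-large embeds into 1024 dimensions,
# so a batch of two texts should come back as a (2, 1024) float array.
# The example texts are illustrative, not from the indexed corpus.
#
#     vecs = embed_texts(["query: heure d'ouverture", "query: politique de retour"])
#     print(vecs.shape)  # (2, 1024)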
# Load the FAISS index and document mapping from disk
def load_faiss_index_and_mapping(index_path="document_index.faiss", mapping_path="document_mapping.txt"):
    """Loads the FAISS index and the index -> filename mapping from files."""
    faiss_index = faiss.read_index(index_path)  # load the FAISS index
    document_mapping = {}  # maps FAISS row index -> source filename
    with open(mapping_path, "r") as f:
        for line in f:
            index, filename = line.strip().split("\t")
            document_mapping[int(index)] = filename
    return faiss_index, document_mapping
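# For reference, a minimal sketch of how document_index.faiss and
# document_mapping.txt could have been produced in the first place. This build
# step is not part of this app (both files are assumed to already exist); the
# helper name build_faiss_index and the choice of a flat L2 index are assumptions.
#
#     def build_faiss_index(folder_path="Data", index_path="document_index.faiss",
#                           mapping_path="document_mapping.txt"):
#         filenames = sorted(os.listdir(folder_path))
#         texts = [open(os.path.join(folder_path, f), encoding="utf-8").read() for f in filenames]
#         vectors = embed_texts(texts).astype("float32")
#         index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search
#         index.add(vectors)
#         faiss.write_index(index, index_path)
#         with open(mapping_path, "w") as f:
#             for i, name in enumerate(filenames):
#                 f.write(f"{i}\t{name}\n")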
# Load the index and mapping at startup
faiss_index, document_mapping = load_faiss_index_and_mapping()
# Load the raw document texts in index order, so list positions line up with FAISS results
def load_documents(document_mapping, folder_path="Data"):
    """Loads documents from disk in the order given by the document mapping."""
    documents = []
    for index in sorted(document_mapping.keys()):
        filename = document_mapping[index]
        file_path = os.path.join(folder_path, filename)
        with open(file_path, "r", encoding="utf-8") as file:
            documents.append(file.read())
    return documents
documents = load_documents(document_mapping)
print(f"Loaded {len(documents)} documents.")
# Hugging Face Inference API client; API_TOKEN must be set in the environment
secret = os.environ["API_TOKEN"]
client = InferenceClient(api_key=secret)
def generate_response(query, retrieved_docs):
    """Generate a response from the retrieved documents via the HF Inference API (chat completion)."""
    context = " ".join(retrieved_docs)
    # French prompt: "Answer the following question concisely, using only the
    # relevant information from the provided context." (A literal "<s>" token is
    # not needed here; the chat endpoint applies the model's template itself.)
    prompt = (
        f"Répondez à la question suivante de manière concise en utilisant uniquement les informations pertinentes du contexte fourni.\n\n"
        f"Contexte : {context}\n\n"
        f"Question : {query}\n\n"
        f"Réponse :"
    )
    messages = [
        # French system prompt: "You are an advanced French-language model, designed to provide
        # clear, complete, grammatically correct, and helpful answers, while remaining courteous."
        {"role": "system", "content": "Vous êtes un modèle de langage avancé en français, conçu pour fournir des réponses claires, complètes, grammaticalement correctes, et utiles, tout en restant courtois."},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message.content
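# A minimal streaming variant (sketch): InferenceClient also supports streamed
# chat completions via stream=True, yielding incremental deltas. Shown only as
# an alternative to the blocking call above; generate_response_streaming is a
# hypothetical helper, not used by this app.
#
#     def generate_response_streaming(messages):
#         stream = client.chat.completions.create(
#             model="meta-llama/Llama-3.2-3B-Instruct",
#             messages=messages,
#             max_tokens=500,
#             stream=True,
#         )
#         for chunk in stream:
#             yield chunk.choices[0].delta.content or ""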
# Query and retrieve relevant documents
def retrieve_documents(query, k=3):
    """Retrieve the top-k most relevant documents from the FAISS index."""
    query_embedding = embed_texts([query])
    distances, indices = faiss_index.search(query_embedding.astype("float32"), k)
    # FAISS pads with -1 when the index holds fewer than k vectors
    return [documents[i] for i in indices[0] if i != -1]
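# Example (illustrative query, not from the corpus):
#     retrieve_documents("Quels sont les horaires d'ouverture ?", k=2)
# returns the two document texts whose embeddings are closest to the query's.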
def rag_pipeline(query):
    """Complete RAG pipeline: retrieve, then generate."""
    # Step 1: retrieve the single most relevant document
    relevant_docs = retrieve_documents(query, k=1)
    # Step 2: generate a response grounded in the retrieved document
    response = generate_response(query, relevant_docs)
    print("Query:", query)
    print("Response:", response)
    return response
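# End-to-end example (sketch; the query is illustrative):
#     rag_pipeline("Comment contacter le support ?")
# embeds the query, pulls the closest document from FAISS, and asks the LLM
# to answer in French using only that document as context.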
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # replace "*" with specific domains in production for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/generate")  # a route decorator is required for FastAPI to expose this handler; the "/generate" path is an assumption
async def generate(query: str = None):
    if not query:
        raise HTTPException(status_code=400, detail="Query parameter is required")
    response = rag_pipeline(query)
    return JSONResponse(content={"response": response})
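# To serve the app locally (sketch; assumes uvicorn is installed, and port 7860
# follows the Hugging Face Spaces convention rather than being a requirement):
#
#     if __name__ == "__main__":
#         import uvicorn
#         uvicorn.run(app, host="0.0.0.0", port=7860)
#
# Then query the endpoint with, e.g.:
#     curl "http://localhost:7860/generate?query=Bonjour"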