import os

import faiss
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import InferenceClient
from transformers import AutoModel, AutoTokenizer

# Embedding model used for query-time retrieval (same model assumed for indexing)
embedding_model_name = "intfloat/multilingual-e5-large"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
embedding_model = AutoModel.from_pretrained(embedding_model_name)


def embed_texts(texts):
    """Generate embeddings for a list of texts."""
    inputs = embedding_tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Use mean pooling over the token embeddings to get one vector per text
    embeddings = torch.mean(outputs.last_hidden_state, dim=1)
    return embeddings.numpy()


# Function to load the FAISS index and document mapping
def load_faiss_index_and_mapping(index_path="document_index.faiss", mapping_path="document_mapping.txt"):
    """Loads the FAISS index and the index-to-filename mapping from files."""
    faiss_index = faiss.read_index(index_path)  # Load the FAISS index
    document_mapping = {}  # Dictionary mapping FAISS row index -> filename
    with open(mapping_path, "r") as f:
        for line in f:
            index, filename = line.strip().split("\t")
            document_mapping[int(index)] = filename
    return faiss_index, document_mapping


# Load the index and mapping at startup
faiss_index, document_mapping = load_faiss_index_and_mapping()


# Function to load the document texts referenced by the mapping
def load_documents(document_mapping, folder_path="Data"):
    """Loads documents based on the document mapping, in index order."""
    documents = []
    for index in sorted(document_mapping.keys()):
        filename = document_mapping[index]
        file_path = os.path.join(folder_path, filename)
        with open(file_path, "r", encoding="utf-8") as file:
            documents.append(file.read())
    return documents


documents = load_documents(document_mapping)
print(f"Loaded {len(documents)} documents.")

secret = os.environ["API_TOKEN"]
client = InferenceClient(api_key=secret)


def generate_response(query, retrieved_docs):
    """Generate an answer from the retrieved documents via the Hugging Face Inference API."""
    context = " ".join(retrieved_docs)
    # French-language prompt: answer the question concisely using only the provided context
    prompt = (
        f"Répondez à la question suivante de manière concise en utilisant uniquement les informations pertinentes du contexte fourni.\n\n"
        f"Contexte : {context}\n\n"
        f"Question : {query}\n\n"
        f"Réponse :"
    )
    messages = [
        {
            "role": "system",
            "content": "Vous êtes un modèle de langage avancé en français, conçu pour fournir des réponses claires, complètes, grammaticalement correctes, et utiles, tout en restant courtois.",
        },
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message.content
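
# The original docstring of generate_response mentioned token streaming, which the
# blocking call above does not do. Below is a minimal sketch of a streaming variant,
# assuming the hosted model supports streamed chat completions through the same
# InferenceClient; the function name is illustrative and it is not wired into the API.
def generate_response_stream(query, retrieved_docs):
    """Yield the answer incrementally instead of returning it in one piece (sketch)."""
    context = " ".join(retrieved_docs)
    messages = [
        {"role": "user", "content": f"Contexte : {context}\n\nQuestion : {query}\n\nRéponse :"},
    ]
    stream = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        max_tokens=500,
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta
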
# 6. Query and Retrieve Relevant Documents
def retrieve_documents(query, k=3):
    """Retrieve the top-k most relevant documents."""
    query_embedding = embed_texts([query])
    distances, indices = faiss_index.search(query_embedding.astype("float32"), k)
    return [documents[i] for i in indices[0]]


def rag_pipeline(query):
    """Complete RAG pipeline."""
    # Step 1: Retrieve relevant documents
    relevant_docs = retrieve_documents(query, 1)
    # Step 2: Generate a response using the retrieved documents
    response = generate_response(query, relevant_docs)
    print("Query:", query)
    print("Response:", response)
    return response


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace '*' with specific domains in production for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/generate")
async def generate(query: str = None):
    if not query:
        raise HTTPException(status_code=400, detail="Query parameter is required")
    response = rag_pipeline(query)
    return JSONResponse(content={"response": response})
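

# The code above assumes document_index.faiss and document_mapping.txt already exist.
# Below is a minimal sketch of the offline indexing step, assuming the documents are
# text files under the Data folder and that a flat L2 index was used; the index type
# and the helper name build_index are assumptions, not part of the original pipeline.
# It is not called at startup; run it once, separately, before serving.
def build_index(folder_path="Data", index_path="document_index.faiss", mapping_path="document_mapping.txt"):
    """Embed every document in folder_path and write the FAISS index and mapping files (sketch)."""
    filenames = sorted(os.listdir(folder_path))
    texts = []
    for filename in filenames:
        with open(os.path.join(folder_path, filename), "r", encoding="utf-8") as file:
            texts.append(file.read())
    embeddings = embed_texts(texts).astype("float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])  # 1024 dimensions for multilingual-e5-large
    index.add(embeddings)
    faiss.write_index(index, index_path)
    with open(mapping_path, "w", encoding="utf-8") as f:
        for i, filename in enumerate(filenames):
            f.write(f"{i}\t{filename}\n")


# Run the API locally with, for example: uvicorn app:app --host 0.0.0.0 --port 8000
# (assuming this file is named app.py).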