tarikko committed
Commit 606851f · 1 Parent(s): e2ad586

Refactor app.py to integrate InferenceClient for response generation and update requirements.txt to include datasets

Files changed:
- app.py +21 -29
- requirements.txt +1 -1
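The core of the change is swapping local text generation for a remote call through huggingface_hub's InferenceClient. As a minimal sketch of that pattern only: the model name and the API_TOKEN environment variable mirror the diff below, while the example prompt and the rest of the snippet are illustrative, not taken from the repository.

import os
from huggingface_hub import InferenceClient

# Read the Hugging Face access token from the environment, as app.py does.
client = InferenceClient(api_key=os.environ.get("API_TOKEN"))

# One chat-completion round trip against the hosted model referenced in app.py.
completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Bonjour, pouvez-vous vous présenter ?"}],
    max_tokens=100,
)
print(completion.choices[0].message.content)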
app.py
CHANGED
@@ -1,8 +1,9 @@
 import os
 import torch
-import numpy as np
 import faiss
-from
+from huggingface_hub import InferenceClient
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
 
@@ -15,7 +16,7 @@ embedding_model = AutoModel.from_pretrained(embedding_model_name)
 
 def embed_texts(texts):
     """Generate embeddings for a list of texts."""
-    inputs = embedding_tokenizer(
+    inputs = embedding_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
     with torch.no_grad():
         outputs = embedding_model(**inputs)
     # Use mean pooling to get embeddings
@@ -53,43 +54,35 @@ def load_documents(document_mapping, folder_path="Data"):
 documents = load_documents(document_mapping)
 print(f"Loaded {len(documents)} documents.")
 
-
-
-
-generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
-generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name, trust_remote_code=True, device_map="cpu")
+secret = os.environ.get("API_TOKEN")
+client = InferenceClient(api_key=secret)
 
 def generate_response(query, retrieved_docs):
-    """Generate a response
+    """Generate a response with a hosted model via the InferenceClient chat completion API."""
     context = " ".join(retrieved_docs)
-    # More natural prompt
     prompt = (
         f"<s>Répondez à la question suivante de manière concise en utilisant uniquement les informations pertinentes du contexte fourni.\n\n"
         f"Contexte : {context}\n\n"
         f"Question : {query}\n\n"
         f"Réponse :"
     )
+
     messages = [
-
-
+        {"role": "system", "content": "Vous êtes un modèle de langage avancé en français, conçu pour fournir des réponses claires, complètes, grammaticalement correctes, et utiles, tout en restant courtois."},
+        {
+            "role": "user",
+            "content": prompt,
+        }
     ]
-    text = generation_tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = generation_tokenizer([text], return_tensors="pt").to(generation_model.device)
 
-
-
-
+    completion = client.chat.completions.create(
+        model="meta-llama/Llama-3.2-3B-Instruct",
+        messages=messages,
+        max_tokens=500,
     )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
 
-
-
+    return completion.choices[0].message.content
+
 # 6. Query and Retrieve Relevant Documents
 def retrieve_documents(query, k=3):
     """Retrieve the top-k most relevant documents."""
@@ -101,11 +94,10 @@ def rag_pipeline(query):
     """Complete RAG pipeline."""
     # Step 1: Retrieve relevant documents
    relevant_docs = retrieve_documents(query, 1)
-    print(f"Retrieved {len(relevant_docs)} relevant documents.")
-    print(relevant_docs)
     # Step 2: Generate a response using the retrieved documents
     response = generate_response(query, relevant_docs)
-
+    print("Query:", query)
+    print("Response:", response)
     return response
 
 app = FastAPI()
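The mean-pooling step referenced by the comment in embed_texts lies outside the changed lines, so it does not appear in the hunk above. As a sketch of what attention-mask-weighted mean pooling over outputs.last_hidden_state typically looks like; this is an assumption about the unchanged code, not part of the commit.

import torch

def mean_pool(last_hidden_state, attention_mask):
    # Zero out padding positions, then average the remaining token embeddings.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                  # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # tokens per text
    return (summed / counts).cpu().numpy()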
requirements.txt
CHANGED
@@ -6,5 +6,5 @@ uvicorn
 numpy
 faiss-cpu
 faiss-gpu
-
+datasets
 accelerate
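Neither hunk touches the route definitions, so the FastAPI wiring that calls rag_pipeline is not visible in this commit. A hypothetical sketch of how the app created in app.py might expose it; the /query route and QueryRequest model are illustrative assumptions, not taken from the repository.

from pydantic import BaseModel

class QueryRequest(BaseModel):
    query: str

@app.post("/query")
def answer_query(request: QueryRequest):
    # Run retrieval + generation and return the answer as JSON.
    try:
        return JSONResponse(content={"response": rag_pipeline(request.query)})
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))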