tarikko committed
Commit 606851f · 1 Parent(s): e2ad586

Refactor app.py to integrate InferenceClient for response generation and update requirements.txt to include datasets

Files changed (2)
  1. app.py +21 -29
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,8 +1,9 @@
import os
import torch
- import numpy as np
import faiss
- from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
+ from huggingface_hub import InferenceClient
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
+
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse

@@ -15,7 +16,7 @@ embedding_model = AutoModel.from_pretrained(embedding_model_name)

def embed_texts(texts):
    """Generate embeddings for a list of texts."""
-     inputs = embedding_tokenizer('query: ' + texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     inputs = embedding_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = embedding_model(**inputs)
    # Use mean pooling to get embeddings
@@ -53,43 +54,35 @@ def load_documents(document_mapping, folder_path="Data"):
documents = load_documents(document_mapping)
print(f"Loaded {len(documents)} documents.")

- # Load your model and tokenizer
- generation_model_name = "Qwen/Qwen2.5-0.5B-Instruct"
-
- generation_tokenizer = AutoTokenizer.from_pretrained(generation_model_name)
- generation_model = AutoModelForCausalLM.from_pretrained(generation_model_name, trust_remote_code=True, device_map="cpu")
+ secret = os.environ.get("API_TOKEN")
+ client = InferenceClient(api_key=secret)

def generate_response(query, retrieved_docs):
-     """Generate a response using Flan-T5-Large based on retrieved documents."""
+     """Generate a response via the Hugging Face Inference API based on retrieved documents."""
    context = " ".join(retrieved_docs)
-     # More natural prompt
    prompt = (
        f"<s>Répondez à la question suivante de manière concise en utilisant uniquement les informations pertinentes du contexte fourni.\n\n"
        f"Contexte : {context}\n\n"
        f"Question : {query}\n\n"
        f"Réponse :"
    )
+
    messages = [
-         {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
-         {"role": "user", "content": prompt}
+         {"role": "system", "content": "Vous êtes un modèle de langage avancé en français, conçu pour fournir des réponses claires, complètes, grammaticalement correctes, et utiles, tout en restant courtois."},
+         {
+             "role": "user",
+             "content": prompt,
+         }
    ]
-     text = generation_tokenizer.apply_chat_template(
-         messages,
-         tokenize=False,
-         add_generation_prompt=True
-     )
-     model_inputs = generation_tokenizer([text], return_tensors="pt").to(generation_model.device)

-     generated_ids = generation_model.generate(
-         **model_inputs,
-         max_new_tokens=512
+     completion = client.chat.completions.create(
+         model="meta-llama/Llama-3.2-3B-Instruct",
+         messages=messages,
+         max_tokens=500,
    )
-     generated_ids = [
-         output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-     ]

-     response = generation_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-     return response
+     return completion.choices[0].message.content
+
# 6. Query and Retrieve Relevant Documents
def retrieve_documents(query, k=3):
    """Retrieve the top-k most relevant documents."""
@@ -101,11 +94,10 @@ def rag_pipeline(query):
    """Complete RAG pipeline."""
    # Step 1: Retrieve relevant documents
    relevant_docs = retrieve_documents(query, 1)
-     print(f"Retrieved {len(relevant_docs)} relevant documents.")
-     print(relevant_docs)
    # Step 2: Generate a response using the retrieved documents
    response = generate_response(query, relevant_docs)
-
+     print("Query:", query)
+     print("Response:", response)
    return response

app = FastAPI()
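
For reference, the new generation path introduced by this commit can be exercised on its own roughly as follows. This is a minimal sketch, assuming a valid Hugging Face token is available in the API_TOKEN environment variable and that the account has access to meta-llama/Llama-3.2-3B-Instruct; the context and question strings are placeholder values, not part of the commit.

import os
from huggingface_hub import InferenceClient

# The app reads the token from a secret named API_TOKEN.
client = InferenceClient(api_key=os.environ.get("API_TOKEN"))

# Placeholder retrieval output; in app.py this comes from the FAISS search step.
context = "Le service client est ouvert du lundi au vendredi, de 9h à 17h."
query = "Quand puis-je contacter le service client ?"

messages = [
    {"role": "system", "content": "Vous êtes un assistant qui répond en français de manière concise."},
    {"role": "user", "content": f"Contexte : {context}\n\nQuestion : {query}\n\nRéponse :"},
]

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=messages,
    max_tokens=500,
)
print(completion.choices[0].message.content)
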
requirements.txt CHANGED
@@ -6,5 +6,5 @@ uvicorn
numpy
faiss-cpu
faiss-gpu
- numpy
+ datasets
accelerate