Pclanglais commited on
Commit
9e80596
·
verified ·
1 Parent(s): f2019a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -169
app.py CHANGED
@@ -13,176 +13,99 @@ import pandas as pd
13
  from chromadb.config import Settings
14
  from chromadb.utils import embedding_functions
15
 
16
- device = "cuda:0"
17
- sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="intfloat/multilingual-e5-base", device = "cuda")
18
- client = chromadb.PersistentClient(path="mfs_vector")
19
- collection = client.get_collection(name="sp_expanded", embedding_function = sentence_transformer_ef)
20
 
 
 
 
21
 
22
  # Define the device
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
  #Define variables
25
  temperature=0.2
26
  max_new_tokens=1000
27
  top_p=0.92
28
  repetition_penalty=1.7
29
 
30
- model_name = "AgentPublic/Guillaume-Tell"
31
 
32
- llm = LLM(model_name, max_model_len=4096)
33
 
34
  #Vector search over the database
35
- def vector_search(collection, text):
36
-
37
- results = collection.query(
38
- query_texts=[text],
39
- n_results=5,
40
- )
41
-
42
- document = []
43
- document_html = []
44
- id_list = ""
45
- list_elm = 0
46
- for ids in results["ids"][0]:
47
- first_link = str(results["metadatas"][0][list_elm]["identifier"])
48
- first_title = results["documents"][0][list_elm]
49
- list_elm = list_elm+1
50
-
51
- document.append(first_link + " : " + first_title)
52
- document_html.append('<div class="source" id="' + first_link + '"><p><b>' + first_link + "</b> : " + first_title + "</div>")
53
-
54
- document = "\n\n".join(document)
55
- document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
56
- # Replace this with the actual implementation of the vector search
57
- return document, document_html
58
 
59
- #CSS for references formatting
60
- css = """
61
- .generation {
62
- margin-left:2em;
63
- margin-right:2em;
64
- size:1.2em;
65
- }
66
-
67
- :target {
68
- background-color: #CCF3DF; /* Change the text color to red */
69
- }
70
-
71
- .source {
72
- float:left;
73
- max-width:17%;
74
- margin-left:2%;
75
- }
76
-
77
- .tooltip {
78
- position: relative;
79
- cursor: pointer;
80
- font-variant-position: super;
81
- color: #97999b;
82
- }
83
-
84
- .tooltip:hover::after {
85
- content: attr(data-text);
86
- position: absolute;
87
- left: 0;
88
- top: 120%; /* Adjust this value as needed to control the vertical spacing between the text and the tooltip */
89
- white-space: pre-wrap; /* Allows the text to wrap */
90
- width: 500px; /* Sets a fixed maximum width for the tooltip */
91
- max-width: 500px; /* Ensures the tooltip does not exceed the maximum width */
92
- z-index: 1;
93
- background-color: #f9f9f9;
94
- color: #000;
95
- border: 1px solid #ddd;
96
- border-radius: 5px;
97
- padding: 5px;
98
- display: block;
99
- box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Optional: Adds a subtle shadow for better visibility */
100
- }"""
101
-
102
- #Curtesy of chatgpt
103
- def format_references(text):
104
- # Define start and end markers for the reference
105
- ref_start_marker = '<ref text="'
106
- ref_end_marker = '</ref>'
107
 
108
- # Initialize an empty list to hold parts of the text
109
- parts = []
110
- current_pos = 0
111
- ref_number = 1
112
 
113
- # Loop until no more reference start markers are found
114
- while True:
115
- start_pos = text.find(ref_start_marker, current_pos)
116
- if start_pos == -1:
117
- # No more references found, add the rest of the text
118
- parts.append(text[current_pos:])
119
- break
120
-
121
- # Add text up to the start of the reference
122
- parts.append(text[current_pos:start_pos])
123
-
124
- # Find the end of the reference text attribute
125
- end_pos = text.find('">', start_pos)
126
- if end_pos == -1:
127
- # Malformed reference, break to avoid infinite loop
128
- break
129
-
130
- # Extract the reference text
131
- ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
132
- ref_text_encoded = ref_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
133
-
134
- # Find the end of the reference tag
135
- ref_end_pos = text.find(ref_end_marker, end_pos)
136
- if ref_end_pos == -1:
137
- # Malformed reference, break to avoid infinite loop
138
- break
139
-
140
- # Extract the reference ID
141
- ref_id = text[end_pos + 2:ref_end_pos].strip()
142
-
143
- # Create the HTML for the tooltip
144
- tooltip_html = f'<span class="tooltip" data-refid="{ref_id}" data-text="{ref_id}: {ref_text_encoded}"><a href="#{ref_id}">[' + str(ref_number) +']</a></span>'
145
- parts.append(tooltip_html)
146
-
147
- # Update current_pos to the end of the current reference
148
- current_pos = ref_end_pos + len(ref_end_marker)
149
- ref_number = ref_number + 1
150
 
151
- # Join and return the parts
152
- parts = ''.join(parts)
153
-
154
- return parts
155
-
156
- # Class to encapsulate the Falcon chatbot
157
- class MistralChatBot:
158
- def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
159
- self.system_prompt = system_prompt
160
-
161
- def predict(self, user_message):
162
- fiches, fiches_html = vector_search(collection, user_message)
163
- sampling_params = SamplingParams(temperature=.7, top_p=.95, max_tokens=2000, presence_penalty = 1.5, stop = ["``"])
164
- detailed_prompt = """<|im_start|>system
165
- Tu es Albert, le chatbot des Maisons France Service qui donne des réponses sourcées.<|im_end|>
166
- <|im_start|>user
167
- Ecrit un texte référencé en réponse à cette question : """ + user_message + """
168
-
169
- Les références doivent être citées de cette manière : texte rédigé<ref text=\"[passage pertinent dans la référence]\">[\"identifiant de la référence\"]</ref>Si les références ne permettent pas de répondre, qu'il n'y a pas de réponse.
170
-
171
- Les cinq références disponibles : """ + fiches + "<|im_end|>\n<|im_start|>assistant\n"
172
- print(detailed_prompt)
173
- prompts = [detailed_prompt]
174
- outputs = llm.generate(prompts, sampling_params, use_tqdm = False)
175
- generated_text = outputs[0].outputs[0].text
176
- generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
177
- fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
178
- return generated_text, fiches_html
179
-
180
- # Create the Falcon chatbot instance
181
- mistral_bot = MistralChatBot()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  # Define the Gradio interface
184
- title = "Guillaume-Tell"
185
- description = "Le LLM répond à des questions administratives sur l'éducation nationale à partir de sources fiables."
186
  examples = [
187
  [
188
  "Qui peut bénéficier de l'AIP?", # user_message
@@ -190,27 +113,12 @@ examples = [
190
  ]
191
  ]
192
 
193
- additional_inputs=[
194
- gr.Slider(
195
- label="Température",
196
- value=0.2, # Default value
197
- minimum=0.05,
198
- maximum=1.0,
199
- step=0.05,
200
- interactive=True,
201
- info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
202
- ),
203
- ]
204
 
205
  demo = gr.Blocks()
206
 
207
  with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
208
- gr.HTML("""<h1 style="text-align:center">Albert (Guillaume-Tell)</h1>""")
209
- text_input = gr.Textbox(label="Votre question ou votre instruction.", type="text", lines=1)
210
- text_button = gr.Button("Interroger Albert")
211
- text_output = gr.HTML(label="La réponse d'Albert")
212
- embedding_output = gr.HTML(label="Les sources utilisées")
213
- text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output, embedding_output])
214
 
215
  if __name__ == "__main__":
216
  demo.queue().launch()
 
13
  from chromadb.config import Settings
14
  from chromadb.utils import embedding_functions
15
 
16
+ model = BGEM3FlagModel('BAAI/bge-m3',
17
+ use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
 
 
18
 
19
+ embeddings = np.load("embeddings_with_api.npy")
20
+ embeddings_data = pd.read_json("embeddings_tchap.json")
21
+ embeddings_text = embeddings_data["text_with_context"].tolist()
22
 
23
  # Define the device
24
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
25
  #Define variables
26
  temperature=0.2
27
  max_new_tokens=1000
28
  top_p=0.92
29
  repetition_penalty=1.7
30
 
31
+ #model_name = "Pclanglais/Tchap"
32
 
33
+ #llm = LLM(model_name, max_model_len=4096)
34
 
35
  #Vector search over the database
36
+ def vector_search(sentence_query):
37
+
38
+ query_embedding = model.encode(sentence_query,
39
+ batch_size=12,
40
+ max_length=256, # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
41
+ )['dense_vecs']
42
+
43
+ # Reshape the query embedding to fit the cosine_similarity function requirements
44
+ query_embedding_reshaped = query_embedding.reshape(1, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # Compute cosine similarities
47
+ similarities = cosine_similarity(query_embedding_reshaped, embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ # Find the index of the closest document (highest similarity)
50
+ closest_doc_index = np.argmax(similarities)
 
 
51
 
52
+ # Closest document's embedding
53
+ closest_doc_embedding = sentences_1[closest_doc_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ return closest_doc_embedding
56
+
57
+
58
+ class StopOnTokens(StoppingCriteria):
59
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
60
+ stop_ids = [29, 0]
61
+ for stop_id in stop_ids:
62
+ if input_ids[0][-1] == stop_id:
63
+ return True
64
+ return False
65
+
66
+ def predict(message, history):
67
+ text = vector_search(message)
68
+ message = message + "\n\n### Source ###\n"
69
+ history_transformer_format = history + [[message, ""]]
70
+ stop = StopOnTokens()
71
+
72
+ messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
73
+ for item in history_transformer_format])
74
+
75
+ return messages
76
+
77
+ def predict_alt(message, history):
78
+ history_transformer_format = history + [[message, ""]]
79
+ stop = StopOnTokens()
80
+
81
+ messages = "".join(["".join(["\n<human>:"+item[0], "\n<bot>:"+item[1]])
82
+ for item in history_transformer_format])
83
+
84
+ model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
85
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
86
+ generate_kwargs = dict(
87
+ model_inputs,
88
+ streamer=streamer,
89
+ max_new_tokens=1024,
90
+ do_sample=True,
91
+ top_p=0.95,
92
+ top_k=1000,
93
+ temperature=1.0,
94
+ num_beams=1,
95
+ stopping_criteria=StoppingCriteriaList([stop])
96
+ )
97
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
98
+ t.start()
99
+
100
+ partial_message = ""
101
+ for new_token in streamer:
102
+ if new_token != '<':
103
+ partial_message += new_token
104
+ yield partial_message
105
 
106
  # Define the Gradio interface
107
+ title = "Tchap"
108
+ description = "Le chatbot du service public"
109
  examples = [
110
  [
111
  "Qui peut bénéficier de l'AIP?", # user_message
 
113
  ]
114
  ]
115
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  demo = gr.Blocks()
118
 
119
  with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
120
+ gr.HTML("""<h1 style="text-align:center">Albert-Tchap</h1>""")
121
+ gr.ChatInterface(predict).launch()
 
 
 
 
122
 
123
  if __name__ == "__main__":
124
  demo.queue().launch()