Pclanglais committed
Commit 0dfb412
1 Parent(s): 8252e5b

Update app.py

Files changed (1)
app.py +89 -175
app.py CHANGED
@@ -1,5 +1,6 @@
  import re
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
  from vllm import LLM, SamplingParams
  import torch
  import gradio as gr
@@ -7,169 +8,82 @@ import json
  import os
  import shutil
  import requests
- import numpy as np
  import pandas as pd
- from threading import Thread
- from FlagEmbedding import BGEM3FlagModel
- from sklearn.metrics.pairwise import cosine_similarity
-
- from transformers import AutoModelForSequenceClassification
 
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
- # Importing the embedding model
- embedding_model = BGEM3FlagModel('BAAI/bge-m3',
-                                  use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
-
- embeddings = np.load("embeddings_albert_tchap.npy")
- embeddings_data = pd.read_json("embeddings_albert_tchap.json")
- embeddings_text = embeddings_data["text_with_context"].tolist()
-
- # Importing the classifier/router (deberta)
- classifier_model = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
- classifier_tokenizer = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")
-
- # Importing the actual generative LLM (llama-based)
- model_name = "Pclanglais/Tchap"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
- model = model.to('cuda:0')
-
- system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
- source_text = "Les sources utilisées par Albert-Tchap vont apparaître ici'"
-
-
- # Function to guess whether we use the RAG or not.
- def classification_chatrag(query):
-     print(query)
-     encoding = classifier_tokenizer(query, return_tensors="pt")
-     encoding = {k: v.to(classifier_model.device) for k, v in encoding.items()}
-
-     outputs = classifier_model(**encoding)
-
-     logits = outputs.logits
-     logits.shape
-
-     # apply sigmoid + threshold
-     sigmoid = torch.nn.Sigmoid()
-     probs = sigmoid(logits.squeeze().cpu())
-     predictions = np.zeros(probs.shape)
-
-     # Extract the float value from the tensor
-     float_value = round(probs.item() * 100)
-
-     print(float_value)
-
-     if float_value > 50:
-         status = True
-         print("We activate RAG")
-     else:
-         status = False
-         print("We remove RAG")
-     return status
-
- # Vector search over the database
- def vector_search(sentence_query):
-
-     query_embedding = embedding_model.encode(sentence_query,
-                                              batch_size=12,
-                                              max_length=256,  # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
-                                              )['dense_vecs']
-
-     # Reshape the query embedding to fit the cosine_similarity function requirements
-     query_embedding_reshaped = query_embedding.reshape(1, -1)
-
-     # Compute cosine similarities
-     similarities = cosine_similarity(query_embedding_reshaped, embeddings)
-
-     # Find the index of the closest document (highest similarity)
-     closest_doc_index = np.argmax(similarities)
-
-     # Closest document's embedding
-     closest_doc_embedding = embeddings_text[closest_doc_index]
-
-     return closest_doc_embedding
-
-
- class StopOnTokens(StoppingCriteria):
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-         stop_ids = [29, 0]
-         for stop_id in stop_ids:
-             if input_ids[0][-1] == stop_id:
-                 return True
-         return False
-
-
- def predict(history_transformer_format):
-
-     print(history_transformer_format)
-     stop = StopOnTokens()
-
-     messages = []
-     id_message = 1
-     total_message = len(history_transformer_format)
-     for item in history_transformer_format:
-
-         # Once we target the ongoing post we add the source.
-         if id_message == total_message:
-             if assess_rag:
-                 question = "<|start_header_id|>user<|end_header_id|>\n\n" + item[0] + "\n\n### Source ###\n" + source_text
-             else:
-                 question = "<|start_header_id|>user<|end_header_id|>\n\n" + item[0]
-         else:
-             question = "<|start_header_id|>user<|end_header_id|>\n\n" + item[0]
-         answer = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + item[1]
-         result = "".join([question, answer])
-         messages.append(result)
-         id_message = id_message + 1
-
-     messages = "".join(messages)
-
-     print(messages)
-
-     messages = system_prompt + messages
-
-     print(messages)
-
-     model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
-     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         model_inputs,
-         streamer=streamer,
-         max_new_tokens=1024,
-         do_sample=False,
-         top_p=0.95,
-         temperature=0.4,
-         stopping_criteria=StoppingCriteriaList([stop])
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     history_transformer_format[-1][1] = ""
-     for new_token in streamer:
-         if new_token != '<':
-             history_transformer_format[-1][1] += new_token
-             yield history_transformer_format
-
- def user(message, history):
-     global source_text
-     global assess_rag
-     # For now, we only query the vector database once, at the start.
-     if len(history) == 0:
-         assess_rag = classification_chatrag(message)
-         if assess_rag:
-             source_text = vector_search(message)
-         else:
-             source_text = "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."
-
-     history_transformer_format = history + [[message, ""]]
-
-     print(history_transformer_format)
-     return "", history_transformer_format, source_text
  # Define the Gradio interface
- title = "Tchap"
- description = "Le chatbot du service public"
  examples = [
      [
          "Qui peut bénéficier de l'AIP?",  # user_message
@@ -177,26 +91,26 @@ examples = [
      ]
  ]
 
- with gr.Blocks() as demo:
-     with gr.Row():
-         with gr.Column(scale=2):
-             gr.HTML("<h2>Chat</2>")
-             chatbot = gr.Chatbot()
-             msg = gr.Textbox()
-             clear = gr.Button("Clear")
-
-             history = gr.State()
-
-         with gr.Column(scale=1):
-             gr.HTML("<h2>Source utilisée</2>")
-             user_output = gr.HTML()  # To display the user's message
-
-     msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot, user_output], queue=False).then(
-         predict, chatbot, chatbot
-     )
-
-     clear.click(lambda: None, None, chatbot, queue=False)
-
- demo.queue()
- demo.launch()
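For reference, the core of the code removed above is a yes/no routing decision: run the query through the chatrag-deberta classifier, apply a sigmoid, and only retrieve a source when the probability clears 50%. A minimal standalone sketch of that logic, illustrative only and not part of the commit, assuming the classifier exposes a single-logit head as the removed code does:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Sketch of the removed routing logic: sigmoid over a single-logit classifier, 50% threshold.
classifier = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
tokenizer = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")

def should_use_rag(query: str) -> bool:
    encoding = tokenizer(query, return_tensors="pt")
    logits = classifier(**encoding).logits
    prob = torch.sigmoid(logits.squeeze()).item()  # probability that the query needs a source
    return prob > 0.5  # above 50%: run the vector search before answering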
 
+ import transformers
  import re
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
  from vllm import LLM, SamplingParams
  import torch
  import gradio as gr
 
  import os
  import shutil
  import requests
+ import chromadb
  import pandas as pd
+ from chromadb.config import Settings
+ from chromadb.utils import embedding_functions
 
+ # Define the device
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
+ model_name = "PleIAs/OCRonos"
 
+ llm = LLM(model_name, max_model_len=8128)
 
+ # CSS for references formatting
+ css = """
+ .generation {
+     margin-left: 2em;
+     margin-right: 2em;
+     size: 1.2em;
+ }
+ :target {
+     background-color: #CCF3DF;  /* Highlight the element targeted by the anchor */
+ }
+ .source {
+     float: left;
+     max-width: 17%;
+     margin-left: 2%;
+ }
+ .tooltip {
+     position: relative;
+     cursor: pointer;
+     font-variant-position: super;
+     color: #97999b;
+ }
+
+ .tooltip:hover::after {
+     content: attr(data-text);
+     position: absolute;
+     left: 0;
+     top: 120%;  /* Adjust this value as needed to control the vertical spacing between the text and the tooltip */
+     white-space: pre-wrap;  /* Allows the text to wrap */
+     width: 500px;  /* Sets a fixed maximum width for the tooltip */
+     max-width: 500px;  /* Ensures the tooltip does not exceed the maximum width */
+     z-index: 1;
+     background-color: #f9f9f9;
+     color: #000;
+     border: 1px solid #ddd;
+     border-radius: 5px;
+     padding: 5px;
+     display: block;
+     box-shadow: 0 4px 8px rgba(0,0,0,0.1);  /* Optional: adds a subtle shadow for better visibility */
+ }"""
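The tooltip rules above assume that cited sources are rendered as spans carrying their text in a data-text attribute, which the CSS reads with content: attr(data-text). A hypothetical example of the markup the app would need to emit for this styling to apply; the exact format is not part of this commit:

# Hypothetical reference markup matching the .tooltip CSS above (illustrative only).
ref_html = '<span class="tooltip" data-text="Texte complet de la source citée">[1]</span>'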
+
+ # Courtesy of ChatGPT
+
+ # Class to encapsulate the chatbot
+ class MistralChatBot:
+     def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
+         self.system_prompt = system_prompt
+
+     def predict(self, user_message):
+         sampling_params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, presence_penalty=0, stop=["#END#"])
+         detailed_prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
+         print(detailed_prompt)
+         prompts = [detailed_prompt]
+         outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
+         generated_text = outputs[0].outputs[0].text
+         generated_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + generated_text + "</div>"
+         return generated_text
+
+ # Create the chatbot instance
+ mistral_bot = MistralChatBot()
 
  # Define the Gradio interface
+ title = "Correction d'OCR"
+ description = "Un outil expérimental de correction d'OCR basé sur des modèles de langue"
  examples = [
      [
          "Qui peut bénéficier de l'AIP?",  # user_message
 
      ]
  ]
 
+ additional_inputs = [
+     gr.Slider(
+         label="Température",
+         value=0.2,  # Default value
+         minimum=0.05,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Des valeurs plus élevées donnent plus de créativité, mais aussi d'étrangeté",
+     ),
+ ]
+
+ demo = gr.Blocks()
+
+ with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
+     gr.HTML("""<h1 style="text-align:center">Correction d'OCR</h1>""")
+     text_input = gr.Textbox(label="Votre texte.", type="text", lines=1)
+     text_button = gr.Button("Corriger l'OCR")
+     text_output = gr.HTML(label="Le texte corrigé")
+     text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output])
+
+ if __name__ == "__main__":
+     demo.queue().launch()
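The new version therefore reduces each request to a single vLLM completion over a fixed prompt template. A minimal standalone sketch of that flow, reusing the model name, prompt format, and sampling settings from the commit; the input text is a made-up example:

from vllm import LLM, SamplingParams

llm = LLM("PleIAs/OCRonos", max_model_len=8128)
params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, stop=["#END#"])

noisy_text = "Lhe qvick brown f0x jumps ovcr the lazy d0g."  # made-up OCR output
prompt = f"### TEXT ###\n{noisy_text}\n\n### CORRECTION ###\n"
outputs = llm.generate([prompt], params, use_tqdm=False)
print(outputs[0].outputs[0].text)  # corrected text, without the HTML wrapper the app adds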