salsarra committed on
Commit
c903136
verified
1 Parent(s): caf1832

Create app.py

Files changed (1)
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
+ import torch
+ import tensorflow as tf
+ from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, AutoModelForCausalLM
+ import gradio as gr
+ import re
+
+ # Use the GPU for the PyTorch models when available. Note that `device` only
+ # affects the PyTorch models (GPT-2 and BLOOM); the TensorFlow QA models
+ # manage their own device placement.
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
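+ # Assumed environment (not specified in this commit): the Space's
+ # requirements.txt would need torch, tensorflow, tf-keras (required by the
+ # TF models when Keras 3 is installed), transformers, and gradio.
+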
+ # Version Information:
+ confli_version_spanish = 'ConfliBERT-Spanish-Beto-Cased-NewsQA'
+ beto_version_spanish = 'Beto-Spanish-Cased-NewsQA'
+ gpt2_spanish_version = 'GPT-2-Small-Spanish'
+ bloom_spanish_version = 'BLOOM-1.7B'
+ beto_sqac_version_spanish = 'Beto-Spanish-Cased-SQAC'
+
+ # Load Spanish models and tokenizers
+ confli_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
+ confli_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(confli_model_spanish)
+ confli_tokenizer_spanish = AutoTokenizer.from_pretrained(confli_model_spanish)
+
+ beto_model_spanish = 'salsarra/Beto-Spanish-Cased-NewsQA'
+ beto_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_model_spanish)
+ beto_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_model_spanish)
+
+ beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
+ beto_sqac_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_sqac_model_spanish)
+ beto_sqac_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_sqac_model_spanish)
+
+ # Load Spanish GPT-2 model and tokenizer
+ gpt2_spanish_model_name = 'datificate/gpt2-small-spanish'
+ gpt2_spanish_tokenizer = AutoTokenizer.from_pretrained(gpt2_spanish_model_name)
+ gpt2_spanish_model = AutoModelForCausalLM.from_pretrained(gpt2_spanish_model_name).to(device)
+
+ # Load BLOOM-1.7B model and tokenizer for Spanish
+ bloom_model_name = 'bigscience/bloom-1b7'
+ bloom_tokenizer = AutoTokenizer.from_pretrained(bloom_model_name)
+ bloom_model = AutoModelForCausalLM.from_pretrained(bloom_model_name).to(device)
+
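+ # Note: bloom-1b7 has ~1.7B parameters (roughly 7 GB in fp32). If a GPU is
+ # available, loading in half precision roughly halves that, e.g. (an optional
+ # variant, not what this app does):
+ # bloom_model = AutoModelForCausalLM.from_pretrained(
+ #     bloom_model_name, torch_dtype=torch.float16).to(device)
+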
+ def handle_error_message(e, default_limit=512):
+     """Turn a raw model error into a readable HTML error message."""
+     error_message = str(e)
+     # PyTorch raises this when the prompt exceeds the model's position limit.
+     pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
+     match = pattern.search(error_message)
+     if match:
+         number_1, number_2 = match.groups()
+         return f"<span style='color: red; font-weight: bold;'>Error: input length {number_1} exceeds the model limit of {number_2} tokens</span>"
+     # TensorFlow raises this when an input index falls outside the embedding range.
+     pattern_qa = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")
+     match_qa = pattern_qa.search(error_message)
+     if match_qa:
+         number_1, number_2 = match_qa.groups()
+         return f"<span style='color: red; font-weight: bold;'>Error: input length {number_1} exceeds the model limit of {number_2} tokens</span>"
+     return f"<span style='color: red; font-weight: bold;'>Error: input exceeds the model limit of {default_limit} tokens</span>"
+
+ # Spanish QA functions
+ def question_answering_spanish(context, question):
+     try:
+         inputs = confli_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
+         outputs = confli_model_spanish_qa(inputs)
+         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
+         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
+         answer = confli_tokenizer_spanish.convert_tokens_to_string(confli_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
+         return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
+     except Exception as e:
+         return handle_error_message(e)
+
+ def beto_question_answering_spanish(context, question):
+     try:
+         inputs = beto_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
+         outputs = beto_model_spanish_qa(inputs)
+         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
+         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
+         answer = beto_tokenizer_spanish.convert_tokens_to_string(beto_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
+         return f"<span style='color: blue; font-weight: bold;'>{answer}</span>"
+     except Exception as e:
+         return handle_error_message(e)
+
+ def beto_sqac_question_answering_spanish(context, question):
+     try:
+         inputs = beto_sqac_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
+         outputs = beto_sqac_model_spanish_qa(inputs)
+         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
+         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
+         answer = beto_sqac_tokenizer_spanish.convert_tokens_to_string(beto_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
+         return f"<span style='color: brown; font-weight: bold;'>{answer}</span>"
+     except Exception as e:
+         return handle_error_message(e)
+
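+ # The three functions above differ only in model, tokenizer, and highlight
+ # color. A shared helper could replace them; this is a sketch (the name and
+ # signature are illustrative, and nothing below calls it):
+ def extractive_qa_spanish(model, tokenizer, context, question, color):
+     try:
+         inputs = tokenizer(question, context, return_tensors='tf', truncation=True)
+         outputs = model(inputs)
+         # Pick the highest-scoring start and end positions of the answer span.
+         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
+         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
+         token_ids = inputs['input_ids'].numpy()[0][answer_start:answer_end]
+         answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(token_ids))
+         return f"<span style='color: {color}; font-weight: bold;'>{answer}</span>"
+     except Exception as e:
+         return handle_error_message(e)
+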
+ # Functions for Spanish GPT-2 and BLOOM-1.7B models
+ def gpt2_spanish_question_answering(context, question):
+     try:
+         prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
+         inputs = gpt2_spanish_tokenizer(prompt, return_tensors='pt').to(device)
+         outputs = gpt2_spanish_model.generate(
+             inputs['input_ids'],
+             max_length=inputs['input_ids'].shape[1] + 50,
+             num_return_sequences=1,
+             pad_token_id=gpt2_spanish_tokenizer.eos_token_id,
+             do_sample=True,
+             top_k=40,
+             temperature=0.8
+         )
+         answer = gpt2_spanish_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         answer = answer.split("Respuesta:")[-1].strip()
+         return f"<span style='color: orange; font-weight: bold;'>{answer}</span>"
+     except Exception as e:
+         return handle_error_message(e)
+
+ def bloom_question_answering(context, question):
+     try:
+         prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
+         inputs = bloom_tokenizer(prompt, return_tensors='pt').to(device)
+         outputs = bloom_model.generate(
+             inputs['input_ids'],
+             max_length=inputs['input_ids'].shape[1] + 50,
+             num_return_sequences=1,
+             pad_token_id=bloom_tokenizer.eos_token_id,
+             do_sample=True,
+             top_k=40,
+             temperature=0.8
+         )
+         answer = bloom_tokenizer.decode(outputs[0], skip_special_tokens=True)
+         answer = answer.split("Respuesta:")[-1].strip()
+         return f"<span style='color: purple; font-weight: bold;'>{answer}</span>"
+     except Exception as e:
+         return handle_error_message(e)
+
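+ # Note: in current transformers releases, `max_new_tokens=50` expresses the
+ # same continuation cap more directly than the `max_length = prompt length
+ # + 50` arithmetic used above.
+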
+ # Main function for Spanish QA
+ def compare_question_answering_spanish(context, question):
+     confli_answer_spanish = question_answering_spanish(context, question)
+     beto_answer_spanish = beto_question_answering_spanish(context, question)
+     beto_sqac_answer_spanish = beto_sqac_question_answering_spanish(context, question)
+     gpt2_answer_spanish = gpt2_spanish_question_answering(context, question)
+     bloom_answer = bloom_question_answering(context, question)
+     return f"""
+     <div>
+         <h2 style='color: #2e8b57; font-weight: bold;'>Respuestas:</h2>
+     </div><br>
+     <div>
+         <strong>ConfliBERT-Spanish-Beto-Cased-NewsQA:</strong><br>{confli_answer_spanish}
+     </div><br>
+     <div>
+         <strong>Beto-Spanish-Cased-NewsQA:</strong><br>{beto_answer_spanish}
+     </div><br>
+     <div>
+         <strong>Beto-Spanish-Cased-SQAC:</strong><br>{beto_sqac_answer_spanish}
+     </div><br>
+     <div>
+         <strong>GPT-2-Small-Spanish:</strong><br>{gpt2_answer_spanish}
+     </div><br>
+     <div>
+         <strong>BLOOM-1.7B:</strong><br>{bloom_answer}
+     </div><br>
+     <div>
+         <strong>Información del modelo:</strong><br>
+         ConfliBERT-Spanish-Beto-Cased-NewsQA: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA</a><br>
+         Beto-Spanish-Cased-NewsQA: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-NewsQA' target='_blank'>salsarra/Beto-Spanish-Cased-NewsQA</a><br>
+         Beto-Spanish-Cased-SQAC: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-SQAC' target='_blank'>salsarra/Beto-Spanish-Cased-SQAC</a><br>
+         GPT-2-Small-Spanish: <a href='https://huggingface.co/datificate/gpt2-small-spanish' target='_blank'>datificate GPT-2 Small Spanish</a><br>
+         BLOOM-1.7B: <a href='https://huggingface.co/bigscience/bloom-1b7' target='_blank'>bigscience BLOOM-1.7B</a><br>
+     </div>
+     """
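+
+ # Note: compare_question_answering_spanish runs its five model calls
+ # sequentially, so per-request latency is the sum of all five inferences;
+ # on CPU the two generative models (GPT-2 and BLOOM-1.7B) dominate that total.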
+
+ # Define the CSS for the Gradio interface
+ css_styles = """
+ body {
+     background-color: #f0f8ff;
+     font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
+ }
+ h1 a {
+     color: #2e8b57;
+     text-align: center;
+     font-size: 2em;
+     text-decoration: none;
+ }
+ h1 a:hover {
+     color: #ff8c00;
+ }
+ h2 {
+     color: #ff8c00;
+     text-align: center;
+     font-size: 1.5em;
+ }
+ .gradio-container {
+     max-width: 100%;
+     margin: 10px auto;
+     padding: 10px;
+     background-color: #ffffff;
+     border-radius: 10px;
+     box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
+ }
+ .gr-input, .gr-output {
+     background-color: #ffffff;
+     border: 1px solid #ddd;
+     border-radius: 5px;
+     padding: 10px;
+     font-size: 1em;
+ }
+ .gr-title {
+     font-size: 1.5em;
+     font-weight: bold;
+     color: #2e8b57;
+     margin-bottom: 10px;
+     text-align: center;
+ }
+ .gr-description {
+     font-size: 1.2em;
+     color: #ff8c00;
+     margin-bottom: 10px;
+     text-align: center;
+ }
+ .header-title-center a {
+     font-size: 4em;
+     font-weight: bold;
+     color: darkorange;
+     text-align: center;
+     display: block;
+ }
+ .gr-button {
+     background-color: #ff8c00;
+     color: white;
+     border: none;
+     padding: 10px 20px;
+     font-size: 1em;
+     border-radius: 5px;
+     cursor: pointer;
+ }
+ .gr-button:hover {
+     background-color: #ff4500;
+ }
+ .footer {
+     text-align: center;
+     margin-top: 10px;
+     font-size: 0.9em;
+     color: #666;
+     width: 100%;
+ }
+ .footer a {
+     color: #2e8b57;
+     font-weight: bold;
+     text-decoration: none;
+ }
+ .footer a:hover {
+     text-decoration: underline;
+ }
+ """
+
+ # Define the Gradio interface
+ demo = gr.Interface(
+     fn=compare_question_answering_spanish,
+     inputs=[
+         gr.Textbox(lines=5, placeholder="Ingrese el contexto aquí...", label="Contexto"),
+         gr.Textbox(lines=2, placeholder="Ingrese su pregunta aquí...", label="Pregunta")
+     ],
+     outputs=gr.HTML(label="Salida"),
+     title="<a href='https://eventdata.utdallas.edu/conflibert/' target='_blank'>ConfliBERT-Spanish-QA</a>",
+     description="Compare respuestas entre los modelos ConfliBERT, BETO, Beto SQAC, GPT-2 Small Spanish y BLOOM-1.7B para preguntas en español.",
+     css=css_styles,
+     allow_flagging="never"
+ )
+
+ # Launch the Gradio demo
+ demo.launch(share=True)
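+
+ # Note: share=True only matters when running locally (it requests a temporary
+ # public gradio.live link); on Hugging Face Spaces the app is already served
+ # publicly and a plain demo.launch() is sufficient.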