osiria's picture
Update app.py
b34b04d
import os
import subprocess
import sys
import gradio as gr
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
install("numpy")
install("torch")
install("transformers")
install("unidecode")
import numpy as np
import torch
from transformers import DebertaV2TokenizerFast, DebertaV2ForQuestionAnswering
from transformers.pipelines import QuestionAnsweringPipeline
from transformers import pipeline
from collections import Counter
from unidecode import unidecode
import re
import string
tokenizer = DebertaV2TokenizerFast.from_pretrained("osiria/deberta-italian-question-answering")
model = DebertaV2ForQuestionAnswering.from_pretrained("osiria/deberta-italian-question-answering")
class OsiriaQA(QuestionAnsweringPipeline):
def __init__(self, punctuation = ',;.:!?()[\]{}', **kwargs):
QuestionAnsweringPipeline.__init__(self, **kwargs)
self.post_regex_left = "^[\s" + punctuation + "]+"
self.post_regex_right = "[\s" + punctuation + "]+$"
def postprocess(self, output):
output = QuestionAnsweringPipeline.postprocess(self, model_outputs=output)
output_length = len(output["answer"])
output["answer"] = re.sub(self.post_regex_left, "", output["answer"])
output["start"] = output["start"] + (output_length - len(output["answer"]))
output_length = len(output["answer"])
output["answer"] = re.sub(self.post_regex_right, "", output["answer"])
output["end"] = output["end"] - (output_length - len(output["answer"]))
return output
device = torch.device("cpu")
model = model.to(device)
model.eval()
pipeline_qa = OsiriaQA(model = model, tokenizer = tokenizer)
header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
writing-mode: vertical-lr;
text-orientation: upright;
background-color:red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
</body>
</center>
<br>
'''
def extract(question, context):
res = pipeline_qa(context = context,
question = question)
out_text = context[0:res["start"]] + '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴀɴs </b> ' + context[res["start"]:res["end"]] + '</span>' + context[res["end"]:]
return out_text
init_question= "Cos'è l'Agenzia Spaziale Italiana?"
init_context = '''L'Agenzia Spaziale Italiana (ASI) è un ente governativo italiano, istituito nel 1988, che ha il compito di predisporre e attuare la politica aerospaziale italiana. Dipende e utilizza i fondi ricevuti dal Governo italiano per finanziare il progetto, lo sviluppo e la gestione operativa di missioni spaziali, con obiettivi scientifici e applicativi.
Gestisce missioni spaziali in proprio e in collaborazione con i maggiori organismi spaziali internazionali, prima tra tutte l'Agenzia Spaziale Europea (dove l'Italia è il terzo maggior contribuente dopo Francia e Germania, e a cui l'ASI corrisponde una parte del proprio budget), quindi la NASA e le altre agenzie spaziali nazionali. Per la realizzazione di satelliti e strumenti scientifici, l'ASI stipula contratti con le imprese, italiane e non, operanti nel settore aerospaziale.
Ha la sede principale a Roma e centri operativi a Matera (sede del Centro di geodesia spaziale Giuseppe Colombo) e Malindi, Kenya (sede del Centro spaziale Luigi Broglio). Il centro di Trapani-Milo, usato per i lanci di palloni stratosferici dal 1975, non è più operativo dal 2010.'''
init_output = extract(question = init_question, context = init_context)
with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
with gr.Row():
gr.Markdown(header)
with gr.Row():
context = gr.Text(label="Context", lines = 10, value = init_context)
with gr.Row():
question = gr.Text(label="Question", lines = 1, value = init_question)
with gr.Row():
gr.Examples([["Cosa fa l'Agenzia Spaziale Italiana?"],
["Qual è la sigla dell'Agenzia Spaziale Italiana?"],
["Quando è stata fondata l'ASI?"],
["Chi finanzia l'ASI?"],
["Chi altro contribuisce all'Agenzia Spaziale Europea oltre all'Italia?"],
["Dove ha sede l'Agenzia Spaziale Italiana?"],
["Dove si trova il centro spaziale Giuseppe Colombo?"],
["Dove si trova il centro spaziale Luigi Broglio?"],
["Il centro di Trapani-Milo è ancora in funzione?"]],
inputs=[question])
with gr.Row():
with gr.Column():
button = gr.Button("Ask").style(full_width=False)
with gr.Row():
with gr.Column():
output = gr.Markdown(init_output)
with gr.Row():
with gr.Column():
gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>")
button.click(extract, inputs=[question, context], outputs = [output])
interface.launch()