import os
import gradio as gr
from gradio.components import Label
import subprocess
import sys
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
install("numpy")
install("transformers")
install("torch")
install("tensorflow")
install("tensorflow-text")
install("tensorflow-hub")
import tensorflow_hub as hub
import tensorflow_text
import tensorflow as tf
import torch
from transformers import AutoTokenizer
from transformers import BertForTokenClassification
import numpy as np
import re
auth_token = os.environ.get("AUTH-TOKEN")
header = '''--------------------------------------------------------------------------------------------------
D
E
M
O
(BETA)
--------------------------------------------------------------------------------------------------'''
model1 = hub.load("https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/3")
tokenizer = AutoTokenizer.from_pretrained("osiria/mbert-base-cased-pos-it", use_auth_token=auth_token)
model = BertForTokenClassification.from_pretrained("osiria/mbert-base-cased-pos-it", num_labels = 17, use_auth_token=auth_token)
model.eval()
from transformers import pipeline
pos = pipeline('ner', model=model, tokenizer=tokenizer, device=-1)
def classify(text, classes):
text = text[:10000]
res = pos(text, aggregation_strategy = "first")
text = " ".join([r["word"] for r in res if r["entity_group"] in ["AGGETTIVO", "NOME", "NOME PROPRIO"]])
classes = {el.split(":")[0].strip(): el.split(":")[1].strip() for el in classes.split("\n")}
t_vec = model1(text).numpy()
t_vec = t_vec/np.linalg.norm(t_vec)
t_vec = t_vec.reshape(-1, 1)
cl_vecs = model1(["L'argomento di cui parliamo è quindi: " + re.sub("\s+", " ", classes[cl].lower().replace(",", " ")).strip() for cl in classes]).numpy()
cl_vecs = cl_vecs/np.sqrt(np.sum(cl_vecs**2, axis = 1).reshape(-1,1))
scores = [el[0] for el in np.dot(cl_vecs, t_vec).tolist()]
scores = np.array([s if s > 0 else 0 for s in scores])
scores = (scores/np.sum(scores)).tolist()
classes = list(classes.keys())
output = sorted(classes, key = lambda cl: scores[classes.index(cl)], reverse = True)
scores = sorted(scores, reverse=True)
out = {tpl[0].capitalize(): tpl[1] for tpl in list(zip(output, scores))}
return out
init_text = '''L'Agenzia spaziale europea, nota internazionalmente con l'acronimo ESA dalla denominazione inglese European Space Agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 Paesi europei. Il suo quartier generale si trova a Parigi in Francia, con uffici a Mosca, Bruxelles, Washington e Houston. Il personale dell'ESA del 2016 ammontava a 2 200 persone (esclusi sub-appaltatori e le agenzie nazionali) e il budget del 2022 è di 7,15 miliardi di euro. Attualmente il direttore generale dell'agenzia è l'austriaco Josef Aschbacher, il quale ha sostituito il tedesco Johann-Dietrich Wörner il primo marzo 2021.
Lo spazioporto dell'ESA è il Centre Spatial Guyanais a Kourou, nella Guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. Durante gli ultimi anni il lanciatore Ariane 5 ha consentito all'ESA di raggiungere una posizione di primo piano nei lanci commerciali e l'ESA è il principale concorrente della NASA nell'esplorazione spaziale.
Le missioni scientifiche dell'ESA hanno le loro basi al Centro europeo per la ricerca e la tecnologia spaziale (ESTEC) di Noordwijk, nei Paesi Bassi. Il Centro europeo per le operazioni spaziali (ESOC), di Darmstadt in Germania, è responsabile del controllo dei satelliti ESA in orbita. Le responsabilità del Centro europeo per l'osservazione della Terra (ESRIN) di Frascati, in Italia, includono la raccolta, l'archiviazione e la distribuzione di dati satellitari ai partner dell'ESA; oltre a ciò, la struttura agisce come centro di informazione tecnologica per l'intera agenzia.
'''
init_classes = '''alimentazione: alimentazione, cibo, agricoltura, allevamento, nutrizione
arte: arte, pittura, scultura, moda
animali: animali, zoologia, botanica, piante
ambiente: ambiente, clima, sostenibilità, ecologia, inquinamento
economia: aziende, banche, economia, finanza, borsa
filosofia: etica, filosofia, religione, teologia
geografia: città, regioni, nazioni, geografia, geologia
giustizia: giustizia, magistratura, reati, criminalità
musica: musica, cantanti, gruppi musicali, generi musicali
cinema: cinema, film, televisione, spettacolo
intrattenimento: intrattenimento, tempo libero, svago, videogiochi
letteratura: letteratura, romanzi, narrativa, poesia
medicina: medicina, salute, farmaci, malattie, patologie
governo: governo, legge, politica, partiti, settore pubblico
scienza: scienza, ingegneria, tecnologia
sport: competizioni, sport
guerra: guerra, conflitti, battaglie, tematiche militari
storia: eventi, storia
società: tematiche sociali, tematiche internazionali
trasporti: automobili, treni, aerei, trasporti, veicoli
informatica: computer, smartphone, applicazioni, internet, social networks'''
init_output = classify(init_text, init_classes)
with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
with gr.Row():
gr.Markdown(header)
with gr.Row():
text = gr.Text(label="Write or paste a text", lines = 5, value = init_text)
with gr.Row():
gr.Examples([["Alessandro Manzoni, nome completo Alessandro Francesco Tommaso Antonio Manzoni (Milano, 7 marzo 1785 – Milano, 22 maggio 1873), è stato uno scrittore, poeta e drammaturgo italiano. Considerato uno dei maggiori romanzieri italiani di tutti i tempi per il suo celebre romanzo I promessi sposi, caposaldo della letteratura italiana, Manzoni ebbe il merito principale di aver gettato le basi per il romanzo moderno e di aver così patrocinato l'unità linguistica italiana, sulla scia di quella letteratura moralmente e civilmente impegnata propria dell'Illuminismo italiano."],
["Oggi sto male perchè ho la febbre"],
["Mi sono registrato su Facebook"],
["Stasera guardo qualcosa su Netflix"],
["La battaglia delle Termòpili, o delle Termòpile, fu una battaglia combattuta da un'alleanza di poleis greche, guidata dal re di Sparta Leonida I, contro l'Impero persiano governato da Serse I. Si svolse in tre giorni, durante la seconda invasione persiana della Grecia, nell'agosto o nel settembre del 480 a.C. presso lo stretto passaggio delle Termopili (o, più correttamente, Termopile, 'Le porte calde'), contemporaneamente alla battaglia navale di Capo Artemisio."],
["Ieri ho comprato l'Xbox One"],
["Domani per pranzo preparo la pasta alle vongole"],
["Ho appena ascoltato l'ultimo album dei Green Day"],
["Sono chiamati gas serra quei gas presenti nell'atmosfera che riescono a trattenere, in maniera consistente, una parte considerevole della componente nell'infrarosso della radiazione solare che colpisce la Terra ed è emessa dalla superficie terrestre, dall'atmosfera e dalle nuvole. Tale proprietà causa il fenomeno noto come 'effetto serra' ed è verificabile da un'analisi spettroscopica in laboratorio."]],
inputs=[text])
with gr.Row():
classes = gr.Text(label="Classes (write a few classes in the form 'class_name: word1, word2, word3...' using 1 to 5 descriptive words for each class)", lines = 1, value = '''alimentazione: alimentazione, cibo, agricoltura, allevamento, nutrizione
arte: arte, pittura, scultura, moda
animali: animali, zoologia, botanica, piante
ambiente: ambiente, clima, sostenibilità, ecologia, inquinamento
economia: aziende, banche, economia, finanza, borsa
filosofia: etica, filosofia, religione, teologia
geografia: città, regioni, nazioni, geografia, geologia
giustizia: giustizia, magistratura, reati, criminalità
musica: musica, cantanti, gruppi musicali, generi musicali
cinema: cinema, film, televisione, spettacolo
intrattenimento: intrattenimento, tempo libero, svago, videogiochi
letteratura: letteratura, romanzi, narrativa, poesia
medicina: medicina, salute, farmaci, malattie, patologie
governo: governo, legge, politica, partiti, settore pubblico
scienza: scienza, ingegneria, tecnologia
sport: competizioni, sport
guerra: guerra, conflitti, battaglie, tematiche militari
storia: eventi, storia
società: tematiche sociali, tematiche internazionali
trasporti: automobili, treni, aerei, trasporti, veicoli
informatica: computer, smartphone, applicazioni, internet, social networks''')
with gr.Row():
button = gr.Button("Classify").style(full_width=False)
with gr.Row():
with gr.Column():
output = Label(label="Result")
with gr.Row():
with gr.Column():
footer = gr.Markdown("A few examples in this demo are extracted from Wikipedia")
button.click(classify, inputs=[text, classes], outputs = [output])
interface.launch()