Spaces:
Sleeping
Sleeping
File size: 12,576 Bytes
1cf999b 43c0e48 1cf999b dd422a6 1cf999b f069948 1cf999b 8fa880c 1cf999b 0eaa8e7 1cf999b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import os
import gradio as gr
import subprocess
import sys
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
install("numpy")
install("torch")
install("transformers")
install("unidecode")
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers import XLMRobertaForTokenClassification
from collections import Counter
from unidecode import unidecode
import string
import re
tokenizer = AutoTokenizer.from_pretrained("osiria/flare-it-ner")
model = XLMRobertaForTokenClassification.from_pretrained("osiria/flare-it-ner", num_labels = 5)
device = torch.device("cpu")
model = model.to(device)
model.eval()
from transformers import pipeline
ner = pipeline('ner', model=model, tokenizer=tokenizer, device=-1)
header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
writing-mode: vertical-lr;
text-orientation: upright;
background-color:red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;"> E</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;"> M</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
</body>
</center>
<br>
--------------------------------------------------------------------------------------------------'''
paragraph = '''<b>What's FLARE-IT?</b>
This app is a demo of [FLARE-IT](https://huggingface.co/osiria/flare-it), a <b>lightweight</b> and <b>uncased</b> italian language model (<b>17M parameters</b> and <b>67MB</b> size). The model is here fine-tuned for named entity recognition on WikiNER (cross-validated F1 score of 81.29%) plus a custom, hand-crafted dataset of 3.500 manually annotated Wikipedia paragraphs.
It can recognize entities of the following types (in order to make the most of the color-coding, it is recommended to use the light theme for the interface):
- <span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> person</span>: names of persons
- <span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> location</span>: names of places
- <span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> organization</span>: names of organizations
- <span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> miscellanea</span>: mixed type entities
- <span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> date</span>: regex-based dates
- <span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ</b> tag</span>: most relevant entities, of any type
The <b>ᴍɪsᴄ</b> class has mixed nature, and it mainly covers names of events or products. Occasionally, entities of other classes might be labeled as <b>ᴍɪsᴄ</b> if the model is not confident enough about their identification.
The execution time in this app depends on the availability of the underlying cloud instance, and is not a reflection of the model inference time.
If unknown tokens are present in the text, they will interfere with the prediction, and the model may behave erratically. In that case, a warning sign will be displayed.
'''
maps = {"O": "NONE", "PER": "PER", "LOC": "LOC", "ORG": "ORG", "MISC": "MISC", "DATE": "DATE"}
reg_month = "(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre|january|february|march|april|may|june|july|august|september|october|november|december)"
reg_date = "(?:\d{1,2}\°{0,1}|primo|\d{1,2}\º{0,1})" + " " + reg_month + " " + "\d{4}|"
reg_date = reg_date + reg_month + " " + "\d{4}|"
reg_date = reg_date + "\d{1,2}" + " " + reg_month
reg_date = reg_date + "\d{1,2}" + "(?:\/|\.)\d{1,2}(?:\/|\.)" + "\d{4}|"
reg_date = reg_date + "(?<=dal )\d{4}|(?<=al )\d{4}|(?<=nel )\d{4}|(?<=anno )\d{4}|(?<=del )\d{4}|"
reg_date = reg_date + "\d{1,5} a\.c\.|\d{1,5} d\.c\."
map_punct = {"’": "'", "«": '"', "»": '"', "”": '"', "“": '"', "–": "-", "$": ""}
unk_tok = 9005
merge_th_1 = 0.8
merge_th_2 = 0.4
min_th = 0.5
def extract(text):
text = text.strip()
for mp in map_punct:
text = text.replace(mp, map_punct[mp])
text = re.sub("\[\d+\]", "", text)
warn_flag = False
res_total = []
out_text = ""
for p_text in text.split("\n"):
if p_text:
toks = tokenizer.encode(p_text)
if unk_tok in toks:
warn_flag = True
res_orig = ner(p_text, aggregation_strategy = "first")
res_orig = [el for r, el in enumerate(res_orig) if len(el["word"].strip()) > 1]
res = []
for r, ent in enumerate(res_orig):
if r > 0 and ent["score"] < merge_th_1 and ent["start"] <= res[-1]["end"] + 1 and ent["score"] <= res[-1]["score"]:
res[-1]["word"] = res[-1]["word"] + " " + ent["word"]
res[-1]["score"] = merge_th_1*(res[-1]["score"] > merge_th_2)
res[-1]["end"] = ent["end"]
elif r < len(res_orig) - 1 and ent["score"] < merge_th_1 and res_orig[r+1]["start"] <= ent["end"] + 1 and res_orig[r+1]["score"] > ent["score"]:
res_orig[r+1]["word"] = ent["word"] + " " + res_orig[r+1]["word"]
res_orig[r+1]["score"] = merge_th_1*(res_orig[r+1]["score"] > merge_th_2)
res_orig[r+1]["start"] = ent["start"]
else:
res.append(ent)
res = [el for r, el in enumerate(res) if el["score"] >= min_th]
dates = [{"entity_group": "DATE", "score": 1.0, "word": p_text[el.span()[0]:el.span()[1]], "start": el.span()[0], "end": el.span()[1]} for el in re.finditer(reg_date, p_text, flags = re.IGNORECASE)]
res.extend(dates)
res = sorted(res, key = lambda t: t["start"])
res_total.extend(res)
chunks = [("", "", 0, "NONE")]
for el in res:
if maps[el["entity_group"]] != "NONE":
tag = maps[el["entity_group"]]
chunks.append((p_text[el["start"]: el["end"]], p_text[chunks[-1][2]:el["end"]], el["end"], tag))
if chunks[-1][2] < len(p_text):
chunks.append(("END", p_text[chunks[-1][2]:], -1, "NONE"))
chunks = chunks[1:]
n_text = []
for i, chunk in enumerate(chunks):
rep = chunk[0]
if chunk[3] == "PER":
rep = '<span style="background-color:lightgreen;border-radius: 3px;padding: 3px;"><b>ᴘᴇʀ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "LOC":
rep = '<span style="background-color:orange;border-radius: 3px;padding: 3px;"><b>ʟᴏᴄ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "ORG":
rep = '<span style="background-color:lightblue;border-radius: 3px;padding: 3px;"><b>ᴏʀɢ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "MISC":
rep = '<span style="background-color:tomato;border-radius: 3px;padding: 3px;"><b>ᴍɪsᴄ</b> ' + chunk[0] + '</span>'
elif chunk[3] == "DATE":
rep = '<span style="background-color:lightgrey;border-radius: 3px;padding: 3px;"><b>ᴅᴀᴛᴇ</b> ' + chunk[0] + '</span>'
n_text.append(chunk[1].replace(chunk[0], rep))
n_text = "".join(n_text)
if out_text:
out_text = out_text + "<br>" + n_text
else:
out_text = n_text
tags = [el["word"] for el in res_total if el["entity_group"] not in ['DATE', None]]
cnt = Counter(tags)
tags = sorted(list(set([el for el in tags if cnt[el] > 1])), key = lambda t: cnt[t]*np.exp(-tags.index(t)))[::-1]
tags = [" ".join(re.sub("[^A-Za-z0-9\s]", "", unidecode(tag)).split()) for tag in tags]
tags = ['<span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ </b> ' + el + '</span>' for el in tags]
tags = " ".join(tags)
if tags:
out_text = out_text + "<br><br><b>Tags:</b> " + tags
if warn_flag:
out_text = out_text + "<br><br><b>Warning ⚠️:</b> Unknown tokens detected in text. The model might behave erratically"
return out_text
init_text = '''l'agenzia spaziale europea, nota internazionalmente con l'acronimo esa dalla denominazione inglese european space agency, è un'agenzia internazionale fondata nel 1975 incaricata di coordinare i progetti spaziali di 22 paesi europei. il suo quartier generale si trova a parigi in francia, con uffici a mosca, bruxelles, washington e houston.
attualmente il direttore generale dell'agenzia è l'austriaco josef aschbacher, il quale ha sostituito il tedesco johann-dietrich wörner il primo marzo 2021.
lo spazioporto dell'esa è il centre spatial guyanais a kourou, nella guyana francese, un sito scelto, come tutte le basi di lancio, per via della sua vicinanza con l'equatore. durante gli ultimi anni il lanciatore ariane 5 ha consentito all'esa di raggiungere una posizione di primo piano nei lanci commerciali e l'esa è il principale concorrente della nasa nell'esplorazione spaziale.
le missioni scientifiche dell'esa hanno le loro basi al centro europeo per la ricerca e la tecnologia spaziale (estec) di noordwijk, nei paesi bassi. il centro europeo per le operazioni spaziali (esoc), di darmstadt in germania, è responsabile del controllo dei satelliti esa in orbita. [...]
l'agenzia spaziale italiana (asi) venne fondata nel 1988 per promuovere, coordinare e condurre le attività spaziali in italia. opera in collaborazione con il ministero dell'università e della ricerca scientifica e coopera in numerosi progetti con entità attive nella ricerca scientifica e nelle attività commerciali legate allo spazio. internazionalmente l'asi fornisce la delegazione italiana per l'agenzia spaziale europea e le sue sussidiarie.'''
init_output = extract(init_text)
with gr.Blocks(theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
with gr.Row():
gr.Markdown(header)
with gr.Row():
with gr.Column():
gr.Markdown(paragraph)
with gr.Column():
incipit = gr.Markdown("<b>Highlighted entities<b>")
entities = gr.Markdown(init_output)
with gr.Row():
with gr.Column():
text = gr.Text(label="Extract entities", lines = 10, value = init_text)
with gr.Column():
gr.Examples([["aristotele nacque nel 384 a.c. o nel 383 a.c. a stagira, l'attuale stavro, colonia greca situata nella parte nord-orientale della penisola calcidica della tracia. si dice che il padre, nicomaco, sia vissuto presso aminta iii, re dei macedoni, prestandogli i servigi di medico e di amico. aristotele, come figlio del medico reale, doveva pertanto risiedere nella capitale del regno di macedonia"],
["enzo ferrari fondò la scuderia ferrari, che è tuttora la divisione principale del reparto corse della ferrari, il 16 novembre 1929 a modena"],
["wikipedia è un'enciclopedia online a contenuto libero, collaborativa, multilingue e gratuita, nata nel 2001, sostenuta e ospitata dalla wikimedia foundation, un'organizzazione non a scopo di lucro statunitense. lanciata da jimmy wales e larry sanger il 15 gennaio 2001, inizialmente nell'edizione in lingua inglese, nei mesi successivi ha aggiunto edizioni in numerose altre lingue"]],
inputs=[text])
with gr.Row():
button = gr.Button("Extract").style(full_width=False)
with gr.Row():
with gr.Column():
gr.Markdown("<center>The input examples in this demo are extracted from https://it.wikipedia.org</center>")
button.click(extract, inputs=[text], outputs = [entities])
interface.launch() |