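"""Gradio demo for English named-entity recognition.

Highlights PER / LOC / ORG / MISC entities predicted by
osiria/bert-base-cased-ner-en, adds regex-matched dates, and lists
frequently repeated entities as document-level tags.
"""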
import os
import gradio as gr
import subprocess
import sys

# Install runtime dependencies at startup (a Hugging Face Spaces convenience;
# a packaged app would pin these in requirements.txt instead)
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("numpy")
install("torch")
install("transformers")
install("unidecode")

import numpy as np
import torch
from transformers import AutoTokenizer, BertForTokenClassification, pipeline
from collections import Counter
from unidecode import unidecode
import re

# Hugging Face access token, read from the environment (needed if the model is gated)
auth_token = os.environ.get("AUTH_TOKEN")

tokenizer = AutoTokenizer.from_pretrained("osiria/bert-base-cased-ner-en", token=auth_token)
model = BertForTokenClassification.from_pretrained("osiria/bert-base-cased-ner-en", num_labels=5, token=auth_token)
device = torch.device("cpu")
model = model.to(device)
model.eval()  # inference only: disable dropout

# device=-1 runs the pipeline on CPU
ner = pipeline("ner", model=model, tokenizer=tokenizer, device=-1)
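
# extract() below feeds the pipeline one paragraph at a time and merges
# model entities with regex-matched dates before rendering.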


# HTML banner shown at the top of the demo (vertical "DEMO" lettering)
header = '''--------------------------------------------------------------------------------------------------
<style>
.vertical-text {
    writing-mode: vertical-lr;
    text-orientation: upright;
    background-color:red;
}
</style>
<center>
<body>
<span class="vertical-text" style="background-color:lightgreen;border-radius: 3px;padding: 3px;"> </span>
<span class="vertical-text" style="background-color:orange;border-radius: 3px;padding: 3px;"> D</span>
<span class="vertical-text" style="background-color:lightblue;border-radius: 3px;padding: 3px;">    E</span>
<span class="vertical-text" style="background-color:tomato;border-radius: 3px;padding: 3px;">    M</span>
<span class="vertical-text" style="background-color:lightgrey;border-radius: 3px;padding: 3px;"> O</span>
<span class="vertical-text" style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"> </span>
</body>
</center>
<br>
'''

maps = {"O": "NONE", "PER": "PER", "LOC": "LOC", "ORG": "ORG", "MISC": "MISC", "DATE": "DATE"}

# Date regex: covers "4 May 1780", "May 1780", "4 May", "04/05/1780",
# a bare year after Italian prepositions (dal/al/nel/anno/del), and
# "a.c."/"d.c." era dates; month names are matched in Italian and English.
reg_month = "(?:gennaio|febbraio|marzo|aprile|maggio|giugno|luglio|agosto|settembre|ottobre|novembre|dicembre|january|february|march|april|may|june|july|august|september|october|november|december)"
reg_date = r"(?:\d{1,2}\°{0,1}|primo|\d{1,2}\º{0,1})" + " " + reg_month + " " + r"\d{4}|"
reg_date = reg_date + reg_month + " " + r"\d{4}|"
reg_date = reg_date + r"\d{1,2}" + " " + reg_month + "|"
reg_date = reg_date + r"\d{1,2}" + r"(?:\/|\.)\d{1,2}(?:\/|\.)" + r"\d{4}|"
reg_date = reg_date + r"(?<=dal )\d{4}|(?<=al )\d{4}|(?<=nel )\d{4}|(?<=anno )\d{4}|(?<=del )\d{4}|"
reg_date = reg_date + r"\d{1,5} a\.c\.|\d{1,5} d\.c\."

# Typographic punctuation normalized before processing
map_punct = {"’": "'", "«": '"', "»": '"', "”": '"', "“": '"', "–": "-", "$": ""}

unk_tok = 9005      # tokenizer id of the unknown token
merge_th_1 = 0.8    # entities scoring below this may be merged with a neighbour
merge_th_2 = 0.4    # a merged span survives only if the kept side scored above this
min_th = 0.5        # minimum confidence for an entity to be kept
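
# Illustrative check: re.finditer(reg_date, "founded on 4 May 1780", flags=re.IGNORECASE)
# yields one match, "4 May 1780"; spans found this way become DATE entities in extract().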

def extract(text):
    """Run NER and date regexes over text, returning HTML with highlighted entities."""

    text = text.strip()
    for mp in map_punct:
        text = text.replace(mp, map_punct[mp])
    text = re.sub(r"\[\d+\]", "", text)  # strip footnote markers like [12]

    warn_flag = False

    res_total = []
    out_text = ""

    # Process the text one paragraph at a time
    for p_text in text.split("\n"):

        if p_text:

            # Flag paragraphs containing tokens unknown to the tokenizer
            toks = tokenizer.encode(p_text)
            if unk_tok in toks:
                warn_flag = True

            res_orig = ner(p_text, aggregation_strategy="first")
            # Drop single-character entity spans
            res_orig = [el for el in res_orig if len(el["word"].strip()) > 1]
            res = []

            # Merge low-confidence entities with an adjacent, higher-scoring
            # neighbour; a merged span is rescored to merge_th_1 and survives
            # only if the kept side scored above merge_th_2.
            for r, ent in enumerate(res_orig):
                if r > 0 and ent["score"] < merge_th_1 and ent["start"] <= res[-1]["end"] + 1 and ent["score"] <= res[-1]["score"]:
                    # Fold into the previous entity
                    res[-1]["word"] = res[-1]["word"] + " " + ent["word"]
                    res[-1]["score"] = merge_th_1*(res[-1]["score"] > merge_th_2)
                    res[-1]["end"] = ent["end"]
                elif r < len(res_orig) - 1 and ent["score"] < merge_th_1 and res_orig[r+1]["start"] <= ent["end"] + 1 and res_orig[r+1]["score"] > ent["score"]:
                    # Fold into the following entity
                    res_orig[r+1]["word"] = ent["word"] + " " + res_orig[r+1]["word"]
                    res_orig[r+1]["score"] = merge_th_1*(res_orig[r+1]["score"] > merge_th_2)
                    res_orig[r+1]["start"] = ent["start"]
                else:
                    res.append(ent)

            # Keep only entities above the minimum confidence threshold
            res = [el for el in res if el["score"] >= min_th]
        
            # Add regex-matched dates as synthetic DATE entities, then sort by position
            dates = [{"entity_group": "DATE", "score": 1.0, "word": p_text[el.span()[0]:el.span()[1]], "start": el.span()[0], "end": el.span()[1]} for el in re.finditer(reg_date, p_text, flags=re.IGNORECASE)]
            res.extend(dates)
            res = sorted(res, key=lambda t: t["start"])
            res_total.extend(res)
    
            # Build (entity, text chunk ending at the entity, end offset, tag) tuples
            chunks = [("", "", 0, "NONE")]

            for el in res:
                if maps[el["entity_group"]] != "NONE":
                    tag = maps[el["entity_group"]]
                    chunks.append((p_text[el["start"]: el["end"]], p_text[chunks[-1][2]:el["end"]], el["end"], tag))

            # Append any trailing text after the last entity, then drop the sentinel
            if chunks[-1][2] < len(p_text):
                chunks.append(("END", p_text[chunks[-1][2]:], -1, "NONE"))
            chunks = chunks[1:]
            
            # Highlight colour and small-caps label for each entity type
            styles = {"PER": ("lightgreen", "ᴘᴇʀ"), "LOC": ("orange", "ʟᴏᴄ"),
                      "ORG": ("lightblue", "ᴏʀɢ"), "MISC": ("tomato", "ᴍɪsᴄ"),
                      "DATE": ("lightgrey", "ᴅᴀᴛᴇ")}

            n_text = []

            # Rebuild the paragraph, wrapping each entity in a coloured <span>
            for chunk in chunks:
                rep = chunk[0]
                if chunk[3] in styles:
                    color, label = styles[chunk[3]]
                    rep = ('<span style="background-color:' + color +
                           ';border-radius: 3px;padding: 3px;"><b>' + label + '</b> ' +
                           chunk[0] + '</span>')
                n_text.append(chunk[1].replace(chunk[0], rep))

            n_text = "".join(n_text)
            if out_text:
                out_text = out_text + "<br>" + n_text
            else:
                out_text = n_text
    

    # Derive document-level tags: non-date entities mentioned more than once,
    # ranked by frequency weighted toward early mentions, then ASCII-folded
    tags = [el["word"] for el in res_total if el["entity_group"] not in ['DATE', None]]
    cnt = Counter(tags)
    tags = sorted(list(set([el for el in tags if cnt[el] > 1])), key=lambda t: cnt[t]*np.exp(-tags.index(t)))[::-1]
    tags = [" ".join(re.sub(r"[^A-Za-z0-9\s]", "", unidecode(tag)).split()) for tag in tags]
    tags = ['<span style="background-color:#CF9FFF;border-radius: 3px;padding: 3px;"><b>ᴛᴀɢ </b> ' + el + '</span>' for el in tags]
    tags = "    ".join(tags)

    if tags:
        out_text = out_text + "<br><br><b>Tags:</b> " + tags

    if warn_flag:
        out_text = out_text + "<br><br><b>Warning ⚠️:</b> Unknown tokens detected in the text. The model might behave erratically."
    
    return out_text
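
# Illustrative call (not executed): extract("John Adams lived in Cambridge, Massachusetts")
# should return an HTML string with "John Adams" highlighted as ᴘᴇʀ and
# "Cambridge, Massachusetts" (or its parts) as ʟᴏᴄ, depending on model output.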



init_text = '''The American Academy of Arts and Sciences (AAA&S) is one of the oldest learned societies in the United States. It was founded in 1780 during the American Revolution by John Adams, John Hancock, James Bowdoin, Andrew Oliver, and other Founding Fathers of the United States. It is headquartered in Cambridge, Massachusetts.
Membership in the academy is achieved through a thorough petition, review, and election process. The academy's quarterly journal, Dædalus, is published by the MIT Press on behalf of the academy. The academy also conducts multidisciplinary public policy research.
The Academy was established by the Massachusetts legislature on May 4, 1780, chartered in order "to cultivate every art and science which may tend to advance the interest, honor, dignity, and happiness of a free, independent, and virtuous people." The sixty-two incorporating fellows represented varying interests and high standing in the political, professional, and commercial sectors of the state. The first class of new members, chosen by the Academy in 1781, included Benjamin Franklin and George Washington as well as several international honorary members.
'''

# Pre-compute the output for the bundled example so the demo loads populated
init_output = extract(init_text)




# Build the Gradio UI: input textbox, extract button, and rendered entity output
with gr.Blocks(css="footer {visibility: hidden}", theme=gr.themes.Default(text_size="lg", spacing_size="lg")) as interface:
    
    with gr.Row():
        gr.Markdown(header)
    with gr.Row():
        text = gr.Text(label="Extract entities", lines = 10, value = init_text)
    with gr.Row():
        with gr.Column():
            button = gr.Button("Extract", scale=0)  # Button.style() was removed in newer Gradio; scale=0 keeps the button compact
    with gr.Row():
        with gr.Column():
            entities = gr.Markdown(init_output)

    with gr.Row():
        with gr.Column():
            gr.Markdown("<center>The example text in this demo is extracted from https://en.wikipedia.org</center>")

    button.click(extract, inputs=[text], outputs=[entities])


interface.launch()