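# Gradio Space: run NER over an uploaded .txt file with b3x0m/bert-xomlac-ner,
# count the detected entities, and filter them by type and minimum frequency.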
import gradio as gr
from transformers import BertTokenizer, BertForTokenClassification, pipeline
from collections import defaultdict

model_name = "b3x0m/bert-xomlac-ner"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForTokenClassification.from_pretrained(model_name)
nlp_ner = pipeline("ner", model=model, tokenizer=tokenizer)
def ner(file, selected_entities, min_count):
    # gr.File may hand the function a filepath string or a tempfile-like object,
    # depending on the Gradio version; handle both.
    path = file if isinstance(file, str) else file.name
    with open(path, encoding="utf-8") as f:
        text = f.read()

    # Group the file's lines into batches so the pipeline is not called line by line.
    lines = text.splitlines()
    batch_size = 32
    batches = [lines[i:i + batch_size] for i in range(0, len(lines), batch_size)]

    entity_count = defaultdict(int)
    def accumulate(segment):
        # Run NER on one text segment and tally completed entities.
        # B-/M-/I- tokens start or extend an entity; an E- token closes it.
        ner_results = nlp_ner(segment)
        current_entity = None
        for entity in ner_results:
            tag = entity['entity']
            if tag.startswith(("B-", "M-", "I-")):
                if current_entity is None:
                    current_entity = {'text': entity['word'], 'label': tag[2:]}
                else:
                    current_entity['text'] += entity['word']
            elif tag.startswith("E-"):
                if current_entity:
                    current_entity['text'] += entity['word']
                    current_entity['label'] = tag[2:]
                    entity_count[(current_entity['text'], current_entity['label'])] += 1
                    current_entity = None

    for batch in batches:
        batch_text = " ".join(batch)
        # Tokenize only to measure length; long batches are split into
        # 128-token chunks before being passed to the pipeline.
        tokens = tokenizer(batch_text)['input_ids']
        if len(tokens) > 128:
            for i in range(0, len(tokens), 128):
                sub_tokens = tokens[i:i + 128]
                sub_batch_text = tokenizer.decode(sub_tokens, skip_special_tokens=True)
                accumulate(sub_batch_text)
        else:
            accumulate(batch_text)
    # Keep entities that meet the frequency threshold and match the selected types.
    output = []
    for (name, label), count in entity_count.items():
        if count >= min_count and (not selected_entities or label in selected_entities):
            output.append(f"{name}={label}={count}")
    return "\n".join(output)

css = '''
h1#title {
  text-align: center;
}
'''
theme = gr.themes.Soft()

demo = gr.Blocks(css=css, theme=theme)

with demo:
    input_file = gr.File(label="Upload File (.txt)", file_types=[".txt"])
    entity_filter = gr.CheckboxGroup(
        label="Entity types",
        choices=["PER", "ORG", "LOC", "GPE"],
        type="value"
    )
    count_entities = gr.Number(
        label="Minimum frequency",
        minimum=1,
        maximum=10,
        step=1,
        value=3
    )
    output_text = gr.Textbox(label="Output", show_copy_button=True, interactive=False, lines=10, max_lines=20)
    interface = gr.Interface(
        fn=ner,
        inputs=[input_file, entity_filter, count_entities],
        outputs=[output_text],
        allow_flagging="never",
    )

demo.launch()