Spaces:
Running
Running
File size: 5,599 Bytes
57d44a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import gradio as gr
import transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
# ---- Model loading -------------------------------------------------------
# Two model/tokenizer pairs are exposed side by side in the UI.  Both
# currently point at the same BioBERTpt chemical-NER checkpoint; distinct
# variables are used for each so the checkpoints can diverge later.
checkpoint_large = "pucpr/clinicalnerpt-chemical"
model_large = AutoModelForTokenClassification.from_pretrained(checkpoint_large)
tokenizer_large = AutoTokenizer.from_pretrained(checkpoint_large)

checkpoint_base = "pucpr/clinicalnerpt-chemical"
model_base = AutoModelForTokenClassification.from_pretrained(checkpoint_base)
tokenizer_base = AutoTokenizer.from_pretrained(checkpoint_base)
# ---- Inline-HTML styling for rendered entities ---------------------------
# Background colours keyed by entity label: one for the highlighted word
# span, a stronger one for the small uppercase label badge.
background_colors_entity_word = {
    'ChemicalDrugs': "#fae8ff",
}
background_colors_entity_tag = {
    'ChemicalDrugs': "#d946ef",
}
# CSS templates; the '#xxxxxx' placeholder is substituted with the
# entity-specific colour at render time.
css = {
    'entity_word': 'color:#000000;background: #xxxxxx; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 2.5; border-radius: 0.35em;',
    'entity_tag': 'color:#fff;background: #xxxxxx; font-size: 0.8em; font-weight: bold; line-height: 2.5; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5em;'
}
# Legend badge listing the supported entity type.
_tag_style = css['entity_tag'].replace('#xxxxxx', background_colors_entity_tag['ChemicalDrugs'])
list_EN = f"<span style='{_tag_style};padding:0.5em;'>ChemicalDrugs</span>"
# ---- Interface metadata --------------------------------------------------
title = "BioBERTpt - Chemical entities"
description = "BioBERTpt - Chemical entities"
allow_screenshot = False
allow_flagging = False
# Portuguese clinical-note snippets offered as one-click examples; each
# entry is a single-element list (one input component).
examples = [
    ["Dispneia venoso central em subclavia D duplolumen recebendo solução salina e glicosada em BI."],
    ["Paciente com Sepse pulmonar em D8 tazocin (paciente não recebeu por 2 dias Atb)."],
    ["FOI REALIZADO CURSO DE ATB COM LEVOFLOXACINA POR 7 DIAS."],
]
def ner(input_text):
    """Tag chemical entities in *input_text* and render them as HTML.

    The text is run through both loaded model/tokenizer pairs; each
    model's BIO predictions are grouped into entity spans and the input
    is re-rendered with styled ``<span>`` markup around every entity.

    Args:
        input_text: Raw sentence (Portuguese clinical text) to annotate.

    Returns:
        A pair of HTML strings: (large-model output, base-model output).
    """
    rendered = []
    for tokenizer, model in zip([tokenizer_large, tokenizer_base],
                                [model_large, model_base]):
        # Tokenization (hard BERT limit of 512 sub-tokens).
        inputs = tokenizer(input_text, max_length=512, truncation=True,
                           return_tensors="pt")
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits
        predictions = torch.argmax(logits, dim=2)
        # BUGFIX: look labels up on the *current* model's config — the
        # original hard-coded model_base.config.id2label for both models.
        preds = [model.config.id2label[p] for p in predictions[0].numpy()]

        # ---- Group consecutive B-/I- predictions into entity spans ----
        # groups_pred: running index -> {'indices': token positions, 'en': label}
        groups_pred = dict()
        group_indices = list()
        group_label = ''
        count = 0
        for i, en in enumerate(preds):
            if en == 'O':
                # Outside any entity: close the span being built, if any.
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
            if en.startswith('B'):
                # A B- tag always opens a new span, closing the previous one.
                if len(group_indices) > 0:
                    groups_pred[count] = {'indices': group_indices, 'en': group_label}
                    group_indices = list()
                    group_label = ''
                    count += 1
                group_indices.append(i)
                group_label = en.replace('B-', '')
            elif en.startswith('I'):
                if len(group_indices) > 0:
                    if en.replace('I-', '') == group_label:
                        # Continuation of the current entity.
                        group_indices.append(i)
                    else:
                        # Label changed mid-span: close old span, open a new one.
                        groups_pred[count] = {'indices': group_indices, 'en': group_label}
                        group_indices = [i]
                        group_label = en.replace('I-', '')
                        count += 1
                else:
                    # Dangling I- without a preceding B-: start a span anyway.
                    group_indices = [i]
                    group_label = en.replace('I-', '')
            # Flush the last open span at the end of the sequence.
            if i == len(preds) - 1 and len(group_indices) > 0:
                groups_pred[count] = {'indices': group_indices, 'en': group_label}
                group_indices = list()
                group_label = ''
                count += 1

        # ---- Render plain text interleaved with highlighted entities ----
        len_groups_pred = len(groups_pred)
        input_ids = inputs['input_ids'][0].numpy()
        if len_groups_pred > 0:
            for pred_num in range(len_groups_pred):
                en = groups_pred[pred_num]['en']
                indices = groups_pred[pred_num]['indices']
                word_style = css['entity_word'].replace('#xxxxxx', background_colors_entity_word[en])
                tag_style = css['entity_tag'].replace('#xxxxxx', background_colors_entity_tag[en])
                entity_html = (f'<span style="{word_style}">'
                               + tokenizer.decode(input_ids[indices[0]:indices[-1] + 1])
                               + f'<span style="{tag_style}">' + en + '</span></span> ')
                if pred_num == 0:
                    # Text before the first entity (includes [CLS]; stripped below).
                    prefix = tokenizer.decode(input_ids[:indices[0]]) if indices[0] > 0 else ''
                    output = prefix + entity_html
                else:
                    # Text between the previous entity and this one.
                    output += tokenizer.decode(input_ids[indices_prev[-1] + 1:indices[0]]) + entity_html
                indices_prev = indices
            # Trailing text after the last entity (includes [SEP]; stripped below).
            output += tokenizer.decode(input_ids[indices_prev[-1] + 1:])
        else:
            # No entities found: show the input unchanged.
            output = input_text
        # Strip special tokens / WordPiece '##' markers, wrap in scrollable div.
        output = output.replace('[CLS]', '').replace(' [SEP]', '').replace('##', '')
        output = ("<div style='max-width:100%; max-height:360px; overflow:auto'>"
                  + output + "</div>")
        rendered.append(output)
    return rendered[0], rendered[1]
# ---- Gradio interface ----------------------------------------------------
# BUGFIX: `article` was referenced below but never defined anywhere in the
# file, which raised NameError at startup. Define it before use.
article = ""

iface = gr.Interface(
    title=title,
    description=description,
    article=article,
    allow_screenshot=allow_screenshot,
    allow_flagging=allow_flagging,
    fn=ner,
    # Legacy gr.inputs/gr.outputs API kept to match the gradio version
    # this Space was written against.
    inputs=gr.inputs.Textbox(placeholder="Digite uma frase aqui ou clique em um exemplo:", lines=5),
    outputs=[gr.outputs.HTML(label="NER1"), gr.outputs.HTML(label="NER2")],
    examples=examples,
)
iface.launch()