|
import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
|
|
|
|
|
def paragraph_leveling(text): |
|
model_name = "./trained_model/fine_tunning_encoder_v2" |
|
model = AutoModel.from_pretrained(model_name) |
|
tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m') |
|
|
|
class MLP(nn.Module): |
|
def __init__(self, target_size=3, input_size=768): |
|
super(MLP, self).__init__() |
|
self.num_classes = target_size |
|
self.input_size = input_size |
|
self.fc1 = nn.Linear(input_size, target_size) |
|
|
|
def forward(self, x): |
|
out = self.fc1(x) |
|
return out |
|
|
|
classifier = MLP(target_size=3, input_size=768) |
|
classifier.load_state_dict(torch.load('./trained_model/fine_tunning_classifier', map_location=torch.device('cpu'))) |
|
classifier.eval() |
|
|
|
output_list = [] |
|
text_list = text.split(".") |
|
result = [] |
|
|
|
output_list.append(("\n", None)) |
|
|
|
for idx_sentence in text_list: |
|
train_encoding = tokenizer( |
|
idx_sentence, |
|
return_tensors='pt', |
|
padding='max_length', |
|
truncation=True, |
|
max_length=120) |
|
output = model(**train_encoding) |
|
output = classifier(output[1]) |
|
output = output[0] |
|
|
|
if output.argmax(-1) == 0: |
|
output_list.append((idx_sentence, 'abnormal')) |
|
result.append(0) |
|
elif output.argmax(-1) == 1: |
|
output_list.append((idx_sentence, 'normal')) |
|
result.append(1) |
|
else: |
|
output_list.append((idx_sentence, 'uncertain')) |
|
result.append(2) |
|
|
|
output_list.append(('\n', None)) |
|
if 0 in result: |
|
output_list.append(('FINAL LABEL: ', None)) |
|
output_list.append(('ABNORMAL', 'abnormal')) |
|
|
|
else: |
|
output_list.append(('FINAL LABEL: ', None)) |
|
output_list.append(('NORMAL', 'normal')) |
|
|
|
return output_list |
|
|
|
|
|
# Gradio UI: one report textbox in, sentence-level highlighted labels out.
demo = gr.Interface(
    paragraph_leveling,
    [
        gr.Textbox(
            label="Medical Report",
            info="Enter a radiology report. Separate each sentence with a period.",
            lines=20,
            value=" ",
        ),
    ],
    gr.HighlightedText(
        label="labeling",
        show_legend=True,
        show_label=True,
        # Keys must match the label strings paragraph_leveling attaches to spans.
        color_map={"abnormal": "violet", "normal": "lightgreen", "uncertain": "lightgray"},
    ),
    theme=gr.themes.Base(),
)

if __name__ == "__main__":
    demo.launch()
|
|