junhyun01's picture
Update app.py
29cd78b verified
raw
history blame
2.55 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
from torch import nn
import re
# Locations of the fine-tuned encoder, its tokenizer and the classifier head.
_ENCODER_PATH = "./trained_model/fine_tunning_encoder_v2"
_TOKENIZER_NAME = 'zzxslp/RadBERT-RoBERTa-4m'
_CLASSIFIER_PATH = './trained_model/fine_tunning_classifier'

# Class index -> label name (index 0 = abnormal, 1 = normal, 2 = uncertain).
_LABELS = ('abnormal', 'normal', 'uncertain')

# Every sentence is padded / truncated to this many tokens.
_MAX_TOKENS = 120

# Lazily populated (encoder, tokenizer, classifier) triple; see _load_pipeline.
_PIPELINE = None


class MLP(nn.Module):
    """Single linear layer mapping an encoder vector to class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        # Keep the attribute name `fc1`: load_state_dict matches parameters
        # by name, and the saved weights are presumably keyed as `fc1.*`.
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


def _load_pipeline():
    """Return (encoder, tokenizer, classifier), loading them on first use only."""
    global _PIPELINE
    if _PIPELINE is None:
        encoder = AutoModel.from_pretrained(_ENCODER_PATH)
        tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_NAME)
        classifier = MLP(target_size=3, input_size=768)
        classifier.load_state_dict(
            torch.load(_CLASSIFIER_PATH, map_location=torch.device('cpu'))
        )
        classifier.eval()
        _PIPELINE = (encoder, tokenizer, classifier)
    return _PIPELINE


def paragraph_leveling(text):
    """Label each sentence of a report and derive a report-level label.

    The text is split on "." and every non-blank fragment is tokenized,
    encoded and classified as ``abnormal``, ``normal`` or ``uncertain``.
    The report as a whole is ABNORMAL if at least one sentence is abnormal,
    otherwise NORMAL.

    Args:
        text: Report text whose sentences are separated by periods.

    Returns:
        A list of ``(text, label)`` tuples, where ``label`` is one of the
        three class names or ``None`` for unlabelled text. The list starts
        with ``("\\n", None)``, followed by one entry per sentence, another
        newline entry, and the ``FINAL LABEL: `` entries. When the text
        contains no sentences, only the leading newline entry is returned.
    """
    highlighted = [("\n", None)]

    # Blank fragments (e.g. the piece after the last period, or an empty
    # textbox) carry no content; classifying them could flip the final label.
    sentences = [part for part in (text or "").split(".") if part.strip()]
    if not sentences:
        return highlighted

    encoder, tokenizer, classifier = _load_pipeline()
    found_abnormal = False

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        for sentence in sentences:
            encoding = tokenizer(
                sentence,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=_MAX_TOKENS,
            )
            # Element 1 of the encoder output is what the classifier consumes
            # (presumably the pooled sentence representation).
            pooled = encoder(**encoding)[1]
            # Batch size is 1, so take the logits of the only row.
            logits = classifier(pooled)[0]
            label = _LABELS[int(logits.argmax(-1))]
            highlighted.append((sentence, label))
            if label == 'abnormal':
                found_abnormal = True

    highlighted.append(('\n', None))
    highlighted.append(('FINAL LABEL: ', None))
    if found_abnormal:
        highlighted.append(('ABNORMAL', 'abnormal'))
    else:
        highlighted.append(('NORMAL', 'normal'))
    return highlighted
# Gradio UI: one multi-line textbox in, per-sentence highlighted labels out.
demo = gr.Interface(
    fn=paragraph_leveling,
    inputs=[
        gr.Textbox(
            label="Medical Report",
            info="You may put radiology medical report. Each sentence should be seperate with period mark.",
            lines=20,
            value=" ",
        ),
    ],
    outputs=gr.HighlightedText(
        label="labeling",
        show_legend=True,
        show_label=True,
        # One colour per label emitted by paragraph_leveling.
        color_map={
            "abnormal": "violet",
            "normal": "lightgreen",
            "uncertain": "lightgray",
        },
    ),
    theme=gr.themes.Base(),
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()