# Source-viewer header (not code): author junhyun01, commit 684b494
# "Rename gradio_inference.py to app.py", file size 2.88 kB.
import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
# Local fine-tuned encoder directory, tokenizer hub id, and classifier state-dict path.
_ENCODER_PATH = "./trained_model/fine_tunning_encoder_v2"
_TOKENIZER_NAME = "zzxslp/RadBERT-RoBERTa-4m"
_CLASSIFIER_PATH = "./trained_model/fine_tunning_classifier"

# Classifier output index -> label string used by the HighlightedText color map.
_LABELS = ("abnormal", "normal", "not much information")


class MLP(nn.Module):
    """Single linear layer mapping an encoder embedding to class logits.

    Args:
        target_size: number of output classes (3 here, one per entry of _LABELS).
        input_size: width of the embedding fed to the layer.
    """

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        # Attribute name must stay `fc1` so the saved state-dict keys still match.
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


@functools.lru_cache(maxsize=1)
def _load_components():
    """Load the encoder, tokenizer and classifier once and reuse them across calls.

    Returns:
        (device, model, tokenizer, classifier), with both networks on ``device``
        and in eval mode.
    """
    # Fall back to CPU instead of crashing on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModel.from_pretrained(_ENCODER_PATH)
    model.to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_NAME)
    classifier = MLP(target_size=3, input_size=768)
    # NOTE: torch.load unpickles its input; only point this at trusted checkpoints.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    classifier.load_state_dict(torch.load(_CLASSIFIER_PATH, map_location=device))
    classifier.to(device)
    classifier.eval()
    return device, model, tokenizer, classifier


def paragraph_leveling(text):
    """Label each '.'-separated sentence of a report and derive a final label.

    Each non-blank segment of ``text.split(".")`` is encoded (truncated/padded to
    120 tokens), passed through the encoder and the linear classifier, and tagged
    with one of "abnormal", "normal" or "not much information".

    Args:
        text: the free-text medical report (``None`` is treated as empty).

    Returns:
        A list of ``(text, label)`` tuples for ``gr.HighlightedText``: a leading
        newline, each sentence followed by a newline, then "FINAL LABEL: " and
        "ABNORMAL" if any sentence was labeled abnormal, otherwise "NORMAL".
    """
    device, model, tokenizer, classifier = _load_components()

    output_list = [("\n", None)]
    result = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for sentence in (text or "").split("."):
            # split(".") yields empty/whitespace pieces (e.g. after the final
            # period); classifying them would add spurious labels.
            if not sentence.strip():
                continue
            encoding = tokenizer(
                sentence,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=120,
            ).to(device)
            # Index 1 of the encoder output is presumably the pooled embedding
            # (pooler_output for a BERT/RoBERTa-style model) — confirm.
            logits = classifier(model(**encoding)[1])[0]
            label_idx = int(logits.argmax(-1))
            output_list.append((sentence, _LABELS[label_idx]))
            result.append(label_idx)
            output_list.append(('\n', None))

    output_list.append(('FINAL LABEL: ', None))
    # A single abnormal sentence makes the whole report abnormal.
    if 0 in result:
        output_list.append(('ABNORMAL', 'abnormal'))
    else:
        output_list.append(('NORMAL', 'normal'))
    return output_list
# Input component: a large free-text box for the report.
report_input = gr.Textbox(
    label="Medical Report",
    info="You can put any types of medical report",
    lines=20,
    value=" ",
)

# Output component: per-sentence spans colored by predicted label.
labeling_output = gr.HighlightedText(
    label="labeling",
    show_legend=True,
    show_label=True,
    color_map={
        "abnormal": "violet",
        "normal": "lightgreen",
        "not much information": "lightgray",
    },
)

# Wire the labeling function between the two components.
demo = gr.Interface(
    fn=paragraph_leveling,
    inputs=[report_input],
    outputs=labeling_output,
    theme=gr.themes.Base(),
)
def _launch():
    """Start the Gradio server for ``demo`` (share=True also requests a public link)."""
    demo.launch(share=True)


if __name__ == "__main__":
    _launch()