File size: 2,547 Bytes
f965c55 6bd5a8b f965c55 d747df2 f965c55 4ebbbf4 f965c55 18d8c3a f965c55 29cd78b f965c55 18d8c3a f965c55 4ebbbf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
# Locations of the fine-tuned encoder, the tokenizer checkpoint and the
# classifier-head weights.
_ENCODER_PATH = "./trained_model/fine_tunning_encoder_v2"
_TOKENIZER_NAME = 'zzxslp/RadBERT-RoBERTa-4m'
_CLASSIFIER_PATH = './trained_model/fine_tunning_classifier'

# Classifier output index -> highlight label. These strings are also the keys
# of the HighlightedText color_map in the Gradio UI.
_LABELS = ('abnormal', 'normal', 'uncertain')


class _MLP(nn.Module):
    """Single linear head mapping an encoder embedding to class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        # Attribute name must stay `fc1` so the saved state_dict keys match.
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


@functools.lru_cache(maxsize=1)
def _load_models():
    """Load tokenizer, encoder and classifier head once and reuse them.

    Returns:
        (tokenizer, encoder, classifier); both models are put in eval mode and
        the classifier weights are mapped to CPU.
    """
    encoder = AutoModel.from_pretrained(_ENCODER_PATH)
    encoder.eval()
    tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_NAME)
    classifier = _MLP(target_size=3, input_size=768)
    # NOTE: torch.load unpickles the file -- only point it at trusted artifacts.
    state = torch.load(_CLASSIFIER_PATH, map_location=torch.device('cpu'))
    classifier.load_state_dict(state)
    classifier.eval()
    return tokenizer, encoder, classifier


def paragraph_leveling(text):
    """Label each period-separated sentence of a radiology report.

    Args:
        text: Free-text report. Sentences are split on '.'; blank fragments
            (trailing/consecutive periods, whitespace-only input) are skipped.

    Returns:
        A list of (text, label) pairs for gr.HighlightedText: a leading
        newline, one (sentence, 'abnormal'|'normal'|'uncertain') entry per
        non-blank sentence, a newline, then ('FINAL LABEL: ', None) followed
        by ('ABNORMAL', 'abnormal') if any sentence was abnormal, otherwise
        ('NORMAL', 'normal').
    """
    tokenizer, encoder, classifier = _load_models()

    output_list = [("\n", None)]
    any_abnormal = False
    # NOTE(review): naive split also breaks decimals such as "3.5 cm".
    for sentence in (text or "").split("."):
        # Empty fragments carry no content; classifying them could flip the
        # report-level label to ABNORMAL for no reason.
        if not sentence.strip():
            continue
        encoding = tokenizer(
            sentence,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=120)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            # Index 1 of the encoder output (presumably pooler_output) feeds the head.
            pooled = encoder(**encoding)[1]
            logits = classifier(pooled)[0]
        label = _LABELS[int(logits.argmax(-1))]
        output_list.append((sentence, label))
        any_abnormal = any_abnormal or label == 'abnormal'

    output_list.append(('\n', None))
    output_list.append(('FINAL LABEL: ', None))
    if any_abnormal:
        output_list.append(('ABNORMAL', 'abnormal'))
    else:
        output_list.append(('NORMAL', 'normal'))
    return output_list
# Gradio UI: one free-text report in, per-sentence highlighted labels out.
# The color_map keys must match the labels emitted by paragraph_leveling.
demo = gr.Interface(
    paragraph_leveling,
    [
        gr.Textbox(
            label="Medical Report",
            info="Enter a radiology report. Separate sentences with a period.",
            lines=20,
            value=" ",
        ),
    ],
    gr.HighlightedText(
        label="labeling",
        show_legend=True,
        show_label=True,
        color_map={"abnormal": "violet", "normal": "lightgreen", "uncertain": "lightgray"},
    ),
    theme=gr.themes.Base(),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|