emr-distillation commited on
Commit
ba4c6e2
·
1 Parent(s): 211c1b5

gradio_file

Browse files
Files changed (1) hide show
  1. gradio_inference.py +94 -0
gradio_inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
4
+ from torch import nn
5
+ import re
6
+
7
+
8
+ def paragraph_leveling(text):
9
+ model_name = "./trained_model/fine_tunning_encoder_v2"
10
+ model = AutoModel.from_pretrained(model_name)
11
+ model.to('cuda')
12
+ tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m')
13
+
14
+ class MLP(nn.Module):
15
+ def __init__(self, target_size=3, input_size=768):
16
+ super(MLP, self).__init__()
17
+ self.num_classes = target_size
18
+ self.input_size = input_size
19
+ self.fc1 = nn.Linear(input_size, target_size)
20
+
21
+ def forward(self, x):
22
+ out = self.fc1(x)
23
+ return out
24
+
25
+ classifier = MLP(target_size=3, input_size=768)
26
+ classifier.load_state_dict(torch.load('./trained_model/fine_tunning_classifier'))
27
+ classifier.cuda()
28
+ classifier.eval()
29
+
30
+ output_list = []
31
+ text_list = text.split(".")
32
+ result = []
33
+
34
+ # output_list.append(('Label: ', None))
35
+ # output_list.append(('abnormal', 'abnormal'))
36
+ # output_list.append(('normal', 'normal'))
37
+ # # output_list.append((' ', 'normal'))
38
+ # output_list.append(('not much information', 'not much information'))
39
+ # # output_list.append((' ', 'not much information'))
40
+ output_list.append(("\n", None))
41
+
42
+ for idx_sentence in text_list:
43
+ train_encoding = tokenizer(
44
+ idx_sentence,
45
+ return_tensors='pt',
46
+ padding='max_length',
47
+ truncation=True,
48
+ max_length=120)
49
+ output = model(**train_encoding.to('cuda'))
50
+ output = classifier(output[1])
51
+ output = output[0]
52
+
53
+ if output.argmax(-1) == 0:
54
+ output_list.append((idx_sentence, 'abnormal'))
55
+ result.append(0)
56
+ elif output.argmax(-1) == 1:
57
+ output_list.append((idx_sentence, 'normal'))
58
+ result.append(1)
59
+ else:
60
+ output_list.append((idx_sentence, 'not much information'))
61
+ result.append(2)
62
+
63
+ output_list.append(('\n', None))
64
+ if 0 in result:
65
+ output_list.append(('FINAL LABEL: ', None))
66
+ output_list.append(('ABNORMAL', 'abnormal'))
67
+
68
+ else:
69
+ output_list.append(('FINAL LABEL: ', None))
70
+ output_list.append(('NORMAL', 'normal'))
71
+
72
+ return output_list
73
+
74
+
75
+ demo = gr.Interface(
76
+ paragraph_leveling,
77
+ [
78
+ gr.Textbox(
79
+ label="Medical Report",
80
+ info="You can put any types of medical report",
81
+ lines=20,
82
+ value=" ",
83
+ ),
84
+ ],
85
+ gr.HighlightedText(
86
+ label="labeling",
87
+ show_legend = True,
88
+ show_label = True,
89
+ color_map={"abnormal": "violet", "normal": "lightgreen", "not much information": "lightgray"}),
90
+ theme=gr.themes.Base()
91
+ )
92
+ if __name__ == "__main__":
93
+ demo.launch(share=True)
94
+