junhyun01 commited on
Commit
f965c55
·
1 Parent(s): ba4c6e2

Create gradio_inference.py

Browse files
Files changed (1) hide show
  1. gradio_inference.py +93 -94
gradio_inference.py CHANGED
@@ -1,94 +1,93 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
4
- from torch import nn
5
- import re
6
-
7
-
8
- def paragraph_leveling(text):
9
- model_name = "./trained_model/fine_tunning_encoder_v2"
10
- model = AutoModel.from_pretrained(model_name)
11
- model.to('cuda')
12
- tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m')
13
-
14
- class MLP(nn.Module):
15
- def __init__(self, target_size=3, input_size=768):
16
- super(MLP, self).__init__()
17
- self.num_classes = target_size
18
- self.input_size = input_size
19
- self.fc1 = nn.Linear(input_size, target_size)
20
-
21
- def forward(self, x):
22
- out = self.fc1(x)
23
- return out
24
-
25
- classifier = MLP(target_size=3, input_size=768)
26
- classifier.load_state_dict(torch.load('./trained_model/fine_tunning_classifier'))
27
- classifier.cuda()
28
- classifier.eval()
29
-
30
- output_list = []
31
- text_list = text.split(".")
32
- result = []
33
-
34
- # output_list.append(('Label: ', None))
35
- # output_list.append(('abnormal', 'abnormal'))
36
- # output_list.append(('normal', 'normal'))
37
- # # output_list.append((' ', 'normal'))
38
- # output_list.append(('not much information', 'not much information'))
39
- # # output_list.append((' ', 'not much information'))
40
- output_list.append(("\n", None))
41
-
42
- for idx_sentence in text_list:
43
- train_encoding = tokenizer(
44
- idx_sentence,
45
- return_tensors='pt',
46
- padding='max_length',
47
- truncation=True,
48
- max_length=120)
49
- output = model(**train_encoding.to('cuda'))
50
- output = classifier(output[1])
51
- output = output[0]
52
-
53
- if output.argmax(-1) == 0:
54
- output_list.append((idx_sentence, 'abnormal'))
55
- result.append(0)
56
- elif output.argmax(-1) == 1:
57
- output_list.append((idx_sentence, 'normal'))
58
- result.append(1)
59
- else:
60
- output_list.append((idx_sentence, 'not much information'))
61
- result.append(2)
62
-
63
- output_list.append(('\n', None))
64
- if 0 in result:
65
- output_list.append(('FINAL LABEL: ', None))
66
- output_list.append(('ABNORMAL', 'abnormal'))
67
-
68
- else:
69
- output_list.append(('FINAL LABEL: ', None))
70
- output_list.append(('NORMAL', 'normal'))
71
-
72
- return output_list
73
-
74
-
75
- demo = gr.Interface(
76
- paragraph_leveling,
77
- [
78
- gr.Textbox(
79
- label="Medical Report",
80
- info="You can put any types of medical report",
81
- lines=20,
82
- value=" ",
83
- ),
84
- ],
85
- gr.HighlightedText(
86
- label="labeling",
87
- show_legend = True,
88
- show_label = True,
89
- color_map={"abnormal": "violet", "normal": "lightgreen", "not much information": "lightgray"}),
90
- theme=gr.themes.Base()
91
- )
92
- if __name__ == "__main__":
93
- demo.launch(share=True)
94
-
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel
4
+ from torch import nn
5
+ import re
6
+
7
+
8
+ def paragraph_leveling(text):
9
+ model_name = "./trained_model/fine_tunning_encoder_v2"
10
+ model = AutoModel.from_pretrained(model_name)
11
+ model.to('cuda')
12
+ tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m')
13
+
14
+ class MLP(nn.Module):
15
+ def __init__(self, target_size=3, input_size=768):
16
+ super(MLP, self).__init__()
17
+ self.num_classes = target_size
18
+ self.input_size = input_size
19
+ self.fc1 = nn.Linear(input_size, target_size)
20
+
21
+ def forward(self, x):
22
+ out = self.fc1(x)
23
+ return out
24
+
25
+ classifier = MLP(target_size=3, input_size=768)
26
+ classifier.load_state_dict(torch.load('./trained_model/fine_tunning_classifier'))
27
+ classifier.cuda()
28
+ classifier.eval()
29
+
30
+ output_list = []
31
+ text_list = text.split(".")
32
+ result = []
33
+
34
+ # output_list.append(('Label: ', None))
35
+ # output_list.append(('abnormal', 'abnormal'))
36
+ # output_list.append(('normal', 'normal'))
37
+ # # output_list.append((' ', 'normal'))
38
+ # output_list.append(('not much information', 'not much information'))
39
+ # # output_list.append((' ', 'not much information'))
40
+ output_list.append(("\n", None))
41
+
42
+ for idx_sentence in text_list:
43
+ train_encoding = tokenizer(
44
+ idx_sentence,
45
+ return_tensors='pt',
46
+ padding='max_length',
47
+ truncation=True,
48
+ max_length=120)
49
+ output = model(**train_encoding.to('cuda'))
50
+ output = classifier(output[1])
51
+ output = output[0]
52
+
53
+ if output.argmax(-1) == 0:
54
+ output_list.append((idx_sentence, 'abnormal'))
55
+ result.append(0)
56
+ elif output.argmax(-1) == 1:
57
+ output_list.append((idx_sentence, 'normal'))
58
+ result.append(1)
59
+ else:
60
+ output_list.append((idx_sentence, 'not much information'))
61
+ result.append(2)
62
+
63
+ output_list.append(('\n', None))
64
+ if 0 in result:
65
+ output_list.append(('FINAL LABEL: ', None))
66
+ output_list.append(('ABNORMAL', 'abnormal'))
67
+
68
+ else:
69
+ output_list.append(('FINAL LABEL: ', None))
70
+ output_list.append(('NORMAL', 'normal'))
71
+
72
+ return output_list
73
+
74
+
75
+ demo = gr.Interface(
76
+ paragraph_leveling,
77
+ [
78
+ gr.Textbox(
79
+ label="Medical Report",
80
+ info="You can put any types of medical report",
81
+ lines=20,
82
+ value=" ",
83
+ ),
84
+ ],
85
+ gr.HighlightedText(
86
+ label="labeling",
87
+ show_legend = True,
88
+ show_label = True,
89
+ color_map={"abnormal": "violet", "normal": "lightgreen", "not much information": "lightgray"}),
90
+ theme=gr.themes.Base()
91
+ )
92
+ if __name__ == "__main__":
93
+ demo.launch(share=True)