File size: 2,547 Bytes
f965c55
 
 
 
 
 
 
 
 
6bd5a8b
f965c55
 
 
 
 
 
 
 
 
 
 
 
 
 
d747df2
f965c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ebbbf4
f965c55
 
 
 
 
 
 
 
 
 
18d8c3a
f965c55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29cd78b
f965c55
 
 
 
 
 
 
 
18d8c3a
f965c55
 
 
4ebbbf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel


# Index -> label name produced by the classifier head (index order as used by
# the original if/elif mapping: 0 abnormal, 1 normal, 2 uncertain).
_LABELS = ("abnormal", "normal", "uncertain")


class MLP(nn.Module):
    """Single linear layer mapping an encoder embedding to per-class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


@functools.lru_cache(maxsize=1)
def _load_components():
    """Load encoder, tokenizer and classifier head once; reused by every request."""
    model = AutoModel.from_pretrained("./trained_model/fine_tunning_encoder_v2")
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained('zzxslp/RadBERT-RoBERTa-4m')

    classifier = MLP(target_size=3, input_size=768)
    classifier.load_state_dict(
        torch.load('./trained_model/fine_tunning_classifier',
                   map_location=torch.device('cpu')))
    classifier.eval()
    return model, tokenizer, classifier


def _split_sentences(text):
    """Split *text* on periods, dropping empty/whitespace-only fragments.

    The fragments are returned unmodified (leading spaces kept) so the
    highlighted output reproduces the user's text.
    """
    # A period-terminated report yields an empty trailing fragment; without
    # this filter it would be classified as if it were a sentence.
    return [fragment for fragment in (text or "").split(".") if fragment.strip()]


def paragraph_leveling(text):
    """Label each sentence of a radiology report and derive an overall label.

    Args:
        text: Report text; sentences are delimited by ``.``.

    Returns:
        A list of ``(text, label)`` pairs for ``gr.HighlightedText``: a leading
        newline, one pair per non-blank sentence labelled ``'abnormal'``,
        ``'normal'`` or ``'uncertain'``, a newline, then ``'FINAL LABEL: '``
        followed by ``('ABNORMAL', 'abnormal')`` if any sentence was abnormal,
        otherwise ``('NORMAL', 'normal')``.
    """
    model, tokenizer, classifier = _load_components()

    output_list = [("\n", None)]
    sentence_labels = []

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for sentence in _split_sentences(text):
            encoding = tokenizer(
                sentence,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=120)
            # Element [1] of the encoder output is fed to the head; presumably
            # the pooled sentence embedding (768-d) — TODO confirm.
            pooled = model(**encoding)[1]
            logits = classifier(pooled)[0]
            label = _LABELS[int(logits.argmax(-1))]
            output_list.append((sentence, label))
            sentence_labels.append(label)

    output_list.append(('\n', None))
    output_list.append(('FINAL LABEL: ', None))
    # 'uncertain' sentences do not make the report abnormal.
    if 'abnormal' in sentence_labels:
        output_list.append(('ABNORMAL', 'abnormal'))
    else:
        output_list.append(('NORMAL', 'normal'))

    return output_list


# Gradio UI: a free-text report box in, per-sentence highlighted labels out.
demo = gr.Interface(
    paragraph_leveling,
    [
        gr.Textbox(
            label="Medical Report",
            info="You may put radiology medical report. Each sentence should be separated with period mark.",
            lines=20,
            value=" ",
        ),
    ],
    gr.HighlightedText(
        label="labeling",
        show_legend=True,
        show_label=True,
        # Keys must match the label strings emitted by paragraph_leveling.
        color_map={"abnormal": "violet", "normal": "lightgreen", "uncertain": "lightgray"},
    ),
    theme=gr.themes.Base(),
)

if __name__ == "__main__":
    demo.launch()