import functools
import re

import gradio as gr
import torch
from torch import nn
from transformers import AutoTokenizer, BertForSequenceClassification, AutoModel

# Locations of the fine-tuned encoder / classifier head and the tokenizer.
ENCODER_PATH = "./trained_model/fine_tunning_encoder_v2"
CLASSIFIER_PATH = "./trained_model/fine_tunning_classifier"
TOKENIZER_NAME = "zzxslp/RadBERT-RoBERTa-4m"

# Every sentence is padded / truncated to this many tokens before encoding.
MAX_TOKENS = 120

# Classifier output index -> highlight category shown in the UI.
LABELS = ("abnormal", "normal", "not much information")


class MLP(nn.Module):
    """Single linear layer that maps an encoder embedding to class logits."""

    def __init__(self, target_size=3, input_size=768):
        super().__init__()
        self.num_classes = target_size
        self.input_size = input_size
        # Keep the attribute name ``fc1``: state-dict keys are derived from
        # attribute names, and the saved checkpoint is loaded by key.
        self.fc1 = nn.Linear(input_size, target_size)

    def forward(self, x):
        return self.fc1(x)


@functools.lru_cache(maxsize=1)
def _load_models():
    """Load tokenizer, encoder and classifier once and cache them.

    Loading is deferred to the first request so importing this module stays
    cheap, and cached so later requests do not re-read the checkpoints.

    Returns:
        Tuple ``(tokenizer, encoder, classifier, device)``.
    """
    # Prefer the GPU, but keep the demo usable on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    encoder = AutoModel.from_pretrained(ENCODER_PATH)
    encoder.to(device)
    encoder.eval()

    classifier = MLP(target_size=len(LABELS), input_size=768)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # map_location lets a checkpoint saved on GPU be restored on CPU.
    classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=device))
    classifier.to(device)
    classifier.eval()

    return tokenizer, encoder, classifier, device


def paragraph_leveling(text):
    """Classify each sentence of a report and derive an overall label.

    The text is split on ``"."``; every non-blank fragment is encoded and
    classified as one of ``LABELS``. The overall label is ``ABNORMAL`` if at
    least one fragment is classified as abnormal, otherwise ``NORMAL`` (this
    includes the case where the text contains no classifiable fragment).

    Args:
        text: The report text entered by the user.

    Returns:
        A list of ``(text, label)`` tuples in the format consumed by
        ``gr.HighlightedText``; ``label`` is ``None`` for unhighlighted spans.
    """
    tokenizer, encoder, classifier, device = _load_models()

    output_list = [("\n", None)]
    predicted = []

    # Inference only: no need to build autograd graphs.
    with torch.no_grad():
        for sentence in (text or "").split("."):
            # Splitting on "." leaves empty / whitespace-only fragments (e.g.
            # after the final period); classifying them could produce a
            # spurious "abnormal" and flip the final label.
            if not sentence.strip():
                continue

            encoding = tokenizer(
                sentence,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=MAX_TOKENS,
            ).to(device)

            # Element 1 of the encoder output is fed to the classifier head
            # (presumably the pooled sentence embedding -- verify for this
            # encoder architecture).
            pooled = encoder(**encoding)[1]
            logits = classifier(pooled)[0]  # single sentence -> first batch row

            class_idx = int(logits.argmax(-1))
            output_list.append((sentence, LABELS[class_idx]))
            predicted.append(class_idx)

    output_list.append(("\n", None))
    output_list.append(("FINAL LABEL: ", None))
    if 0 in predicted:
        output_list.append(("ABNORMAL", "abnormal"))
    else:
        output_list.append(("NORMAL", "normal"))

    return output_list


# Web UI: a free-text report box in, sentence-level highlighting out.
demo = gr.Interface(
    paragraph_leveling,
    [
        gr.Textbox(
            label="Medical Report",
            info="You can put any types of medical report",
            lines=20,
            value=" ",
        ),
    ],
    gr.HighlightedText(
        label="labeling",
        show_legend=True,
        show_label=True,
        color_map={
            "abnormal": "violet",
            "normal": "lightgreen",
            "not much information": "lightgray",
        },
    ),
    theme=gr.themes.Base(),
)

if __name__ == "__main__":
    demo.launch(share=True)