File size: 2,049 Bytes
1bb21bd
860fd6c
1bb21bd
 
 
 
 
 
 
6c2ab8a
1bb21bd
 
 
 
117020a
1bb21bd
860fd6c
117020a
 
860fd6c
183a63d
73b8367
117020a
860fd6c
cbbf9d1
117020a
860fd6c
183a63d
73b8367
1bb21bd
860fd6c
 
 
 
 
 
 
 
117020a
1bb21bd
860fd6c
 
 
 
 
 
 
 
 
 
 
 
 
3c7ab9e
 
 
 
860fd6c
 
 
3c7ab9e
860fd6c
1bb21bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
print('INFO: import modules')
import json
import gradio as gr
import pickle
from required_classes import *


print('INFO: loading model')
# Unpickle the fine-tuned model stored next to this script.
# NOTE: pickle.load can execute arbitrary code; only ever point this at a
# trusted, locally shipped file.
try:
    with open('model_finetuned_clear.pkl', 'rb') as model_file:
        model = pickle.load(model_file)
    # The handlers below always embed a single text at a time; presumably
    # batch_size controls the embedding batch size — TODO confirm.
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as load_error:
    # Failure is only reported, the app still starts. If opening or
    # unpickling failed, `model` is never bound, so every later request
    # raises NameError, which the request handlers turn into "Error: ...".
    print(f"ERROR: loading models failed with: {str(load_error)}")

def classify_code(text, top_n):
    embed = model._texts2vecs([text])
    probs = model.classifier_code.predict_proba(embed)
    best_n = np.flip(np.argsort(probs, axis=1,)[0,-top_n:])
    preds = {model.classifier_code.classes_[i]: probs[0][i] for i in best_n}
    return preds

def classify_group(text, top_n):
    embed = model._texts2vecs([text])
    probs = model.classifier_group.predict_proba(embed)
    best_n = np.flip(np.argsort(probs, axis=1,)[0,-top_n:])
    preds = {model.classifier_group.classes_[i]: probs[0][i] for i in best_n}
    return preds

def classify(text, top_n):
    """Run both classifiers on ``text`` and return ``(codes, groups)``.

    ``top_n`` is converted with ``int()`` before use. Any exception raised
    along the way is turned into a single ``"Error: ..."`` string that is
    returned in both positions, so callers always receive a pair.
    """
    try:
        n = int(top_n)
        code_preds = classify_code(text, n)
        group_preds = classify_group(text, n)
    except Exception as err:
        message = f"Error: {str(err)}"
        return message, message
    return code_preds, group_preds

print('INFO: starting gradio interface')
# Output components: one Label for the class result, one for the group
# result. `predict` keys its return dict on these two objects.
box_class, box_group = (
    gr.Label(label="Result class"),
    gr.Label(label="Result group"),
)
def predict(text, top_n):
    """Gradio callback: compute the values for both result boxes.

    Delegates to ``classify``, which converts ``top_n`` with ``int()``, runs
    both classifiers, and replaces any failure with an ``"Error: ..."``
    string for both outputs.

    Returns:
        Dict mapping each output component (``box_class``, ``box_group``) to
        the value it should display.
    """
    code_result, group_result = classify(text, top_n)
    return {box_class: code_result, box_group: group_result}

# Pre-fill the input box with the example text stored in default_input.json.
# The file is closed deterministically and decoded as UTF-8 (the encoding
# JSON requires) instead of the platform's locale encoding.
with open('default_input.json', encoding='utf-8') as default_input_file:
    default_input_text = json.load(default_input_file)['input_text']

iface = gr.Interface(
    # NOTE(review): `enable_queue` is deprecated in later Gradio 3.x and
    # removed in 4.x (use `iface.queue()` instead) — confirm installed version.
    enable_queue=True,
    title="ICD10-codes classification",
    description="",
    fn=predict,
    inputs=[gr.Textbox(label="Input text", value=default_input_text), gr.Number(label="TOP-N candidates", value=3)],
    outputs=[box_class, box_group],
)

iface.launch()