Spaces:
Build error
Build error
File size: 2,049 Bytes
1bb21bd 860fd6c 1bb21bd 6c2ab8a 1bb21bd 117020a 1bb21bd 860fd6c 117020a 860fd6c 183a63d 73b8367 117020a 860fd6c cbbf9d1 117020a 860fd6c 183a63d 73b8367 1bb21bd 860fd6c 117020a 1bb21bd 860fd6c 3c7ab9e 860fd6c 3c7ab9e 860fd6c 1bb21bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
print('INFO: import modules')
import json
import pickle

import gradio as gr
import numpy as np

# Star import kept: presumably provides the classes needed to unpickle the model.
from required_classes import *
print('INFO: loading model')
# Bind the global up front so that, if unpickling fails, later code sees
# ``model is None`` instead of crashing with an unrelated NameError.
model = None
try:
    # NOTE(review): pickle.load can execute arbitrary code from the file. This is
    # only acceptable because the pickle ships with the app; never point it at
    # user-supplied data.
    with open('model_finetuned_clear.pkl', 'rb') as f:
        model = pickle.load(f)
    # Requests embed a single text at a time (see ``_texts2vecs([text])``);
    # presumably this is the embedding batch size - confirm against the model class.
    model.batch_size = 1
    print('INFO: model loaded')
except Exception as e:
    # Best-effort startup: report the failure and keep the UI running so that
    # requests surface an error message instead of the app refusing to start.
    print(f"ERROR: loading models failed with: {e}")
def _predict_top_n(classifier, text, top_n):
    """Return the ``top_n`` most probable labels of ``classifier`` for ``text``.

    The text is embedded with the module-level ``model`` and scored with
    ``classifier.predict_proba``.

    Args:
        classifier: Fitted classifier exposing ``predict_proba`` and ``classes_``.
        text: Input text to classify.
        top_n: Number of labels to return. Values <= 0 give an empty dict;
            values larger than the number of classes return every class.

    Returns:
        Dict ``{label: probability}`` ordered from most to least probable.
    """
    embed = model._texts2vecs([text])
    probs = classifier.predict_proba(embed)[0]
    # Sort descending first, then truncate. The old ``argsort(...)[-top_n:]``
    # form returned *every* class for top_n == 0 (because -0 == 0) and an
    # arbitrary subset for negative values.
    best_n = np.argsort(probs)[::-1][:max(top_n, 0)]
    return {classifier.classes_[i]: probs[i] for i in best_n}


def classify_code(text, top_n):
    """Return the ``top_n`` most probable labels of ``model.classifier_code``.

    Result is a ``{label: probability}`` dict in descending probability order.
    """
    return _predict_top_n(model.classifier_code, text, top_n)


def classify_group(text, top_n):
    """Return the ``top_n`` most probable labels of ``model.classifier_group``.

    Result is a ``{label: probability}`` dict in descending probability order.
    """
    return _predict_top_n(model.classifier_group, text, top_n)
def classify(text, top_n):
    """Run both classifiers on ``text`` and return ``(codes, groups)``.

    ``top_n`` is coerced with ``int`` before use. If anything fails (bad
    ``top_n``, model error), both elements of the returned pair are the same
    ``"Error: ..."`` string, so the caller can always unpack two values.
    """
    try:
        n = int(top_n)
        code_preds = classify_code(text, n)
        group_preds = classify_group(text, n)
    except Exception as exc:
        message = f"Error: {exc!s}"
        return message, message
    return code_preds, group_preds
print('INFO: starting gradio interface')
# Output widgets for the interface; ``predict`` returns a dict keyed by these
# component objects, so they must exist as module-level names.
box_class, box_group = (
    gr.Label(label="Result class"),
    gr.Label(label="Result group"),
)
def predict(text, top_n):
    """Gradio callback: classify ``text`` into its top-N codes and groups.

    Args:
        text: Input text from the textbox.
        top_n: Number of candidates per output; coerced with ``int`` because
            it comes from a ``gr.Number`` field.

    Returns:
        Dict mapping each output ``gr.Label`` (``box_class``, ``box_group``) to
        its predictions. On any error, both outputs receive the same
        ``"Error: ..."`` string instead.
    """
    # ``classify`` already performs the int coercion, both classifier calls and
    # the error capture. Reuse it rather than keeping a duplicate copy here.
    predicted_codes, predicted_groups = classify(text, top_n)
    return {box_class: predicted_codes, box_group: predicted_groups}
# Pre-filled example text for the input box. Use a context manager so the file
# handle is closed, and an explicit UTF-8 encoding (the JSON standard encoding)
# instead of the platform default.
with open('default_input.json', encoding='utf-8') as fh:
    default_input_text = json.load(fh)['input_text']

iface = gr.Interface(
    title="ICD10-codes classification",
    description="",
    fn=predict,
    inputs=[
        gr.Textbox(label="Input text", value=default_input_text),
        gr.Number(label="TOP-N candidates", value=3),
    ],
    outputs=[box_class, box_group],
)
# Enable request queueing via ``queue()``. The ``enable_queue`` constructor
# argument is deprecated, and newer Gradio releases ignore or reject it.
iface.queue()
iface.launch()
|