lyangas
committed on
Commit
·
488bb56
1
Parent(s):
1efad19
Add predict_code endpoint for predicting a code within a given group
Browse files
app.py
CHANGED
@@ -19,7 +19,7 @@ try:
|
|
19 |
except Exception as e:
|
20 |
print(f"ERROR: loading embedder failed with: {str(e)}")
|
21 |
|
22 |
-
|
23 |
classifiers_codes = {}
|
24 |
try:
|
25 |
for clf_name in os.listdir('classifiers/codes'):
|
@@ -28,10 +28,11 @@ try:
|
|
28 |
with open('classifiers/codes/'+clf_name, 'rb') as f:
|
29 |
model = pickle.load(f)
|
30 |
classifiers_codes[clf_name.split('.')[0]] = model
|
31 |
-
print(f'INFO: classifier {clf_name} loaded')
|
32 |
except Exception as e:
|
33 |
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
34 |
|
|
|
35 |
classifiers_groups = {}
|
36 |
try:
|
37 |
for clf_name in os.listdir('classifiers/groups'):
|
@@ -40,7 +41,21 @@ try:
|
|
40 |
with open('classifiers/groups/'+clf_name, 'rb') as f:
|
41 |
model = pickle.load(f)
|
42 |
classifiers_groups[clf_name.split('.')[0]] = model
|
43 |
-
print(f'INFO: classifier {clf_name} loaded')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
except Exception as e:
|
45 |
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
46 |
|
@@ -68,6 +83,17 @@ def classify_group(text, top_n):
|
|
68 |
preds[clf_name] = clf_preds
|
69 |
return preds
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
def get_top_result(preds):
|
72 |
total_scores = {}
|
73 |
for clf_name, scores in preds.items():
|
@@ -97,7 +123,7 @@ def test():
|
|
97 |
return {'response': data}
|
98 |
|
99 |
@app.route("/predict", methods=['POST'])
|
100 |
-
def
|
101 |
data = request.json
|
102 |
base64_bytes = str(data['textB64']).encode("ascii")
|
103 |
sample_string_bytes = base64.b64decode(base64_bytes)
|
@@ -121,5 +147,28 @@ def read_root():
|
|
121 |
}
|
122 |
return result
|
123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
if __name__ == "__main__":
|
125 |
app.run(host='0.0.0.0', port=7860)
|
|
|
19 |
except Exception as e:
|
20 |
print(f"ERROR: loading embedder failed with: {str(e)}")
|
21 |
|
22 |
+
print('Loading classifiers of codes')
|
23 |
classifiers_codes = {}
|
24 |
try:
|
25 |
for clf_name in os.listdir('classifiers/codes'):
|
|
|
28 |
with open('classifiers/codes/'+clf_name, 'rb') as f:
|
29 |
model = pickle.load(f)
|
30 |
classifiers_codes[clf_name.split('.')[0]] = model
|
31 |
+
print(f'INFO: codes classifier {clf_name} loaded')
|
32 |
except Exception as e:
|
33 |
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
34 |
|
35 |
+
print('Loading classifiers of groups')
|
36 |
classifiers_groups = {}
|
37 |
try:
|
38 |
for clf_name in os.listdir('classifiers/groups'):
|
|
|
41 |
with open('classifiers/groups/'+clf_name, 'rb') as f:
|
42 |
model = pickle.load(f)
|
43 |
classifiers_groups[clf_name.split('.')[0]] = model
|
44 |
+
print(f'INFO: groups classifier {clf_name} loaded')
|
45 |
+
except Exception as e:
|
46 |
+
print(f"ERROR: loading classifiers failed with: {str(e)}")
|
47 |
+
|
48 |
+
# Load the per-group code classifiers: every non-hidden file in
# classifiers/codes_in_groups is unpickled and stored in groups_models under
# the file name with the '_code_clf.pkl' part removed.
print('Loading classifiers in groups')
groups_models = {}
try:
    for clf_name in os.listdir('classifiers/codes_in_groups'):
        # Names beginning with a dot are skipped rather than unpickled.
        if not clf_name.startswith('.'):
            with open('classifiers/codes_in_groups/'+clf_name, 'rb') as f:
                model = pickle.load(f)
            group_name = clf_name.replace('_code_clf.pkl', '')
            groups_models[group_name] = model
            print(f'INFO: codes classifier for group {group_name} loaded')
except Exception as e:
    # Any failure stops the loop; models loaded before it stay registered.
    print(f"ERROR: loading classifiers failed with: {str(e)}")
|
61 |
|
|
|
83 |
preds[clf_name] = clf_preds
|
84 |
return preds
|
85 |
|
86 |
+
def classify_code_by_group(text, group_name, top_n):
    """Score ``text`` with the code classifier registered for ``group_name``.

    Returns a 3-tuple:
      * the class with the highest predicted probability,
      * a dict mapping ``str(class)`` to its ``float`` probability for the
        ``top_n`` most probable classes, ordered from most to least probable,
      * ``model.classes_`` as stored on the model (every class it knows).

    Raises ``KeyError`` when ``group_name`` has no entry in ``groups_models``.
    """
    # The model is given a one-item batch, so only row 0 of the output is used.
    batch = [embedder(text)]
    group_clf = groups_models[group_name]
    row = group_clf.predict_proba(batch)[0]

    # argsort is ascending: keep the last top_n indices, then reverse them so
    # the most probable class comes first.
    ranked = np.argsort(row)[-top_n:][::-1]

    known_codes = group_clf.classes_
    scored = {}
    for idx in ranked:
        scored[str(known_codes[idx])] = float(row[idx])

    return known_codes[ranked[0]], scored, known_codes
|
96 |
+
|
97 |
def get_top_result(preds):
|
98 |
total_scores = {}
|
99 |
for clf_name, scores in preds.items():
|
|
|
123 |
return {'response': data}
|
124 |
|
125 |
@app.route("/predict", methods=['POST'])
|
126 |
+
def predict_api():
|
127 |
data = request.json
|
128 |
base64_bytes = str(data['textB64']).encode("ascii")
|
129 |
sample_string_bytes = base64.b64decode(base64_bytes)
|
|
|
147 |
}
|
148 |
return result
|
149 |
|
150 |
+
@app.route("/predict_code", methods=['POST'])
def predict_code_api():
    """Predict the most likely codes inside a single group.

    Expects a JSON body with:
      * ``textB64``  - the input text, base64-encoded,
      * ``top_n``    - how many of the best-scoring codes to return (>= 1),
      * ``dx_group`` - name of the group whose classifier should be used.

    Returns ``{"icd10": {"result": ..., "details": ..., "all_codes": ...}}``
    on success, or ``{"error": <message>}`` when ``top_n`` is below 1, the
    text is blank, or no classifier is loaded for the requested group.
    """
    data = request.json
    base64_bytes = str(data['textB64']).encode("ascii")
    sample_string_bytes = base64.b64decode(base64_bytes)
    # UTF-8 is a superset of ASCII, so every payload that decoded before still
    # decodes to the same text, while non-ASCII text no longer raises
    # UnicodeDecodeError.
    text = sample_string_bytes.decode("utf-8")
    top_n = int(data['top_n'])
    group_name = data['dx_group']

    # NOTE(review): 'geather' is a typo for 'greater'; the message is left
    # unchanged in case clients match on it.
    if top_n < 1:
        return {'error': 'top_n should be geather than 0'}
    if text.strip() == '':
        return {'error': 'text is empty'}
    if group_name not in groups_models:
        return {'error': 'have no classifier for the group'}

    top_pred_code, pred_codes, all_codes_in_group = classify_code_by_group(text, group_name, top_n)
    result = {
        "icd10": {
            # classify_code_by_group hands back numpy values (a numpy scalar
            # and the model's classes_ array). The default JSON serialisation
            # cannot encode an ndarray, so both are converted to plain
            # strings, matching the str() keys already used in 'details'.
            'result': str(top_pred_code),
            'details': pred_codes,
            'all_codes': [str(code) for code in all_codes_in_group],
        }
    }
    return result
|
172 |
+
|
173 |
# Serve the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # 0.0.0.0 listens on every network interface, on port 7860.
    app.run(port=7860, host='0.0.0.0')
|