mohanjebaraj committed on
Commit
cbebebf
·
verified ·
1 Parent(s): 28f37f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch.nn.functional as F
5
+ import os
6
+
7
+ # Define the model class
8
class MedicalCodePredictor(torch.nn.Module):
    """Encoder followed by two linear heads: one for ICD codes, one for CPT codes.

    Args:
        bert_model: Encoder module. It is called as
            ``bert_model(input_ids=..., attention_mask=...)`` and its result
            must expose ``last_hidden_state``.
        num_icd_codes: Output width of the ICD head. Defaults to
            ``len(ICD_CODES)`` when omitted.
        num_cpt_codes: Output width of the CPT head. Defaults to
            ``len(CPT_CODES)`` when omitted.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)

        # Take the encoder width from its config rather than hard-coding 768,
        # so an encoder with a different hidden size still gets matching
        # heads. 768 remains the fallback when no config is exposed.
        config = getattr(bert_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)

        # Fall back to the module-level code tables only when the caller did
        # not give explicit head sizes.
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)

        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)`` for a batch of tokenized inputs.

        Both arguments are passed straight through to the encoder.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at sequence position 0 as the summary vector
        # (presumably the [CLS] token for a BERT tokenizer -- TODO confirm).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)

        return icd_logits, cpt_logits
25
+
26
+ # Load ICD codes from files
27
def _load_codes_from_directory(directory_path):
    """Read every ``.txt`` file in *directory_path* into a code -> description dict.

    Each line is split on tabs; the first field is the code and the second is
    the description. Lines with fewer than two fields are skipped. When the
    same code appears more than once, the last occurrence wins.

    Returns an empty dict (after printing a notice) when *directory_path* is
    not an existing directory.
    """
    codes = {}
    # isdir() rather than exists(): a plain file at this path would make
    # os.listdir() raise.
    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return codes

    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # insertion order of the resulting table reproducible between runs.
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        # A sub-directory whose name happens to end in ".txt" cannot be opened.
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split("\t")  # Adjust delimiter as needed
                if len(parts) < 2:
                    continue
                codes[parts[0].strip()] = parts[1].strip()
    return codes


# Load ICD codes from files
def load_icd_codes_from_files(directory_path="./codes/icd_txt_files/"):
    """Return the ICD code table as a ``{code: description}`` dict.

    Args:
        directory_path: Directory holding tab-separated ``.txt`` files.
            Defaults to the ICD codes directory under ``./codes``.
    """
    return _load_codes_from_directory(directory_path)


# Load CPT codes from files
def load_cpt_codes_from_files(directory_path="./codes/cpt_txt_files/"):
    """Return the CPT code table as a ``{code: description}`` dict.

    Args:
        directory_path: Directory holding tab-separated ``.txt`` files.
            Defaults to the CPT codes directory under ``./codes``.
    """
    return _load_codes_from_directory(directory_path)
65
+
66
+ # Load ICD and CPT codes dynamically
67
# Build both code tables once, when the module is first executed.
ICD_CODES = load_icd_codes_from_files()
CPT_CODES = load_cpt_codes_from_files()
69
+
70
+ # Load models
71
@torch.no_grad()
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT"):
    """Load the tokenizer and build the code predictor around a pretrained encoder.

    Args:
        model_name: Checkpoint identifier passed to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``, so the two always match.

    Returns:
        A ``(tokenizer, model)`` tuple; ``model`` is a
        ``MedicalCodePredictor`` already switched to evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)
    # NOTE(review): only the encoder comes from a checkpoint. The two
    # classifier heads are freshly constructed here and no trained weights
    # are loaded for them in this function -- confirm whether a fine-tuned
    # state dict should be loaded before serving predictions.
    # Put the model in evaluation mode once so dropout is disabled for every
    # caller, not only those that remember to call eval() themselves.
    model.eval()
    return tokenizer, model
77
+
78
+ # Prediction function
79
def _format_top_codes(heading, logits, code_table, k=3):
    """Render the highest-scoring entries of one classifier head as text.

    Args:
        heading: First line of the returned section.
        logits: Tensor of raw scores, indexed as ``[batch, class]``; only the
            first batch row is reported.
        code_table: ``{code: description}`` dict used to label class indices.
        k: Maximum number of entries to list.

    Returns:
        The heading followed by one numbered line per entry, each line
        terminated by a newline.
    """
    lines = [heading]
    # Never request more entries than the head has outputs; torch.topk raises
    # when k exceeds the size of the dimension.
    top_k = min(k, logits.shape[-1])
    if top_k == 0:
        lines.append("No codes available.")
        return "\n".join(lines) + "\n"

    # Class indices are integers while the table is keyed by code string, so
    # resolve an index by position in the table's insertion order.
    # NOTE(review): assumes output index i corresponds to the i-th loaded
    # code -- confirm against the label order used for training.
    entries = list(code_table.items())

    probs = F.softmax(logits, dim=1)
    top = torch.topk(probs, k=top_k)
    for rank, (prob, idx) in enumerate(zip(top.values[0], top.indices[0]), start=1):
        position = idx.item()
        if position < len(entries):
            code, description = entries[position]
            label = f"{code} - {description}"
        else:
            label = "Unknown"
        lines.append(f"{rank}. {label} (Confidence: {prob.item():.2f})")
    return "\n".join(lines) + "\n"


def predict_codes(text):
    """Return a text report of the top ICD-10 and CPT codes for *text*.

    Args:
        text: Free-text medical summary. ``None``, empty, or whitespace-only
            input yields a prompt message instead of a prediction.

    Returns:
        A multi-line string with an ICD-10 section followed by a CPT section,
        or ``"Please enter a medical summary."`` when there is no input.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Inference only: disable dropout and skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])

    icd_section = _format_top_codes("Recommended ICD-10 Codes:", icd_logits, ICD_CODES)
    cpt_section = _format_top_codes("Recommended CPT Codes:", cpt_logits, CPT_CODES)
    return f"{icd_section}\n{cpt_section}"
112
+
113
+ # Load models globally
114
# Instantiate the tokenizer and model a single time; predict_codes reads
# these module-level names on every call.
tokenizer, model = load_models()

# Input and output widgets for the one-field form.
summary_box = gr.Textbox(
    label="Medical Summary",
    placeholder="Enter medical summary here...",
    lines=5,
)
codes_box = gr.Textbox(lines=8, label="Predicted Codes")

# Sample summaries offered as examples in the interface.
example_summaries = [
    ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
    ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
    ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
]

# Wire the prediction function to the widgets.
iface = gr.Interface(
    fn=predict_codes,
    inputs=summary_box,
    outputs=codes_box,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=example_summaries,
)

# Start serving the interface.
iface.launch(share=True)