"""AutoRCM - Gradio demo that suggests ICD-10 and CPT codes for a medical summary.

A Bio_ClinicalBERT encoder feeds its first-token ([CLS]) embedding into two
linear heads, one scoring the ICD-10 table and one scoring the CPT table
defined below. The top suggestions from each head are shown in a Gradio UI.
"""

import warnings

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Hugging Face model id of the clinical BERT encoder.
BASE_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

# Number of suggestions displayed per code system (capped at the table size).
TOP_K = 3

# Output index of the ICD head -> "CODE - description" display label.
# NOTE(review): some labels (e.g. R51, M54.5) may have been subdivided in later
# ICD-10-CM releases — verify against the code set you bill against.
ICD_CODES = {
    0: "I10 - Essential hypertension",
    1: "E11.9 - Type 2 diabetes without complications",
    2: "J44.9 - COPD",
    3: "I25.10 - Atherosclerotic heart disease",
    4: "M54.5 - Low back pain",
    5: "F41.9 - Anxiety disorder",
    6: "J45.909 - Asthma, unspecified",
    7: "K21.9 - GERD",
    8: "E78.5 - Dyslipidemia",
    9: "M17.9 - Osteoarthritis of knee",
    10: "E10.9 - Type 1 diabetes without complications",
    11: "R51 - Headache",
    12: "R50.9 - Fever, unspecified",
    13: "R05 - Cough",
    14: "S52.5 - Fracture of forearm",
    15: "A49.9 - Bacterial infection, unspecified",
    16: "R52 - Pain, unspecified",
    17: "R11 - Nausea",
    18: "S33.5 - Sprain and strain of lumbar spine",
}

# Output index of the CPT head -> "CODE - description" display label.
CPT_CODES = {
    0: "99213 - Office visit, established patient",
    1: "99214 - Office visit, established patient, moderate complexity",
    2: "99203 - Office visit, new patient",
    3: "80053 - Comprehensive metabolic panel",
    4: "85025 - Complete blood count",
    5: "93000 - ECG with interpretation",
    6: "71045 - Chest X-ray",
    7: "99395 - Preventive visit, established patient",
    8: "96127 - Brief emotional/behavioral assessment",
    9: "99396 - Preventive visit, age 40-64",
    10: "96372 - Therapeutic injection",
    11: "97110 - Therapeutic exercises",
    12: "10060 - Incision and drainage of abscess",
    13: "76700 - Abdominal ultrasound",
    14: "87500 - Infectious agent detection",
    15: "72100 - X-ray of lower spine",
    16: "72148 - MRI of lumbar spine",
}


class MedicalCodePredictor(torch.nn.Module):
    """BERT encoder shared by two linear heads: one for ICD-10, one for CPT."""

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)
        # Size the heads from the encoder's own config rather than hard-coding
        # BERT-base's 768; fall back to 768 for encoders without a config.
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.icd_classifier = torch.nn.Linear(hidden_size, len(ICD_CODES))
        self.cpt_classifier = torch.nn.Linear(hidden_size, len(CPT_CODES))

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``, one row of logits per input sequence."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is the first ([CLS]) token.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        return self.icd_classifier(pooled_output), self.cpt_classifier(pooled_output)


def load_models(checkpoint_path=None):
    """Load the tokenizer and build a ``MedicalCodePredictor`` in eval mode.

    Args:
        checkpoint_path: Optional path to a ``state_dict`` saved from a trained
            ``MedicalCodePredictor``. Without it, only the BERT encoder is
            pretrained and both classifier heads keep their random init.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    base_model = AutoModel.from_pretrained(BASE_MODEL_NAME)
    model = MedicalCodePredictor(base_model)
    if checkpoint_path is None:
        warnings.warn(
            "No fine-tuned checkpoint supplied: the ICD/CPT classifier heads are "
            "randomly initialised, so predicted codes are not meaningful.",
            stacklevel=2,
        )
    else:
        # torch.load unpickles the file — only load checkpoints from trusted sources.
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    # Switch off dropout once here instead of on every request.
    model.eval()
    return tokenizer, model


def _format_top_codes(header, code_table, probs, k=TOP_K):
    """Render the top-``k`` entries of ``probs`` (first row) as a numbered list.

    Returns ``header`` followed by one ``"N. label (Confidence: p)"`` line per
    suggestion, every line newline-terminated.
    """
    top = torch.topk(probs, k=min(k, len(code_table)), dim=1)
    lines = [header]
    pairs = zip(top.values[0].tolist(), top.indices[0].tolist())
    for rank, (prob, idx) in enumerate(pairs, start=1):
        lines.append(f"{rank}. {code_table[idx]} (Confidence: {prob:.2f})")
    return "\n".join(lines) + "\n"


def predict_codes(text):
    """Return a text report of the top ICD-10 and CPT suggestions for ``text``."""
    # Gradio normally sends "", but tolerate None as well.
    if not text or not text.strip():
        return "Please enter a medical summary."

    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding=True)

    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])
        # NOTE(review): softmax makes codes compete for a single label; if a note
        # can carry several codes at once, per-code sigmoid may be the intended
        # head activation — confirm against how the heads were trained.
        icd_probs = F.softmax(icd_logits, dim=1)
        cpt_probs = F.softmax(cpt_logits, dim=1)

    icd_section = _format_top_codes("Recommended ICD-10 Codes:", ICD_CODES, icd_probs)
    cpt_section = _format_top_codes("Recommended CPT Codes:", CPT_CODES, cpt_probs)
    # A blank line separates the two sections.
    return icd_section + "\n" + cpt_section


# Load once at import time; predict_codes reads these module globals.
tokenizer, model = load_models()

iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

if __name__ == "__main__":
    # NOTE(review): share=True publishes a public tunnel URL; confirm that is
    # acceptable before entering real patient text.
    iface.launch(share=True)