import os

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Hugging Face checkpoint used for both the tokenizer and the encoder.
BASE_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
# Directories scanned for ".txt" files of "<code> <description>" lines.
ICD_CODES_DIR = "./codes/icd_txt_files/"
CPT_CODES_DIR = "./codes/cpt_txt_files/"
# Maximum number of suggestions shown per code system (clamped to the
# number of available codes).
TOP_K = 3


# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """BERT encoder with two linear heads scoring ICD-10 and CPT codes.

    The heads are sized from the module-level ``ICD_CODES`` / ``CPT_CODES``
    dicts, so those must be loaded before this class is instantiated. Output
    column ``i`` of each head corresponds to the ``i``-th entry of the
    matching dict in insertion order.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)
        # Take the encoder width from its config rather than hard-coding 768;
        # fall back to 768 (the original value) if no config is present.
        config = getattr(bert_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        self.icd_classifier = torch.nn.Linear(hidden_size, len(ICD_CODES))
        self.cpt_classifier = torch.nn.Linear(hidden_size, len(CPT_CODES))

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)`` for a batch of token ids."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state at position 0 (the [CLS] slot for BERT tokenizers)
        # is used as the sequence embedding.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits


def _load_codes_from_directory(directory_path):
    """Parse every ``.txt`` file in *directory_path* into ``{code: description}``.

    Each non-blank line is split on its first run of whitespace: the first
    token is the code, the remainder the description. Lines without a
    description are reported and skipped. If a code appears more than once,
    the last occurrence wins.

    Files are read in sorted name order because ``os.listdir`` order is
    arbitrary, and the dict order defines each code's classifier index.

    Returns an empty dict (after printing a message) if the directory does
    not exist.
    """
    codes = {}
    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return codes

    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                stripped = line.strip()
                if not stripped:  # skip empty lines
                    continue
                parts = stripped.split(maxsplit=1)
                if len(parts) == 2:
                    code, description = parts
                    codes[code.strip()] = description.strip()
                else:
                    print(f"Invalid line format in file {file_name}: {stripped}")
    return codes


# Load ICD codes from files
def load_icd_codes_from_files():
    """Load ICD-10 codes from ``ICD_CODES_DIR`` as ``{code: description}``.

    Raises:
        ValueError: if no codes were loaded (the ICD head would be empty).
    """
    icd_codes = _load_codes_from_directory(ICD_CODES_DIR)
    if not icd_codes:
        raise ValueError(
            "No ICD codes were loaded. Please check your files and directory structure."
        )
    return icd_codes


# Load CPT codes from files
def load_cpt_codes_from_files():
    """Load CPT codes from ``CPT_CODES_DIR`` as ``{code: description}``.

    Unlike the ICD loader this does not raise when nothing is found; it
    returns an empty dict, and ``predict_codes`` reports no CPT suggestions.
    """
    return _load_codes_from_directory(CPT_CODES_DIR)


# Load ICD and CPT codes once at import time.
ICD_CODES = load_icd_codes_from_files()
print(f"Loaded {len(ICD_CODES)} ICD codes.")
CPT_CODES = load_cpt_codes_from_files()
print(f"Loaded {len(CPT_CODES)} CPT codes.")


# Load models
def load_models():
    """Create the tokenizer and a ``MedicalCodePredictor`` in eval mode.

    NOTE(review): only the pretrained encoder is loaded from
    ``BASE_MODEL_NAME``. The ICD/CPT classifier heads are freshly initialised
    and no fine-tuned weights are restored here, so the suggested codes and
    confidences are not meaningful until trained head weights are loaded.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
    base_model = AutoModel.from_pretrained(BASE_MODEL_NAME)
    model = MedicalCodePredictor(base_model)
    # Disable dropout so repeated predictions on the same text agree.
    model.eval()
    return tokenizer, model


def _format_top_codes(probs, codes, k=TOP_K):
    """Return numbered lines for the *k* most probable entries of *codes*.

    *probs* is the softmax output for a single input text, so row 0 holds
    one probability per code; column ``i`` maps to the ``i``-th entry of
    *codes* in insertion order. *k* is clamped to ``len(codes)``.
    """
    k = min(k, len(codes))
    if k == 0:
        return "No codes available.\n"

    entries = list(codes.items())  # index -> (code, description)
    top = torch.topk(probs[0], k=k)
    lines = []
    for rank, (prob, idx) in enumerate(
        zip(top.values.tolist(), top.indices.tolist()), start=1
    ):
        code, description = entries[idx]
        lines.append(f"{rank}. {code} - {description} (Confidence: {prob:.2f})\n")
    return "".join(lines)


# Prediction function
def predict_codes(text):
    """Suggest the most likely ICD-10 and CPT codes for a medical summary.

    Returns a formatted multi-line string for the Gradio output box, or a
    prompt message if *text* is empty or whitespace-only.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input (truncated to BERT's 512-token limit).
    inputs = tokenizer(
        text, return_tensors="pt", max_length=512, truncation=True, padding=True
    )

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Turn each head's logits into a distribution over its own codes.
    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    result = "Recommended ICD-10 Codes:\n"
    result += _format_top_codes(icd_probs, ICD_CODES)
    result += "\nRecommended CPT Codes:\n"
    result += _format_top_codes(cpt_probs, CPT_CODES)
    return result


# Load models globally (used by predict_codes)
tokenizer, model = load_models()

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

# Launch the interface
iface.launch(share=True)