"""AutoRCM: a Gradio demo that suggests ICD-10 and CPT codes for a medical summary.

NOTE(review): ``num_labels`` is derived from the local code files, so the
sequence-classification head cannot come from the Bio_ClinicalBERT checkpoint
and is presumably freshly initialised by ``from_pretrained``. Suggested codes
are therefore not meaningful until a fine-tuned checkpoint is loaded instead.
"""

import os

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Number of suggestions shown per code family.
TOP_K = 3


def load_codes_from_files(directory_path, code_type):
    """Load a ``{code: description}`` mapping from every ``.txt`` file in a directory.

    Each line is split on the first run of whitespace into a code and its
    description. Lines without both parts (e.g. blank lines) are skipped.
    A code seen again later overwrites the earlier description.

    Files are read in sorted name order. The resulting dict order defines the
    classifier's label indices, so it must be reproducible; ``os.listdir``
    order is arbitrary.

    Args:
        directory_path: Directory containing the ``.txt`` code files.
        code_type: Name of the code family (e.g. "ICD", "CPT"), used in messages.

    Returns:
        Dict mapping code string to description. It is empty if the directory
        does not exist.
    """
    codes = {}
    if not os.path.isdir(directory_path):
        print(f"{code_type} directory {directory_path} does not exist!")
        return codes
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split(maxsplit=1)
                if len(parts) == 2:
                    codes[parts[0].strip()] = parts[1].strip()
    return codes


# Load ICD and CPT codes
ICD_CODES = load_codes_from_files("./codes/icd_txt_files/", "ICD")
CPT_CODES = load_codes_from_files("./codes/cpt_txt_files/", "CPT")

# Check if codes were loaded
if not ICD_CODES or not CPT_CODES:
    raise ValueError("No ICD or CPT codes were loaded. Please check your files and directory structure.")

# Ordered code lists, built once. Position i in each list is label index i
# within that family's slice of the model logits.
ICD_CODE_LIST = list(ICD_CODES)
CPT_CODE_LIST = list(CPT_CODES)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
# A single head covers both code families:
#   logits[:len(ICD)] -> ICD labels, logits[len(ICD):] -> CPT labels.
# The CPT codes need their own outputs. Reusing ICD indices to look up CPT
# codes gives wrong codes and can raise IndexError.
model = AutoModelForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    num_labels=len(ICD_CODE_LIST) + len(CPT_CODE_LIST),
)
model.eval()  # inference only: switch off dropout once rather than per request


def _format_top_codes(title, logits, code_list, code_descriptions, top_k=TOP_K):
    """Render the ``top_k`` most probable codes of one logit slice as numbered lines.

    Args:
        title: Header line for the section.
        logits: 1-D tensor of logits for this code family only.
        code_list: Codes ordered so that ``code_list[i]`` matches ``logits[i]``.
        code_descriptions: Mapping from code to its description.
        top_k: Maximum number of codes to list.

    Returns:
        The header followed by one line per code, each terminated by a newline.
    """
    # Softmax per family so each section's confidences sum to 1 on their own.
    probs = F.softmax(logits, dim=-1)
    top = torch.topk(probs, k=min(top_k, len(code_list)))
    lines = [title]
    for rank, (prob, idx) in enumerate(zip(top.values.tolist(), top.indices.tolist()), start=1):
        code = code_list[idx]
        lines.append(f"{rank}. {code}: {code_descriptions[code]} (Confidence: {prob:.2f})")
    return "\n".join(lines) + "\n"


def predict_codes(text):
    """Suggest the top ICD-10 and CPT codes for a free-text medical summary.

    Args:
        text: The medical summary entered by the user. It may be empty or None.

    Returns:
        Human-readable text listing up to ``TOP_K`` ICD-10 codes and up to
        ``TOP_K`` CPT codes with their confidences. If ``text`` is blank, a
        prompt asking for input is returned instead.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input (BERT-style encoders accept at most 512 tokens).
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    with torch.no_grad():
        logits = model(**inputs).logits[0]  # single input -> take the only row

    n_icd = len(ICD_CODE_LIST)
    icd_section = _format_top_codes(
        "Recommended ICD-10 Codes:", logits[:n_icd], ICD_CODE_LIST, ICD_CODES
    )
    cpt_section = _format_top_codes(
        "Recommended CPT Codes:", logits[n_icd:], CPT_CODE_LIST, CPT_CODES
    )
    return icd_section + "\n" + cpt_section


# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=10,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module has no server side effect.
    # NOTE(review): share=True publishes a public link. Confirm this is acceptable
    # for the medical text users will enter.
    iface.launch(share=True)