"""AutoRCM demo: a Gradio UI that suggests ICD-10 and CPT codes for a medical summary.

NOTE(review): the codes displayed come entirely from the mock ``fetch_icd_codes`` /
``fetch_cpt_codes`` helpers below. The ClinicalBERT classifier is loaded, but its
predictions are not mapped to any code yet.
"""

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
NUM_LABELS = 1000  # Adjust to the size of the real label set.
TOP_K = 3  # Number of codes shown per code system.


# Mock ICD and CPT data (replace with actual API calls or datasets)
def fetch_icd_codes(query):
    """Return mock ICD-10 code records.

    ``query`` is currently ignored: the same three records are returned for every
    input. Each record is a dict with ``"code"`` and ``"description"`` keys.
    """
    return [
        {"code": "R50.9", "description": "Fever, unspecified"},
        {"code": "A00", "description": "Cholera"},
        {"code": "J06.9", "description": "Acute upper respiratory infection, unspecified"},
    ]


def fetch_cpt_codes(query):
    """Return mock CPT code records.

    ``query`` is currently ignored: the same three records are returned for every
    input. Each record is a dict with ``"code"`` and ``"description"`` keys.
    """
    return [
        {"code": "99213", "description": "Office or other outpatient visit"},
        {"code": "87804", "description": "Infectious agent detection by immunoassay"},
        {"code": "85025", "description": "Complete blood count (CBC)"},
    ]


# Load tokenizer and model once at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS)
# Inference only: switch to eval mode once here rather than on every request.
model.eval()


def _model_top_k(text, k=TOP_K):
    """Classify *text* and return ``torch.topk`` (values, indices) over label probabilities.

    NOTE(review): not called by ``predict_codes`` yet. Nothing in this file maps a label
    index to an ICD/CPT code. The ``NUM_LABELS``-way head is also presumably freshly
    initialized, since the checkpoint looks like a base model, so its outputs are not
    meaningful until the head is fine-tuned. Confirm before wiring this in.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)
    return torch.topk(probs, k=k)


def _format_section(title, codes, limit=TOP_K):
    """Render *title* followed by up to *limit* numbered ``"code: description"`` lines.

    Missing keys fall back to ``"Unknown"`` / ``"No description"``. Every line,
    including the title, ends with a newline.
    """
    lines = [f"{title}\n"]
    for rank, entry in enumerate(codes[:limit], start=1):
        code = entry.get("code", "Unknown")
        description = entry.get("description", "No description")
        lines.append(f"{rank}. {code}: {description}\n")
    # join avoids the repeated string concatenation of a += loop.
    return "".join(lines)


def predict_codes(text):
    """Return a plain-text list of recommended ICD-10 and CPT codes for *text*.

    Args:
        text: Free-text medical summary from the UI textbox.

    Returns:
        A prompt message if *text* is empty, whitespace-only or None. Otherwise an
        ICD-10 section and a CPT section, each with up to ``TOP_K`` numbered entries.
    """
    # Guard against None as well as empty / whitespace-only input.
    if not text or not text.strip():
        return "Please enter a medical summary."

    # TODO: replace the mock lookups with _model_top_k(text) mapped to real codes once
    # a fine-tuned head and an index -> code table exist. The previous implementation
    # ran that forward pass on every request and then discarded the result.
    icd_results = fetch_icd_codes(text)
    cpt_results = fetch_cpt_codes(text)

    return (
        _format_section("Recommended ICD-10 Codes:", icd_results)
        + "\n"
        + _format_section("Recommended CPT Codes:", cpt_results)
    )


# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=10,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
    # Disables the "Flag" button; this setting does not control example caching.
    # NOTE(review): newer Gradio releases rename this parameter to `flagging_mode`;
    # confirm against the installed Gradio version.
    allow_flagging="never",
)


if __name__ == "__main__":
    # SECURITY NOTE(review): share=True publishes a public share URL. Medical summaries
    # typed into this UI may contain PHI, so confirm public exposure is intended.
    iface.launch(share=True)