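# AutoRCM / app.py
# NOTE: the following lines are residue from the hosting site's file viewer,
# kept as comments so that the module parses as Python.
# mohanjebaraj's picture
# Update app.py
# 56c077f verified
# raw
# history blame
# 3.55 kB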
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import os
# Load ICD and CPT codes from files
def load_codes_from_files(directory_path, code_type):
    """Build a ``code -> description`` table from the ``.txt`` files in a directory.

    Every line is stripped and split on its first run of whitespace: the
    first token is the code and the remainder is the description. Lines that
    do not yield both parts (blank lines, a code with no description) are
    skipped. If a code occurs more than once, the last occurrence wins.

    Files are read in sorted file-name order. ``os.listdir`` returns entries
    in arbitrary order, and this module later maps model class indices to
    codes by their position in the returned dict, so a reproducible insertion
    order is required for that mapping to be stable between runs.

    Args:
        directory_path: Directory that holds the ``.txt`` code files.
        code_type: Label for the kind of code being loaded (e.g. ``"ICD"``).
            Accepted for the callers' benefit; the loader does not use it.

    Returns:
        A dict mapping each code to its description. Empty when the directory
        is missing (a message is printed) or contains no usable lines.
    """
    codes = {}

    # isdir (rather than exists) also rejects a regular file at this path,
    # which would otherwise make os.listdir raise.
    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return codes

    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split(maxsplit=1)
                if len(parts) == 2:
                    code, description = parts
                    codes[code.strip()] = description.strip()
    return codes
# Populate the ICD and CPT lookup tables from their text-file directories.
ICD_CODES = load_codes_from_files(
    directory_path="./codes/icd_txt_files/",
    code_type="ICD",
)
CPT_CODES = load_codes_from_files(
    directory_path="./codes/cpt_txt_files/",
    code_type="CPT",
)

# Refuse to start unless both tables contain at least one code.
if not (ICD_CODES and CPT_CODES):
    raise ValueError(
        "No ICD or CPT codes were loaded. Please check your files and directory structure."
    )

# Tokenizer and classifier come from the same checkpoint; the classifier's
# output layer is sized to the number of ICD codes that were loaded.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModelForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    num_labels=len(ICD_CODES),
)
# Prediction helpers and entry point
def _format_ranked_codes(heading, code_table, scores, indices):
    """Render one numbered recommendation section as text.

    Args:
        heading: Section heading, including its trailing newline.
        code_table: Dict mapping code to description; a class index is
            resolved to a code by its position in this dict.
        scores: Confidence values, one per prediction.
        indices: Class indices, parallel to ``scores``.

    Returns:
        The heading followed by one ``"N. code: description (Confidence: x.xx)"``
        line per prediction. Indices that fall outside ``code_table`` are
        skipped instead of raising ``IndexError``.
    """
    code_keys = list(code_table)  # built once, not once per prediction
    lines = [heading]
    rank = 0
    for score, index in zip(scores, indices):
        if index >= len(code_keys):
            continue
        rank += 1
        code = code_keys[index]
        lines.append(
            f"{rank}. {code}: {code_table[code]} (Confidence: {score:.2f})\n"
        )
    return "".join(lines)


def predict_codes(text):
    """Return the top ICD-10 and CPT code recommendations for a medical summary.

    The text is tokenized (truncated to 512 tokens), run through the module's
    ``model`` without gradient tracking, and the softmax of the logits is used
    to pick up to three highest-scoring classes.

    NOTE(review): the model is created with ``num_labels=len(ICD_CODES)``, so
    it produces a single ICD-sized ranking. The CPT section reuses that same
    ranking and its confidences; class indices beyond the CPT table are
    skipped. Whether the classification head has been fine-tuned for this
    task is not visible in this file - confirm.

    Args:
        text: Free-text medical summary. ``None``, empty or whitespace-only
            input yields a prompt message instead of a prediction.

    Returns:
        A multi-line string with a "Recommended ICD-10 Codes" section followed
        by a "Recommended CPT Codes" section, or the prompt message.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Inference only: switch to eval mode and skip gradient bookkeeping
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=1)

    # Never ask for more classes than the model has outputs
    top_k = min(3, len(ICD_CODES))
    top_icd = torch.topk(probs, k=top_k)

    # Single input, so row 0 holds the whole ranking; plain Python values are
    # easier to format and to use as list indices.
    scores = top_icd.values[0].tolist()
    indices = top_icd.indices[0].tolist()

    icd_section = _format_ranked_codes(
        "Recommended ICD-10 Codes:\n", ICD_CODES, scores, indices
    )
    cpt_section = _format_ranked_codes(
        "\nRecommended CPT Codes:\n", CPT_CODES, scores, indices
    )
    return icd_section + cpt_section
# Assemble the Gradio UI: one free-text input box and one read-only result box.
summary_input = gr.Textbox(
    label="Medical Summary",
    placeholder="Enter medical summary here...",
    lines=5,
)
codes_output = gr.Textbox(lines=10, label="Predicted Codes")

# Each example is a one-element row because the interface has a single input.
example_summaries = [
    [summary]
    for summary in (
        "Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension.",
        "Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes.",
        "Patient complains of chronic lower back pain, worse with movement. No radiation to legs.",
    )
]

iface = gr.Interface(
    fn=predict_codes,
    inputs=summary_input,
    outputs=codes_output,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=example_summaries,
)

# Start serving the interface, requesting a shareable link.
iface.launch(share=True)