AutoRCM / app.py
mohanjebaraj's picture
Update app.py
ac169b3 verified
raw
history blame
5.76 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import os
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
def __init__(self, bert_model):
super().__init__()
self.bert = bert_model
self.dropout = torch.nn.Dropout(0.1)
self.icd_classifier = torch.nn.Linear(768, len(ICD_CODES))
self.cpt_classifier = torch.nn.Linear(768, len(CPT_CODES))
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.last_hidden_state[:, 0, :]
pooled_output = self.dropout(pooled_output)
icd_logits = self.icd_classifier(pooled_output)
cpt_logits = self.cpt_classifier(pooled_output)
return icd_logits, cpt_logits
# Load ICD codes from files
def load_icd_codes_from_files():
icd_codes = {}
directory_path = "./codes/icd_txt_files/" # Path to ICD codes directory
if os.path.exists(directory_path):
for file_name in os.listdir(directory_path):
if file_name.endswith(".txt"):
file_path = os.path.join(directory_path, file_name)
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
# Skip empty lines
if line.strip():
# Split the line into code and description
parts = line.strip().split(maxsplit=1)
if len(parts) == 2:
code = parts[0].strip()
description = parts[1].strip()
icd_codes[code] = description
else:
print(f"Invalid line format in file {file_name}: {line}")
else:
print(f"Directory {directory_path} does not exist!")
if not icd_codes:
raise ValueError("No ICD codes were loaded. Please check your files and directory structure.")
return icd_codes
ICD_CODES = load_icd_codes_from_files()
print(f"Loaded {len(ICD_CODES)} ICD codes.")
# Load CPT codes from files
def load_cpt_codes_from_files():
cpt_codes = {}
directory_path = "./codes/cpt_txt_files/" # Path to CPT codes directory
if os.path.exists(directory_path):
for file_name in os.listdir(directory_path):
if file_name.endswith(".txt"):
file_path = os.path.join(directory_path, file_name)
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
# Split the line into code and description
parts = line.strip().split(maxsplit=1)
if len(parts) == 2:
code = parts[0].strip()
description = parts[1].strip()
cpt_codes[code] = description
else:
print(f"Directory {directory_path} does not exist!")
return cpt_codes
# Load ICD and CPT codes dynamically
ICD_CODES = load_icd_codes_from_files()
CPT_CODES = load_cpt_codes_from_files()
# Load models
@torch.no_grad()
def load_models():
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
base_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = MedicalCodePredictor(base_model)
return tokenizer, model
# Prediction function
def predict_codes(text):
if not text.strip():
return "Please enter a medical summary."
# Tokenize input
inputs = tokenizer(text,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True)
# Get predictions
model.eval()
icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])
# Get probabilities
icd_probs = F.softmax(icd_logits, dim=1)
cpt_probs = F.softmax(cpt_logits, dim=1)
# Get top 3 predictions
top_icd = torch.topk(icd_probs, k=3)
top_cpt = torch.topk(cpt_probs, k=3)
# Get top k predictions (limit k to the number of available codes)
top_k = min(3, len(ICD_CODES))
top_icd = torch.topk(icd_probs, k=top_k)
# Format results
result = "Recommended ICD-10 Codes:\n"
for i, (prob, idx) in enumerate(zip(top_icd.values[0], top_icd.indices[0])):
result += f"{i+1}. {ICD_CODES.get(idx.item(), 'Unknown')} (Confidence: {prob.item():.2f})\n"
result += "\nRecommended CPT Codes:\n"
for i, (prob, idx) in enumerate(zip(top_cpt.values[0], top_cpt.indices[0])):
result += f"{i+1}. {CPT_CODES.get(idx.item(), 'Unknown')} (Confidence: {prob.item():.2f})\n"
return result
# Load models globally
tokenizer, model = load_models()
# Create Gradio interface
iface = gr.Interface(
fn=predict_codes,
inputs=gr.Textbox(
lines=5,
placeholder="Enter medical summary here...",
label="Medical Summary"
),
outputs=gr.Textbox(
label="Predicted Codes",
lines=8
),
title="AutoRCM - Medical Code Predictor",
description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
examples=[
["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."]
]
)
# Launch the interface
iface.launch(share=True)