# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
import torch.nn.functional as F | |
import os | |
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Two-headed code classifier on top of a BERT-style encoder.

    The encoder's first-token hidden state (the [CLS] position for BERT-style
    tokenizers) goes through dropout and then through two independent linear
    heads: one scoring ICD codes, one scoring CPT codes.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None):
        """Wrap *bert_model* with ICD and CPT classification heads.

        Args:
            bert_model: Encoder module; called as
                ``bert_model(input_ids=..., attention_mask=...)`` and expected to
                return an object with ``last_hidden_state``. When it exposes
                ``config.hidden_size`` that width sizes the heads; otherwise
                768 (the previously hard-coded width) is used.
            num_icd_codes: Output size of the ICD head. ``None`` (default)
                means ``len(ICD_CODES)`` from the module-level code table.
            num_cpt_codes: Output size of the CPT head. ``None`` (default)
                means ``len(CPT_CODES)`` from the module-level code table.
        """
        super().__init__()
        # Resolve the head sizes lazily so callers/tests can size the heads
        # without the module-level code tables existing.
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)
        # Read the encoder width instead of assuming 768 so other encoders fit.
        config = getattr(bert_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)

        # Registration order (bert, dropout, icd, cpt) is kept identical to the
        # original so state_dict keys and parameter-init RNG order match.
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)
        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``, each shaped (batch, num_codes)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # last_hidden_state is (batch, seq, hidden); take sequence position 0.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits
# Load ICD codes from files
def load_icd_codes_from_files(directory_path="./codes/icd_txt_files/"):
    """Parse every ``*.txt`` file in *directory_path* into an ICD code table.

    Each non-blank line must look like ``<code> <description>``. The first
    whitespace-separated token is the code, and the rest of the line is the
    description. Blank lines are skipped. Other malformed lines are reported
    on stdout and ignored.

    Files are read in sorted name order because ``os.listdir`` order is
    arbitrary. That makes the resulting dict order identical on every
    machine, which matters when dict position is used to map classifier
    outputs back to codes. A code that appears again in a later file
    overrides the earlier entry.

    Args:
        directory_path: Directory containing the ICD ``.txt`` files.

    Returns:
        dict mapping code -> description.

    Raises:
        ValueError: If no codes were loaded. This covers a missing directory,
            no ``.txt`` files, or no well-formed lines.
    """
    icd_codes = {}
    if os.path.isdir(directory_path):
        for file_name in sorted(os.listdir(directory_path)):
            if not file_name.endswith(".txt"):
                continue
            file_path = os.path.join(directory_path, file_name)
            with open(file_path, "r", encoding="utf-8") as file:
                for line in file:
                    stripped = line.strip()
                    if not stripped:
                        continue  # blank separator line
                    # maxsplit=1 keeps multi-word descriptions intact; the
                    # split also drops whitespace between code and description.
                    parts = stripped.split(maxsplit=1)
                    if len(parts) == 2:
                        code, description = parts
                        icd_codes[code] = description
                    else:
                        print(f"Invalid line format in file {file_name}: {stripped}")
    else:
        print(f"Directory {directory_path} does not exist!")
    if not icd_codes:
        raise ValueError("No ICD codes were loaded. Please check your files and directory structure.")
    return icd_codes
# Build the module-level ICD lookup (code -> description) once at import time.
# load_icd_codes_from_files() raises ValueError when nothing is parsed, so a
# missing or empty code directory stops the app here, before any model is built.
ICD_CODES = load_icd_codes_from_files()
print(f"Loaded {len(ICD_CODES)} ICD codes.")
# Load CPT codes from files
def load_cpt_codes_from_files(directory_path="./codes/cpt_txt_files/"):
    """Parse every ``*.txt`` file in *directory_path* into a CPT code table.

    The line format matches the ICD files: ``<code> <description>``. Blank
    lines are skipped. Other malformed lines are now reported on stdout
    instead of being silently dropped.

    Files are read in sorted name order so the dict order is reproducible,
    since ``os.listdir`` order is arbitrary. A code that appears again in a
    later file overrides the earlier entry.

    Unlike the ICD loader, a missing directory is not an error. The message
    is printed and an empty dict is returned.

    Args:
        directory_path: Directory containing the CPT ``.txt`` files.

    Returns:
        dict mapping code -> description. It may be empty.
    """
    cpt_codes = {}
    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return cpt_codes
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                stripped = line.strip()
                if not stripped:
                    continue  # blank separator line
                # maxsplit=1 keeps multi-word descriptions intact.
                parts = stripped.split(maxsplit=1)
                if len(parts) == 2:
                    code, description = parts
                    cpt_codes[code] = description
                else:
                    print(f"Invalid line format in file {file_name}: {stripped}")
    return cpt_codes
# Load the CPT code table. ICD_CODES is already built by the module-level
# call placed right after load_icd_codes_from_files() is defined, so it is
# not re-parsed from disk here.
CPT_CODES = load_cpt_codes_from_files()
print(f"Loaded {len(CPT_CODES)} CPT codes.")
# Load models
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT"):
    """Build the tokenizer and the code-prediction model.

    Args:
        model_name: Model id or local path passed to ``from_pretrained`` for
            both the tokenizer and the encoder. Defaults to Bio_ClinicalBERT.

    Returns:
        ``(tokenizer, model)``, where ``model`` is a ``MedicalCodePredictor``
        already switched to eval mode.

    NOTE(review): only the encoder weights come from ``from_pretrained``.
    The ICD/CPT linear heads are freshly initialised, and no fine-tuned
    checkpoint is loaded anywhere, so predictions are not meaningful until
    trained head weights are loaded. Confirm where those weights should come
    from.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)
    model.eval()  # inference-only app: disable dropout from the start
    return tokenizer, model
# Prediction function
def _format_ranked_codes(header, probs, code_table, top_k=3):
    """Render the ``top_k`` most probable codes of one classifier head.

    Args:
        header: Section title written on the first line.
        probs: Probabilities for a single example, shaped (1, num_codes).
            Column ``j`` is taken to correspond to the ``j``-th key of
            *code_table*, because the heads are sized from these tables.
        code_table: Mapping code -> description, in classifier-column order.
        top_k: Maximum number of codes to list. It is clamped to the number
            of columns so small code tables do not make ``topk`` fail.

    Returns:
        The section text: the header plus one numbered line per code, each
        terminated by a newline.
    """
    codes = list(code_table)  # position j <-> logit column j
    k = min(top_k, probs.shape[-1])
    lines = [header]
    if k > 0:
        top = torch.topk(probs, k=k)
        for rank, (prob, idx) in enumerate(zip(top.values[0], top.indices[0]), start=1):
            position = idx.item()
            if position < len(codes):
                code = codes[position]
                label = f"{code}: {code_table[code]}"
            else:
                label = "Unknown"  # head wider than the code table
            lines.append(f"{rank}. {label} (Confidence: {prob.item():.2f})")
    return "".join(f"{line}\n" for line in lines)


def predict_codes(text):
    """Predict ICD-10 and CPT codes for a free-text medical summary.

    Args:
        text: The summary entered in the UI. ``None`` or whitespace-only
            input returns a prompt instead of running the model.

    Returns:
        A human-readable string with the top ICD-10 codes, a blank line, and
        then the top CPT codes. Each entry shows code, description and
        softmax confidence.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input (truncated to the encoder's 512-token limit).
    inputs = tokenizer(text,
                       return_tensors="pt",
                       max_length=512,
                       truncation=True,
                       padding=True)

    model.eval()  # make sure dropout is off even if the model was left in train mode
    with torch.no_grad():  # inference only: skip building the autograd graph
        icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])

    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    return (
        _format_ranked_codes("Recommended ICD-10 Codes:", icd_probs, ICD_CODES)
        + "\n"
        + _format_ranked_codes("Recommended CPT Codes:", cpt_probs, CPT_CODES)
    )
# Load models globally (predict_codes reads `tokenizer` and `model` as globals)
tokenizer, model = load_models()

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

# Launch the interface only when this file is run as a script, so importing
# the module (tests, another app) does not start a web server.
# NOTE(review): share=True requests a public share link; users type medical
# summaries here, so confirm that public exposure is intended.
if __name__ == "__main__":
    iface.launch(share=True)