Spaces:
Sleeping
Sleeping
File size: 5,756 Bytes
cbebebf ac169b3 cbebebf 42052fb ac169b3 cbebebf ac169b3 cbebebf 42052fb cbebebf 42052fb cbebebf ac169b3 cbebebf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import os
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Encoder with two linear classification heads (ICD and CPT).

    The encoder output at sequence position 0 is passed through dropout and
    then through two independent linear layers, one per code system.

    Args:
        bert_model: Encoder called as
            ``bert_model(input_ids=..., attention_mask=...)`` whose result
            exposes ``last_hidden_state``.
        num_icd_codes: Output width of the ICD head. Defaults to
            ``len(ICD_CODES)`` (module-level table) when omitted.
        num_cpt_codes: Output width of the CPT head. Defaults to
            ``len(CPT_CODES)`` (module-level table) when omitted.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)

        # Take the encoder width from its config when available so encoders
        # that are not 768-wide still work; 768 remains the fallback.
        encoder_config = getattr(bert_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        # Fall back to the module-level code tables only when the caller did
        # not supply explicit head sizes.
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)

        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)`` for a tokenized batch."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first token as the sequence summary.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits
# Load ICD codes from files
def load_icd_codes_from_files(directory_path="./codes/icd_txt_files/"):
    """Read every ``.txt`` file in a directory into a code -> description dict.

    Each non-blank line is split on its first run of whitespace: the first
    token is the code and the remainder is the description. Lines with only
    one token are reported on stdout and skipped. Files are visited in sorted
    name order so the resulting dict order is the same on every run.

    Args:
        directory_path: Directory holding the ICD ``.txt`` files.

    Returns:
        dict mapping code string to description string. When a code occurs
        more than once, the last occurrence wins.

    Raises:
        ValueError: If no code could be loaded (missing directory, no
            ``.txt`` files, or no valid lines).
    """
    icd_codes = {}

    if os.path.exists(directory_path):
        # os.listdir returns names in arbitrary order; sort for determinism.
        for file_name in sorted(os.listdir(directory_path)):
            if not file_name.endswith(".txt"):
                continue
            file_path = os.path.join(directory_path, file_name)
            with open(file_path, "r", encoding="utf-8") as file:
                for line in file:
                    stripped = line.strip()
                    if not stripped:
                        continue  # skip empty lines
                    parts = stripped.split(maxsplit=1)
                    if len(parts) == 2:
                        code, description = parts
                        icd_codes[code] = description
                    else:
                        print(f"Invalid line format in file {file_name}: {stripped}")
    else:
        print(f"Directory {directory_path} does not exist!")

    if not icd_codes:
        raise ValueError("No ICD codes were loaded. Please check your files and directory structure.")
    return icd_codes
# Build the module-level ICD table at import time and report its size.
# load_icd_codes_from_files() raises ValueError when nothing could be loaded,
# so the script stops here if the ICD files are missing or empty.
ICD_CODES = load_icd_codes_from_files()
print(f"Loaded {len(ICD_CODES)} ICD codes.")
# Load CPT codes from files
def load_cpt_codes_from_files(directory_path="./codes/cpt_txt_files/"):
    """Read every ``.txt`` file in a directory into a code -> description dict.

    Each non-blank line is split on its first run of whitespace: the first
    token is the code and the remainder is the description. Lines with only
    one token are reported on stdout and skipped. Files are visited in sorted
    name order so the resulting dict order is the same on every run.

    Args:
        directory_path: Directory holding the CPT ``.txt`` files.

    Returns:
        dict mapping code string to description string; empty when the
        directory is missing or holds no valid lines (no exception is raised).
    """
    cpt_codes = {}

    if os.path.exists(directory_path):
        # os.listdir returns names in arbitrary order; sort for determinism.
        for file_name in sorted(os.listdir(directory_path)):
            if not file_name.endswith(".txt"):
                continue
            file_path = os.path.join(directory_path, file_name)
            with open(file_path, "r", encoding="utf-8") as file:
                for line in file:
                    stripped = line.strip()
                    if not stripped:
                        continue  # skip empty lines
                    parts = stripped.split(maxsplit=1)
                    if len(parts) == 2:
                        code, description = parts
                        cpt_codes[code] = description
                    else:
                        # Report malformed lines instead of dropping them silently.
                        print(f"Invalid line format in file {file_name}: {stripped}")
    else:
        print(f"Directory {directory_path} does not exist!")

    return cpt_codes
# Load CPT codes dynamically. ICD_CODES has already been populated by the
# module-level load earlier in this file, so the ICD files are not parsed a
# second time here.
CPT_CODES = load_cpt_codes_from_files()
print(f"Loaded {len(CPT_CODES)} CPT codes.")
# Load models
@torch.no_grad()
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT"):
    """Load the tokenizer and build the prediction model.

    The pretrained encoder is wrapped in ``MedicalCodePredictor``. No
    checkpoint for the ICD/CPT classifier heads is loaded here, so those
    layers keep their freshly initialised weights.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the encoder.

    Returns:
        Tuple ``(tokenizer, model)`` with the model already in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)
    # This app only runs inference; disable dropout from the start.
    model.eval()
    return tokenizer, model
def _format_top_codes(header, probs, codes, max_results=3):
    """Render the highest-probability codes of one code system as numbered lines."""
    # The code tables are keyed by code string, so a class index has to be
    # resolved by position. Assumes class i corresponds to the i-th table
    # entry in insertion order -- TODO confirm against how the heads are trained.
    code_items = list(codes.items())
    # Never ask topk for more entries than exist.
    k = min(max_results, len(code_items), probs.shape[1])

    lines = [header]
    if k > 0:
        top = torch.topk(probs, k=k)
        ranked = zip(top.values[0].tolist(), top.indices[0].tolist())
        for rank, (prob, idx) in enumerate(ranked, start=1):
            if idx < len(code_items):
                code, description = code_items[idx]
                label = f"{code} - {description}"
            else:
                label = "Unknown"
            lines.append(f"{rank}. {label} (Confidence: {prob:.2f})")
    return "\n".join(lines) + "\n"


# Prediction function
def predict_codes(text):
    """Return a formatted list of the top ICD-10 and CPT codes for a summary.

    Uses the module-level ``tokenizer``, ``model``, ``ICD_CODES`` and
    ``CPT_CODES``.

    Args:
        text: Free-text medical summary. ``None``, empty, or whitespace-only
            input yields a prompt message instead of predictions.

    Returns:
        A multi-line string with up to three ICD-10 codes followed by up to
        three CPT codes, each with its softmax probability.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Get predictions; no gradients are needed for inference.
    model.eval()
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])

    # Get probabilities
    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    # Format results: the two sections are separated by one blank line.
    sections = [
        _format_top_codes("Recommended ICD-10 Codes:", icd_probs, ICD_CODES),
        _format_top_codes("Recommended CPT Codes:", cpt_probs, CPT_CODES),
    ]
    return "\n".join(sections)
# Load models globally
tokenizer, model = load_models()

# Components and sample inputs for the Gradio interface
_summary_input = gr.Textbox(
    lines=5,
    placeholder="Enter medical summary here...",
    label="Medical Summary",
)
_codes_output = gr.Textbox(label="Predicted Codes", lines=8)
_example_summaries = [
    ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
    ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
    ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
]

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=_summary_input,
    outputs=_codes_output,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=_example_summaries,
)

# Launch the interface
iface.launch(share=True)
|