# AutoRCM - Medical Code Predictor
# Gradio app that suggests ICD-10 and CPT codes for a free-text medical summary.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import os

# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Encoder followed by two linear heads: one for ICD codes, one for CPT codes.

    The width of each head is taken from the module-level ``ICD_CODES`` and
    ``CPT_CODES`` tables at construction time, so both tables must already be
    populated when an instance is created.

    Args:
        bert_model: Encoder module that is called with ``input_ids`` and
            ``attention_mask`` keyword arguments and whose output exposes
            ``last_hidden_state``.
    """

    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)
        # Size the heads from the encoder's own configuration when it reports
        # one, so encoders of other widths work too; 768 was the previously
        # hard-coded width and remains the fallback.
        encoder_config = getattr(bert_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)
        self.icd_classifier = torch.nn.Linear(hidden_size, len(ICD_CODES))
        self.cpt_classifier = torch.nn.Linear(hidden_size, len(CPT_CODES))

    def forward(self, input_ids, attention_mask):
        """Run the encoder and both heads.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            A ``(icd_logits, cpt_logits)`` tuple of unnormalised scores.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state at sequence position 0 is used as the summary
        # vector (presumably the [CLS] token for BERT-style tokenizers).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)

        return icd_logits, cpt_logits

# Load ICD codes from files
def load_icd_codes_from_files(directory_path="./codes/icd_txt_files/"):
    """Load ICD codes and their descriptions from ``.txt`` files.

    Every non-blank line is split on the first run of whitespace into a code
    and a description. Lines that contain only one token are reported on
    stdout and skipped. When the same code appears more than once, the last
    occurrence wins.

    Args:
        directory_path: Directory that holds the ICD ``.txt`` files.

    Returns:
        A dict mapping each code to its description, in the order the codes
        were read (files are visited in sorted name order).

    Raises:
        ValueError: If no code could be loaded, which includes the case where
            the directory is missing.
    """
    icd_codes = {}

    if os.path.isdir(directory_path):
        # os.listdir returns names in arbitrary order; sorting keeps the
        # insertion order of the returned dict reproducible between runs.
        for file_name in sorted(os.listdir(directory_path)):
            if not file_name.endswith(".txt"):
                continue
            file_path = os.path.join(directory_path, file_name)
            with open(file_path, "r", encoding="utf-8") as file:
                for line in file:
                    stripped = line.strip()
                    if not stripped:
                        # Skip empty lines
                        continue
                    # Split the line into code and description
                    parts = stripped.split(maxsplit=1)
                    if len(parts) == 2:
                        code, description = parts
                        icd_codes[code] = description
                    else:
                        print(f"Invalid line format in file {file_name}: {line}")
    else:
        print(f"Directory {directory_path} does not exist!")

    if not icd_codes:
        raise ValueError("No ICD codes were loaded. Please check your files and directory structure.")

    return icd_codes

# Load CPT codes from files
def load_cpt_codes_from_files(directory_path="./codes/cpt_txt_files/"):
    """Load CPT codes and their descriptions from ``.txt`` files.

    Every line is split on the first run of whitespace into a code and a
    description. Lines that do not yield both parts (blank lines included)
    are skipped silently. When the same code appears more than once, the last
    occurrence wins.

    Args:
        directory_path: Directory that holds the CPT ``.txt`` files.

    Returns:
        A dict mapping each code to its description, in the order the codes
        were read (files are visited in sorted name order). The dict is empty
        when the directory is missing or contains no usable lines.
    """
    cpt_codes = {}

    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return cpt_codes

    # os.listdir returns names in arbitrary order; sorting keeps the insertion
    # order of the returned dict reproducible between runs.
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                # Split the line into code and description
                parts = line.strip().split(maxsplit=1)
                if len(parts) == 2:
                    code, description = parts
                    cpt_codes[code] = description

    return cpt_codes

# Load ICD and CPT codes dynamically
# Both tables map a code string to its description. They are read at import
# time because the model class sizes its classifier heads from their lengths
# and predict_codes uses them to label its results. The ICD loader raises
# ValueError when it finds nothing, which aborts the import at this point.
ICD_CODES = load_icd_codes_from_files()
CPT_CODES = load_cpt_codes_from_files()

# Load models
@torch.no_grad()
def load_models():
    """Build the tokenizer and the code-prediction model.

    The tokenizer and the encoder are both fetched from the
    ``emilyalsentzer/Bio_ClinicalBERT`` checkpoint, and the encoder is then
    wrapped in ``MedicalCodePredictor``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): nothing here loads trained weights for the wrapper's
    # classifier heads -- confirm whether a checkpoint should be restored.
    checkpoint = "emilyalsentzer/Bio_ClinicalBERT"
    clinical_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)
    return clinical_tokenizer, MedicalCodePredictor(encoder)

# Prediction function
def _format_top_codes(probs, code_table, limit=3):
    """Return one formatted result line per top-scoring class in ``probs``."""
    codes = list(code_table)
    # torch.topk raises when k exceeds the number of classes, so clamp it.
    k = min(limit, probs.shape[1])
    if k == 0:
        return []

    top = torch.topk(probs, k=k)
    lines = []
    ranked = zip(top.values[0].tolist(), top.indices[0].tolist())
    for rank, (prob, idx) in enumerate(ranked, start=1):
        if idx < len(codes):
            # The table is keyed by code string, so the class index has to be
            # translated to a code before the description can be looked up.
            code = codes[idx]
            label = f"{code} - {code_table[code]}"
        else:
            label = "Unknown"
        lines.append(f"{rank}. {label} (Confidence: {prob:.2f})\n")
    return lines


def predict_codes(text):
    """Return a formatted list of the most likely ICD-10 and CPT codes.

    Relies on the module-level ``tokenizer``, ``model``, ``ICD_CODES`` and
    ``CPT_CODES`` objects.

    Args:
        text: Free-text medical summary. ``None``, empty and whitespace-only
            values are answered with a prompt instead of a prediction.

    Returns:
        A multi-line string with up to three ICD-10 entries followed by up to
        three CPT entries, each showing the code, its description and the
        softmax probability rounded to two decimals.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Get predictions; this is inference only, so autograd tracking is skipped.
    model.eval()
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])

    # Get probabilities
    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    # NOTE(review): class index i is assumed to correspond to the i-th entry
    # of the code table in insertion order -- confirm against how the
    # classifier heads were trained.
    sections = ["Recommended ICD-10 Codes:\n"]
    sections.extend(_format_top_codes(icd_probs, ICD_CODES))
    sections.append("\nRecommended CPT Codes:\n")
    sections.extend(_format_top_codes(cpt_probs, CPT_CODES))

    return "".join(sections)

# Load models globally
tokenizer, model = load_models()

# Describe the input and output widgets separately so the Interface call
# below stays short.
summary_box = gr.Textbox(
    lines=5,
    placeholder="Enter medical summary here...",
    label="Medical Summary",
)
codes_box = gr.Textbox(label="Predicted Codes", lines=8)

# Sample summaries offered as clickable examples in the UI.
example_summaries = [
    ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
    ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
    ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
]

# Wire predict_codes to the widgets.
iface = gr.Interface(
    fn=predict_codes,
    inputs=summary_box,
    outputs=codes_box,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=example_summaries,
)

# Start the app; share=True is passed through to Gradio's launch call.
iface.launch(share=True)