File size: 4,743 Bytes
f40aca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6a3045
 
 
 
 
 
 
 
 
 
f40aca6
 
 
 
 
 
 
 
 
 
 
 
c6a3045
 
 
 
 
 
 
 
f40aca6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97b126b
f40aca6
 
 
 
 
 
 
 
 
c6a3045
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """BERT encoder with two linear heads scoring ICD-10 and CPT codes.

    The embedding at token position 0 of the encoder's ``last_hidden_state``
    is passed through dropout and then into two independent linear layers,
    one per code system.

    Args:
        bert_model: Encoder module called as
            ``bert_model(input_ids=..., attention_mask=...)``; its output must
            expose ``last_hidden_state``.
        num_icd_codes: Number of ICD outputs. Defaults to ``len(ICD_CODES)``.
        num_cpt_codes: Number of CPT outputs. Defaults to ``len(CPT_CODES)``.
        hidden_size: Width of the encoder output. Defaults to
            ``bert_model.config.hidden_size`` when the encoder has such a
            config, otherwise 768 (the previously hard-coded value).
        dropout: Dropout probability applied to the pooled embedding.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None,
                 hidden_size=None, dropout=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(dropout)

        if hidden_size is None:
            # Derive the width from the encoder instead of assuming 768, so
            # non-base-sized encoders do not cause a shape mismatch.
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)

        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``, each of shape (batch, num_codes)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Take the first token's embedding for every sequence -> (batch, hidden).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)

        return icd_logits, cpt_logits

# Define code dictionaries
# ICD-10-CM diagnosis labels, keyed by the output index of the ICD head
# (predict_codes looks up ICD_CODES[idx] for each top-k index).
# NOTE: labels were updated to codes valid in current ICD-10-CM releases:
#   - M54.5 was expanded into M54.50-M54.59 (effective Oct 1, 2021)
#   - R51 was expanded into R51.0/R51.9 (effective Oct 1, 2020)
#   - R05 was expanded into R05.1-R05.9 (effective Oct 1, 2021)
#   - R11 is a category header; R11.0 is "Nausea"
# S52.5 and S33.5 are still category-level codes: they need further
# characters (including a 7th-character encounter extension) to be billable.
ICD_CODES = {
    0: "I10 - Essential hypertension",
    1: "E11.9 - Type 2 diabetes without complications",
    2: "J44.9 - COPD",
    3: "I25.10 - Atherosclerotic heart disease",
    4: "M54.50 - Low back pain, unspecified",
    5: "F41.9 - Anxiety disorder",
    6: "J45.909 - Asthma, unspecified",
    7: "K21.9 - GERD",
    8: "E78.5 - Dyslipidemia",
    9: "M17.9 - Osteoarthritis of knee",
    10: "E10.9 - Type 1 diabetes without complications",
    11: "R51.9 - Headache, unspecified",
    12: "R50.9 - Fever, unspecified",
    13: "R05.9 - Cough, unspecified",
    14: "S52.5 - Fracture of lower end of radius",
    15: "A49.9 - Bacterial infection, unspecified",
    16: "R52 - Pain, unspecified",
    17: "R11.0 - Nausea",
    18: "S33.5 - Sprain of ligaments of lumbar spine",
}

# CPT procedure labels, keyed by the output index of the CPT head
# (predict_codes looks up CPT_CODES[idx] for each top-k index).
# Descriptors now carry the detail that separates neighbouring codes
# (E/M complexity level, preventive-visit age band, number of views).
# NOTE(review): 87500 is specific to vancomycin-resistance testing; confirm
# it is the intended code for generic infectious-agent detection.
CPT_CODES = {
    0: "99213 - Office visit, established patient, low complexity",
    1: "99214 - Office visit, established patient, moderate complexity",
    2: "99203 - Office visit, new patient, low complexity",
    3: "80053 - Comprehensive metabolic panel",
    4: "85025 - Complete blood count",
    5: "93000 - ECG with interpretation",
    6: "71045 - Chest X-ray, single view",
    7: "99395 - Preventive visit, established patient, age 18-39",
    8: "96127 - Brief emotional/behavioral assessment",
    9: "99396 - Preventive visit, established patient, age 40-64",
    10: "96372 - Therapeutic injection",
    11: "97110 - Therapeutic exercises",
    12: "10060 - Incision and drainage of abscess",
    13: "76700 - Abdominal ultrasound",
    14: "87500 - Infectious agent detection by nucleic acid, vancomycin resistance",
    15: "72100 - X-ray of lower spine",
    16: "72148 - MRI of lumbar spine",
}

# Load models
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT", weights_path=None):
    """Load the tokenizer and build a ``MedicalCodePredictor`` ready for inference.

    Args:
        model_name: Hugging Face model id (or local path) used for both the
            tokenizer and the BERT encoder.
        weights_path: Optional path to a ``state_dict`` saved from a trained
            ``MedicalCodePredictor``. Without it, the ICD/CPT linear heads keep
            their random initialisation and the predictions carry no meaning.

    Returns:
        ``(tokenizer, model)``, with the model already switched to eval mode.
    """
    import warnings

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)

    if weights_path is not None:
        # weights_only=True refuses to unpickle arbitrary objects from the file.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
    else:
        warnings.warn(
            "No classifier weights supplied; ICD/CPT heads are randomly "
            "initialised and predictions are not meaningful.",
            stacklevel=2,
        )

    # Disable dropout for inference. Gradient tracking is turned off at
    # prediction time, not here: a no_grad decorator on a loader has no effect.
    model.eval()
    return tokenizer, model

# Prediction function
def _format_ranked_codes(probs, code_names, k=3):
    """Return numbered lines for the ``k`` most probable codes in ``probs[0]``.

    Each line has the form ``"<rank>. <label> (Confidence: 0.xx)\\n"``.
    """
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(k, probs.size(-1))
    top = torch.topk(probs, k=k)
    return "".join(
        f"{rank}. {code_names[idx.item()]} (Confidence: {prob.item():.2f})\n"
        for rank, (prob, idx) in enumerate(zip(top.values[0], top.indices[0]), start=1)
    )


def predict_codes(text):
    """Return a text report of the top ICD-10 and CPT codes for ``text``.

    Uses the module-level ``tokenizer`` and ``model``. Input that is empty,
    whitespace-only or ``None`` gets a prompt message instead of a prediction.
    """
    if text is None or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input (truncated to the encoder's 512-token limit).
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    model.eval()
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])

    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    return (
        "Recommended ICD-10 Codes:\n"
        + _format_ranked_codes(icd_probs, ICD_CODES)
        + "\nRecommended CPT Codes:\n"
        + _format_ranked_codes(cpt_probs, CPT_CODES)
    )

# Load models globally: predict_codes reads the module-level
# `tokenizer` and `model` names.
tokenizer, model = load_models()

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary"
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."]
    ]
)

# Launch only when executed as a script, so importing this module
# (from tests or another app) does not start a web server.
if __name__ == "__main__":
    # NOTE(review): share=True publishes a public tunnel URL. Users may paste
    # patient information here, so confirm this exposure is intended.
    iface.launch(share=True)