Spaces:
Sleeping
Sleeping
File size: 4,743 Bytes
f40aca6 c6a3045 f40aca6 c6a3045 f40aca6 97b126b f40aca6 c6a3045 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
# Define the model class
class MedicalCodePredictor(torch.nn.Module):
    """Encoder with two linear heads: one for ICD-10 codes, one for CPT codes.

    Both heads read the encoder's first-token hidden state (``[CLS]`` for BERT
    tokenizers) after dropout. The heads are newly created ``Linear`` layers, so
    trained head weights have to be loaded separately (e.g. ``load_state_dict``).

    Args:
        bert_model: Encoder whose forward accepts ``input_ids``/``attention_mask``
            keyword arguments and returns an object with ``last_hidden_state``.
        num_icd_codes: Number of ICD-10 classes; defaults to ``len(ICD_CODES)``.
        num_cpt_codes: Number of CPT classes; defaults to ``len(CPT_CODES)``.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None, dropout=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(dropout)
        # Size the heads from the encoder's config rather than assuming BERT-base;
        # encoders without a config fall back to the previous hard-coded 768.
        config = getattr(bert_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)
        # Default class counts come from the module-level label maps (resolved at
        # construction time, so the maps may be defined after this class).
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)
        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``.

        Each is shaped ``(batch, num_classes)``, assuming ``last_hidden_state`` is
        ``(batch, seq, hidden)``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The first token's hidden state serves as the whole-sequence summary.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])
        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)
        return icd_logits, cpt_logits
# Define code dictionaries
# Class index -> display label for the ICD-10 head. predict_codes looks labels up
# by the head's output index, so keep indices stable (append, don't reorder).
ICD_CODES = {
    0: "I10 - Essential hypertension",
    1: "E11.9 - Type 2 diabetes without complications",
    2: "J44.9 - COPD",
    3: "I25.10 - Atherosclerotic heart disease",
    # M54.5 was retired in ICD-10-CM FY2022; M54.50 is the unspecified replacement.
    4: "M54.50 - Low back pain, unspecified",
    5: "F41.9 - Anxiety disorder",
    6: "J45.909 - Asthma, unspecified",
    7: "K21.9 - GERD",
    8: "E78.5 - Dyslipidemia",
    9: "M17.9 - Osteoarthritis of knee",
    10: "E10.9 - Type 1 diabetes without complications",
    # R51 was retired in ICD-10-CM FY2021; R51.9 is the unspecified replacement.
    11: "R51.9 - Headache, unspecified",
    12: "R50.9 - Fever, unspecified",
    # R05 was expanded in ICD-10-CM FY2022; R05.9 is the unspecified code.
    13: "R05.9 - Cough, unspecified",
    # NOTE(review): S52.5 is a non-billable subcategory (lower end of radius) that
    # needs further characters incl. a 7th-character extension — TODO confirm code.
    14: "S52.5 - Fracture of forearm",
    15: "A49.9 - Bacterial infection, unspecified",
    16: "R52 - Pain, unspecified",
    # R11 is a category header; R11.0 is the billable code for nausea alone.
    17: "R11.0 - Nausea",
    # NOTE(review): S33.5 also needs a 7th-character extension (e.g. S33.5XXA) to be
    # billable — TODO confirm which encounter type is intended.
    18: "S33.5 - Sprain and strain of lumbar spine",
}
# Class index -> display label for the CPT head. Each label's position in this
# list is its class index, so append new codes at the end.
CPT_CODES = dict(
    enumerate(
        [
            "99213 - Office visit, established patient",
            "99214 - Office visit, established patient, moderate complexity",
            "99203 - Office visit, new patient",
            "80053 - Comprehensive metabolic panel",
            "85025 - Complete blood count",
            "93000 - ECG with interpretation",
            "71045 - Chest X-ray",
            "99395 - Preventive visit, established patient",
            "96127 - Brief emotional/behavioral assessment",
            "99396 - Preventive visit, age 40-64",
            "96372 - Therapeutic injection",
            "97110 - Therapeutic exercises",
            "10060 - Incision and drainage of abscess",
            "76700 - Abdominal ultrasound",
            "87500 - Infectious agent detection",
            "72100 - X-ray of lower spine",
            "72148 - MRI of lumbar spine",
        ]
    )
)
# Load models
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT"):
    """Load the tokenizer and wrap the pretrained encoder in a MedicalCodePredictor.

    Gradient tracking is irrelevant while loading, so no ``no_grad`` context is
    used here. Inference code should disable autograd around the forward pass.

    Args:
        model_name: Hugging Face model id or local path used for both the
            tokenizer and the encoder (default: Bio_ClinicalBERT).

    Returns:
        ``(tokenizer, model)``, with ``model`` already switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    # NOTE(review): only the encoder is pretrained. The ICD/CPT heads are freshly
    # initialized Linear layers and no fine-tuned weights are loaded here, so the
    # predictions are not meaningful until trained head weights are loaded — TODO.
    model = MedicalCodePredictor(base_model)
    # Switch to inference behaviour (disables dropout) once, at load time.
    model.eval()
    return tokenizer, model
# Prediction function
def _format_top_codes(title, logits, code_labels, top_k):
    """Render the ``top_k`` most probable codes of one head as a numbered list.

    Args:
        title: Heading line, e.g. ``"Recommended ICD-10 Codes:"``.
        logits: Head output; only row 0 (the single input) is reported.
        code_labels: Mapping from class index to display label.
        top_k: Maximum number of codes to list; clamped to the number of classes.

    Returns:
        The heading followed by ``"<rank>. <label> (Confidence: <p>)"`` lines, each
        terminated by a newline.
    """
    probs = F.softmax(logits, dim=1)
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    top = torch.topk(probs, k=min(top_k, probs.shape[1]))
    lines = [title]
    for rank, (prob, idx) in enumerate(zip(top.values[0], top.indices[0]), start=1):
        lines.append(f"{rank}. {code_labels[idx.item()]} (Confidence: {prob.item():.2f})")
    return "\n".join(lines) + "\n"


def predict_codes(text, top_k=3):
    """Suggest ICD-10 and CPT codes for a free-text medical summary.

    Uses the module-level ``tokenizer`` and ``model`` and the ``ICD_CODES`` /
    ``CPT_CODES`` label maps.

    Args:
        text: Medical summary. ``None`` or whitespace-only input returns a prompt
            message instead of a prediction.
        top_k: Number of codes to list per code system (default 3).

    Returns:
        A text report: the top ICD-10 codes, a blank line, then the top CPT codes,
        each shown with its softmax probability as "Confidence".
    """
    # Treat None like blank input instead of failing on .strip().
    if text is None or not text.strip():
        return "Please enter a medical summary."
    # Inputs longer than 512 tokens are truncated.
    inputs = tokenizer(text,
                       return_tensors="pt",
                       max_length=512,
                       truncation=True,
                       padding=True)
    model.eval()
    # Inference only: skip autograd bookkeeping so requests don't build a graph.
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])
    icd_section = _format_top_codes("Recommended ICD-10 Codes:", icd_logits, ICD_CODES, top_k)
    cpt_section = _format_top_codes("Recommended CPT Codes:", cpt_logits, CPT_CODES, top_k)
    return icd_section + "\n" + cpt_section
# Load models globally
# Module-level so predict_codes can reach them; loaded once per process.
tokenizer, model = load_models()
# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary"
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."]
    ]
)
# Launch the interface only when run as a script, so importing this module does
# not start the web server.
if __name__ == "__main__":
    # No share=True: this app handles clinical text, and a share link would
    # expose it at a public gradio.live URL. Gradio does not support share links
    # on Hugging Face Spaces anyway. Pass share=True explicitly if intended.
    iface.launch()
|