File size: 3,545 Bytes
cbebebf
 
56c077f
cbebebf
 
 
56c077f
 
 
cbebebf
 
 
 
 
 
42052fb
 
cbebebf
 
56c077f
cbebebf
 
56c077f
42052fb
56c077f
 
 
cbebebf
56c077f
 
 
cbebebf
56c077f
 
 
cbebebf
 
 
 
 
 
 
56c077f
 
 
 
 
 
 
cbebebf
 
 
56c077f
 
 
cbebebf
 
56c077f
cbebebf
56c077f
ac169b3
56c077f
cbebebf
 
 
 
56c077f
 
 
cbebebf
 
56c077f
 
 
 
cbebebf
 
 
 
 
 
 
 
 
 
 
 
 
56c077f
cbebebf
 
 
 
 
 
 
 
 
 
 
56c077f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import os

# Load ICD and CPT codes from files
def load_codes_from_files(directory_path, code_type):
    """Load a ``code -> description`` table from every ``.txt`` file in a directory.

    Each line is expected to look like ``<code> <description>``: the first
    whitespace-separated token is the code and the rest of the line (stripped)
    is the description. Lines without a description (including blank lines)
    are ignored. If a code appears more than once, the last occurrence wins.

    Files are read in sorted filename order so the key order of the returned
    dict is deterministic. This matters because callers resolve model output
    indices to codes by position in that order.

    Args:
        directory_path: Directory to scan (non-recursive); only ``*.txt``
            files are read.
        code_type: Label for the code family (e.g. "ICD", "CPT"). Not used by
            this function; kept for interface compatibility.

    Returns:
        dict mapping code string to description string. Empty (after printing
        a message) if ``directory_path`` is not an existing directory.
    """
    codes = {}
    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return codes

    # sorted(): os.listdir order is arbitrary, and insertion order is significant.
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        # "utf-8-sig" also strips a leading byte-order mark, which would
        # otherwise end up glued to the first code in the file.
        with open(file_path, "r", encoding="utf-8-sig") as file:
            for line in file:
                parts = line.strip().split(maxsplit=1)
                if len(parts) == 2:
                    code, description = parts
                    codes[code.strip()] = description.strip()
    return codes

# Code tables ({code: description}). Model output index i is resolved to the
# i-th key of these dicts, so their insertion order is significant.
ICD_CODES = load_codes_from_files("./codes/icd_txt_files/", "ICD")
CPT_CODES = load_codes_from_files("./codes/cpt_txt_files/", "CPT")

# Abort at import time if either table came back empty.
if not (ICD_CODES and CPT_CODES):
    raise ValueError(
        "No ICD or CPT codes were loaded. Please check your files and directory structure."
    )

# Tokenizer and classifier come from the same Hugging Face checkpoint; the
# classifier is sized to one output per loaded ICD code.
# NOTE(review): unless this checkpoint already ships a sequence-classification
# head with exactly len(ICD_CODES) labels, that head is freshly initialized and
# its predictions are untrained. Confirm a fine-tuned checkpoint is intended.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModelForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_ClinicalBERT",
    num_labels=len(ICD_CODES),
)

# Prediction function
def _format_section(title, code_table, probs, indices):
    """Render one "Recommended ... Codes" section of the prediction report.

    Args:
        title: Heading text written first, including its own newlines.
        code_table: Mapping of code -> description. A model output index ``i``
            is resolved to the ``i``-th key in the mapping's insertion order.
        probs: Sequence of confidences (floats), aligned with ``indices``.
        indices: Sequence of model output indices (ints).

    Returns:
        ``title`` followed by one numbered line per entry. Indices that fall
        outside ``code_table`` are skipped instead of raising ``IndexError``,
        and numbering stays contiguous.
    """
    codes = list(code_table)  # materialize key order once, not once per entry
    entries = [(p, i) for p, i in zip(probs, indices) if 0 <= i < len(codes)]
    lines = [title]
    for rank, (prob, idx) in enumerate(entries, start=1):
        code = codes[idx]
        lines.append(f"{rank}. {code}: {code_table[code]} (Confidence: {prob:.2f})\n")
    return "".join(lines)


def predict_codes(text):
    """Suggest ICD-10 and CPT codes for a free-text medical summary.

    Tokenizes ``text`` (truncated to 512 tokens) and runs it through the
    module-level ``model``. It then softmaxes the logits, takes the top 3
    outputs, and resolves each output index by position in ``ICD_CODES``
    (and in ``CPT_CODES`` — see the note below).

    Args:
        text: The summary entered in the UI. ``None``, empty, or
            whitespace-only input is treated as missing.

    Returns:
        A human-readable report string. When ``text`` is missing, returns a
        prompt asking for input instead.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    model.eval()  # inference mode (e.g. no dropout); idempotent per request
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=1)

    # Never ask topk for more entries than the model actually outputs.
    top_k = min(3, probs.size(-1))
    top = torch.topk(probs, k=top_k)
    # A single text was tokenized, so take batch row 0; convert to plain
    # Python numbers once.
    top_probs = top.values[0].tolist()
    top_indices = top.indices[0].tolist()

    icd_section = _format_section(
        "Recommended ICD-10 Codes:\n", ICD_CODES, top_probs, top_indices
    )

    # NOTE(review): the classifier has one output per ICD code only
    # (num_labels=len(ICD_CODES) at load time), so nothing here is a real CPT
    # prediction. This list reuses the ICD top-k indices as positions in
    # CPT_CODES and shows the ICD confidence. Indices beyond the CPT table are
    # skipped. TODO: give CPT its own model/head.
    cpt_section = _format_section(
        "\nRecommended CPT Codes:\n", CPT_CODES, top_probs, top_indices
    )

    return icd_section + cpt_section

# Create Gradio interface: a 5-line text input wired to predict_codes, a
# 10-line text output, and three sample summaries offered as examples.
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=10,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

# Launch the interface only when run as a script, so importing this module
# (e.g. from tests or another app) does not start a server.
# NOTE(review): share=True requests a public Gradio share link. Inputs are
# medical summaries, so confirm that public exposure is intended.
if __name__ == "__main__":
    iface.launch(share=True)