import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

# Model: shared encoder with one classification head per code system.
class MedicalCodePredictor(torch.nn.Module):
    """Shared-encoder classifier with separate ICD-10 and CPT output heads.

    The encoder's first-token embedding (the [CLS] position for BERT-style
    tokenizers) goes through dropout. It is then scored by two independent
    linear heads, one per code system.

    Args:
        bert_model: Encoder module, called as
            ``bert_model(input_ids=..., attention_mask=...)``. Its output must
            have a ``last_hidden_state`` attribute.
        num_icd_codes: Output width of the ICD head. Defaults to ``len(ICD_CODES)``.
        num_cpt_codes: Output width of the CPT head. Defaults to ``len(CPT_CODES)``.
        dropout_prob: Dropout probability applied to the pooled embedding.
    """

    def __init__(self, bert_model, num_icd_codes=None, num_cpt_codes=None, dropout_prob=0.1):
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(dropout_prob)
        # Size the heads from the encoder's config rather than hard-coding 768,
        # so encoders with a different hidden size also work. An encoder with no
        # ``config.hidden_size`` falls back to 768 (the previous fixed value).
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        # The label tables are looked up here rather than used as default
        # argument values, because they are defined after this class in the module.
        if num_icd_codes is None:
            num_icd_codes = len(ICD_CODES)
        if num_cpt_codes is None:
            num_cpt_codes = len(CPT_CODES)
        self.icd_classifier = torch.nn.Linear(hidden_size, num_icd_codes)
        self.cpt_classifier = torch.nn.Linear(hidden_size, num_cpt_codes)

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)``, each shaped ``(batch, num_codes)``.

        Both arguments are passed unchanged to the encoder. They are presumably
        ``(batch, seq_len)`` token-id and attention-mask tensors from the tokenizer.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Pool each sequence by taking the embedding of its first token.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0, :])

        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)

        return icd_logits, cpt_logits

# ICD-10 label table: position i in the list becomes key i, which names
# output i of the model's ICD head.
# NOTE(review): some entries appear outdated or non-billable under current
# ICD-10-CM. For example, M54.5, R51 and R05 were later subdivided, and S52.5 /
# S33.5 need further characters. S52.5 also looks like "lower end of radius"
# rather than "forearm". Verify against the code set in use.
ICD_CODES = dict(enumerate([
    "I10 - Essential hypertension",
    "E11.9 - Type 2 diabetes without complications",
    "J44.9 - COPD",
    "I25.10 - Atherosclerotic heart disease",
    "M54.5 - Low back pain",
    "F41.9 - Anxiety disorder",
    "J45.909 - Asthma, unspecified",
    "K21.9 - GERD",
    "E78.5 - Dyslipidemia",
    "M17.9 - Osteoarthritis of knee",
    "E10.9 - Type 1 diabetes without complications",
    "R51 - Headache",
    "R50.9 - Fever, unspecified",
    "R05 - Cough",
    "S52.5 - Fracture of forearm",
    "A49.9 - Bacterial infection, unspecified",
    "R52 - Pain, unspecified",
    "R11 - Nausea",
    "S33.5 - Sprain and strain of lumbar spine",
]))

# CPT label table: the key is the output index of the model's CPT head, and the
# value is the text shown to the user.
CPT_CODES = {
    index: label
    for index, label in enumerate((
        "99213 - Office visit, established patient",
        "99214 - Office visit, established patient, moderate complexity",
        "99203 - Office visit, new patient",
        "80053 - Comprehensive metabolic panel",
        "85025 - Complete blood count",
        "93000 - ECG with interpretation",
        "71045 - Chest X-ray",
        "99395 - Preventive visit, established patient",
        "96127 - Brief emotional/behavioral assessment",
        "99396 - Preventive visit, age 40-64",
        "96372 - Therapeutic injection",
        "97110 - Therapeutic exercises",
        "10060 - Incision and drainage of abscess",
        "76700 - Abdominal ultrasound",
        "87500 - Infectious agent detection",
        "72100 - X-ray of lower spine",
        "72148 - MRI of lumbar spine",
    ))
}

# Tokenizer + model construction.
def load_models(model_name="emilyalsentzer/Bio_ClinicalBERT", checkpoint_path=None):
    """Load the tokenizer and build a ``MedicalCodePredictor`` for inference.

    Args:
        model_name: Hugging Face model id (or local directory) used for both the
            tokenizer and the encoder.
        checkpoint_path: Optional path to a ``state_dict`` saved from a trained
            ``MedicalCodePredictor``. If omitted, the ICD/CPT heads keep their
            random initialization and a warning is emitted.

    Returns:
        ``(tokenizer, model)``, with ``model`` already switched to eval mode.
    """
    import warnings

    # The former @torch.no_grad() decorator was removed: it only affected this
    # loader, not inference. Gradient tracking must be disabled where the
    # forward pass runs.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)

    if checkpoint_path is None:
        warnings.warn(
            "No checkpoint_path given: the ICD/CPT classifier heads are randomly "
            "initialized, so predicted codes are not meaningful.",
            stacklevel=2,
        )
    else:
        # weights_only=True refuses arbitrary pickled objects in the checkpoint
        # file (this argument requires torch >= 1.13).
        state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)

    # Disable dropout for inference.
    model.eval()
    return tokenizer, model

# Prediction function
def predict_codes(text):
    if not text.strip():
        return "Please enter a medical summary."
    
    # Tokenize input
    inputs = tokenizer(text, 
                      return_tensors="pt",
                      max_length=512,
                      truncation=True,
                      padding=True)
    
    # Get predictions
    model.eval()
    icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])
    
    # Get probabilities
    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)
    
    # Get top 3 predictions
    top_icd = torch.topk(icd_probs, k=3)
    top_cpt = torch.topk(cpt_probs, k=3)
    
    # Format results
    result = "Recommended ICD-10 Codes:\n"
    for i, (prob, idx) in enumerate(zip(top_icd.values[0], top_icd.indices[0])):
        result += f"{i+1}. {ICD_CODES[idx.item()]} (Confidence: {prob.item():.2f})\n"
    
    result += "\nRecommended CPT Codes:\n"
    for i, (prob, idx) in enumerate(zip(top_cpt.values[0], top_cpt.indices[0])):
        result += f"{i+1}. {CPT_CODES[idx.item()]} (Confidence: {prob.item():.2f})\n"
    
    return result

# Load the tokenizer and model once at import time; predict_codes reads these globals.
tokenizer, model = load_models()

# Build the UI at module level, so code that imports this file can reference
# `iface` without starting a server.
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary"
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."]
    ]
)

if __name__ == "__main__":
    # NOTE(review): share=True publishes the app through a public Gradio share
    # link. Clinical text entered here may contain protected health information,
    # so confirm this exposure is intended before using it with real data.
    iface.launch(share=True)