import os
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO
import base64

# Load the Keras model once, at import time, so every inference call reuses
# the same in-memory instance. The path is relative, so it is resolved
# against the process's current working directory.
_MODEL_PATH = "MobileNet_model.h5"
model = tf.keras.models.load_model(_MODEL_PATH)

# Model output index -> human-readable label. Scores are looked up by their
# position in the model's output vector, so this order must match the order
# of the model's output units (NOTE(review): confirm against training setup).
class_labels = dict(enumerate(("Fake", "Low", "Medium", "High")))

def preprocess_image(image, target_size=(128, 128)):
    """Turn a PIL image into a single-item batch ready for ``model.predict``.

    Args:
        image: A PIL ``Image`` in any mode. Non-RGB images are converted to
            RGB first.
        target_size: ``(width, height)`` passed to ``Image.resize``. The
            default ``(128, 128)`` presumably matches the model's input
            layer; NOTE(review): confirm against the saved model.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)`` with pixel
        values scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    # The model needs exactly 3 channels. Grayscale ("L"), palette ("P"),
    # "LA", "RGBA", "CMYK" etc. would otherwise yield arrays of the wrong rank
    # or channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")

    image = image.resize(target_size)

    # Scale to [0, 1] in float32, the dtype Keras models normally consume.
    img_array = np.asarray(image, dtype=np.float32) / 255.0

    # Prepend the batch axis: (H, W, 3) -> (1, H, W, 3).
    return np.expand_dims(img_array, axis=0)

def predict_image(image):
    """Classify a single PIL image with the module-level ``model``.

    Args:
        image: A PIL ``Image``; it is preprocessed by ``preprocess_image``.

    Returns:
        A dict with:
            - ``"predicted_class"``: label of the highest-scoring output.
            - ``"confidence"``: that output's score as a Python ``float``.
            - ``"class_probabilities"``: ``{label: score}`` for every output.
        The scores are the raw model outputs. They are probabilities only if
        the model's final layer is a softmax (NOTE(review): confirm).
    """
    img_array = preprocess_image(image)

    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = model.predict(img_array, verbose=0)

    # The batch contains exactly one image, so work on its score vector only.
    # This keeps argmax, confidence and the per-class breakdown consistent.
    scores = predictions[0]
    predicted_idx = int(np.argmax(scores))

    return {
        "predicted_class": class_labels[predicted_idx],
        "confidence": float(scores[predicted_idx]),
        "class_probabilities": {
            class_labels[i]: float(score) for i, score in enumerate(scores)
        },
    }

# Seconds to wait for a remote server before abandoning a URL download.
_URL_TIMEOUT_SECONDS = 30


def _load_image(data):
    """Decode one supported input value into a PIL image (no mode conversion)."""
    if isinstance(data, Image.Image):
        return data

    if isinstance(data, (bytes, bytearray)):
        return Image.open(BytesIO(data))

    if not isinstance(data, str):
        raise TypeError(
            f"Unsupported input type {type(data).__name__!r}; "
            "expected a str, bytes, or PIL image"
        )

    if not data.strip():
        raise ValueError("Input string is empty")

    # Local file path. Read it fully so the OS file handle is closed here,
    # rather than left open until PIL finishes its lazy decoding.
    if os.path.isfile(data):
        with open(data, "rb") as fh:
            return Image.open(BytesIO(fh.read()))

    # http(s) URL.
    if data.startswith(("http://", "https://")):
        # NOTE(review): this fetches an arbitrary caller-supplied URL from the
        # server, which is an SSRF risk; consider restricting allowed hosts.
        from urllib.request import urlopen

        with urlopen(data, timeout=_URL_TIMEOUT_SECONDS) as response:
            return Image.open(BytesIO(response.read()))

    # Data URI, e.g. "data:image/png;base64,<payload>".
    if data.startswith("data:image"):
        _, sep, payload = data.partition(",")
        if not sep:
            raise ValueError("Malformed data URI: no ',' before the payload")
        return Image.open(BytesIO(base64.b64decode(payload)))

    # Fallback: bare base64 text. Strip whitespace so line-wrapped payloads
    # still decode. validate=True rejects strings that are not base64 at all
    # (e.g. a mistyped file path) instead of silently dropping characters.
    try:
        image_bytes = base64.b64decode("".join(data.split()), validate=True)
    except ValueError:  # binascii.Error is a subclass of ValueError
        raise ValueError(
            "Input string is not an existing file path, an http(s) URL, "
            "a data URI, or valid base64"
        ) from None
    return Image.open(BytesIO(image_bytes))


def inference(data):
    """Run the classifier on one image supplied in any supported format.

    Entry point for the Hugging Face API.

    Args:
        data: One of
            - a local file path (str),
            - an ``http://`` / ``https://`` URL (str),
            - a ``data:image...`` URI or bare base64 text (str),
            - raw image bytes (``bytes`` / ``bytearray``),
            - an already-opened PIL ``Image``,
            - a dict whose ``"image"`` key holds any of the above.

    Returns:
        The dict produced by ``predict_image``, with the keys
        ``"predicted_class"``, ``"confidence"`` and ``"class_probabilities"``.

    Raises:
        ValueError: a dict without an ``"image"`` key, an empty string, a
            malformed data URI, or a string matching none of the supported
            forms.
        TypeError: an unsupported input type.
        PIL.UnidentifiedImageError: ``Image.open`` cannot identify the
            decoded bytes as an image.
        urllib.error.URLError: a URL could not be fetched.
    """
    if isinstance(data, dict):
        if "image" not in data:
            raise ValueError("dict input must contain an 'image' key")
        data = data["image"]

    image = _load_image(data)

    # The model needs exactly 3 channels. This covers RGBA as well as
    # grayscale, palette, LA, CMYK, ... images.
    if image.mode != "RGB":
        image = image.convert("RGB")

    return predict_image(image)

# For local testing: python <this file> [path/to/image]
if __name__ == "__main__":
    import sys

    # Take the image path from the command line if given; otherwise fall back
    # to the placeholder path.
    test_image_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/test/image.jpg"

    # isfile (not exists) mirrors inference's own file-path check, so a
    # directory is reported here instead of failing inside inference.
    if not os.path.isfile(test_image_path):
        sys.exit(f"Test image not found: {test_image_path}")

    result = inference(test_image_path)
    print(f"Predicted class: {result['predicted_class']}")
    print(f"Confidence: {result['confidence']:.4f}")
    for label, score in result["class_probabilities"].items():
        print(f"  {label}: {score:.4f}")