# MobileNet_Fire / app.py
# Suhani-2407's picture
# Update app.py
# b291d86 verified
import os
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO
import base64
# The Keras model is loaded once, at import time, and shared by every call
# to predict_image below.
model = tf.keras.models.load_model("MobileNet_model.h5")

# Output index of the model -> human-readable class label.
class_labels = dict(enumerate(["Fake", "Low", "Medium", "High"]))
def preprocess_image(image, target_size=(128, 128)):
    """Prepare a PIL image as a single-item batch for model prediction.

    Args:
        image: A PIL ``Image`` in any mode.
        target_size: ``(width, height)`` to resize to. Defaults to
            ``(128, 128)``, the size this module has always used.

    Returns:
        A float64 numpy array of shape ``(1, height, width, 3)`` with pixel
        values scaled from ``[0, 255]`` to ``[0, 1]``.
    """
    # Force three channels first: grayscale ("L"), palette ("P"), "LA", "1",
    # "CMYK", ... would otherwise produce arrays with the wrong channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    # Normalize 8-bit pixel values to [0, 1].
    img_array = np.asarray(image) / 255.0
    # Prepend the batch axis expected by model.predict.
    return np.expand_dims(img_array, axis=0)
def predict_image(image):
    """Run the model on one PIL image and summarize its output.

    Args:
        image: A PIL ``Image``; it is passed through ``preprocess_image``.

    Returns:
        dict with keys:
            ``"predicted_class"``: the ``class_labels`` entry with the highest score.
            ``"confidence"``: that highest score as a Python ``float``.
            ``"class_probabilities"``: label -> score for every model output
                (these are probabilities only if the model ends in a softmax;
                NOTE(review): confirm against the model's final layer).

    Raises:
        KeyError: if the model emits more outputs than ``class_labels`` has.
    """
    img_array = preprocess_image(image)
    # verbose=0 suppresses Keras' per-call progress bar, which would
    # otherwise be printed to the log on every request.
    predictions = model.predict(img_array, verbose=0)
    # The batch holds exactly one image, so row 0 holds its scores.
    scores = predictions[0]
    best_idx = int(np.argmax(scores))
    return {
        "predicted_class": class_labels[best_idx],
        "confidence": float(scores[best_idx]),
        "class_probabilities": {
            class_labels[i]: float(prob) for i, prob in enumerate(scores)
        },
    }
def _open_image_bytes(image_bytes):
    """Decode encoded image bytes (PNG, JPEG, ...) into a PIL image."""
    return Image.open(BytesIO(image_bytes))


def _load_image(data, url_timeout=10):
    """Turn one supported ``inference`` input into a PIL image (any mode).

    Args:
        data: A PIL ``Image``, raw bytes / bytearray / memoryview, or a string
            holding a file path, an http(s) URL, a ``data:image`` URI, or bare
            base64.
        url_timeout: Seconds to wait on the network when ``data`` is a URL.

    Raises:
        TypeError: if ``data`` is none of the supported types.
        ValueError: if ``data`` is a string matching none of the supported forms.
    """
    if isinstance(data, Image.Image):
        return data
    if isinstance(data, (bytes, bytearray, memoryview)):
        return _open_image_bytes(bytes(data))
    if not isinstance(data, str):
        raise TypeError(
            f"Unsupported input type for inference: {type(data).__name__}"
        )

    if os.path.isfile(data):
        # Read the file eagerly so no file handle stays open behind a lazily
        # loaded PIL image.
        with open(data, "rb") as fh:
            return _open_image_bytes(fh.read())

    if data.startswith(("http://", "https://")):
        from urllib.request import urlopen

        # SECURITY NOTE(review): this fetches any URL the caller supplies,
        # including internal/private addresses (SSRF). Restrict allowed hosts
        # if this endpoint is reachable by untrusted clients.
        with urlopen(data, timeout=url_timeout) as response:
            return _open_image_bytes(response.read())

    if data.startswith("data:image"):
        _, sep, payload = data.partition(",")
        if not sep:
            raise ValueError(
                "Malformed data URI: missing ',' before the base64 payload"
            )
        return _open_image_bytes(base64.b64decode(payload))

    # Fall back to a bare base64 string (no data-URI prefix). Whitespace is
    # stripped because base64 text is often line-wrapped; strict validation
    # keeps e.g. a mistyped file path from being decoded into garbage bytes.
    compact = "".join(data.split())
    try:
        image_bytes = base64.b64decode(compact, validate=True)
    except ValueError as err:  # binascii.Error subclasses ValueError
        raise ValueError(
            "Unsupported input: expected an existing file path, an http(s) "
            "URL, a data:image URI, or a base64-encoded image"
        ) from err
    if not image_bytes:
        raise ValueError("Empty image input")
    return _open_image_bytes(image_bytes)


def inference(data):
    """
    Inference function for Hugging Face API.

    data can be:
    - File path (string)
    - URL string (http:// or https://; the image is downloaded here)
    - Base64 encoded image (``data:image/...;base64,...`` URI or bare base64)
    - Raw image bytes (bytes, bytearray or memoryview)
    - A PIL Image
    - Dict with an "image" key containing any of the above

    Returns:
        The result dict produced by ``predict_image``.

    Raises:
        TypeError: for an unsupported input type.
        ValueError: for a string that matches none of the supported forms.
    """
    # Unwrap the {"image": ...} envelope.
    if isinstance(data, dict) and "image" in data:
        data = data["image"]
    image = _load_image(data)
    # The model needs 3-channel input; convert every non-RGB mode (RGBA, L,
    # P, LA, CMYK, ...), not just RGBA.
    if image.mode != "RGB":
        image = image.convert("RGB")
    return predict_image(image)
# For local testing: python app.py [path/to/image.jpg]
if __name__ == "__main__":
    import sys

    # Take the image path from the command line, falling back to a placeholder.
    test_image_path = (
        sys.argv[1] if len(sys.argv) > 1 else "path/to/test/image.jpg"
    )
    # isfile (not exists): inference only accepts regular files as paths.
    if os.path.isfile(test_image_path):
        result = inference(test_image_path)
        print(f"Predicted class: {result['predicted_class']}")
        print(f"Confidence: {result['confidence']:.4f}")
    else:
        print(f"Test image not found: {test_image_path}")