# CIFAR-10 / app.py
# abdulrahman245's picture
# Update app.py
# 09f8bbf verified
import gradio as gr
import numpy as np
from tensorflow.keras.models import load_model as lm
from PIL import Image
import plotly.graph_objects as go
# Human-readable labels for the ten CIFAR-10 categories; the position in this
# list is used as the class index when reading the model's output vector.
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Restore the trained Keras network once at import time so every request
# reuses the same in-memory model.
model = lm('build_model_1_v2.h5')
# Function to preprocess the image and predict the class
def classify_image(image):
    """Classify an uploaded image into one of the CIFAR-10 classes.

    Args:
        image: Pixel data as a NumPy array (the Gradio input component is
            declared with ``type="numpy"``), or ``None`` when the form was
            submitted without an image.

    Returns:
        A ``(message, figure)`` tuple. ``message`` names the top-scoring class
        and its confidence; ``figure`` is a Plotly horizontal bar chart with
        one bar per entry in ``class_names``. When ``image`` is ``None`` the
        message asks for an upload and ``figure`` is ``None``.
    """
    # Submitting with no upload hands us None; answer with a readable prompt
    # instead of letting Image.fromarray raise.
    if image is None:
        return "Please upload an image.", None

    # Force three channels (drops alpha, expands grayscale) and shrink to the
    # 32x32 resolution used for CIFAR-10.
    resized = Image.fromarray(image).convert('RGB').resize((32, 32))

    # Scale pixels to [0, 1] and add a leading batch axis -> (1, 32, 32, 3).
    batch = np.expand_dims(
        np.array(resized).astype(np.float32) / 255.0,
        axis=0,
    )

    # predict() returns one row per batch item; keep the single row.
    predictions = model.predict(batch)[0]

    # NOTE(review): treats the outputs as probabilities in [0, 1] (e.g. a
    # softmax head) -- confirm against the saved model.
    percentages = predictions * 100

    # Pick the highest-scoring class and report it.
    predicted_class_idx = np.argmax(predictions)
    predicted_class = class_names[predicted_class_idx]
    predicted_confidence = percentages[predicted_class_idx]
    predicted_info = f"Predicted Class: {predicted_class} with {predicted_confidence:.2f}% confidence."

    # Horizontal bar chart with one labelled bar per class.
    fig = go.Figure(go.Bar(
        x=percentages,
        y=class_names,
        orientation='h',
        marker=dict(color='skyblue'),
        text=[f"{conf:.1f}%" for conf in percentages],
        hoverinfo="text",
    ))

    fig.update_layout(
        title="Class Confidence Levels",
        xaxis_title="Confidence (%)",
        yaxis_title="Classes",
        xaxis=dict(range=[0, 100]),  # fixed 0-100% scale across predictions
        yaxis=dict(categoryorder='total ascending'),  # order bars by value
        bargap=0.2,
    )

    return predicted_info, fig
# Wire the classifier into a web UI: a single image input, and two outputs
# (the text verdict followed by the per-class confidence chart).
interface = gr.Interface(
    classify_image,
    gr.Image(type="numpy"),
    ["text", gr.Plot()],
    description="Upload an image, and the model will classify it as one of the CIFAR-10 classes.",
    title="CIFAR Image Classification",
)

# Start serving the app.
interface.launch()