# Source: Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import numpy as np | |
from tensorflow.keras.models import load_model as lm | |
from PIL import Image | |
import plotly.graph_objects as go | |
# Trained CIFAR model, loaded once at module import from the saved .h5 file
model = lm('build_model_1_v2.h5')

# CIFAR-10 labels; classify_image indexes this list with the argmax of the
# model output, so the order here must match the model's output order
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
def classify_image(image):
    """Classify an image with the CIFAR-10 model and chart per-class scores.

    Args:
        image: Image data as a NumPy array (the interface is configured with
            ``gr.Image(type="numpy")``), or ``None`` when the request was
            submitted without an image.

    Returns:
        A ``(message, figure)`` tuple. ``message`` is a one-line summary of
        the top prediction; ``figure`` is a Plotly horizontal bar chart of
        the confidence for every class. When ``image`` is ``None`` the
        message asks for an upload and ``figure`` is ``None``.
    """
    # Gradio hands the callback None when nothing was uploaded;
    # Image.fromarray would raise on it, so answer with a prompt instead.
    if image is None:
        return "Please upload an image to classify.", None

    # Force 3-channel RGB and shrink to 32x32, the CIFAR-10 image size.
    rgb_image = Image.fromarray(image).convert('RGB').resize((32, 32))

    # Normalize pixel values to [0, 1], then add a leading batch axis so the
    # array has the (1, 32, 32, 3) shape passed to the model.
    pixels = np.array(rgb_image).astype(np.float32) / 255.0
    batch = np.expand_dims(pixels, axis=0)

    # The batch holds a single image, so keep only its row of scores.
    predictions = model.predict(batch)[0]

    # Scale once and reuse for the summary, the bar lengths and the labels.
    # NOTE(review): treats the model output as probabilities in [0, 1]
    # (e.g. a softmax layer) - confirm against the saved model.
    percentages = predictions * 100

    # Top-scoring class and its confidence.
    predicted_class_idx = np.argmax(predictions)
    predicted_class = class_names[predicted_class_idx]
    predicted_confidence = percentages[predicted_class_idx]
    predicted_info = f"Predicted Class: {predicted_class} with {predicted_confidence:.2f}% confidence."

    # Horizontal bar chart with one bar per class, labelled with its percentage.
    fig = go.Figure(go.Bar(
        x=percentages,
        y=class_names,
        orientation='h',
        marker=dict(color='skyblue'),
        text=[f"{conf:.1f}%" for conf in percentages],
        hoverinfo="text",
    ))

    fig.update_layout(
        title="Class Confidence Levels",
        xaxis_title="Confidence (%)",
        yaxis_title="Classes",
        xaxis=dict(range=[0, 100]),  # Fixed 0-100% axis
        yaxis=dict(categoryorder='total ascending'),  # Sort bars by confidence
        bargap=0.2,
    )

    return predicted_info, fig
# Gradio UI: a single image input (delivered to the callback as a NumPy
# array) and two outputs - the text summary and the confidence plot
interface = gr.Interface(
    title="CIFAR Image Classification",
    description="Upload an image, and the model will classify it as one of the CIFAR-10 classes.",
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs=["text", gr.Plot()],
)

# Start serving the app
interface.launch()