# pjohn2409's picture
# First commit
# 7d8db22
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications.imagenet_utils import preprocess_input
import os
# Load the saved Keras model from disk once, when the module is imported
MODEL_PATH = "final_trashnet_transfer_learning_model.keras"
model = tf.keras.models.load_model(MODEL_PATH)
# Integer class index -> broad waste category name.
# The position of each name in the tuple is its class index (0, 1, 2).
class_mapping = dict(enumerate(("Compostable", "Recyclables", "Trash")))
# Define a function to preprocess the input image
def preprocess_image(image):
    """Convert a PIL image into a single-image batch for ``model.predict``.

    Args:
        image: A PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        The output of ``preprocess_input`` with a leading batch axis added.
        For Keras' default channels-last layout this is shaped
        (1, 128, 128, 3).
    """
    # Force exactly three colour channels: RGBA, grayscale or palette images
    # would otherwise reach preprocess_input / the model with the wrong
    # channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Resize the image to 128x128
    image = image.resize((128, 128))
    # Convert the image to a NumPy array and apply the imagenet_utils
    # preprocessing.
    # NOTE(review): assumes the model was trained with this same
    # preprocess_input mode -- confirm against the training code.
    img_array = img_to_array(image)
    img_array = preprocess_input(img_array)
    # Add the leading batch axis: (H, W, C) -> (1, H, W, C)
    img_array = np.expand_dims(img_array, axis=0)
    return img_array
# Define the prediction function
def classify_trash(image):
    """Classify an image and report the predicted category and confidence.

    Args:
        image: A PIL image, or ``None`` when the request carried no image.

    Returns:
        A pair of strings ``(category_text, confidence_text)``, one for each
        output textbox. The confidence is the highest prediction score
        multiplied by 100 and formatted with two decimals.
    """
    # Gradio passes None when the form is submitted without an image;
    # answer with a readable message instead of crashing on None.resize().
    if image is None:
        return "Predicted Category: N/A (no image provided)", "Confidence: N/A"

    processed_image = preprocess_image(image)
    predictions = model.predict(processed_image)
    print(predictions)
    # Cast NumPy scalars to plain Python numbers so the dict lookup and the
    # string formatting do not depend on NumPy scalar behaviour.
    class_index = int(np.argmax(predictions))
    confidence = float(np.max(predictions))
    predicted_class = class_mapping[class_index]
    return f"Predicted Category: {predicted_class}", f"Confidence: {confidence*100:.2f}"
# Function to gather example images dynamically
def get_example_images():
    """Collect example image paths from ``examples/<category>/`` folders.

    The folders ``Compostable``, ``Recyclables`` and ``Trash`` under
    ``examples`` (relative to the current working directory) are scanned in
    that order. A category that is missing, or that is not a directory, is
    skipped.

    Returns:
        A list of paths to files ending in ``.jpg`` or ``.png`` (matched
        case-insensitively), sorted by file name within each category.
    """
    base_dir = "examples"
    categories = ["Compostable", "Recyclables", "Trash"]
    image_suffixes = (".jpg", ".png")

    example_images = []
    for category in categories:
        folder_path = os.path.join(base_dir, category)
        # isdir (not exists): a plain file with a category name would make
        # os.listdir raise NotADirectoryError.
        if not os.path.isdir(folder_path):
            continue
        # os.listdir returns entries in arbitrary order; sort so the examples
        # appear in the same order on every run.
        example_images.extend(
            os.path.join(folder_path, file_name)
            for file_name in sorted(os.listdir(folder_path))
            if file_name.lower().endswith(image_suffixes)
        )
    return example_images
# Example images gathered from every category folder
example_images = get_example_images()

# Build the input and output components first, then assemble the interface
image_input = gr.Image(type="pil", label="Upload an Image")
result_outputs = [
    gr.Textbox(label="Predicted Category"),
    gr.Textbox(label="Confidence"),
]
interface = gr.Interface(
    fn=classify_trash,
    inputs=image_input,
    outputs=result_outputs,
    examples=example_images,
    title="Trash Classifier",
    description=(
        "Upload an image of trash, and the model will classify it into "
        "'Compostable', 'Recyclables' and 'Trash' based on its category."
    ),
)

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    interface.launch()