# pylint: disable=import-error
import gradio as gr
import torch
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification, ViTImageProcessor

# Load pre-trained Vision Transformer model
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name)
processor = ViTImageProcessor.from_pretrained(model_name)


# Function to predict image class
def classify_image(image):
    if image is None:
        return None, None

    # Process image
    inputs = processor(images=image, return_tensors="pt")

    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get predicted class and probabilities
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]

    # Get top 5 predictions
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top5_prob, top5_indices = torch.topk(probs, 5)

    # Create plot for visualization
    fig, ax = plt.subplots(figsize=(10, 5))

    # Get class names and probabilities
    classes = [model.config.id2label[idx.item()] for idx in top5_indices]
    probabilities = [prob.item() * 100 for prob in top5_prob]

    # Create horizontal bar chart
    bars = ax.barh(classes, probabilities, color='#4C72B0')
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top 5 Predictions')

    # Add percentage labels
    for i, bar in enumerate(bars):
        width = bar.get_width()
        ax.text(width + 1, bar.get_y() + bar.get_height() / 2,
                f'{probabilities[i]:.1f}%',
                va='center', fontsize=10)

    # Improve layout
    plt.tight_layout()

    return predicted_class, fig


# Create Gradio interface
with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Image Classification Tool

        This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.
        Upload an image or take a photo to see what the AI recognizes in it!
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label="Upload or capture an image",
                type="pil",
                height=400
            )
            classify_btn = gr.Button("Classify Image", variant="primary")

        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Levels")

    # Add examples
    example_images = [
        "examples/dog.jpg",
        "examples/cat.jpg",
        "examples/coffee.jpg",
        "examples/laptop.jpg",
        "examples/beach.jpg"
    ]

    gr.Examples(
        examples=example_images,
        inputs=image_input,
        outputs=[prediction, confidence_plot],
        fn=classify_image,
        cache_examples=True
    )

    # Set up the click event
    classify_btn.click(
        fn=classify_image,
        inputs=image_input,
        outputs=[prediction, confidence_plot]
    )

    # Set up the input change event
    image_input.change(
        fn=classify_image,
        inputs=image_input,
        outputs=[prediction, confidence_plot]
    )

    gr.Markdown("""
    ### How it works

    This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize
    1,000 different object categories ranging from animals and plants to vehicles, household items, and more.

    ### Applications

    - **Content Categorization**: Automatically organize image libraries
    - **Accessibility**: Help describe images for visually impaired users
    - **Education**: Learn about objects in the world around you
    - **Data Analysis**: Process and categorize large image datasets

    Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()
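
# --- Optional: quick local check of the classifier without the UI ---
# A minimal sketch, assuming one of the example images listed above (e.g.
# "examples/dog.jpg") exists on disk; uncomment and run in place of
# demo.launch() to verify that the model loads and returns a label.
#
# from PIL import Image
# label, chart = classify_image(Image.open("examples/dog.jpg"))
# print(f"Predicted class: {label}")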