Spaces:
Sleeping
Sleeping
# pylint: disable=import-error
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
# Hugging Face checkpoint shared by the preprocessor and the classifier below.
model_name = "google/vit-base-patch16-224"

# Image preprocessor that matches the checkpoint.
processor = ViTImageProcessor.from_pretrained(model_name)

# Vision Transformer with its image-classification head.
model = ViTForImageClassification.from_pretrained(model_name)
def _plot_top_predictions(classes, probabilities):
    """Build a horizontal bar chart of ranked class probabilities.

    Args:
        classes: Class names, ordered from most to least probable.
        probabilities: Matching probabilities, in percent.

    Returns:
        A ``matplotlib.figure.Figure`` with one labelled bar per class.
    """
    # A standalone Figure is not registered with pyplot. Nothing keeps it
    # alive after the caller drops it, and it does not touch pyplot's
    # process-global "current figure", which concurrent requests would share.
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()

    # Plot at numeric positions and attach the names as tick labels. Two
    # classes that happen to share a display name still get separate bars
    # instead of being merged into a single categorical entry.
    positions = list(range(len(classes)))
    bars = ax.barh(positions, probabilities, color='#4C72B0')
    ax.set_yticks(positions)
    ax.set_yticklabels(classes)
    # Put the most probable class at the top so the chart reads as a ranking.
    ax.invert_yaxis()
    ax.set_xlabel('Probability (%)')
    ax.set_title(f'Top {len(classes)} Predictions')

    # Percentage label just to the right of each bar.
    for bar, pct in zip(bars, probabilities):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{pct:.1f}%', va='center', fontsize=10)

    fig.tight_layout()
    return fig


def classify_image(image, top_k=5):
    """Classify an image with the ViT model and chart its top predictions.

    Args:
        image: Image to classify, or ``None``. The UI's
            ``gr.Image(type="pil")`` component supplies a PIL image.
        top_k: How many of the highest-probability classes to chart. It is
            capped at the number of classes the model outputs.

    Returns:
        A ``(predicted_label, figure)`` tuple, where ``predicted_label`` is
        the most probable class name and ``figure`` is a bar chart of the
        top ``top_k`` probabilities. Returns ``(None, None)`` when no image
        is supplied.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if image is None:
        return None, None
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    # Convert non-RGB PIL images (e.g. RGBA PNGs, grayscale) to RGB so the
    # processor always receives 3-channel input.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Class probabilities for the single image in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top_probs, top_indices = torch.topk(probs, min(top_k, probs.numel()))

    classes = [model.config.id2label[idx.item()] for idx in top_indices]
    probabilities = [p.item() * 100 for p in top_probs]

    # torch.topk returns values in descending order, so the first entry is
    # the argmax. Reusing it keeps the textbox and the chart in agreement.
    return classes[0], _plot_top_predictions(classes, probabilities)
# Assemble the web UI:
# - a header,
# - an input column and an output column,
# - cached examples,
# - event wiring,
# - a closing "about" section.
with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🖼️ Image Classification Tool
This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.
Upload an image or take a photo to see what the AI recognizes in it!
"""
    )

    with gr.Row():
        # Left column: image source plus an explicit trigger button.
        with gr.Column():
            image_input = gr.Image(label="Upload or capture an image", type="pil", height=400)
            classify_btn = gr.Button("Classify Image", variant="primary")
        # Right column: top label and the probability chart.
        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Levels")

    # Both components filled by classify_image, in its return order.
    classify_outputs = [prediction, confidence_plot]

    # Bundled sample pictures; results are precomputed (cache_examples=True).
    example_images = [
        f"examples/{stem}.jpg"
        for stem in ("dog", "cat", "coffee", "laptop", "beach")
    ]
    gr.Examples(
        examples=example_images,
        inputs=image_input,
        outputs=classify_outputs,
        fn=classify_image,
        cache_examples=True,
    )

    # Run the classifier on button click and whenever the image changes.
    for register_event in (classify_btn.click, image_input.change):
        register_event(
            fn=classify_image,
            inputs=image_input,
            outputs=classify_outputs,
        )

    gr.Markdown("""
### How it works
This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize 1,000
different object categories ranging from animals and plants to vehicles, household items, and more.
### Applications
- **Content Categorization**: Automatically organize image libraries
- **Accessibility**: Help describe images for visually impaired users
- **Education**: Learn about objects in the world around you
- **Data Analysis**: Process and categorize large image datasets
Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
""")
# Start the Gradio server only when this file is executed as a script;
# importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()