# NOTE: the following is residue from the web file viewer this source was
# copied from; it is not part of the program and is kept only as a comment
# so that the module parses.
#   Spaces: Sleeping
#   File size: 4,142 Bytes
#   Revisions: 71d6e29 9495c6e
#   (line-number gutter 1-134 omitted)
# pylint: disable=import-error
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification, ViTImageProcessor
# Checkpoint identifier shared by the image processor and the classifier so
# that preprocessing and model weights always come from the same source.
model_name = "google/vit-base-patch16-224"

# Preprocessor that turns input images into model-ready tensors.
processor = ViTImageProcessor.from_pretrained(model_name)

# Pre-trained Vision Transformer classification model.
model = ViTForImageClassification.from_pretrained(model_name)
# Function to predict image class
def classify_image(image):
    """Classify an image and chart its most probable labels.

    Args:
        image: Image to classify, passed straight to ``processor`` (the UI
            supplies a PIL image), or ``None`` when no image is set.

    Returns:
        A ``(predicted_class, fig)`` tuple: the label with the highest
        score, and a Matplotlib figure containing a horizontal bar chart of
        the top predictions (up to five) as percentages, best match first.
        Returns ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # Imported here so the figure can be built without pyplot: pyplot keeps
    # a global reference to every figure it creates until it is explicitly
    # closed, which leaks one figure per request in a long-running app, and
    # its "current figure" state is shared between concurrent requests.
    from matplotlib.figure import Figure

    # Process image
    inputs = processor(images=image, return_tensors="pt")

    # Make prediction (inference only, so gradients are not tracked)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Get predicted class
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]

    # Get top predictions; clamp k so a model with fewer than five classes
    # does not make torch.topk raise.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top_k = min(5, probs.numel())
    top_prob, top_indices = torch.topk(probs, top_k)

    # Get class names and probabilities (in percent)
    classes = [model.config.id2label[idx.item()] for idx in top_indices]
    probabilities = [prob.item() * 100 for prob in top_prob]

    # Create plot for visualization
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()

    # Create horizontal bar chart. Bars are placed at numeric positions and
    # labelled afterwards: positioning by label string would draw two
    # classes that share a label on the same row.
    positions = list(range(len(classes)))
    bars = ax.barh(positions, probabilities, color='#4C72B0')
    ax.set_yticks(positions)
    ax.set_yticklabels(classes)
    # barh draws position 0 at the bottom; flip so the best match is on top.
    ax.invert_yaxis()
    ax.set_xlabel('Probability (%)')
    ax.set_title(f'Top {top_k} Predictions')

    # Add percentage labels just past the end of each bar
    for bar, probability in zip(bars, probabilities):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{probability:.1f}%',
                va='center', fontsize=10)

    # Improve layout
    fig.tight_layout()
    return predicted_class, fig
# Build the Gradio page: intro text, input/output columns, example images,
# event wiring, and a closing description.
with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
    # Page heading and short introduction.
    gr.Markdown(
"""
# 🖼️ Image Classification Tool
This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.
Upload an image or take a photo to see what the AI recognizes in it!
"""
    )

    with gr.Row():
        # Left column: the image to classify and the button that triggers it.
        with gr.Column():
            image_input = gr.Image(label="Upload or capture an image", type="pil", height=400)
            classify_btn = gr.Button("Classify Image", variant="primary")
        # Right column: the predicted label and the confidence chart.
        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Levels")

    # Components that receive the two values returned by classify_image.
    result_components = (prediction, confidence_plot)

    # Sample images offered below the inputs; results are cached.
    example_images = [
        "examples/dog.jpg",
        "examples/cat.jpg",
        "examples/coffee.jpg",
        "examples/laptop.jpg",
        "examples/beach.jpg",
    ]
    gr.Examples(
        examples=example_images,
        inputs=image_input,
        outputs=list(result_components),
        fn=classify_image,
        cache_examples=True,
    )

    # Run the classifier both on a button click and whenever the image
    # changes; the two events share the same handler, input and outputs.
    for register_event in (classify_btn.click, image_input.change):
        register_event(
            fn=classify_image,
            inputs=image_input,
            outputs=list(result_components),
        )

    # Closing notes shown under the tool.
    gr.Markdown("""
### How it works
This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize 1,000
different object categories ranging from animals and plants to vehicles, household items, and more.
### Applications
- **Content Categorization**: Automatically organize image libraries
- **Accessibility**: Help describe images for visually impaired users
- **Education**: Learn about objects in the world around you
- **Data Analysis**: Process and categorize large image datasets
Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
""")
# Launch the app only when this file is run as a script, so importing the
# module (e.g. to reuse classify_image) does not start a server.
if __name__ == "__main__":
    demo.launch()