# Author: viniciusgribas
# Commit 71d6e29 — "commit temporário antes do filtro" ("temporary commit before the filter")
# Original file size: 4.14 kB
# pylint: disable=import-error
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from transformers import ViTForImageClassification, ViTImageProcessor
# Hugging Face checkpoint of the pre-trained Vision Transformer used below.
model_name = "google/vit-base-patch16-224"
# Image preprocessor that accompanies the checkpoint (turns images into model inputs).
processor = ViTImageProcessor.from_pretrained(model_name)
# Classification model; its config.id2label maps class indices to label names.
model = ViTForImageClassification.from_pretrained(model_name)
def _plot_top_predictions(labels, percentages):
    """Draw a horizontal bar chart of class probabilities.

    Args:
        labels: Class names, most likely first.
        percentages: Probabilities in percent, aligned with ``labels``.

    Returns:
        A standalone matplotlib ``Figure``. It is created without pyplot, so
        it is never registered in pyplot's global figure list (no per-call
        leak) and does not touch pyplot's shared "current figure" state,
        which concurrent requests could otherwise race on.
    """
    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()

    # Plot against numeric positions and attach the names as tick labels.
    # Passing the name strings directly makes matplotlib treat them as
    # categories, which would merge duplicate names (the ImageNet label set
    # repeats a few, e.g. "crane") onto a single bar.
    positions = list(range(len(labels)))
    bars = ax.barh(positions, percentages, color='#4C72B0')
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)
    # Rank top-down: the most likely class goes on the top row.
    ax.invert_yaxis()

    ax.set_xlabel('Probability (%)')
    ax.set_title(f'Top {len(labels)} Predictions')
    # Headroom so the text label beside a bar near 100% stays inside the axes.
    ax.set_xlim(0, 110)

    # Annotate each bar with its percentage just past the bar's end.
    for bar, pct in zip(bars, percentages):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f'{pct:.1f}%', va='center', fontsize=10)

    fig.tight_layout()
    return fig


def classify_image(image, top_k=5):
    """Classify an image with the pre-trained ViT model.

    Args:
        image: The image to classify (a PIL image when called from the
            ``gr.Image(type="pil")`` input), or ``None`` when the input is
            empty.
        top_k: How many of the most likely classes to chart. Clamped to the
            range ``1 .. number of model classes``. Defaults to 5.

    Returns:
        ``(predicted_class, figure)``: the name of the most likely class and
        a bar chart of the ``top_k`` class probabilities in percent, or
        ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # NOTE(review): the ViT processor presumably expects 3-channel RGB input.
    # Convert RGBA/greyscale/palette PIL images so channels line up. Inputs
    # without a ``mode`` attribute (non-PIL) are passed through untouched.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits

    # A single image was processed, so take row 0 of the class probabilities.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    k = max(1, min(int(top_k), probs.numel()))
    top_probs, top_indices = torch.topk(probs, k)

    id2label = model.config.id2label
    labels = [id2label[idx.item()] for idx in top_indices]
    percentages = [p.item() * 100 for p in top_probs]

    # topk returns values sorted descending, so its first entry is the
    # arg-max class (softmax preserves the logits' ordering).
    predicted_class = labels[0]
    return predicted_class, _plot_top_predictions(labels, percentages)
# Gradio UI: image input and trigger button on the left, predicted label and
# confidence chart on the right, then example images and an explanatory footer.
with gr.Blocks(theme=gr.themes.Soft(), title="Image Classifier") as demo:
    gr.Markdown(
        """
# 🖼️ Image Classification Tool
This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.
Upload an image or take a photo to see what the AI recognizes in it!
"""
    )

    with gr.Row():
        # Left column: the image to classify plus an explicit trigger.
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                height=400,
                label="Upload or capture an image",
            )
            classify_btn = gr.Button("Classify Image", variant="primary")
        # Right column: top-1 label and the probability bar chart.
        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Levels")

    # Sample images; with cache_examples=True Gradio precomputes
    # classify_image's outputs for them instead of running it on each click.
    example_images = [
        "examples/dog.jpg",
        "examples/cat.jpg",
        "examples/coffee.jpg",
        "examples/laptop.jpg",
        "examples/beach.jpg",
    ]
    gr.Examples(
        examples=example_images,
        fn=classify_image,
        inputs=image_input,
        outputs=[prediction, confidence_plot],
        cache_examples=True,
    )

    # Run classification on a button press and whenever the image input's
    # value changes; listeners are registered click-first, then change.
    for listener in (classify_btn.click, image_input.change):
        listener(
            fn=classify_image,
            inputs=image_input,
            outputs=[prediction, confidence_plot],
        )

    gr.Markdown("""
### How it works
This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize 1,000
different object categories ranging from animals and plants to vehicles, household items, and more.
### Applications
- **Content Categorization**: Automatically organize image libraries
- **Accessibility**: Help describe images for visually impaired users
- **Education**: Learn about objects in the world around you
- **Data Analysis**: Process and categorize large image datasets
Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
""")

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()