File size: 4,142 Bytes
71d6e29
9495c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# pylint: disable=import-error
import gradio as gr
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification, ViTImageProcessor

# Load pre-trained Vision Transformer model.
# A single checkpoint identifier feeds both the preprocessor and the
# classifier so the two always stay paired.
model_name = "google/vit-base-patch16-224"
# Preprocessing pipeline matching the checkpoint (turns images into model inputs).
processor = ViTImageProcessor.from_pretrained(model_name)
# Classifier; its config also carries the id2label map used for display.
model = ViTForImageClassification.from_pretrained(model_name)

# Function to predict image class
def classify_image(image):
    if image is None:
        return None, None
    
    # Process image
    inputs = processor(images=image, return_tensors="pt")
    
    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    
    # Get predicted class and probabilities
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class = model.config.id2label[predicted_class_idx]
    
    # Get top 5 predictions
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    top5_prob, top5_indices = torch.topk(probs, 5)
    
    # Create plot for visualization
    fig, ax = plt.subplots(figsize=(10, 5))
    
    # Get class names and probabilities
    classes = [model.config.id2label[idx.item()] for idx in top5_indices]
    probabilities = [prob.item() * 100 for prob in top5_prob]
    
    # Create horizontal bar chart
    bars = ax.barh(classes, probabilities, color='#4C72B0')
    ax.set_xlabel('Probability (%)')
    ax.set_title('Top 5 Predictions')
    
    # Add percentage labels
    for i, bar in enumerate(bars):
        width = bar.get_width()
        ax.text(width + 1, bar.get_y() + bar.get_height()/2, 
                f'{probabilities[i]:.1f}%', 
                va='center', fontsize=10)
    
    # Improve layout
    plt.tight_layout()
    
    return predicted_class, fig

# Create Gradio interface
with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
    # Header shown above the classifier widgets.
    gr.Markdown(
        """
        # 🖼️ Image Classification Tool
        
        This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.
        
        Upload an image or take a photo to see what the AI recognizes in it!
        """
    )

    # Two-column layout: image input and trigger button on the left,
    # predicted label and confidence chart on the right.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label="Upload or capture an image",
                type="pil",
                height=400,
            )
            classify_btn = gr.Button("Classify Image", variant="primary")

        with gr.Column():
            prediction = gr.Textbox(label="Prediction")
            confidence_plot = gr.Plot(label="Confidence Levels")

    # Relative paths to the bundled sample images. cache_examples=True asks
    # Gradio to precompute their outputs.
    example_images = [
        f"examples/{stem}.jpg"
        for stem in ("dog", "cat", "coffee", "laptop", "beach")
    ]
    gr.Examples(
        examples=example_images,
        inputs=image_input,
        outputs=[prediction, confidence_plot],
        fn=classify_image,
        cache_examples=True,
    )

    # Run the classifier on an explicit button press and whenever the image
    # input changes. Handlers are registered in that order.
    for register_event in (classify_btn.click, image_input.change):
        register_event(
            fn=classify_image,
            inputs=image_input,
            outputs=[prediction, confidence_plot],
        )

    # Footer describing the model and its uses.
    gr.Markdown("""
    ### How it works
    
    This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize 1,000 
    different object categories ranging from animals and plants to vehicles, household items, and more.
    
    ### Applications
    
    - **Content Categorization**: Automatically organize image libraries
    - **Accessibility**: Help describe images for visually impaired users
    - **Education**: Learn about objects in the world around you
    - **Data Analysis**: Process and categorize large image datasets
    
    Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
    """)

# Launch the app
def main():
    """Start the Gradio server that hosts the ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    main()