viniciusgribas commited on
Commit
d30bbd6
·
1 Parent(s): af09f13

modelchanging

Browse files
Files changed (1) hide show
  1. app.py +63 -116
app.py CHANGED
@@ -1,133 +1,80 @@
1
- # pylint: disable=import-error
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
- from torchvision import transforms
6
  from PIL import Image
7
  import matplotlib.pyplot as plt
8
- from transformers import ViTForImageClassification, ViTImageProcessor
9
 
10
- # Load pre-trained Vision Transformer model
11
- model_name = "google/vit-base-patch16-224"
12
- model = ViTForImageClassification.from_pretrained(model_name)
13
- processor = ViTImageProcessor.from_pretrained(model_name)
14
 
15
- # Function to predict image class
 
 
 
 
16
  def classify_image(image):
17
  if image is None:
18
- return None, None
19
-
20
- # Process image
21
- inputs = processor(images=image, return_tensors="pt")
22
-
23
- # Make prediction
24
- with torch.no_grad():
25
- outputs = model(**inputs)
26
- logits = outputs.logits
27
-
28
- # Get predicted class and probabilities
29
- predicted_class_idx = logits.argmax(-1).item()
30
- predicted_class = model.config.id2label[predicted_class_idx]
31
-
32
- # Get top 5 predictions
33
- probs = torch.nn.functional.softmax(logits, dim=-1)[0]
34
- top5_prob, top5_indices = torch.topk(probs, 5)
35
-
36
- # Create plot for visualization
37
- fig, ax = plt.subplots(figsize=(10, 5))
38
-
39
- # Get class names and probabilities
40
- classes = [model.config.id2label[idx.item()] for idx in top5_indices]
41
- probabilities = [prob.item() * 100 for prob in top5_prob]
42
 
43
- # Create horizontal bar chart
44
- bars = ax.barh(classes, probabilities, color='#4C72B0')
45
- ax.set_xlabel('Probability (%)')
46
- ax.set_title('Top 5 Predictions')
47
-
48
- # Add percentage labels
49
- for i, bar in enumerate(bars):
50
- width = bar.get_width()
51
- ax.text(width + 1, bar.get_y() + bar.get_height()/2,
52
- f'{probabilities[i]:.1f}%',
53
- va='center', fontsize=10)
54
-
55
- # Improve layout
56
- plt.tight_layout()
57
-
58
- return predicted_class, fig
59
-
60
- # Create Gradio interface
61
- with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
62
- gr.Markdown(
63
- """
64
- # 🖼️ Image Classification Tool
65
 
66
- This application uses a Vision Transformer (ViT) model to classify images into 1,000 different categories.
 
 
 
67
 
68
- Upload an image or take a photo to see what the AI recognizes in it!
69
- """
70
- )
71
-
72
- with gr.Row():
73
- with gr.Column():
74
- image_input = gr.Image(
75
- label="Upload or capture an image",
76
- type="pil",
77
- height=400
78
- )
79
- classify_btn = gr.Button("Classify Image", variant="primary")
80
 
81
- with gr.Column():
82
- prediction = gr.Textbox(label="Prediction")
83
- confidence_plot = gr.Plot(label="Confidence Levels")
84
-
85
- # Add examples
86
- example_images = [
87
- "examples/dog.jpg",
88
- "examples/cat.jpg",
89
- "examples/coffee.jpg",
90
- "examples/laptop.jpg",
91
- "examples/beach.jpg"
92
- ]
93
-
94
- gr.Examples(
95
- examples=example_images,
96
- inputs=image_input,
97
- outputs=[prediction, confidence_plot],
98
- fn=classify_image,
99
- cache_examples=True
100
- )
101
-
102
- # Set up the click event
103
- classify_btn.click(
104
- fn=classify_image,
105
- inputs=image_input,
106
- outputs=[prediction, confidence_plot]
107
- )
108
-
109
- # Set up the input change event
110
- image_input.change(
111
- fn=classify_image,
112
- inputs=image_input,
113
- outputs=[prediction, confidence_plot]
114
- )
115
-
116
- gr.Markdown("""
117
- ### How it works
118
-
119
- This tool uses a Vision Transformer (ViT) model pre-trained on ImageNet, enabling it to recognize 1,000
120
- different object categories ranging from animals and plants to vehicles, household items, and more.
121
-
122
- ### Applications
123
-
124
- - **Content Categorization**: Automatically organize image libraries
125
- - **Accessibility**: Help describe images for visually impaired users
126
- - **Education**: Learn about objects in the world around you
127
- - **Data Analysis**: Process and categorize large image datasets
128
 
129
- Created by [Vinicius Guerra e Ribas](https://viniciusgribas.netlify.app/)
130
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Launch the app
133
  if __name__ == "__main__":
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
 
4
  from PIL import Image
5
  import matplotlib.pyplot as plt
6
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
7
 
8
+ # Use a smaller, more efficient model
9
+ model_name = "microsoft/resnet-18" # Smaller model that should work with Hugging Face constraints
 
 
10
 
11
+ # Load model and feature extractor
12
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
13
+ model = AutoModelForImageClassification.from_pretrained(model_name)
14
+
15
+ # Function to classify image
16
  def classify_image(image):
17
  if image is None:
18
+ return "No image provided", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ try:
21
+ # Process image
22
+ inputs = feature_extractor(images=image, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # Make prediction
25
+ with torch.no_grad():
26
+ outputs = model(**inputs)
27
+ logits = outputs.logits
28
 
29
+ # Get predicted class
30
+ predicted_class_idx = logits.argmax(-1).item()
31
+ predicted_class = model.config.id2label[predicted_class_idx]
 
 
 
 
 
 
 
 
 
32
 
33
+ # Get top 5 predictions
34
+ probs = torch.nn.functional.softmax(logits, dim=-1)[0]
35
+ top5_prob, top5_indices = torch.topk(probs, 5)
36
+
37
+ # Create plot for visualization
38
+ fig, ax = plt.subplots(figsize=(10, 5))
39
+
40
+ # Get class names and probabilities
41
+ classes = [model.config.id2label[idx.item()] for idx in top5_indices]
42
+ probabilities = [prob.item() * 100 for prob in top5_prob]
43
+
44
+ # Create horizontal bar chart
45
+ bars = ax.barh(classes, probabilities, color='#4C72B0')
46
+ ax.set_xlabel('Probability (%)')
47
+ ax.set_title('Top 5 Predictions')
48
+
49
+ # Add percentage labels
50
+ for i, bar in enumerate(bars):
51
+ width = bar.get_width()
52
+ ax.text(width + 1, bar.get_y() + bar.get_height()/2,
53
+ f'{probabilities[i]:.1f}%',
54
+ va='center', fontsize=10)
55
+
56
+ # Improve layout
57
+ plt.tight_layout()
58
+
59
+ return predicted_class, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ except Exception as e:
62
+ return f"Error: {str(e)}", None
63
+
64
+ # Create Gradio interface with simpler structure
65
+ demo = gr.Interface(
66
+ fn=classify_image,
67
+ inputs=gr.Image(type="pil"),
68
+ outputs=[
69
+ gr.Textbox(label="Prediction"),
70
+ gr.Plot(label="Confidence Levels")
71
+ ],
72
+ title="🖼️ Image Classification Tool",
73
+ description="Upload an image to see what the AI recognizes in it!",
74
+ allow_flagging="never",
75
+ examples=[], # No examples to avoid dependencies
76
+ theme=gr.themes.Soft()
77
+ )
78
 
79
  # Launch the app
80
  if __name__ == "__main__":