akhaliq HF staff commited on
Commit
1e2f8be
·
1 Parent(s): 3264419

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow_hub as hub
3
+
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import gradio as gr
11
+
12
+ #@title Helper functions for loading image (hidden)
13
+
14
+ original_image_cache = {}
15
+
16
+ def preprocess_image(image):
17
+ image = np.array(image)
18
+ # reshape into shape [batch_size, height, width, num_channels]
19
+ img_reshaped = tf.reshape(image, [1, image.shape[0], image.shape[1], image.shape[2]])
20
+ # Use `convert_image_dtype` to convert to floats in the [0,1] range.
21
+ image = tf.image.convert_image_dtype(img_reshaped, tf.float32)
22
+ return image
23
+
24
+ def load_image_from_url(img_url):
25
+ """Returns an image with shape [1, height, width, num_channels]."""
26
+ user_agent = {'User-agent': 'Colab Sample (https://tensorflow.org)'}
27
+ response = requests.get(img_url, headers=user_agent)
28
+ image = Image.open(BytesIO(response.content))
29
+ image = preprocess_image(image)
30
+ return image
31
+
32
+ def load_image(image_url, image_size=256, dynamic_size=False, max_dynamic_size=512):
33
+ """Loads and preprocesses images."""
34
+ # Cache image file locally.
35
+ if image_url in original_image_cache:
36
+ img = original_image_cache[image_url]
37
+ elif image_url.startswith('https://'):
38
+ img = load_image_from_url(image_url)
39
+ else:
40
+ fd = tf.io.gfile.GFile(image_url, 'rb')
41
+ img = preprocess_image(Image.open(fd))
42
+ original_image_cache[image_url] = img
43
+ # Load and convert to float32 numpy array, add batch dimension, and normalize to range [0, 1].
44
+ img_raw = img
45
+ if tf.reduce_max(img) > 1.0:
46
+ img = img / 255.
47
+ if len(img.shape) == 3:
48
+ img = tf.stack([img, img, img], axis=-1)
49
+ if not dynamic_size:
50
+ img = tf.image.resize_with_pad(img, image_size, image_size)
51
+ elif img.shape[1] > max_dynamic_size or img.shape[2] > max_dynamic_size:
52
+ img = tf.image.resize_with_pad(img, max_dynamic_size, max_dynamic_size)
53
+ return img, img_raw
54
+
55
+
56
+
57
+ image_size = 224
58
+ dynamic_size = False
59
+
60
+ model_name = "efficientnetv2-s"
61
+
62
+ model_handle_map = {
63
+ "efficientnetv2-s": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/classification/2",
64
+ "efficientnetv2-m": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_m/classification/2",
65
+ "efficientnetv2-l": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/classification/2",
66
+ "efficientnetv2-s-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/classification/2",
67
+ "efficientnetv2-m-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_m/classification/2",
68
+ "efficientnetv2-l-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_l/classification/2",
69
+ "efficientnetv2-xl-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/classification/2",
70
+ "efficientnetv2-b0-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b0/classification/2",
71
+ "efficientnetv2-b1-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b1/classification/2",
72
+ "efficientnetv2-b2-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b2/classification/2",
73
+ "efficientnetv2-b3-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b3/classification/2",
74
+ "efficientnetv2-s-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/classification/2",
75
+ "efficientnetv2-m-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/classification/2",
76
+ "efficientnetv2-l-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_l/classification/2",
77
+ "efficientnetv2-xl-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_xl/classification/2",
78
+ "efficientnetv2-b0-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/classification/2",
79
+ "efficientnetv2-b1-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b1/classification/2",
80
+ "efficientnetv2-b2-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b2/classification/2",
81
+ "efficientnetv2-b3-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b3/classification/2",
82
+ "efficientnetv2-b0": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b0/classification/2",
83
+ "efficientnetv2-b1": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b1/classification/2",
84
+ "efficientnetv2-b2": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b2/classification/2",
85
+ "efficientnetv2-b3": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b3/classification/2",
86
+ "efficientnet_b0": "https://tfhub.dev/tensorflow/efficientnet/b0/classification/1",
87
+ "efficientnet_b1": "https://tfhub.dev/tensorflow/efficientnet/b1/classification/1",
88
+ "efficientnet_b2": "https://tfhub.dev/tensorflow/efficientnet/b2/classification/1",
89
+ "efficientnet_b3": "https://tfhub.dev/tensorflow/efficientnet/b3/classification/1",
90
+ "efficientnet_b4": "https://tfhub.dev/tensorflow/efficientnet/b4/classification/1",
91
+ "efficientnet_b5": "https://tfhub.dev/tensorflow/efficientnet/b5/classification/1",
92
+ "efficientnet_b6": "https://tfhub.dev/tensorflow/efficientnet/b6/classification/1",
93
+ "efficientnet_b7": "https://tfhub.dev/tensorflow/efficientnet/b7/classification/1",
94
+ "bit_s-r50x1": "https://tfhub.dev/google/bit/s-r50x1/ilsvrc2012_classification/1",
95
+ "inception_v3": "https://tfhub.dev/google/imagenet/inception_v3/classification/4",
96
+ "inception_resnet_v2": "https://tfhub.dev/google/imagenet/inception_resnet_v2/classification/4",
97
+ "resnet_v1_50": "https://tfhub.dev/google/imagenet/resnet_v1_50/classification/4",
98
+ "resnet_v1_101": "https://tfhub.dev/google/imagenet/resnet_v1_101/classification/4",
99
+ "resnet_v1_152": "https://tfhub.dev/google/imagenet/resnet_v1_152/classification/4",
100
+ "resnet_v2_50": "https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4",
101
+ "resnet_v2_101": "https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4",
102
+ "resnet_v2_152": "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/4",
103
+ "nasnet_large": "https://tfhub.dev/google/imagenet/nasnet_large/classification/4",
104
+ "nasnet_mobile": "https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4",
105
+ "pnasnet_large": "https://tfhub.dev/google/imagenet/pnasnet_large/classification/4",
106
+ "mobilenet_v2_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4",
107
+ "mobilenet_v2_130_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4",
108
+ "mobilenet_v2_140_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4",
109
+ "mobilenet_v3_small_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_small_100_224/classification/5",
110
+ "mobilenet_v3_small_075_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_small_075_224/classification/5",
111
+ "mobilenet_v3_large_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/5",
112
+ "mobilenet_v3_large_075_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/classification/5",
113
+ }
114
+
115
+ model_image_size_map = {
116
+ "efficientnetv2-s": 384,
117
+ "efficientnetv2-m": 480,
118
+ "efficientnetv2-l": 480,
119
+ "efficientnetv2-b0": 224,
120
+ "efficientnetv2-b1": 240,
121
+ "efficientnetv2-b2": 260,
122
+ "efficientnetv2-b3": 300,
123
+ "efficientnetv2-s-21k": 384,
124
+ "efficientnetv2-m-21k": 480,
125
+ "efficientnetv2-l-21k": 480,
126
+ "efficientnetv2-xl-21k": 512,
127
+ "efficientnetv2-b0-21k": 224,
128
+ "efficientnetv2-b1-21k": 240,
129
+ "efficientnetv2-b2-21k": 260,
130
+ "efficientnetv2-b3-21k": 300,
131
+ "efficientnetv2-s-21k-ft1k": 384,
132
+ "efficientnetv2-m-21k-ft1k": 480,
133
+ "efficientnetv2-l-21k-ft1k": 480,
134
+ "efficientnetv2-xl-21k-ft1k": 512,
135
+ "efficientnetv2-b0-21k-ft1k": 224,
136
+ "efficientnetv2-b1-21k-ft1k": 240,
137
+ "efficientnetv2-b2-21k-ft1k": 260,
138
+ "efficientnetv2-b3-21k-ft1k": 300,
139
+ "efficientnet_b0": 224,
140
+ "efficientnet_b1": 240,
141
+ "efficientnet_b2": 260,
142
+ "efficientnet_b3": 300,
143
+ "efficientnet_b4": 380,
144
+ "efficientnet_b5": 456,
145
+ "efficientnet_b6": 528,
146
+ "efficientnet_b7": 600,
147
+ "inception_v3": 299,
148
+ "inception_resnet_v2": 299,
149
+ "mobilenet_v2_100_224": 224,
150
+ "mobilenet_v2_130_224": 224,
151
+ "mobilenet_v2_140_224": 224,
152
+ "nasnet_large": 331,
153
+ "nasnet_mobile": 224,
154
+ "pnasnet_large": 331,
155
+ "resnet_v1_50": 224,
156
+ "resnet_v1_101": 224,
157
+ "resnet_v1_152": 224,
158
+ "resnet_v2_50": 224,
159
+ "resnet_v2_101": 224,
160
+ "resnet_v2_152": 224,
161
+ "mobilenet_v3_small_100_224": 224,
162
+ "mobilenet_v3_small_075_224": 224,
163
+ "mobilenet_v3_large_100_224": 224,
164
+ "mobilenet_v3_large_075_224": 224,
165
+ }
166
+
167
+ model_handle = model_handle_map[model_name]
168
+
169
+
170
+ max_dynamic_size = 512
171
+ if model_name in model_image_size_map:
172
+ image_size = model_image_size_map[model_name]
173
+ dynamic_size = False
174
+ print(f"Images will be converted to {image_size}x{image_size}")
175
+ else:
176
+ dynamic_size = True
177
+ print(f"Images will be capped to a max size of {max_dynamic_size}x{max_dynamic_size}")
178
+
179
+ labels_file = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
180
+
181
+ #download labels and creates a maps
182
+ downloaded_file = tf.keras.utils.get_file("labels.txt", origin=labels_file)
183
+
184
+ classes = []
185
+
186
+ with open(downloaded_file) as f:
187
+ labels = f.readlines()
188
+ classes = [l.strip() for l in labels]
189
+
190
+
191
+ classifier = hub.load(model_handle)
192
+
193
+
194
+ def inference(img):
195
+ image, original_image = load_image(img, image_size, dynamic_size, max_dynamic_size)
196
+
197
+
198
+ input_shape = image.shape
199
+ warmup_input = tf.random.uniform(input_shape, 0, 1.0)
200
+ warmup_logits = classifier(warmup_input).numpy()
201
+
202
+ # Run model on image
203
+ probabilities = tf.nn.softmax(classifier(image)).numpy()
204
+
205
+ top_5 = tf.argsort(probabilities, axis=-1, direction="DESCENDING")[0][:5].numpy()
206
+ np_classes = np.array(classes)
207
+
208
+ # Some models include an additional 'background' class in the predictions, so
209
+ # we must account for this when reading the class labels.
210
+ includes_background_class = probabilities.shape[1] == 1001
211
+ result = {}
212
+ for i, item in enumerate(top_5):
213
+ class_index = item if includes_background_class else item + 1
214
+ line = f'({i+1}) {class_index:4} - {classes[class_index]}: {probabilities[0][top_5][i]}'
215
+ result[classes[class_index]] = probabilities[0][top_5][i].item()
216
+ return result
217
+
218
+ title="efficientnetv2-s"
219
+ description="Gradio Demo for efficientnetv2-s: EfficientNet V2 are a family of image classification models, which achieve better parameter efficiency and faster training speed than prior arts. To use it, simply upload your image or click on one of the examples to load them. Read more at the links below"
220
+ article = "<p style='text-align: center'><a href='https://tfhub.dev/google/collections/efficientnet_v2/1' target='_blank'>Tensorflow Hub</a></p>"
221
+ examples=[['apple1.jpg']]
222
+ gr.Interface(inference,gr.inputs.Image(type="filepath"),"label",title=title,description=description,article=article,examples=examples).launch(enable_queue=True)