File size: 1,318 Bytes
8d4307c
 
 
 
 
 
 
 
e47a3dd
 
 
 
 
 
 
 
8d4307c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bce6a13
cd5f2c1
8d4307c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import gradio as gr

# Side length, in pixels, that uploaded images are resized to before
# being fed to the classifier.
IMAGE_SIZE = 72

# CIFAR-10 class names keyed by integer class index (0..9), taken from
# https://huggingface.co/datasets/cifar10
labels = dict(enumerate([
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]))


# Classifier loaded once, at import time, from this Hugging Face Hub repo.
_MODEL_REPO_ID = "keras-io/randaugment"
model = from_pretrained_keras(_MODEL_REPO_ID)



def predict_img_label(img):
  """Classify one image and return a score for every CIFAR-10 label.

  Args:
    img: the image passed in by the Gradio ``Image`` input component
      (presumably an H x W x C array — TODO confirm for the pinned
      Gradio version).

  Returns:
    A dict mapping each class name in ``labels`` to the model's float
    output for that class, in ``labels`` order. This is the mapping
    format the ``Label`` output component displays.
  """
  # Resize to the model's fixed input size; with the default bilinear
  # method tf.image.resize also returns float32.
  resized = tf.image.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  batch = tf.expand_dims(resized, 0)  # add a leading batch axis of size 1
  # For a single small batch, calling the model directly avoids the
  # per-call loop/dataset overhead that Model.predict adds.
  scores = np.asarray(model(batch, training=False)).flatten()
  return {name: float(scores[idx]) for idx, name in labels.items()}

# Gradio UI: one image upload in, the three highest-scoring labels out.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=3)

title = "Image Classification Model Using RandAugment"
description = "Upload an image to classify images"

article = (
    "<div style='text-align: center;'>"
    "<a href='https://github.com/BishmoyPaul' target='_blank'>Space by Bishmoy Paul</a>"
    "<br>"
    "<a href='https://keras.io/examples/vision/randaugment/' target='_blank'>Keras example by Sayak Paul</a>"
    "</div>"
)

# Sample files offered in the UI. Each example is a one-element list
# because the interface has exactly one input component.
_EXAMPLE_FILES = ("./airplane.jpg", "./car.png", "./cat.jpg", "./horse.jpg")

# NOTE(review): allow_flagging=False and launch(enable_queue=True) are the
# older Gradio forms; newer releases expect allow_flagging="never" and
# .queue() instead. Confirm against the pinned Gradio version before
# changing them.
demo = gr.Interface(
    predict_img_label,
    inputs=image,
    outputs=label,
    allow_flagging=False,
    examples=[[path] for path in _EXAMPLE_FILES],
    title=title,
    description=description,
    article=article,
)
demo.launch(enable_queue=True)