# randaugment / app.py
# Hugging Face Space by bishmoy — commit cd5f2c1 ("Update app.py")
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
import gradio as gr
# Height and width (in pixels) that every input image is resized to before
# it is handed to the model.
IMAGE_SIZE = 72

# Class names indexed by the model's output position.
# Labels taken from https://huggingface.co/datasets/cifar10
labels = dict(enumerate([
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]))
# Hugging Face Hub repository that hosts the pretrained RandAugment classifier.
_MODEL_REPO = "keras-io/randaugment"
# Loaded once at module level; predict_img_label reuses this global for
# every request.
model = from_pretrained_keras(_MODEL_REPO)
def predict_img_label(img):
    """Classify one image and return a score for every class in ``labels``.

    Args:
        img: The image delivered by the Gradio ``Image`` input. It is
            presumably an ``(H, W, 3)`` array; TODO confirm this for the
            installed Gradio version.

    Returns:
        A dict mapping each class name in ``labels`` to the model's score for
        that class as a plain Python ``float``. ``gr.outputs.Label`` consumes
        this dict.
    """
    # Resize to the square IMAGE_SIZE x IMAGE_SIZE input. This is presumably
    # the resolution the model was trained at; verify against the model card.
    resized = tf.image.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    # Add a batch axis of size 1 and run inference. verbose=0 stops Keras from
    # printing a progress bar on every request. flatten() collapses the
    # (presumably (1, num_classes)) output to a 1-D score vector.
    scores = model.predict(tf.expand_dims(resized, 0), verbose=0).flatten()
    # The keys of `labels` are the output indices, so pair them directly.
    # float() turns NumPy scalars into plain Python floats.
    return {name: float(scores[idx]) for idx, name in labels.items()}
# Gradio I/O components: an image upload as input, and a label widget that
# shows the top three classes as output.
image = gr.inputs.Image()
label = gr.outputs.Label(num_top_classes=3)

# Text rendered around the demo.
title = "Image Classification Model Using RandAugment"
description = "Upload an image to classify images"
article = "<div style='text-align: center;'><a href='https://github.com/BishmoyPaul' target='_blank'>Space by Bishmoy Paul</a><br><a href='https://keras.io/examples/vision/randaugment/' target='_blank'>Keras example by Sayak Paul</a></div>"

# Clickable sample inputs. Each inner list holds the value for the single
# image input.
example_images = [
    ['./airplane.jpg'],
    ['./car.png'],
    ['./cat.jpg'],
    ['./horse.jpg'],
]

# Wire predict_img_label into a Gradio Interface, then launch it. The
# allow_flagging / enable_queue options are passed through unchanged.
demo = gr.Interface(
    predict_img_label,
    inputs=image,
    outputs=label,
    allow_flagging=False,
    examples=example_images,
    title=title,
    description=description,
    article=article,
)
demo.launch(enable_queue=True)