# Hugging Face Space page status at time of capture: "Build error".
import numpy as np | |
import tensorflow as tf | |
from huggingface_hub import from_pretrained_keras | |
import gradio as gr | |
# Side length, in pixels, of the square input fed to the model; uploaded
# images are resized to IMAGE_SIZE x IMAGE_SIZE before prediction.
IMAGE_SIZE = 72

# CIFAR-10 class names, indexed by the model's output position.
# Label order taken from https://huggingface.co/datasets/cifar10
labels = dict(
    enumerate(
        (
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
    )
)
# Hugging Face Hub repository holding the Keras RandAugment image classifier.
_HUB_REPO = "keras-io/randaugment"

# Loaded once at import time so every prediction request reuses the same model.
model = from_pretrained_keras(_HUB_REPO)
def predict_img_label(img):
    """Classify one uploaded image into the classes listed in ``labels``.

    Args:
        img: The uploaded image, presumably a NumPy array of shape
            (height, width, channels) as produced by Gradio's ``Image`` input
            with ``type="numpy"``; ``None`` when the user submits without
            uploading anything.

    Returns:
        A dict mapping every class name in ``labels`` to the model's score for
        that class, or an empty dict when no image was supplied. Gradio's
        ``Label`` output renders the highest-scoring entries.
    """
    if img is None:
        # Without this guard, submitting an empty input fails inside
        # tf.image.resize and the UI only shows a generic error.
        return {}

    # Resize to the square size the model is fed, then add a leading batch
    # axis of length 1 because Keras predicts on batches.
    inp = tf.image.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    pred = model.predict(tf.expand_dims(inp, 0)).flatten()

    # float() turns NumPy scalars into plain Python floats for the UI payload.
    return {name: float(pred[idx]) for idx, name in labels.items()}
# Input component: the uploaded picture is handed to predict_img_label as an
# RGB NumPy array. The legacy gr.inputs / gr.outputs namespaces were
# deprecated in Gradio 3 and removed in Gradio 4; the top-level component
# classes are the supported replacements. image_mode="RGB" keeps PNGs with an
# alpha channel at 3 channels.
image = gr.Image(type="numpy", image_mode="RGB")
# Output component: show the three highest-scoring classes.
label = gr.Label(num_top_classes=3)

title = "Image Classification Model Using RandAugment"
description = "Upload an image to classify images"
article = "<div style='text-align: center;'><a href='https://github.com/BishmoyPaul' target='_blank'>Space by Bishmoy Paul</a><br><a href='https://keras.io/examples/vision/randaugment/' target='_blank'>Keras example by Sayak Paul</a></div>"

demo = gr.Interface(
    predict_img_label,
    inputs=image,
    outputs=label,
    # Gradio expects a string mode here; the boolean form is deprecated.
    allow_flagging="never",
    examples=[["./airplane.jpg"], ["./car.png"], ["./cat.jpg"], ["./horse.jpg"]],
    title=title,
    description=description,
    article=article,
)

# launch(enable_queue=True) is no longer accepted; request queuing is enabled
# through queue(), which returns the same interface for chaining.
demo.queue().launch()