|
import gradio as gr |
|
from huggingface_hub import from_pretrained_keras |
|
import tensorflow as tf |
|
|
|
|
|
# Label index -> human-readable class name. classify_image() uses each key as
# the position of that class in the model's output vector, so the order here
# must match the order the model was trained with.
CLASSES = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

# Side length, in pixels, of the square image fed to the model.
IMAGE_SIZE = 32
|
|
|
|
|
# Loaded once at import time so every request reuses the same model instance.
model = from_pretrained_keras("keras-io/cct")
|
|
|
|
|
def reshape_image(image):
    """Turn a single 3-channel image into a one-element batch for the model.

    Args:
        image: Image data accepted by ``tf.convert_to_tensor`` whose shape is
            compatible with ``(height, width, 3)``.

    Returns:
        The image resized with ``tf.image.resize`` to
        ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and given a leading batch axis, i.e.
        a tensor of shape ``(1, IMAGE_SIZE, IMAGE_SIZE, 3)``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    # Fail with a readable message instead of TensorFlow's generic
    # "unsupported type" ValueError when no image was supplied.
    if image is None:
        raise ValueError("No image was provided; upload or select an image first.")

    image = tf.convert_to_tensor(image)
    # Declare the expected (height, width, 3) layout before resizing.
    # NOTE(review): this is presumably also what rejects non-RGB inputs
    # (e.g. RGBA) early - confirm on the pinned TensorFlow version.
    image.set_shape([None, None, 3])
    image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    # The model is called on batches, so add a batch axis of size one.
    return tf.expand_dims(image, axis=0)
|
|
|
|
|
def classify_image(input_image):
    """Score one image against every label in ``CLASSES``.

    Args:
        input_image: Image as delivered by the Gradio image input; it is
            passed to ``reshape_image`` unchanged.

    Returns:
        A dict mapping each class name in ``CLASSES`` to its softmax
        probability as a Python ``float``.
    """
    batch = reshape_image(input_image)
    # The batch holds a single image, so flattening the prediction yields one
    # score per class; those scores are treated as logits.
    logits = model.predict(batch).flatten()
    # Convert to NumPy once rather than indexing the tensor for every class.
    probabilities = tf.nn.softmax(logits).numpy()
    return {
        label: float(probabilities[index])
        for index, label in CLASSES.items()
    }
|
|
|
|
|
|
|
# Sample images offered in the UI; paths are relative to the working directory.
examples = [["./bird.png"], ["./cat.png"], ["./dog.png"], ["./horse.png"]]

title = "Image Classification using Compact Convolutional Transformer (CCT)"

# HTML passed to gr.Interface as the page description.
description = """
Upload an image or select one from the examples and ask the model to label it!
<br />
The model was trained on the <a href="https://www.cs.toronto.edu/~kriz/cifar.html" target="_blank">CIFAR-10 dataset</a>. Therefore, it is able to recognise these 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
<br />
<br />
<p>
<b>Model:</b> <a href="https://huggingface.co/keras-io/cct" target="_blank">https://huggingface.co/keras-io/cct</a>
<br />
<b>Keras Example:</b> <a href="https://keras.io/examples/vision/cct/" target="_blank">https://keras.io/examples/vision/cct/</a>
</p>
<br />
"""

# HTML passed to gr.Interface as the article (credits).
article = """
<div style="text-align: center;">
Space by <a href="https://github.com/EdAbati" target="_blank">Edoardo Abati</a>
<br />
Keras example by <a href="https://twitter.com/RisingSayak" target="_blank">Sayak Paul</a>
</div>
"""
|
|
|
# Wire the classifier into a Gradio UI: image in, label probabilities out.
# NOTE(review): the gr.inputs / gr.outputs namespaces and
# launch(enable_queue=...) are, as far as I know, deprecated in Gradio 3.x and
# removed in 4.x. They are left untouched because the pinned Gradio version is
# not visible here - confirm before upgrading.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(),
    examples=examples,
    title=title,
    description=description,
    article=article,
    allow_flagging="never",  # hide the flagging button
)

# Starts the web server as a module-level side effect.
interface.launch(enable_queue=True)
|
|