Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
# User-facing copy shown at the top of the Gradio interface.
TITLE = "Handwritten Digit Recognition Demo"
DESCRIPTION = (
    "This demo employs a basic CNN architecture inspired by "
    "[MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "
    "It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, "
    "particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
)

# Dropdown choice -> PIL resampling filter used when shrinking the sketch to
# 28x28. Insertion order is the order in which the choices appear in the UI.
PIL_INTERPOLATION_METHODS = {
    name: Image.Resampling[name.upper()]
    for name in ("nearest", "bilinear", "bicubic", "hamming", "box", "lanczos")
}

# Saved Keras model, loaded once at import time and shared by every request.
MODEL_PATH = "tf_model_mnist"
model = tf.keras.saving.load_model(MODEL_PATH)
def preprocess(image, resample_method):
    """Normalize a Gradio sketch to the MNIST input format.

    The sketch is resized to 28x28 with the selected PIL resampling filter and
    then binarized on its alpha channel: fully transparent pixels become black
    (0) and every other pixel becomes white (255), i.e. white strokes on a
    black background.

    Args:
        image: PIL image from the Sketchpad composite. Images that are not
            RGBA are converted to RGBA first; note that an image without any
            transparency therefore becomes entirely white.
        resample_method: key of ``PIL_INTERPOLATION_METHODS``.

    Returns:
        ``(image_array, new_image)`` where ``image_array`` is a float32 array
        of shape (1, 28, 28, 1) scaled to [0, 1], and ``new_image`` is the
        binarized 28x28 grayscale ("L") PIL image.
    """
    # The threshold below reads channel 3 (alpha); make sure it exists so
    # e.g. "LA" or palette images with transparency no longer raise IndexError.
    if image.mode != "RGBA":
        image = image.convert("RGBA")
    image = image.resize((28, 28), PIL_INTERPOLATION_METHODS[resample_method])

    # Vectorized replacement for the old per-pixel loop: any stroke coverage
    # left after resampling (alpha > 0) counts as foreground.
    alpha = np.asarray(image)[:, :, 3]
    binary = np.where(alpha == 0, 0, 255).astype(np.uint8)
    new_image = Image.fromarray(binary)  # 2-D uint8 array -> mode "L"

    image_array = tf.keras.utils.img_to_array(new_image)
    image_array = (np.expand_dims(image_array, axis=0) / 255.).astype(np.float32)
    return image_array, new_image
def predict(img, resample_method):
    """Classify the digit drawn on the sketchpad.

    Args:
        img: Sketchpad value whose ``"composite"`` entry holds the flattened
            PIL image. It may be ``None``, or carry a ``None`` composite, when
            the user submits without drawing (NOTE(review): confirm for the
            installed Gradio version).
        resample_method: key of ``PIL_INTERPOLATION_METHODS`` chosen in the
            dropdown.

    Returns:
        ``(confidences, new_image)``: a dict mapping each class index (as a
        string; "0"-"9" for the 10-class MNIST model) to the model's score,
        plus the 28x28 preprocessed image shown back to the user.

    Raises:
        gr.Error: if there is no drawing to classify.
    """
    composite = None if img is None else img["composite"]
    if composite is None:
        # Show a readable message in the UI instead of crashing inside
        # preprocess with an AttributeError on None.
        raise gr.Error("Please draw a digit first.")

    input_arr, new_image = preprocess(composite, resample_method)
    print("input:", input_arr.shape)
    predictions = model.predict(input_arr)
    # float() converts numpy scalars to plain Python floats for gr.Label.
    return {str(i): float(p) for i, p in enumerate(predictions[0])}, new_image
# Drawing surface: a 500x500 single-layer sketchpad delivered to predict()
# as PIL images.
input_image = gr.Sketchpad(
    type="pil",
    layers=False,
    canvas_size=(500, 500),
)

# Lets the user pick which PIL filter downsamples the drawing to 28x28.
resample_method = gr.Dropdown(
    value='bilinear',
    choices=[*PIL_INTERPOLATION_METHODS],
)

# Wire the inputs to predict(): a class-score label plus the preprocessed
# 28x28 image as outputs.
demo = gr.Interface(
    fn=predict,
    inputs=[input_image, resample_method],
    outputs=['label', 'image'],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()