|
import gradio as gr
|
|
import numpy as np
|
|
import pickle
|
|
from PIL import Image
|
|
import os
|
|
import random
|
|
|
|
|
|
# Load the trained network parameters from disk once, at import time.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only deploy a model.pkl that comes from a trusted source.
with open('model.pkl', 'rb') as f:
    model_params = pickle.load(f)

# Unpack the two weight matrices and bias terms that predict_digit feeds to
# forward_prop.
W1, b1, W2, b2 = (model_params[key] for key in ('W1', 'b1', 'W2', 'b2'))
|
|
|
|
|
|
def ReLu(Z):
    """Apply the rectified linear unit element-wise.

    Negative entries become 0; non-negative entries pass through unchanged.
    """
    activated = np.maximum(Z, 0)
    return activated
|
|
|
|
def softmax(Z):
    """Return the softmax of *Z*, normalised along axis 0 (per column).

    Each column of the result is non-negative and sums to 1. The column-wise
    maximum is subtracted before exponentiating. Softmax is invariant to that
    shift, and it keeps ``np.exp`` from overflowing to ``inf``, which would
    otherwise turn large logits into ``nan`` probabilities.
    """
    shifted = Z - np.max(Z, axis=0, keepdims=True)
    exps = np.exp(shifted)
    # keepdims keeps the denominator broadcastable for both 1-D and 2-D input.
    return exps / np.sum(exps, axis=0, keepdims=True)
|
|
|
|
def forward_prop(W1, b1, W2, b2, X):
    """Run one forward pass through the two-layer network.

    Layer 1 is affine followed by ReLu; layer 2 is affine followed by softmax.
    ``X`` is expected to hold inputs as columns (predict_digit passes the
    single column produced by preprocess_image).

    Returns:
        ``(Z1, Z2, A1, A2)``: both pre-activations first, then both
        activations. Callers unpack in this exact order.
    """
    hidden_pre = W1.dot(X) + b1
    hidden_act = ReLu(hidden_pre)

    output_pre = W2.dot(hidden_act) + b2
    output_act = softmax(output_pre)

    return hidden_pre, output_pre, hidden_act, output_act
|
|
|
|
def get_predictions(A2):
    """Return, for each column of *A2*, the row index of its largest entry."""
    scores = np.asarray(A2)
    return np.argmax(scores, axis=0)
|
|
|
|
def preprocess_image(image):
    """Turn an image into a ``(784, 1)`` column of scaled grayscale pixels.

    The image is converted to mode ``'L'`` (8-bit grayscale) and resized to
    28x28. It is then flattened row by row and divided by 255.0, so intensities
    land in [0, 1].
    """
    grayscale = image.convert('L').resize((28, 28))
    flat_row = np.array(grayscale).reshape(1, 28 * 28)
    return (flat_row / 255.0).T
|
|
|
|
|
|
def predict_digit(image):
    """Classify the handwritten digit in *image*.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when that component is empty (e.g. "Predict" clicked first).

    Returns:
        The predicted class index as a plain Python ``int`` for the Label output.

    Raises:
        gr.Error: if no image was provided. Gradio then shows a readable
            message instead of an ``AttributeError`` traceback from
            ``None.convert``.
    """
    if image is None:
        raise gr.Error("Please upload or select an image first.")

    X = preprocess_image(image)
    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)
    prediction = get_predictions(A2)
    # X is a single column, so there is exactly one prediction.
    return int(prediction[0])
|
|
|
|
|
|
# Discover the bundled example images used by the sample gallery and the
# random-sample button.
# - Sorted, because os.listdir order is arbitrary and would reshuffle the
#   gallery between machines.
# - Extension matching is case-insensitive, so e.g. "seven.PNG" is included.
# - A missing directory yields an empty list instead of aborting startup.
sample_dir = "sample_images"

if os.path.isdir(sample_dir):
    sample_images = [
        os.path.join(sample_dir, filename)
        for filename in sorted(os.listdir(sample_dir))
        if filename.lower().endswith((".png", ".jpg", ".jpeg"))
    ]
else:
    sample_images = []
|
|
|
|
|
|
def select_random_image(images=None):
    """Pick a random sample-image path, or ``None`` if there are none.

    Args:
        images: Optional sequence of candidate paths. Defaults to the
            module-level ``sample_images`` list, so existing zero-argument
            callers behave as before.

    Returns:
        One path chosen uniformly at random. Returns ``None`` when the pool is
        empty (``random.choice`` would raise ``IndexError``); ``None`` leaves
        the image component empty rather than erroring.
    """
    pool = sample_images if images is None else images
    if not pool:
        return None
    return random.choice(pool)
|
|
|
|
|
|
def _uploaded_path(file):
    """Return a filesystem path for the value delivered by ``gr.UploadButton``.

    Depending on the Gradio version and the button's ``type`` setting, the
    upload event passes either a plain path string or a tempfile-like object
    exposing ``.name``. A bare ``file.name`` fails on the string form.
    """
    if file is None or isinstance(file, str):
        return file
    return file.name


with gr.Blocks() as demo:
    gr.Markdown("# Handwritten Digit Recognition")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            upload_button = gr.UploadButton("Upload Image", file_types=["image"])
            sample_button = gr.Button("Use Random Sample Image")

        with gr.Column():
            output_label = gr.Label(label="Prediction")
            predict_button = gr.Button("Predict")

    upload_button.upload(fn=_uploaded_path, inputs=upload_button, outputs=input_image)
    sample_button.click(fn=select_random_image, inputs=None, outputs=input_image)
    predict_button.click(fn=predict_digit, inputs=input_image, outputs=output_label)

    gr.Markdown("## Sample Images")

    # Show up to ten samples, five per row; rows with no images are skipped
    # rather than rendered empty.
    for start in (0, 5):
        row_paths = sample_images[start:start + 5]
        if not row_paths:
            continue
        with gr.Row():
            for img_path in row_paths:
                gr.Image(img_path, show_label=False, height=100)


demo.launch()