File size: 2,894 Bytes
fd35fa1
 
 
 
164ffbc
 
fd35fa1
164ffbc
fd35fa1
 
 
 
 
 
 
 
164ffbc
fd35fa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164ffbc
fd35fa1
 
 
 
 
 
 
 
 
 
 
164ffbc
 
 
 
 
 
 
fd35fa1
164ffbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import numpy as np
import pickle
from PIL import Image
import os
import random

# 1. Load the model
# NOTE: pickle can run arbitrary code while loading, so only ever load a
# model.pkl that you produced yourself or otherwise trust.
with open('model.pkl', 'rb') as f:
    model_params = pickle.load(f)

# Unpack the network parameters used by forward_prop: weights/biases of the
# hidden layer (W1, b1) and of the output layer (W2, b2).
W1, b1, W2, b2 = (model_params[key] for key in ('W1', 'b1', 'W2', 'b2'))

# 2. Define helper functions
def ReLu(Z):
    """Apply the rectified linear unit elementwise: max(Z, 0)."""
    lower_bound = 0
    return np.maximum(Z, lower_bound)

def softmax(Z):
    """Column-wise softmax: normalise exp(Z) so each column sums to 1.

    For a 2-D input, every column is treated as one sample's logits. This
    matches the axis-0 reduction the builtin ``sum`` performed before. For a
    1-D input, the whole vector is normalised.

    The per-column maximum is subtracted before exponentiating. That leaves
    the result mathematically unchanged but prevents ``exp`` overflow, which
    would otherwise turn large logits into inf/inf = nan.
    """
    Z = np.asarray(Z)
    shifted = Z - np.max(Z, axis=0, keepdims=True)
    exp_Z = np.exp(shifted)  # computed once, reused for numerator and sum
    return exp_Z / np.sum(exp_Z, axis=0, keepdims=True)

def forward_prop(W1, b1, W2, b2, X):
    """Run one forward pass of the two-layer network on input X.

    Returns a 4-tuple in this order: hidden pre-activation, output
    pre-activation, hidden activation (ReLU), and output activation
    (softmax).
    """
    hidden_pre = W1.dot(X) + b1
    hidden = ReLu(hidden_pre)
    output_pre = W2.dot(hidden) + b2
    probabilities = softmax(output_pre)
    return hidden_pre, output_pre, hidden, probabilities

def get_predictions(A2):
    """Return the index of the largest entry in each column of A2."""
    return np.argmax(A2, axis=0)

def preprocess_image(image):
    """Turn an input image into the model's (784, 1) float column vector.

    The image is converted to 8-bit grayscale ('L') and resized to 28x28.
    Its pixel values are then scaled from 0..255 to 0.0..1.0.

    NOTE(review): no colour inversion is applied. Confirm that uploaded
    images use the same foreground/background polarity as the training data.
    """
    grayscale = image.convert('L').resize((28, 28))
    scaled = np.asarray(grayscale, dtype=np.float64) / 255.0
    # Flatten row-major into a single column.
    return scaled.reshape(28 * 28, 1)

# 3. Define prediction function
def predict_digit(image):
    """Classify a single image and return the predicted class index.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None
            when the input component is empty.

    Returns:
        The predicted class index as a Python ``int``.

    Raises:
        gr.Error: if no image was provided. Gradio then shows a readable
            message instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload or select an image first.")

    X = preprocess_image(image)

    # Only the output-layer activations are needed for the prediction.
    _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)

    # get_predictions yields one index per column; X is a single column.
    return int(get_predictions(A2)[0])

# 4. Load sample images
# Directory of example digits shipped with the app (relative to the working
# directory). If it does not exist, the sample list is simply empty.
sample_dir = "sample_images"
SAMPLE_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def list_sample_images(directory, extensions=SAMPLE_IMAGE_EXTENSIONS):
    """Return the paths of image files directly inside *directory*.

    Args:
        directory: Folder to scan (not recursive).
        extensions: Accepted filename suffixes, matched case-insensitively.

    Returns:
        Paths sorted by filename. Sorting makes the displayed samples
        deterministic, because ``os.listdir`` order is arbitrary. An empty
        list is returned when *directory* does not exist.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, filename)
        for filename in sorted(os.listdir(directory))
        if filename.lower().endswith(extensions)
        # Skip sub-directories that happen to have an image-like name.
        and os.path.isfile(os.path.join(directory, filename))
    ]


sample_images = list_sample_images(sample_dir)

# 5. Define function to select random image
def select_random_image():
    """Return the path of a randomly chosen sample image.

    Raises:
        gr.Error: when no sample images were found. This replaces the bare
            IndexError that ``random.choice`` raises on an empty list.
    """
    if not sample_images:
        raise gr.Error("No sample images are available.")
    return random.choice(sample_images)

# 6. Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Handwritten Digit Recognition")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            upload_button = gr.UploadButton("Upload Image", file_types=["image"])
            sample_button = gr.Button("Use Random Sample Image")

        with gr.Column():
            output_label = gr.Label(label="Prediction")
            predict_button = gr.Button("Predict")

    # Older Gradio releases hand the upload callback a tempfile-like object
    # exposing `.name`. Newer ones pass the file path directly as a str.
    # Accept both so the uploaded path reaches the image input either way.
    upload_button.upload(
        fn=lambda file: getattr(file, "name", file),
        inputs=upload_button,
        outputs=input_image,
    )
    sample_button.click(fn=select_random_image, inputs=None, outputs=input_image)
    predict_button.click(fn=predict_digit, inputs=input_image, outputs=output_label)

    gr.Markdown("## Sample Images")
    # Show up to 10 samples as two rows of 5 thumbnails.
    for start in (0, 5):
        with gr.Row():
            for img_path in sample_images[start:start + 5]:
                gr.Image(img_path, show_label=False, height=100)

# 7. Launch the app
# Starts the Gradio server for `demo`. This is a top-level call, so it also
# runs if this module is imported. NOTE(review): consider guarding it with
# `if __name__ == "__main__":` if the module is ever imported elsewhere.
demo.launch()