Spaces:
Sleeping
Sleeping
im2
commited on
Commit
•
41a6271
1
Parent(s):
02729d4
final no
Browse files
app.py
CHANGED
@@ -31,9 +31,15 @@ model = ImageClassifier()
|
|
31 |
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
|
32 |
model.eval()
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# Gradio preprocessing and prediction pipeline
|
35 |
def predict_digit(image):
|
36 |
-
#
|
37 |
image = Image.fromarray(image).convert('L') # Convert to grayscale
|
38 |
|
39 |
# Preprocess: resize to 28x28 and normalize
|
@@ -54,8 +60,8 @@ def predict_digit(image):
|
|
54 |
|
55 |
# Create Gradio Interface
|
56 |
interface = gr.Interface(
|
57 |
-
fn=predict_digit,
|
58 |
-
inputs=gr.Sketchpad(),
|
59 |
outputs="text",
|
60 |
title="Digit Recognizer",
|
61 |
description="Draw a digit (0-9) and the model will predict the number!"
|
|
|
31 |
model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
|
32 |
model.eval()
|
33 |
|
34 |
+
# Function to process sketchpad input
|
35 |
+
def sketchToNumpy(image):
|
36 |
+
# Extract the 'composite' key from the sketchpad input dictionary
|
37 |
+
imArray = image['composite'] # 'composite' contains the drawn image
|
38 |
+
return imArray
|
39 |
+
|
40 |
# Gradio preprocessing and prediction pipeline
|
41 |
def predict_digit(image):
|
42 |
+
# Convert the sketchpad input into a PIL Image
|
43 |
image = Image.fromarray(image).convert('L') # Convert to grayscale
|
44 |
|
45 |
# Preprocess: resize to 28x28 and normalize
|
|
|
60 |
|
61 |
# Create Gradio Interface
|
62 |
interface = gr.Interface(
|
63 |
+
fn=lambda x: predict_digit(sketchToNumpy(x)),
|
64 |
+
inputs=gr.Sketchpad(crop_size=(256,256), type='numpy', image_mode='L'),
|
65 |
outputs="text",
|
66 |
title="Digit Recognizer",
|
67 |
description="Draw a digit (0-9) and the model will predict the number!"
|