im2 commited on
Commit
41a6271
1 Parent(s): 02729d4
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -31,9 +31,15 @@ model = ImageClassifier()
31
  model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
  model.eval()
33
 
 
 
 
 
 
 
34
  # Gradio preprocessing and prediction pipeline
35
  def predict_digit(image):
36
- # Directly handle the numpy array from gr.Sketchpad
37
  image = Image.fromarray(image).convert('L') # Convert to grayscale
38
 
39
  # Preprocess: resize to 28x28 and normalize
@@ -54,8 +60,8 @@ def predict_digit(image):
54
 
55
  # Create Gradio Interface
56
  interface = gr.Interface(
57
- fn=predict_digit,
58
- inputs=gr.Sketchpad(), # Sketchpad for users to draw
59
  outputs="text",
60
  title="Digit Recognizer",
61
  description="Draw a digit (0-9) and the model will predict the number!"
 
31
  model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
  model.eval()
33
 
34
+ # Function to process sketchpad input
35
+ def sketchToNumpy(image):
36
+ # Extract the 'composite' key from the sketchpad input dictionary
37
+ imArray = image['composite'] # 'composite' contains the drawn image
38
+ return imArray
39
+
40
  # Gradio preprocessing and prediction pipeline
41
  def predict_digit(image):
42
+ # Convert the sketchpad input into a PIL Image
43
  image = Image.fromarray(image).convert('L') # Convert to grayscale
44
 
45
  # Preprocess: resize to 28x28 and normalize
 
60
 
61
  # Create Gradio Interface
62
  interface = gr.Interface(
63
+ fn=lambda x: predict_digit(sketchToNumpy(x)),
64
+ inputs=gr.Sketchpad(crop_size=(256,256), type='numpy', image_mode='L'),
65
  outputs="text",
66
  title="Digit Recognizer",
67
  description="Draw a digit (0-9) and the model will predict the number!"