im2 commited on
Commit
85738e1
1 Parent(s): 95da1e4
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -32,9 +32,15 @@ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
- def predict_digit(image):
36
- # Convert the numpy array (from gr.Sketchpad) to a PIL Image
37
- image = Image.fromarray(image).convert('L') # Convert to grayscale
 
 
 
 
 
 
38
 
39
  # Preprocess: resize to 28x28 and normalize
40
  transform = transforms.Compose([
 
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
+ def predict_digit(image_dict):
36
+ # Extract the image from the 'image' key in the dictionary (if it exists)
37
+ if isinstance(image_dict, dict) and "image" in image_dict:
38
+ image = image_dict["image"] # Access the image data
39
+ else:
40
+ raise ValueError("Invalid input format")
41
+
42
+ # Convert the image (numpy array) to a PIL Image
43
+ image = Image.fromarray(np.array(image)).convert('L') # Convert to grayscale
44
 
45
  # Preprocess: resize to 28x28 and normalize
46
  transform = transforms.Compose([