im2 commited on
Commit
02729d4
1 Parent(s): 85738e1
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -32,15 +32,9 @@ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
- def predict_digit(image_dict):
36
- # Extract the image from the 'image' key in the dictionary (if it exists)
37
- if isinstance(image_dict, dict) and "image" in image_dict:
38
- image = image_dict["image"] # Access the image data
39
- else:
40
- raise ValueError("Invalid input format")
41
-
42
- # Convert the image (numpy array) to a PIL Image
43
- image = Image.fromarray(np.array(image)).convert('L') # Convert to grayscale
44
 
45
  # Preprocess: resize to 28x28 and normalize
46
  transform = transforms.Compose([
 
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
+ def predict_digit(image):
36
+ # Directly handle the numpy array from gr.Sketchpad
37
+ image = Image.fromarray(image).convert('L') # Convert to grayscale
 
 
 
 
 
 
38
 
39
  # Preprocess: resize to 28x28 and normalize
40
  transform = transforms.Compose([