im2 commited on
Commit
95da1e4
1 Parent(s): 1937011

struggling

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -32,12 +32,9 @@ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path))
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
- def predict_digit(image_dict):
36
- # Extract the image array from the 'mask' key (gr.Sketchpad output)
37
- image = image_dict["mask"] # Get the image from the dict
38
-
39
- # Convert the image to a numpy array, then to a PIL image, and preprocess
40
- image = Image.fromarray(np.array(image)).convert('L') # Convert to grayscale
41
 
42
  # Preprocess: resize to 28x28 and normalize
43
  transform = transforms.Compose([
 
32
  model.eval()
33
 
34
  # Gradio preprocessing and prediction pipeline
35
+ def predict_digit(image):
36
+ # Convert the numpy array (from gr.Sketchpad) to a PIL Image
37
+ image = Image.fromarray(image).convert('L') # Convert to grayscale
 
 
 
38
 
39
  # Preprocess: resize to 28x28 and normalize
40
  transform = transforms.Compose([