Jeet Paul commited on
Commit
8b6ef4c
·
1 Parent(s): a411364

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import numpy as np
2
- import pandas as pd
3
  import gradio as gr
4
- from matplotlib import pyplot as plt
5
 
6
  def ReLU(Z):
7
  return np.maximum(Z, 0)
@@ -32,16 +31,16 @@ def make_predictions(X, W1, b1, W2, b2):
32
  predictions = get_predictions(A2)
33
  return predictions
34
 
35
- def predict_digit(input_data):
36
  # Load the trained parameters
37
  params = np.load("trained_params.npz", allow_pickle=True)
38
  W1, b1, W2, b2 = params["W1"], params["b1"], params["W2"], params["b2"]
39
 
40
- # Convert the uploaded image to grayscale and resize it to (28, 28)
41
- img = Image.fromarray(input_data).convert("L").resize((28, 28))
42
 
43
  # Convert the image to a NumPy array and normalize it
44
- X = np.array(img).reshape((784, 1)) / 255.
45
 
46
  # Get the prediction
47
  prediction = make_predictions(X, W1, b1, W2, b2)
@@ -50,7 +49,7 @@ def predict_digit(input_data):
50
 
51
  iface = gr.Interface(
52
  fn=predict_digit,
53
- inputs=gr.inputs.Image(shape=(28, 28)),
54
  outputs=gr.outputs.Label(num_top_classes=3),
55
  live=True,
56
  capture_session=True,
 
1
  import numpy as np
 
2
  import gradio as gr
3
+ from PIL import Image
4
 
5
  def ReLU(Z):
6
  return np.maximum(Z, 0)
 
31
  predictions = get_predictions(A2)
32
  return predictions
33
 
34
+ def predict_digit(img):
35
  # Load the trained parameters
36
  params = np.load("trained_params.npz", allow_pickle=True)
37
  W1, b1, W2, b2 = params["W1"], params["b1"], params["W2"], params["b2"]
38
 
39
+ # Convert the sketchpad drawing to grayscale and resize it to (28, 28)
40
+ img_pil = Image.fromarray(np.uint8(img * 255)).convert("L").resize((28, 28))
41
 
42
  # Convert the image to a NumPy array and normalize it
43
+ X = np.array(img_pil).reshape((784, 1)) / 255.
44
 
45
  # Get the prediction
46
  prediction = make_predictions(X, W1, b1, W2, b2)
 
49
 
50
  iface = gr.Interface(
51
  fn=predict_digit,
52
+ inputs="sketchpad",
53
  outputs=gr.outputs.Label(num_top_classes=3),
54
  live=True,
55
  capture_session=True,