Jeet Paul commited on
Commit
c3a61ad
·
1 Parent(s): 8b6ef4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py CHANGED
@@ -2,6 +2,68 @@ import numpy as np
2
  import gradio as gr
3
  from PIL import Image
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  def ReLU(Z):
6
  return np.maximum(Z, 0)
7
 
@@ -59,3 +121,4 @@ iface = gr.Interface(
59
 
60
  if __name__ == "__main__":
61
  iface.launch()
 
 
2
  import gradio as gr
3
  from PIL import Image
4
 
5
+ def ReLU(Z):
6
+ return np.maximum(Z, 0)
7
+
8
+ def softmax(Z):
9
+ A = np.exp(Z) / np.sum(np.exp(Z), axis=0)
10
+ return A
11
+
12
+ def init_params():
13
+ W1 = np.random.rand(10, 784) - 0.5
14
+ b1 = np.random.rand(10, 1) - 0.5
15
+ W2 = np.random.rand(10, 10) - 0.5
16
+ b2 = np.random.rand(10, 1) - 0.5
17
+ return W1, b1, W2, b2
18
+
19
+ def forward_prop(W1, b1, W2, b2, X):
20
+ Z1 = W1.dot(X) + b1
21
+ A1 = ReLU(Z1)
22
+ Z2 = W2.dot(A1) + b2
23
+ A2 = softmax(Z2)
24
+ return Z1, A1, Z2, A2
25
+
26
+ def get_predictions(A2):
27
+ return np.argmax(A2, axis=0)
28
+
29
+ def make_predictions(X, W1, b1, W2, b2):
30
+ _, _, _, A2 = forward_prop(W1, b1, W2, b2, X)
31
+ predictions = get_predictions(A2)
32
+ return predictions
33
+
34
+ def predict_digit(img):
35
+ # Load the trained parameters
36
+ params = np.load("trained_params.npz", allow_pickle=True)
37
+ W1, b1, W2, b2 = params["W1"], params["b1"], params["W2"], params["b2"]
38
+
39
+ # Convert the sketchpad drawing to grayscale and resize it to (28, 28)
40
+ img_pil = Image.fromarray(np.uint8(img * 255)).convert("L").resize((28, 28))
41
+
42
+ # Convert the image to a NumPy array and normalize it
43
+ X = np.array(img_pil).reshape((784, 1)) / 255.
44
+
45
+ # Get the prediction
46
+ prediction = make_predictions(X, W1, b1, W2, b2)
47
+
48
+ return int(prediction)
49
+
50
+ iface = gr.Interface(
51
+ fn=predict_digit,
52
+ inputs="sketchpad",
53
+ outputs=gr.outputs.Label(num_top_classes=3),
54
+ live=True,
55
+ capture_session=True,
56
+ title="Handwritten Digit Recognizer",
57
+ description="Draw a digit using your mouse, and the model will try to recognize it.",
58
+ )
59
+
60
+ if __name__ == "__main__":
61
+ iface.launch()
62
+
63
+ '''import numpy as np
64
+ import gradio as gr
65
+ from PIL import Image
66
+
67
  def ReLU(Z):
68
  return np.maximum(Z, 0)
69
 
 
121
 
122
  if __name__ == "__main__":
123
  iface.launch()
124
+ '''