ClassCat commited on
Commit
6dbbb44
Β·
1 Parent(s): dd31a0d

update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -23
app.py CHANGED
@@ -1,31 +1,38 @@
1
 
2
  import torch
3
  from torch import nn
4
- from torchvision import datasets
5
  from torchvision.transforms import ToTensor
6
 
7
-
8
  # Define model
9
- class NeuralNetwork(nn.Module):
10
  def __init__(self):
11
- super().__init__()
12
- self.flatten = nn.Flatten()
13
- self.linear_relu_stack = nn.Sequential(
14
- nn.Linear(28*28, 512),
15
- nn.ReLU(),
16
- nn.Linear(512, 512),
17
- nn.ReLU(),
18
- nn.Linear(512, 10)
19
- )
20
 
21
  def forward(self, x):
22
- x = self.flatten(x)
23
- logits = self.linear_relu_stack(x)
 
 
 
 
 
 
 
 
24
  return logits
25
 
26
 
27
- model = NeuralNetwork()
28
- model.load_state_dict(torch.load("model_mnist_mlp.pth"))
 
 
 
29
 
30
  model.eval()
31
 
@@ -40,25 +47,24 @@ def predict(image):
40
  prob = torch.nn.functional.softmax(pred[0], dim=0)
41
 
42
  confidences = {i: float(prob[i]) for i in range(10)}
43
-
44
  return confidences
45
 
46
 
47
- with gr.Blocks(css=".gradio-container {background:lightyellow;color:red;}", title="γƒ†γ‚Ήγƒˆ"
48
  ) as demo:
49
- gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;"MNIST εˆ†ι‘žε™¨</div>')
50
 
51
  with gr.Row():
52
  with gr.Tab("キャンバス"):
53
  input_image1 = gr.Image(label="画像ε…₯εŠ›", source="canvas", type="pil", image_mode="L", shape=(28,28), invert_colors=True)
54
- send_btn1 = gr.Button("δΊˆζΈ¬γ™γ‚‹")
55
 
56
  with gr.Tab("画像フゑむル"):
57
  input_image2 = gr.Image(label="画像ε…₯εŠ›", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
58
- send_btn2 = gr.Button("δΊˆζΈ¬γ™γ‚‹")
59
- gr.Examples(['examples/example02.png', 'examples/example04.png'], inputs=input_image2)
60
 
61
- output_label=gr.Label(label="δΊˆζΈ¬η’ΊηŽ‡", num_top_classes=5)
62
 
63
  send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
64
  send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)
@@ -66,4 +72,5 @@ with gr.Blocks(css=".gradio-container {background:lightyellow;color:red;}", titl
66
  # demo.queue(concurrency_count=3)
67
  demo.launch()
68
 
 
69
  ### EOF ###
 
1
 
2
  import torch
3
  from torch import nn
4
+ import torch.nn.functional as F
5
  from torchvision.transforms import ToTensor
6
 
 
7
  # Define model
8
+ class ConvNet(nn.Module):
9
  def __init__(self):
10
+ super(ConvNet, self).__init__()
11
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
12
+ self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
13
+ self.conv3 = nn.Conv2d(32,64, kernel_size=5)
14
+ self.fc1 = nn.Linear(3*3*64, 256)
15
+ self.fc2 = nn.Linear(256, 10)
 
 
 
16
 
17
  def forward(self, x):
18
+ x = F.relu(self.conv1(x))
19
+ #x = F.dropout(x, p=0.5, training=self.training)
20
+ x = F.relu(F.max_pool2d(self.conv2(x), 2))
21
+ x = F.dropout(x, p=0.5, training=self.training)
22
+ x = F.relu(F.max_pool2d(self.conv3(x),2))
23
+ x = F.dropout(x, p=0.5, training=self.training)
24
+ x = x.view(-1,3*3*64 )
25
+ x = F.relu(self.fc1(x))
26
+ x = F.dropout(x, training=self.training)
27
+ logits = self.fc2(x)
28
  return logits
29
 
30
 
31
+ model = ConvNet()
32
+ model.load_state_dict(
33
+ torch.load("weights/mnist_convnet_model.pth",
34
+ map_location=torch.device('cpu'))
35
+ )
36
 
37
  model.eval()
38
 
 
47
  prob = torch.nn.functional.softmax(pred[0], dim=0)
48
 
49
  confidences = {i: float(prob[i]) for i in range(10)}
 
50
  return confidences
51
 
52
 
53
+ with gr.Blocks(css=".gradio-container {background:honeydew;}", title="MNIST εˆ†ι‘žε™¨"
54
  ) as demo:
55
+ gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">MNIST εˆ†ι‘žε™¨</div>""")
56
 
57
  with gr.Row():
58
  with gr.Tab("キャンバス"):
59
  input_image1 = gr.Image(label="画像ε…₯εŠ›", source="canvas", type="pil", image_mode="L", shape=(28,28), invert_colors=True)
60
+ send_btn1 = gr.Button("ζŽ¨θ«–γ™γ‚‹")
61
 
62
  with gr.Tab("画像フゑむル"):
63
  input_image2 = gr.Image(label="画像ε…₯εŠ›", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
64
+ send_btn2 = gr.Button("ζŽ¨θ«–γ™γ‚‹")
65
+ gr.Examples(['examples/sample02.png', 'examples/sample04.png'], inputs=input_image2)
66
 
67
+ output_label=gr.Label(label="ζŽ¨θ«–η’ΊηŽ‡", num_top_classes=3)
68
 
69
  send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
70
  send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)
 
72
  # demo.queue(concurrency_count=3)
73
  demo.launch()
74
 
75
+
76
  ### EOF ###