ClassCat commited on
Commit
ffd7267
·
1 Parent(s): b31591d

add app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+ from torchvision import datasets
5
+ from torchvision.transforms import ToTensor
6
+
7
+
8
+ # Define model
9
+ class NeuralNetwork(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.flatten = nn.Flatten()
13
+ self.linear_relu_stack = nn.Sequential(
14
+ nn.Linear(28*28, 512),
15
+ nn.ReLU(),
16
+ nn.Linear(512, 512),
17
+ nn.ReLU(),
18
+ nn.Linear(512, 10)
19
+ )
20
+
21
+ def forward(self, x):
22
+ x = self.flatten(x)
23
+ logits = self.linear_relu_stack(x)
24
+ return logits
25
+
26
+
27
+ model = NeuralNetwork()
28
+ model.load_state_dict(torch.load("model_mnist_mlp.pth"))
29
+
30
+ model.eval()
31
+
32
+ import gradio as gr
33
+ from torchvision import transforms
34
+
35
+ def predict(image):
36
+ tsr_image = transforms.ToTensor()(image)
37
+
38
+ with torch.no_grad():
39
+ pred = model(tsr_image)
40
+ prob = torch.nn.functional.softmax(pred[0], dim=0)
41
+
42
+ confidences = {i: float(prob[i]) for i in range(10)}
43
+
44
+ return confidences
45
+
46
+
47
+ with gr.Blocks(css=".gradio-container {background:lightyellow;color:red;}", title="テスト"
48
+ ) as demo:
49
+ gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;"MNIST 分類器</div>')
50
+
51
+ with gr.Row():
52
+ input_image = gr.Image(label="画像入力", type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
53
+
54
+ output_label=gr.Label(label="予測確率", num_top_classes=5)
55
+
56
+ send_btn = gr.Button("予測する")
57
+ send_btn.click(fn=predict, inputs=input_image, outputs=output_label)
58
+
59
+ # demo.queue(concurrency_count=3)
60
+ demo.launch()
61
+
62
+ ### EOF ###