File size: 1,597 Bytes
ffd7267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62

import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor


# Define model
class NeuralNetwork(nn.Module):
    """Fully connected 10-way classifier over inputs that flatten to 784 features.

    Layout: Flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512) -> ReLU
    -> Linear(512, 10).

    The submodule names ``flatten`` / ``linear_relu_stack`` and the layer order
    inside ``linear_relu_stack`` must not change: they determine the
    ``state_dict`` keys (``linear_relu_stack.0.weight`` etc.) that the saved
    checkpoint is loaded against.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden, n_classes = 28 * 28, 512, 10
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return the raw (pre-softmax) logits."""
        return self.linear_relu_stack(self.flatten(x))


model = NeuralNetwork()
# map_location="cpu" lets a checkpoint that was saved from a GPU process load on
# a CPU-only host; without it torch.load tries to restore each tensor onto the
# device it was saved from and fails when that device is unavailable.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load("model_mnist_mlp.pth", map_location="cpu"))

# Put the module in evaluation mode for inference.
model.eval()

import gradio as gr
from torchvision import transforms

def predict(image):
    """Classify a digit image and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the button is clicked without an image.

    Returns:
        Dict mapping class index (0-9) to its softmax probability, as consumed
        by ``gr.Label``; ``None`` when no image was supplied.
    """
    if image is None:
        # Nothing to classify; return None instead of letting ToTensor raise.
        return None

    # The model's first layer takes 28*28 features, i.e. a single-channel
    # 28x28 image. The UI component requests mode "L" and shape (28, 28), but
    # normalize here as well so predict() does not depend on every Gradio
    # version honoring those options. Both are no-ops for conforming input.
    if image.mode != "L":
        image = image.convert("L")
    if image.size != (28, 28):
        image = image.resize((28, 28))

    # ToTensor yields a (1, 28, 28) tensor for an "L" image; add an explicit
    # batch dimension so the model sees (batch=1, 1, 28, 28).
    batch = transforms.ToTensor()(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        prob = torch.nn.functional.softmax(logits[0], dim=0)

    return {i: float(p) for i, p in enumerate(prob.tolist())}


# Build the UI: an image input and a top-5 probability label side by side,
# plus a button that runs predict() on the current image.
with gr.Blocks(
    css=".gradio-container {background:lightyellow;color:red;}",
    title="テスト",
) as demo:
    # Centered heading ("MNIST classifier"). The opening <div> tag must be
    # closed with ">" before the text, otherwise the text is parsed as part of
    # the tag and nothing is displayed.
    gr.HTML('<div style="font-size:12pt; text-align:center; color:yellow;">MNIST 分類器</div>')

    with gr.Row():
        # Requests a grayscale ("L") 28x28 PIL image with inverted colors
        # (these options are interpreted by the installed Gradio version).
        input_image = gr.Image(
            label="画像入力",  # "image input"
            type="pil",
            image_mode="L",
            shape=(28, 28),
            invert_colors=True,
        )

        # Shows the five most probable classes from predict()'s dict.
        output_label = gr.Label(label="予測確率", num_top_classes=5)  # "predicted probabilities"

    send_btn = gr.Button("予測する")  # "predict"
    send_btn.click(fn=predict, inputs=input_image, outputs=output_label)

# Request queueing is currently disabled; uncomment to enable it.
# demo.queue(concurrency_count=3)
demo.launch()

### EOF ###