File size: 2,665 Bytes
ffd7267
 
 
6dbbb44
ffd7267
 
 
6dbbb44
ffd7267
6dbbb44
 
 
 
 
 
ffd7267
 
6dbbb44
 
 
 
 
 
 
 
 
 
ffd7267
 
 
6dbbb44
 
 
 
 
ffd7267
 
 
 
 
 
aa767bc
 
 
 
 
 
ffd7267
 
 
 
 
 
 
 
 
 
 
345be36
ffd7267
79e50e5
ffd7267
55b8eb4
345be36
 
 
1cf13d0
345be36
79e50e5
345be36
aa767bc
 
ffd7267
345be36
ffd7267
1cf13d0
 
ffd7267
 
 
 
6dbbb44
ffd7267
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor

# Define model
class ConvNet(nn.Module):
    """Small CNN that maps single-channel digit images to 10 class logits.

    Three 5x5 convolutions (the last two each followed by 2x2 max-pooling)
    reduce a 28x28 input to a 64-channel 3x3 feature map, which two fully
    connected layers turn into 10 logits. The attribute names ``conv1`` ..
    ``fc2`` must stay as they are so previously saved state dicts still load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)   # 28x28 -> 24x24
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)  # 24x24 -> 20x20, pooled to 10x10
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)  # 10x10 -> 6x6, pooled to 3x3
        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Tensor of shape ``(N, 1, H, W)`` (28x28 for MNIST). A single
                unbatched image of shape ``(1, H, W)`` is also accepted and
                treated as a batch of one.

        Returns:
            Logits tensor of shape ``(N, 10)``.

        Raises:
            RuntimeError: if the final feature map is not 64x3x3 per sample.
                Previously, ``view(-1, 576)`` could silently reshape such input
                into a different number of rows.
        """
        if x.dim() == 3:
            # Unbatched (C, H, W) input: add a batch axis so the per-sample
            # flatten below keeps samples separate.
            x = x.unsqueeze(0)
        x = F.relu(self.conv1(x))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(F.max_pool2d(self.conv3(x), 2))
        x = F.dropout(x, p=0.5, training=self.training)
        # Flatten each sample independently; a wrong feature size now fails
        # loudly in fc1 instead of being regrouped across the batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.fc2(x)


# Instantiate the network, restore the trained weights (mapped onto the CPU so
# no GPU is required), then switch to eval mode, which turns off the dropout
# layers in ConvNet.forward.
# NOTE(review): torch.load unpickles the checkpoint; only point it at trusted
# files (torch >= 1.13 also accepts weights_only=True for plain state dicts).
model = ConvNet()
model.load_state_dict(
    torch.load(
        "weights/mnist_convnet_model.pth",
        map_location="cpu",
    )
)
model.eval()

import gradio as gr
from torchvision import transforms

import os
import glob

# PNG files offered as clickable examples under the "Image file" tab.
# glob.glob returns matches in arbitrary, filesystem-dependent order, so sort
# them to keep the example gallery stable across machines and restarts.
examples_dir = './examples'
example_files = sorted(glob.glob(os.path.join(examples_dir, '*.png')))

def predict(image):
    """Classify one digit image and return per-class probabilities.

    Args:
        image: PIL image from a Gradio ``gr.Image`` input (configured below
            with ``image_mode="L"`` and ``shape=(28, 28)``), or ``None`` when
            "Infer" is clicked with an empty canvas or no uploaded file.

    Returns:
        Dict mapping each digit index (int) to its softmax probability
        (float), suitable for ``gr.Label``; ``None`` when there is no image,
        instead of raising inside ``ToTensor``.
    """
    if image is None:
        return None
    # ToTensor yields (C, H, W); add the batch axis the model expects.
    batch = transforms.ToTensor()(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits[0], dim=0)

    return dict(enumerate(probs.tolist()))


# Web UI: a "Canvas" tab for drawing a digit and an "Image file" tab for
# uploading one (or picking an example). Both tabs feed the same
# probability label through predict().
with gr.Blocks(css=".gradio-container {background:honeydew;}", title="MNIST Classification"
               ) as demo:
    # CSS generic family keywords must be unquoted: a quoted 'Serif' names a
    # font family literally called "Serif" rather than the generic fallback.
    gr.HTML("""<div style="font-family:'Times New Roman', serif; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">MNIST Classification</div>""")

    with gr.Row():
        with gr.Tab("Canvas"):
            # invert_colors turns dark-on-light drawings into light-on-dark
            # input, and shape=(28, 28) resizes to the model's input size.
            input_image1 = gr.Image(source="canvas", type="pil", image_mode="L", shape=(28,28), invert_colors=True)
            send_btn1 = gr.Button("Infer")

        with gr.Tab("Image file"):
            input_image2 = gr.Image(type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
            send_btn2 = gr.Button("Infer")
            gr.Examples(example_files, inputs=input_image2)

        # Shared output: shows the three most likely digits.
        output_label=gr.Label(label="Probabilities", num_top_classes=3)

    send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
    send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)

demo.launch()


### EOF ###