Spaces:
Sleeping
Sleeping
File size: 2,665 Bytes
ffd7267 6dbbb44 ffd7267 6dbbb44 ffd7267 6dbbb44 ffd7267 6dbbb44 ffd7267 6dbbb44 ffd7267 aa767bc ffd7267 345be36 ffd7267 79e50e5 ffd7267 55b8eb4 345be36 1cf13d0 345be36 79e50e5 345be36 aa767bc ffd7267 345be36 ffd7267 1cf13d0 ffd7267 6dbbb44 ffd7267 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
# Define model
class ConvNet(nn.Module):
    """Three-convolution CNN that maps a 1-channel image to 10 class logits.

    The attribute names ``conv1``..``conv3``, ``fc1`` and ``fc2`` (and the
    order they are registered in) must not change: they are the keys of the
    state dict this file loads into the model with ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # For a 28x28 input the spatial size evolves as
        #   28 -conv5-> 24 -conv5-> 20 -pool2-> 10 -conv5-> 6 -pool2-> 3,
        # so 3*3*64 features reach the first fully connected layer.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(3 * 3 * 64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape ``(N, 10)``.

        ``x`` may be batched ``(N, 1, 28, 28)`` or unbatched ``(1, 28, 28)``;
        the ``view(-1, ...)`` flatten turns the unbatched case into N == 1.
        """
        train_mode = self.training  # dropout is only applied while training
        feats = F.relu(self.conv1(x))
        feats = F.relu(F.max_pool2d(self.conv2(feats), 2))
        feats = F.dropout(feats, p=0.5, training=train_mode)
        feats = F.relu(F.max_pool2d(self.conv3(feats), 2))
        feats = F.dropout(feats, p=0.5, training=train_mode)
        flat = feats.view(-1, 3 * 3 * 64)
        hidden = F.dropout(F.relu(self.fc1(flat)), training=train_mode)
        return self.fc2(hidden)
# Build the network, restore the saved parameters onto the CPU, and switch to
# inference mode (eval() sets self.training to False, which disables the
# F.dropout calls inside ConvNet.forward).
cpu = torch.device('cpu')
model = ConvNet()
checkpoint = torch.load("weights/mnist_convnet_model.pth", map_location=cpu)
model.load_state_dict(checkpoint)
model.eval()
# Drop the temporaries so the module namespace matches the original (only
# `model` is left behind).
del checkpoint, cpu
import gradio as gr
from torchvision import transforms
import os
import glob
# Example images offered under the "Image file" tab. glob.glob returns matches
# in filesystem-dependent order, so sort them to keep the gallery order stable
# across machines and deployments.
examples_dir = './examples'
example_files = sorted(glob.glob(os.path.join(examples_dir, '*.png')))
def predict(image):
    """Classify one digit image and return per-class probabilities.

    Args:
        image: the image delivered by the ``gr.Image`` inputs (configured with
            ``type="pil"`` and ``image_mode="L"``), or ``None`` when "Infer"
            is pressed with an empty canvas or no uploaded file.

    Returns:
        A ``{class_index: probability}`` dict for the ``gr.Label`` output, or
        ``None`` when there is no image. NOTE(review): ``None`` presumably
        leaves the label empty; confirm with the installed Gradio version.
    """
    if image is None:
        # ToTensor cannot convert None; bail out instead of raising.
        return None
    # ToTensor yields a (C, H, W) tensor; add a leading batch axis so the
    # model always receives the conventional (N, C, H, W) layout.
    batch = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    prob = F.softmax(logits[0], dim=0)
    # Derive the class count from the model output rather than hard-coding 10.
    return {i: float(p) for i, p in enumerate(prob.tolist())}
# Gradio UI: a title banner, two input tabs (drawing canvas and image upload),
# one "Infer" button per tab, and a shared probability read-out; both buttons
# call predict() and write to the same label.
# NOTE(review): the original indentation was lost; the nesting below (tabs and
# the output label inside one gr.Row, click handlers inside the Blocks
# context) is a reconstruction -- confirm against the source repository.
with gr.Blocks(css=".gradio-container {background:honeydew;}", title="MNIST Classification"
) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">MNIST Classification</div>""")
    with gr.Row():
        with gr.Tab("Canvas"):
            # shape/image_mode/invert_colors request a 28x28 grayscale,
            # inverted image -- presumably to match what ConvNet was trained
            # on; confirm. NOTE(review): `source=`, `shape=` and
            # `invert_colors=` are Gradio 3.x arguments not accepted by
            # Gradio 4.x; pin gradio<4 or migrate.
            input_image1 = gr.Image(source="canvas", type="pil", image_mode="L", shape=(28,28), invert_colors=True)
            send_btn1 = gr.Button("Infer")
        with gr.Tab("Image file"):
            # Same preprocessing options as the canvas input, for uploads.
            input_image2 = gr.Image(type="pil", image_mode="L", shape=(28, 28), invert_colors=True)
            send_btn2 = gr.Button("Infer")
            # Clickable examples: the *.png files collected from ./examples.
            gr.Examples(example_files, inputs=input_image2)
            #gr.Examples(['examples/sample02.png', 'examples/sample04.png'], inputs=input_image2)
        # Displays the three highest-probability classes returned by predict().
        output_label=gr.Label(label="Probabilities", num_top_classes=3)
    send_btn1.click(fn=predict, inputs=input_image1, outputs=output_label)
    send_btn2.click(fn=predict, inputs=input_image2, outputs=output_label)
# demo.queue(concurrency_count=3)
demo.launch()
### EOF ### |