# MIE1517_A1_Demo / app.py
# dhkim2810's picture
# MNIST practice
# 17c015f
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import transforms
class simpleCNN(nn.Module):
    """Small two-convolution CNN producing ``num_classes`` logits.

    Layer stack: conv(3->5, kernel 5) -> ReLU -> 2x2 max-pool ->
    conv(5->10, kernel 5) -> ReLU -> 2x2 max-pool -> flatten ->
    linear(250->32) -> ReLU -> linear(32->num_classes).

    The first linear layer is sized for a 10 x 5 x 5 feature map, which is
    what a 3 x 32 x 32 input yields after the two conv/pool stages.

    Args:
        num_classes: Number of output logits.
    """

    # Feature-map shape (channels, height, width) expected before flattening.
    _FEATURE_SHAPE = (10, 5, 5)

    def __init__(self, num_classes=3):
        super().__init__()
        self.name = "simpleCNN"
        # Attribute names double as state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 5 * 5, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits for ``x``.

        Args:
            x: Image tensor whose trailing dimensions are 3 x 32 x 32.

        Returns:
            Tensor of shape ``(N, num_classes)``.

        Raises:
            RuntimeError: If the input's spatial size does not produce the
                10 x 5 x 5 feature map the linear layers were built for.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Without this check, a wrongly sized input whose feature count is a
        # multiple of 250 would be reshaped into extra "samples" and silently
        # return more rows than images.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "simpleCNN expects 3x32x32 inputs; got a feature map of shape "
                f"{tuple(x.shape[-3:])} instead of {self._FEATURE_SHAPE}"
            )

        # reshape (rather than view) also copes with non-contiguous tensors.
        x = x.reshape(-1, 10 * 5 * 5)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Build the 3-class classifier and restore its trained weights on the CPU.
net = simpleCNN(num_classes=3)
# NOTE(review): torch.load unpickles the file, so "./ckpt.pth" must come from
# a trusted source.
net.load_state_dict(
    torch.load(
        "./ckpt.pth",
        map_location=torch.device("cpu"),
    )
)
# Switch to evaluation mode; the model is only used for inference below.
net.eval()
# Human-readable class names; predict() pairs them with the model outputs by
# index, so the order matters.
class_labels = ["other", "car", "truck"]

# Preprocessing applied to every input image: resize to 32 x 32, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
@torch.no_grad()
def predict(img):
    """Classify one image and return a probability for every class label.

    Args:
        img: Image as a NumPy array (presumably H x W x 3 as delivered by the
            Gradio ``"image"`` input; grayscale and RGBA arrays are converted
            to RGB as well).

    Returns:
        dict mapping each entry of ``class_labels`` to its softmax
        probability as a Python ``float``.

    Raises:
        ValueError: If ``img`` is ``None`` (no image was submitted).
    """
    if img is None:
        raise ValueError("No image provided.")

    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode "RGB" on a non-RGB array would misinterpret the pixel data.
    pil_img = Image.fromarray(img.astype("uint8")).convert("RGB")
    batch = transform(pil_img).unsqueeze(0)

    # torch.softmax is numerically stable, unlike a hand-written
    # exp(x) / sum(exp(x)), which overflows to NaN for large logits.
    probs = torch.softmax(net(batch), dim=1)[0].tolist()
    return dict(zip(class_labels, probs))
# Gradio UI: an image input wired to predict(), with the returned
# label -> probability mapping shown by a label output.
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs="label",
)
# Start the web server as soon as the script runs.
iface.launch()