# Spaces: Running
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import transforms
class simpleCNN(nn.Module):
    """Small two-convolution CNN classifier for 3-channel 32x32 images.

    Spatial sizes for a 32x32 input: conv1 (5x5) -> 28, pool -> 14,
    conv2 (5x5) -> 10, pool -> 5, so each sample flattens to
    10 * 5 * 5 = 250 features, which is what ``fc1`` expects. Other input
    sizes are therefore not supported.

    Args:
        num_classes: number of output logits produced by ``fc2``.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.name = "simpleCNN"
        # Attribute names become the state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 5 * 5, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        Args:
            x: float tensor of shape (batch, 3, 32, 32).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension. Unlike
        # view(-1, 250), this cannot silently split one wrongly sized image
        # into several fake "samples"; a size mismatch now fails in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Class names in output-index order: logit i of the network maps to
# class_labels[i] (see predict).
class_labels = ["other", "car", "truck"]

# Build the classifier with exactly one output per label, so the label list
# and the model head cannot drift apart, then load the trained weights onto
# the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
net = simpleCNN(num_classes=len(class_labels))
net.load_state_dict(torch.load("./ckpt.pth", map_location=torch.device("cpu")))
net.eval()  # switch the module to inference mode
# Preprocessing applied to each uploaded image before it reaches `net`:
#   1. resize to 32x32, the only size simpleCNN's fc1 (10 * 5 * 5 inputs) accepts;
#   2. convert the PIL image to a float tensor scaled to [0, 1];
#   3. normalise each channel with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
def predict(img):
    """Classify an image and return a probability for every class label.

    Args:
        img: image array as passed by the Gradio "image" input, or None when
            no image was submitted.

    Returns:
        dict mapping each entry of ``class_labels`` to its softmax
        probability (a Python float); an empty dict when ``img`` is None.
    """
    if img is None:
        return {}
    # Let Pillow infer the mode from the array shape, then normalise to RGB,
    # so grayscale (H, W) and RGBA (H, W, 4) arrays are handled too.
    pil_img = Image.fromarray(np.asarray(img).astype("uint8")).convert("RGB")
    batch = transform(pil_img).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = net(batch)[0]
    # torch.softmax is numerically stable: large logits do not overflow to
    # inf/NaN as a plain exp()/sum() can.
    probs = torch.softmax(logits, dim=0).tolist()
    return {label: float(p) for label, p in zip(class_labels, probs)}
# Web UI: an image upload in, a label/confidence display out.
iface = gr.Interface(fn=predict, inputs="image", outputs="label")

# Launch only when run as a script (`python app.py`), so importing this
# module (e.g. to reuse `predict`) does not start a web server.
if __name__ == "__main__":
    iface.launch()