import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import gradio as gr

# Preprocessing: resize to 28x28, convert to grayscale, and scale to a [0, 1] tensor
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# Thai digit labels (glyph and Thai spelling)
labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)",
          "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
LABELS = {i: k for i, k in enumerate(labels)}  # map class index -> label


class DropoutThaiDigit(nn.Module):
    """Fully connected network with dropout for Thai digit classification."""

    def __init__(self):
        super(DropoutThaiDigit, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, 10)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten the 28x28 image into a vector
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc4(x)
        return x


# Load the trained weights and switch to evaluation mode (disables dropout)
model = DropoutThaiDigit()
model.load_state_dict(torch.load("thai_digit_net.pth"))
model.eval()


def predict(img):
    """
    Take the sketchpad image and return the top 5 predictions as a
    dictionary: {label: confidence, label: confidence, ...}
    """
    if img is None:
        return {}

    # The Sketchpad may return a dict of layers; pick the drawn image
    if isinstance(img, dict):
        img = img.get("image") or img.get("composite") or img.get("background")
        if img is None:
            return {}

    # Invert so the strokes become bright on a dark background before inference
    img = 1 - transform(img)
    probs = model(img).softmax(dim=1).ravel()
    probs, indices = torch.topk(probs, 5)  # keep the 5 most likely classes
    confidences = {LABELS[i]: float(prob) for i, prob in zip(indices.tolist(), probs.tolist())}
    return confidences


with gr.Blocks(title="Thai Digit Handwritten Classification") as interface:
    gr.Markdown("# Thai Digit Handwritten Classification")
    gr.Markdown("Draw a Thai digit (๐-๙) in the box below:")

    with gr.Row():
        with gr.Column():
            input_component = gr.Sketchpad(
                label="Draw Here",
                height=300,
                width=300,
                brush=gr.Brush(default_size=8, colors=["#000000"]),
                eraser=False,
                type="pil",
                canvas_size=(300, 300),
            )
        with gr.Column():
            output_component = gr.Label(label="Prediction", num_top_classes=5)

    # Re-run the prediction whenever the drawing changes
    input_component.change(
        fn=predict,
        inputs=input_component,
        outputs=output_component
    )

interface.launch()