# mnist/model.py
# Uploaded to the Hugging Face Hub by tezuesh with huggingface_hub
# (commit 2728806, verified; 463 bytes).
import torch
import torch.nn as nn
class MNISTModel(nn.Module):
    """Two-layer fully connected classifier for 28x28 images (e.g. MNIST).

    Pipeline: flatten to 784 features -> Linear(784, hidden_size) -> ReLU
    -> Dropout -> Linear(hidden_size, num_classes). ``forward`` returns raw
    logits; no softmax is applied.

    Parameter names (``fc1``, ``fc2``) and the layer creation order match
    the original model, so existing state dicts load unchanged.
    """

    IMAGE_SIZE = 28  # MNIST images are 28x28 pixels

    def __init__(self, *, hidden_size=128, num_classes=10, dropout=0.5):
        """Build the layers.

        Args:
            hidden_size: Width of the hidden layer (default 128).
            num_classes: Number of output logits (default 10).
            dropout: Dropout probability applied to the hidden activations
                (default 0.5).
        """
        super().__init__()
        in_features = self.IMAGE_SIZE * self.IMAGE_SIZE
        self.fc1 = nn.Linear(in_features, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor whose total element count is a multiple of 784. Each
                consecutive run of 784 elements is treated as one flattened
                image, e.g. (N, 1, 28, 28), (N, 28, 28), (N, 784), or a
                single (28, 28) image, which yields one row.

        Returns:
            Tensor of shape (x.numel() // 784, num_classes) holding raw
            logits.
        """
        # reshape rather than view: view raises on non-contiguous inputs
        # (e.g. transposed or permuted batches); reshape copies only when
        # it has to and otherwise returns the same view.
        x = x.reshape(-1, self.IMAGE_SIZE * self.IMAGE_SIZE)
        x = torch.relu(self.fc1(x))
        # nn.Dropout only drops activations in training mode.
        x = self.dropout(x)
        return self.fc2(x)