# number-identity-model / number-identity-model.git
# Author: szili2011 — last commit b99e640 ("Update number-identity-model.git")
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define your model architecture
class NumberIdentityModel(nn.Module):
    """Small CNN that maps single-channel 28x28 images to class logits.

    Architecture: two [3x3 conv (padding=1) -> ReLU -> 2x2 max-pool] stages,
    then two fully connected layers. The convolutions preserve spatial size
    and each pool halves it (28 -> 14 -> 7), which is why ``fc1`` expects
    ``64 * 7 * 7`` input features. Inputs must therefore be 28x28.

    Args:
        num_classes: Number of output logits. Defaults to 10 (digits 0-9).
        in_channels: Number of input image channels. Defaults to 1.

    Input:  tensor of shape ``(batch, in_channels, 28, 28)``.
    Output: raw, unnormalized logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, *, num_classes: int = 10, in_channels: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for the image batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))  # (B, 32, 14, 14) for 28x28 input
        x = self.pool(F.relu(self.conv2(x)))  # (B, 64, 7, 7) for 28x28 input
        # Flatten every dimension except the batch. Unlike view(-1, N),
        # this never regroups elements into a different batch size. An input
        # of the wrong spatial size therefore fails loudly in fc1. flatten also
        # handles non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Instantiate the model. Its weights are randomly initialized; no training happens here.
model = NumberIdentityModel()

# Optionally, load a trained state_dict if one is available.
# torch.load unpickles the file, so weights_only=True is the safer choice for
# checkpoints you did not create. map_location="cpu" lets a GPU-saved
# checkpoint load on a CPU-only machine.
# model.load_state_dict(
#     torch.load("path_to_pretrained_model.pth", map_location="cpu", weights_only=True)
# )

# Sanity-check the architecture with a dummy batch. The input must be
# (batch_size, 1, 28, 28): fc1 is sized for 64 * 7 * 7 features, which is what
# two 2x2 poolings produce from a 28x28 image. Other sizes will fail.
dummy_input = torch.randn(1, 1, 28, 28)
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(dummy_input)
if output.shape != (1, 10):
    raise RuntimeError(f"Unexpected output shape {tuple(output.shape)}; expected (1, 10)")

# Save only the parameters (state_dict) rather than pickling the whole module.
model_path = "number-identity-model.pth"
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")