import torch
import torch.nn as nn
# Define an enhanced neural network model with more layers | |
class ComplexModel(nn.Module):
    """Small convolutional classifier.

    Architecture: three 3x3 convolutions (16, 32 and 64 output channels), each
    followed by batch normalization and ReLU, with a single 2x2 max pooling
    step between the second and third convolution, followed by two fully
    connected layers (hidden width 128).

    Args:
        in_channels: Number of channels of the input images. Defaults to 3.
        num_classes: Number of output units of the last layer. Defaults to 10.
        input_size: ``(height, width)`` of the input images. It is needed to
            size the first fully connected layer. Defaults to ``(4, 4)``,
            which yields 64 * 2 * 2 = 256 flattened features.

    Raises:
        ValueError: If either spatial dimension is smaller than 2, because
            the 2x2 pooling would then leave no features.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=(4, 4)):
        super().__init__()
        height, width = input_size
        if height < 2 or width < 2:
            raise ValueError(
                f"input_size must be at least (2, 2), got {tuple(input_size)}"
            )

        # First convolutional layer: in_channels -> 16 channels.
        # kernel_size=3 with padding=1 and stride=1 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        # Second convolutional layer: 16 -> 32 channels, same spatial size.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Max pooling halves each spatial dimension (rounding down).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Third convolutional layer: 32 -> 64 channels, same spatial size.
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Only the pooling changes the spatial size, so the flattened feature
        # count is 64 * (height // 2) * (width // 2); 256 for the 4x4 default.
        flat_features = 64 * (height // 2) * (width // 2)
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the raw output scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch_size, in_channels, height, width)``
                where ``(height, width)`` matches the ``input_size`` given to
                the constructor.

        Returns:
            Tensor of shape ``(batch_size, num_classes)``; no softmax is
            applied.
        """
        # Conv -> batch norm -> ReLU, twice, at full resolution.
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))

        # Pooling to reduce spatial dimensions.
        x = self.pool(x)

        # Third conv stage at the reduced resolution.
        x = torch.relu(self.bn3(self.conv3(x)))

        # Flatten everything except the batch dimension. torch.flatten copies
        # when needed, so it also works for non-contiguous activations
        # (e.g. channels_last inputs) where Tensor.view would raise.
        x = torch.flatten(x, 1)

        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Demo: run a single forward pass of ComplexModel on a small random batch.
# Three random tensors are created, each with shape
# (batch_size=2, channels=3, height=4, width=4).
tensor1 = torch.rand(2, 3, 4, 4)
tensor2 = torch.rand(2, 3, 4, 4)
tensor3 = torch.rand(2, 3, 4, 4)

# Add the tensors element-wise to form the input tensor (same shape as each).
input_tensor = tensor1 + tensor2 + tensor3

# Build the model with its default configuration (3 input channels, 4x4 images).
model = ComplexModel()

# Forward pass through the model. The module is left in its default training
# mode, so the batch-norm layers normalize with this batch's statistics.
output = model(input_tensor)

print("Output shape:", output.shape)
print("Output:", output)