# tensor_network.py
import torch
import torch.nn as nn
# Define a small convolutional neural network: three conv layers with batch normalization, followed by two fully connected layers
class ComplexModel(nn.Module):
def __init__(self):
super(ComplexModel, self).__init__()
# First convolutional layer: input channels=3, output channels=16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(16)
# Second convolutional layer: input channels=16, output channels=32
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(32)
# Apply max pooling to reduce spatial dimensions by a factor of 2
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# Third convolutional layer: input channels=32, output channels=64
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.bn3 = nn.BatchNorm2d(64)
# Fully connected layers:
# For input tensors of shape (batch_size, 3, 4, 4):
# - conv1 and conv2 keep spatial dimensions at 4x4.
# - After conv2, we apply pooling, reducing 4x4 -> 2x2.
# - conv3 keeps 2x2 spatial dimensions.
# Thus, the flattened feature size is 64 * 2 * 2 = 256.
self.fc1 = nn.Linear(64 * 2 * 2, 128)
self.fc2 = nn.Linear(128, 10) # For example, output layer with 10 classes
def forward(self, x):
# First conv layer with batch normalization and ReLU activation
x = self.conv1(x)
x = self.bn1(x)
x = torch.relu(x)
# Second conv layer with batch normalization and ReLU activation
x = self.conv2(x)
x = self.bn2(x)
x = torch.relu(x)
# Pooling to reduce spatial dimensions
x = self.pool(x)
# Third conv layer with batch normalization and ReLU activation
x = self.conv3(x)
x = self.bn3(x)
x = torch.relu(x)
# Flatten the tensor for the fully connected layers
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
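# Optional sanity check (a minimal sketch, not part of the model itself): register
# forward hooks to print the intermediate shapes described in the comments above,
# assuming a 4x4 input. The names `_print_shape` and `debug_model` are illustrative.
def _print_shape(name):
    def hook(module, inputs, output):
        print(f"{name}: {tuple(output.shape)}")
    return hook

debug_model = ComplexModel()
for name, module in debug_model.named_children():
    module.register_forward_hook(_print_shape(name))
with torch.no_grad():
    debug_model(torch.rand(2, 3, 4, 4))
# Expected prints: conv1/bn1 -> (2, 16, 4, 4), conv2/bn2 -> (2, 32, 4, 4),
# pool -> (2, 32, 2, 2), conv3/bn3 -> (2, 64, 2, 2), fc1 -> (2, 128), fc2 -> (2, 10)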
# Example of creating input tensors (each with shape: batch_size=2, channels=3, height=4, width=4)
tensor1 = torch.rand(2, 3, 4, 4)
tensor2 = torch.rand(2, 3, 4, 4)
tensor3 = torch.rand(2, 3, 4, 4)
# Add the tensors element-wise to form the model input
input_tensor = tensor1 + tensor2 + tensor3
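# The element-wise sum above requires all three tensors to share the same shape
# and preserves it, so input_tensor is still (2, 3, 4, 4). An equivalent way to
# combine them (sketch only) is to stack along a new dimension and reduce:
stacked_sum = torch.stack([tensor1, tensor2, tensor3], dim=0).sum(dim=0)
assert torch.allclose(input_tensor, stacked_sum)
assert input_tensor.shape == (2, 3, 4, 4)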
# Instantiate the model
model = ComplexModel()
# Forward pass through the model
output = model(input_tensor)
print("Output shape:", output.shape)
print("Output:", output)