import torch
import torch.nn as nn


# Define an enhanced neural network model with more layers and self-attention
class ComplexModel(nn.Module):
    def __init__(self):
        super(ComplexModel, self).__init__()
        # First convolutional layer: input channels=3, output channels=16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        # Second convolutional layer: input channels=16, output channels=32
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        # Max pooling to reduce spatial dimensions by a factor of 2
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Third convolutional layer: input channels=32, output channels=64
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Self-attention layer:
        # For a 4x4 input image, the feature map after conv3 is [batch, 64, 2, 2].
        # We treat the spatial positions (2x2 = 4 tokens) as the sequence length.
        # For nn.MultiheadAttention, the embedding dimension is 64.
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=4)

        # Fully connected layers:
        # After conv3 and attention, the tensor shape remains [batch, 64, 2, 2],
        # so the flattened feature size is 64 * 2 * 2 = 256.
        self.fc1 = nn.Linear(64 * 2 * 2, 128)
        self.fc2 = nn.Linear(128, 10)  # For example, an output layer with 10 classes

    def forward(self, x):
        # First conv layer with batch normalization and ReLU activation
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)

        # Second conv layer with batch normalization and ReLU activation
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)

        # Pooling to reduce spatial dimensions
        x = self.pool(x)

        # Third conv layer with batch normalization and ReLU activation
        x = self.conv3(x)
        x = self.bn3(x)
        x = torch.relu(x)

        # --------- Self-Attention Block ---------
        # x shape: [batch_size, channels=64, height=2, width=2]
        batch, channels, height, width = x.shape

        # Flatten the spatial dimensions into a sequence of tokens.
        # New shape: [batch_size, channels, sequence_length], where sequence_length = height * width (4 tokens)
        x_flat = x.view(batch, channels, height * width)  # Shape: [B, 64, 4]

        # Permute to match the nn.MultiheadAttention input layout: [sequence_length, batch_size, embed_dim]
        x_flat = x_flat.permute(2, 0, 1)  # Shape: [4, B, 64]

        # Apply self-attention (queries, keys, and values are all x_flat)
        attn_output, _ = self.attention(x_flat, x_flat, x_flat)
        # attn_output shape remains: [4, B, 64]

        # Permute back to [batch_size, channels, sequence_length].
        # .contiguous() is required here: permute returns a non-contiguous tensor,
        # and the later flatten via .view(x.size(0), -1) would otherwise raise a runtime error.
        x_flat = attn_output.permute(1, 2, 0).contiguous()  # Shape: [B, 64, 4]

        # Reshape back to spatial dimensions: [B, 64, 2, 2]
        x = x_flat.view(batch, channels, height, width)
        # --------- End Self-Attention Block ---------

        # Flatten the tensor for the fully connected layers
        x = x.view(x.size(0), -1)  # Flatten to [batch, 256]
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x


# Example input tensors (each with shape: batch_size=2, channels=3, height=4, width=4)
tensor1 = torch.rand(2, 3, 4, 4)
tensor2 = torch.rand(2, 3, 4, 4)
tensor3 = torch.rand(2, 3, 4, 4)

# Add the tensors element-wise to form the input tensor
input_tensor = tensor1 + tensor2 + tensor3

# Initialize the enhanced model
model = ComplexModel()

# Forward pass through the model
output = model(input_tensor)
print("Output shape:", output.shape)
print("Output:", output)
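
# ---------------------------------------------------------------------------
# Optional variant (a sketch, not part of the model above): PyTorch >= 1.9
# supports batch_first=True on nn.MultiheadAttention, which accepts tensors of
# shape [batch, seq_len, embed_dim] directly and avoids the permute round-trip
# used in ComplexModel.forward. The names below (attn_bf, features, tokens,
# restored) are illustrative only; shapes assume the same 4x4 input as above.
attn_bf = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
features = torch.rand(2, 64, 2, 2)              # stand-in for the conv3 output on a 4x4 input
tokens = features.flatten(2).permute(0, 2, 1)   # [B, 64, 4] -> [B, 4, 64]
attn_out, _ = attn_bf(tokens, tokens, tokens)   # [B, 4, 64]
restored = attn_out.permute(0, 2, 1).reshape(2, 64, 2, 2)  # back to [B, 64, 2, 2]
print("batch_first attention output shape:", restored.shape)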