import torch
import torch.nn as nn


class FourDimensionalTransformer(nn.Module):
    def __init__(self, num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10):
        super(FourDimensionalTransformer, self).__init__()
        self.embed_dim = embed_dim
        self.num_extra_tokens = num_extra_tokens
        # Input embedding layer to map the input to the desired embedding dimension.
        self.embedding = nn.Conv2d(3, embed_dim, kernel_size=1)
        # Learnable extra tokens (to augment the spatial tokens).
        self.extra_tokens = nn.Parameter(torch.randn(num_extra_tokens, embed_dim))
        # Build a stack of self-attention layers with layer normalization.
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_layers)
        ])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(embed_dim)
            for _ in range(num_layers)
        ])
        # GRU cell for recurrent updating, mimicking working memory or recurrent feedback.
        self.gru = nn.GRUCell(embed_dim, embed_dim)
        # Final classification head applied to the flattened token sequence.
        # The hardcoded 16 is the number of spatial tokens produced by the expected 4x4 input.
        self.fc = nn.Linear((16 + num_extra_tokens) * embed_dim, num_classes)

    def forward(self, x):
        # x: [batch, 3, height=4, width=4]
        batch_size = x.size(0)
        # Embed the input: [batch, 3, height, width] -> [batch, embed_dim, height, width]
        x = self.embedding(x)
        # Flatten spatial dimensions: [batch, embed_dim, height, width] -> [batch, embed_dim, height * width]
        # Then permute to [sequence_length, batch, embed_dim] for attention.
        x = x.view(batch_size, self.embed_dim, -1).permute(2, 0, 1)  # [height * width, batch, embed_dim]
        # Expand and concatenate extra tokens: extra_tokens [num_extra_tokens, embed_dim]
        # becomes [num_extra_tokens, batch, embed_dim] and is concatenated along the sequence dim.
        extra_tokens = self.extra_tokens.unsqueeze(1).expand(-1, batch_size, -1)
        x = torch.cat([x, extra_tokens], dim=0)  # [height * width + num_extra_tokens, batch, embed_dim]
        # Process through the transformer layers with recurrent GRU updates.
        for attn, norm in zip(self.attention_layers, self.layer_norms):
            residual = x
            attn_out, _ = attn(x, x, x)
            # Residual connection and layer normalization.
            x = norm(residual + attn_out)
            # --- Brain-inspired recurrent update ---
            # Reshape tokens so the GRUCell is applied to all tokens in parallel.
            seq_len, batch, embed_dim = x.shape
            x_flat = x.view(seq_len * batch, embed_dim)
            # Use the same x_flat as both input and hidden state.
            x_updated_flat = self.gru(x_flat, x_flat)
            x = x_updated_flat.view(seq_len, batch, embed_dim)
            # --- End recurrent update ---
        # Rearrange back to [batch, sequence_length, embed_dim] and flatten.
        x = x.permute(1, 0, 2).contiguous()
        x = x.view(batch_size, -1)
        # Final fully connected layer (classification head).
        out = self.fc(x)
        return out
# Example usage:
input_tensor = torch.rand(2, 3, 4, 4) # [batch=2, channels=3, height=4, width=4]
model = FourDimensionalTransformer(num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10)
output = model(input_tensor)
print("Output shape:", output.shape)
print("Output:", output) |