import torch
import torch.nn as nn


class FourDimensionalTransformer(nn.Module):
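    """Transformer-style classifier: 1x1-conv pixel embeddings plus learned extra
    tokens, a stack of residual self-attention blocks, a token-wise GRUCell update,
    and a linear head over the flattened token sequence."""
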
    def __init__(self, num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_extra_tokens = num_extra_tokens

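        # 1x1 convolution: project the 3 input channels to embed_dim channels at every spatial position.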
        self.embedding = nn.Conv2d(3, embed_dim, kernel_size=1)

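        # Learnable extra tokens that are appended to the spatial token sequence.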
        self.extra_tokens = nn.Parameter(torch.randn(num_extra_tokens, embed_dim))

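        # Stack of self-attention layers, each paired with its own LayerNorm.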
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_layers)
        ])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(embed_dim)
            for _ in range(num_layers)
        ])

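        # GRU cell applied token-wise after the attention stack.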
        self.gru = nn.GRUCell(embed_dim, embed_dim)

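        # Classification head over the flattened token sequence. The hard-coded 16 is the
        # number of spatial tokens, so this head assumes a 4x4 input like the demo below.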
        self.fc = nn.Linear((16 + num_extra_tokens) * embed_dim, num_classes)
    def forward(self, x):
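        # x: (batch, 3, height, width)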
        batch_size = x.size(0)

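        # (batch, 3, H, W) -> (batch, embed_dim, H, W)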
        x = self.embedding(x)

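        # Flatten the spatial grid into a token sequence: (seq_len, batch, embed_dim).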
        x = x.view(batch_size, self.embed_dim, -1).permute(2, 0, 1)

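        # Broadcast the learnable extra tokens across the batch and append them to the sequence.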
        extra_tokens = self.extra_tokens.unsqueeze(1).expand(-1, batch_size, -1)
        x = torch.cat([x, extra_tokens], dim=0)

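        # Residual self-attention blocks with post-norm: attend, add the residual, then normalize.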
        for attn, norm in zip(self.attention_layers, self.layer_norms):
            residual = x
            attn_out, _ = attn(x, x, x)
            x = norm(residual + attn_out)

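        # Token-wise GRU update, using each token embedding as both input and hidden state.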
        seq_len, batch, embed_dim = x.shape
        x_flat = x.view(seq_len * batch, embed_dim)
        x_updated_flat = self.gru(x_flat, x_flat)
        x = x_updated_flat.view(seq_len, batch, embed_dim)

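        # Concatenate all token embeddings per sample and classify.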
        x = x.permute(1, 0, 2).contiguous()
        x = x.view(batch_size, -1)
        out = self.fc(x)
        return out


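# Quick smoke test on a random batch of two 4x4 RGB inputs.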
input_tensor = torch.rand(2, 3, 4, 4)
model = FourDimensionalTransformer(num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10)
output = model(input_tensor)

print("Output shape:", output.shape)
print("Output:", output)