import torch
import torch.nn as nn

class FourDimensionalTransformer(nn.Module):
    def __init__(self, num_layers=16, embed_dim=7, num_heads=1, num_extra_tokens=16, num_classes=10):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_extra_tokens = num_extra_tokens

        # Input embedding: a 1x1 convolution acts as a per-pixel linear
        # projection from 3 input channels to embed_dim.
        self.embedding = nn.Conv2d(3, embed_dim, kernel_size=1)

        # Learnable extra tokens (to augment the spatial tokens).
        self.extra_tokens = nn.Parameter(torch.randn(num_extra_tokens, embed_dim))

        # Build a stack of self-attention layers with layer normalization.
        self.attention_layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_layers)
        ])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(embed_dim)
            for _ in range(num_layers)
        ])
        
        # GRU cell for recurrent updating—mimicking working memory or recurrent feedback.
        self.gru = nn.GRUCell(embed_dim, embed_dim)

        # Final classification head, applied to the flattened token sequence
        # after the attention + GRU stack. The hardcoded 16 is the number of
        # spatial tokens from the expected 4x4 input (4 * 4 = 16); other input
        # sizes would require changing it.
        self.fc = nn.Linear((16 + num_extra_tokens) * embed_dim, num_classes)

    def forward(self, x):
        # x: [batch, 3, height=4, width=4]
        batch_size = x.size(0)
        
        # Embed the input: [batch, 3, height, width] -> [batch, embed_dim, height, width]
        x = self.embedding(x)
        
        # Flatten spatial dimensions: [batch, embed_dim, height, width] -> [batch, embed_dim, height * width],
        # then permute to [sequence_length, batch, embed_dim], the sequence-first
        # layout that nn.MultiheadAttention expects by default (batch_first=False).
        x = x.view(batch_size, self.embed_dim, -1).permute(2, 0, 1)  # [height * width, batch, embed_dim]

        # Expand and concatenate extra tokens: extra_tokens [num_extra_tokens, embed_dim]
        # becomes [num_extra_tokens, batch, embed_dim] and concatenated along sequence dim.
        extra_tokens = self.extra_tokens.unsqueeze(1).expand(-1, batch_size, -1)
        x = torch.cat([x, extra_tokens], dim=0)  # [height * width + num_extra_tokens, batch, embed_dim]

        # Process through the transformer layers with recurrent GRU updates.
        for attn, norm in zip(self.attention_layers, self.layer_norms):
            residual = x
            attn_out, _ = attn(x, x, x)
            # Residual connection and layer normalization.
            x = norm(residual + attn_out)
            
            # --- Brain-inspired recurrent update ---
            # Merge the sequence and batch dimensions so one shared GRUCell
            # updates every token in parallel.
            seq_len, batch, embed_dim = x.shape
            x_flat = x.view(seq_len * batch, embed_dim)
            # Each token vector serves as both the GRUCell input and its
            # hidden state, so every token gates its own update.
            x_updated_flat = self.gru(x_flat, x_flat)
            x = x_updated_flat.view(seq_len, batch, embed_dim)
            # --- End recurrent update ---
        
        # Rearrange back to [batch, sequence_length, embed_dim] and flatten.
        x = x.permute(1, 0, 2).contiguous()
        x = x.view(batch_size, -1)
        
        # Final fully connected layer (classification head).
        out = self.fc(x)
        return out
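
# --- Illustrative aside (an assumption, not part of the original file) ---
# The forward loop applies one shared GRUCell to every token independently,
# using each token vector as both input and hidden state. The same
# computation in isolation, shape for shape:
def _gru_token_update_demo():
    gru = nn.GRUCell(7, 7)                        # embed_dim -> embed_dim
    tokens = torch.randn(32, 2, 7)                # [seq_len=32, batch=2, embed_dim=7]
    flat = tokens.view(32 * 2, 7)                 # merge sequence and batch dims
    updated = gru(flat, flat)                     # token is both input and hidden state
    return updated.view(32, 2, 7)                 # restore [seq_len, batch, embed_dim]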

# Example usage:
if __name__ == "__main__":
    input_tensor = torch.rand(2, 3, 4, 4)  # [batch=2, channels=3, height=4, width=4]
    model = FourDimensionalTransformer(num_layers=16, embed_dim=7, num_heads=1,
                                       num_extra_tokens=16, num_classes=10)
    output = model(input_tensor)

    print("Output shape:", output.shape)  # expected: torch.Size([2, 10])
    print("Output:", output)