Commit 56f9050 · Parent: 44d0907

naming fix

Changed files:
- attn.py (+9 -9)
- configuration_minitransformer.py (+1 -1)
- modeling_minitransformer.py (+6 -6)
attn.py CHANGED

@@ -27,13 +27,13 @@ class Attention(nn.Module):

        self.device = torch.device("cuda")
        self.bsz = config.bsz
-        self.
+        self.attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
        )
-        self.
+        self.o_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
        )
-        self.
+        self.o_proj.SCALE_INIT = 1
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)
        self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
@@ -65,7 +65,7 @@ class Attention(nn.Module):
    def forward(self, x):
        bsz, seq_len, d_in = x.size()

-        qkv = self.
+        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
@@ -82,7 +82,7 @@ class Attention(nn.Module):
            softcap=self.softcap, # https://arxiv.org/pdf/2408.00118
        )
        y = y.contiguous().view(bsz, seq_len, d_in)
-        y = self.resid_dropout(self.
+        y = self.resid_dropout(self.o_proj(y))
        return y

class AttentionSDPA(nn.Module):
@@ -98,15 +98,15 @@ class AttentionSDPA(nn.Module):

        self.device = torch.device("cuda")  # Technically don't need CUDA for SDPA
        self.bsz = config.bsz
-        self.
-        self.
+        self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
+        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        bsz, seq_len, d_in = x.size()

-        qkv = self.
+        qkv = self.attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
@@ -121,5 +121,5 @@ class AttentionSDPA(nn.Module):

        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)

-        y = self.resid_dropout(self.
+        y = self.resid_dropout(self.o_proj(y))
        return y
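For context, a minimal self-contained sketch of how the renamed layers fit together in the SDPA variant: `attn` is the fused QKV projection and `o_proj` the output projection. The class name, default sizes, and the `is_causal=True` flag below are illustrative assumptions rather than code from this repository, and the ALiBi slopes and logit softcapping used by the flash-attn path are omitted.

# Sketch of the renamed projection layout (attn = fused QKV, o_proj = output),
# following the AttentionSDPA shapes shown in the diff. Not the repo's module.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySDPAAttention(nn.Module):
    def __init__(self, n_embd: int = 768, n_heads: int = 12,
                 bias: bool = False, dropout: float = 0.0):
        super().__init__()
        assert n_embd % n_heads == 0
        self.n_heads = n_heads
        # Fused QKV projection, named `attn` after this commit.
        self.attn = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        # Output projection, named `o_proj` after this commit.
        self.o_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, d_in = x.size()
        qkv = self.attn(x)                       # (bsz, seq_len, 3 * d_in)
        q, k, v = torch.chunk(qkv, 3, dim=2)     # each (bsz, seq_len, d_in)
        # Split heads: (bsz, n_heads, seq_len, head_dim)
        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        # Causal masking is an assumption here; the repo's masking/window settings may differ.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)
        return self.resid_dropout(self.o_proj(y))

# Example: out = TinySDPAAttention()(torch.randn(1, 16, 768))  # -> (1, 16, 768)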
configuration_minitransformer.py CHANGED

@@ -8,7 +8,7 @@ class MiniTransformerConfig(PretrainedConfig):
        self,
        bsz: int = 1,
        n_embd: int = 768,
-        n_heads: int =
+        n_heads: int = 12,
        n_layers: int = 27,
        seq_len: int = 8192,
        window_size: int = 8192,
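With the new default of n_heads=12, n_embd=768 splits evenly into 64-dim heads. A hypothetical usage sketch; the import path is assumed from the file name, and attribute access assumes the config stores its constructor arguments as attributes in the usual PretrainedConfig style.

# Hypothetical usage of the defaults shown in the diff.
from configuration_minitransformer import MiniTransformerConfig  # assumed import path

config = MiniTransformerConfig()                # n_heads now defaults to 12
assert config.n_embd % config.n_heads == 0      # 768 / 12 = 64-dim heads
custom = MiniTransformerConfig(n_heads=6, n_layers=12)  # defaults remain overridable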
modeling_minitransformer.py CHANGED

@@ -121,12 +121,12 @@ class MiniTransformer(PreTrainedModel):
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, Attention):
-            torch.nn.init.xavier_normal_(module.
-            torch.nn.init.xavier_normal_(module.
-            if module.
-                torch.nn.init.zeros_(module.
-            if module.
-                torch.nn.init.zeros_(module.
+            torch.nn.init.xavier_normal_(module.attn.weight)
+            torch.nn.init.xavier_normal_(module.o_proj.weight)
+            if module.attn.bias is not None:
+                torch.nn.init.zeros_(module.attn.bias)
+            if module.o_proj.bias is not None:
+                torch.nn.init.zeros_(module.o_proj.bias)
    @staticmethod
    def top_k_top_p_filtering(
        logits: torch.Tensor,
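The rename propagates to weight initialization as well. Below is a standalone sketch of the branch above, pulled out as a free function for illustration; in the repository the logic lives inside MiniTransformer's weight-init method and is gated on `isinstance(module, Attention)`, and the function name here is hypothetical.

# Standalone version of the updated init branch (illustrative only).
import torch
import torch.nn as nn

def init_attention_module(module: nn.Module) -> None:
    # Xavier-init both renamed projections; zero their biases when bias=True.
    torch.nn.init.xavier_normal_(module.attn.weight)
    torch.nn.init.xavier_normal_(module.o_proj.weight)
    if module.attn.bias is not None:
        torch.nn.init.zeros_(module.attn.bias)
    if module.o_proj.bias is not None:
        torch.nn.init.zeros_(module.o_proj.bias)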