yagizdevre committed
Commit 56f9050 · 1 Parent(s): 44d0907

naming fix

attn.py CHANGED
@@ -27,13 +27,13 @@ class Attention(nn.Module):
 
         self.device = torch.device("cuda")
         self.bsz = config.bsz
-        self.c_attn = nn.Linear(
+        self.attn = nn.Linear(
             config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype
         )
-        self.c_proj = nn.Linear(
+        self.o_proj = nn.Linear(
             config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype
         )
-        self.c_proj.SCALE_INIT = 1
+        self.o_proj.SCALE_INIT = 1
         self.dropout = config.dropout
         self.resid_dropout = nn.Dropout(self.dropout)
         self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
@@ -65,7 +65,7 @@ class Attention(nn.Module):
     def forward(self, x):
         bsz, seq_len, d_in = x.size()
 
-        qkv = self.c_attn(x)
+        qkv = self.attn(x)
         q, k, v = torch.chunk(qkv, 3, dim=2)
 
         q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads)
@@ -82,7 +82,7 @@ class Attention(nn.Module):
             softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
         )
         y = y.contiguous().view(bsz, seq_len, d_in)
-        y = self.resid_dropout(self.c_proj(y))
+        y = self.resid_dropout(self.o_proj(y))
         return y
 
 class AttentionSDPA(nn.Module):
@@ -98,15 +98,15 @@ class AttentionSDPA(nn.Module):
 
         self.device = torch.device("cuda")  # Technically don't need CUDA for SDPA
         self.bsz = config.bsz
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
+        self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
+        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
         self.dropout = config.dropout
         self.resid_dropout = nn.Dropout(self.dropout)
 
     def forward(self, x):
         bsz, seq_len, d_in = x.size()
 
-        qkv = self.c_attn(x)
+        qkv = self.attn(x)
         q, k, v = torch.chunk(qkv, 3, dim=2)
 
         q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
@@ -121,5 +121,5 @@ class AttentionSDPA(nn.Module):
 
         y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)
 
-        y = self.resid_dropout(self.c_proj(y))
+        y = self.resid_dropout(self.o_proj(y))
         return y
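Note: because this commit only renames submodules (c_attn → attn, c_proj → o_proj), checkpoints saved before the rename will no longer match the model's state-dict keys. A minimal remapping sketch is shown below; the checkpoint path and the `model` object in the usage comment are placeholders, not files or names from this repo.

import torch

def remap_renamed_keys(state_dict):
    """Rename pre-56f9050 checkpoint keys (c_attn/c_proj) to the new names (attn/o_proj)."""
    rename_map = {"c_attn": "attn", "c_proj": "o_proj"}  # old submodule name -> new name
    remapped = {}
    for key, tensor in state_dict.items():
        for old, new in rename_map.items():
            key = key.replace(old, new)
        remapped[key] = tensor
    return remapped

# Hypothetical usage -- "old_checkpoint.pt" and `model` are illustrative placeholders:
# state_dict = torch.load("old_checkpoint.pt", map_location="cpu")
# model.load_state_dict(remap_renamed_keys(state_dict), strict=True)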
configuration_minitransformer.py CHANGED
@@ -8,7 +8,7 @@ class MiniTransformerConfig(PretrainedConfig):
         self,
         bsz: int = 1,
         n_embd: int = 768,
-        n_heads: int = 24,
+        n_heads: int = 12,
         n_layers: int = 27,
         seq_len: int = 8192,
         window_size: int = 8192,
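Note: with the default n_embd = 768, the new default of n_heads = 12 gives a per-head dimension of 768 / 12 = 64 (the previous default of 24 heads gave 32-dim heads). A quick sanity check, assuming the config class can be imported from this repo and constructed with its defaults:

# Hypothetical sanity check; the import path and default construction are assumptions.
from configuration_minitransformer import MiniTransformerConfig

config = MiniTransformerConfig()
assert config.n_embd % config.n_heads == 0  # heads must divide the embedding dim evenly
head_dim = config.n_embd // config.n_heads
print(config.n_heads, head_dim)  # expected: 12 64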
modeling_minitransformer.py CHANGED
@@ -121,12 +121,12 @@ class MiniTransformer(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
         elif isinstance(module, Attention):
-            torch.nn.init.xavier_normal_(module.c_attn.weight)
-            torch.nn.init.xavier_normal_(module.c_proj.weight)
-            if module.c_attn.bias is not None:
-                torch.nn.init.zeros_(module.c_attn.bias)
-            if module.c_proj.bias is not None:
-                torch.nn.init.zeros_(module.c_proj.bias)
+            torch.nn.init.xavier_normal_(module.attn.weight)
+            torch.nn.init.xavier_normal_(module.o_proj.weight)
+            if module.attn.bias is not None:
+                torch.nn.init.zeros_(module.attn.bias)
+            if module.o_proj.bias is not None:
+                torch.nn.init.zeros_(module.o_proj.bias)
     @staticmethod
     def top_k_top_p_filtering(
         logits: torch.Tensor,
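Note: the isinstance(module, Attention) branch of the weight init now touches module.attn and module.o_proj, so a partial rename would raise an AttributeError at init time. A minimal consistency check is sketched below; it assumes the model can be constructed locally from a default config (the Attention module pins torch.device("cuda") and the flash-attn path), so treat it as illustrative rather than part of the repo.

# Hypothetical consistency check; import paths and default construction are assumptions.
from configuration_minitransformer import MiniTransformerConfig
from modeling_minitransformer import MiniTransformer

model = MiniTransformer(MiniTransformerConfig())
param_names = [name for name, _ in model.named_parameters()]

# No parameter should still carry the old submodule names...
assert not any(".c_attn." in n or ".c_proj." in n for n in param_names)
# ...and the renamed projections should show up in the parameter tree.
assert any(".attn.weight" in n for n in param_names)
assert any(".o_proj.weight" in n for n in param_names)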