jon-tow commited on
Commit
8e5b1aa
1 Parent(s): 589adbf

feat: add dropout support

Browse files
configuration_stablelm_epoch.py CHANGED
@@ -65,6 +65,8 @@ class StableLMEpochConfig(PretrainedConfig):
65
  Whether or not the model should use bias for qkv layers.
66
  tie_word_embeddings(`bool`, *optional*, defaults to `False`):
67
  Whether to tie weight embeddings
 
 
68
  """
69
  model_type = "stablelm_epoch"
70
  keys_to_ignore_at_inference = ["past_key_values"]
@@ -88,6 +90,7 @@ class StableLMEpochConfig(PretrainedConfig):
88
  bos_token_id=0,
89
  eos_token_id=2,
90
  tie_word_embeddings=False,
 
91
  **kwargs,
92
  ):
93
  self.vocab_size = vocab_size
@@ -105,6 +108,7 @@ class StableLMEpochConfig(PretrainedConfig):
105
  self.use_cache = use_cache
106
  self.use_qkv_bias = use_qkv_bias
107
  self.tie_word_embeddings = tie_word_embeddings
 
108
  super().__init__(
109
  bos_token_id=bos_token_id,
110
  eos_token_id=eos_token_id,
 
65
  Whether or not the model should use bias for qkv layers.
66
  tie_word_embeddings(`bool`, *optional*, defaults to `False`):
67
  Whether to tie weight embeddings
68
+ attention_dropout (`float`, *optional*, defaults to 0.0):
69
+ The dropout ratio for the attention probabilities.
70
  """
71
  model_type = "stablelm_epoch"
72
  keys_to_ignore_at_inference = ["past_key_values"]
 
90
  bos_token_id=0,
91
  eos_token_id=2,
92
  tie_word_embeddings=False,
93
+ attention_dropout: float = 0.0,
94
  **kwargs,
95
  ):
96
  self.vocab_size = vocab_size
 
108
  self.use_cache = use_cache
109
  self.use_qkv_bias = use_qkv_bias
110
  self.tie_word_embeddings = tie_word_embeddings
111
+ self.attention_dropout = attention_dropout
112
  super().__init__(
113
  bos_token_id=bos_token_id,
114
  eos_token_id=eos_token_id,
modeling_stablelm_epoch.py CHANGED
@@ -191,6 +191,7 @@ class Attention(nn.Module):
191
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
192
  self.max_position_embeddings = config.max_position_embeddings
193
  self.is_causal = True
 
194
 
195
  if (self.head_dim * self.num_heads) != self.hidden_size:
196
  raise ValueError(
@@ -275,6 +276,7 @@ class Attention(nn.Module):
275
 
276
  # Upcast attention to fp32
277
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
278
  attn_output = torch.matmul(attn_weights, value_states)
279
 
280
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
 
191
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
192
  self.max_position_embeddings = config.max_position_embeddings
193
  self.is_causal = True
194
+ self.attention_dropout = config.attention_dropout
195
 
196
  if (self.head_dim * self.num_heads) != self.hidden_size:
197
  raise ValueError(
 
276
 
277
  # Upcast attention to fp32
278
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
279
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
280
  attn_output = torch.matmul(attn_weights, value_states)
281
 
282
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):