buyun committed · Commit aa82b18 · verified · 1 Parent(s): 40bdebd

add custom ops

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
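
Because these .so binaries are routed through Git LFS, a plain clone made without git-lfs leaves small text pointer stubs in lib/ instead of the ~31 MB libraries added below. A minimal check along these lines (a sketch, assuming the repository root as the working directory) can confirm the binaries were actually materialized:

    # Sketch: detect git-lfs pointer stubs vs. materialized binaries in lib/.
    # The pointer prefix matches the spec line shown in the ADDED files below.
    from pathlib import Path

    def is_lfs_pointer(path: Path) -> bool:
        return path.read_bytes()[:64].startswith(b"version https://git-lfs.github.com/spec/v1")

    for so in sorted(Path("lib").glob("liboptimus_ths-*.so")):
        state = "LFS pointer stub (run `git lfs pull`)" if is_lfs_pointer(so) else "materialized binary"
        print(f"{so.name}: {state}")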
config.json CHANGED
@@ -6,7 +6,7 @@
     "AutoConfig": "configuration_step1.Step1Config",
     "AutoModelForCausalLM": "modeling_step1.Step1ForCausalLM"
   },
-  "model_type": "step_audio",
+  "model_type": "step1",
   "bos_token_id": 1,
   "pad_token_id": 0,
   "eos_token_id": 3,
lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e018916e5e93fb904be6b34af32e71d03ba9e888d8c086a43a5c9fcacda661a1
+size 31250408
lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee23bba95f7806364e101e285720892b755a176d603842fb4646822800ac2344
+size 31250472
lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fa1a77f035203ff90a071218f775381f705269ef454163474d22501684b7e1f
+size 31258792
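
The attention path added in modeling_step1.py below calls torch.ops.Optimus.fwd, so one of these prebuilt libraries has to be registered against the running torch/CUDA combination first. How the repository actually loads them is not shown in this commit; the following is only a plausible sketch that picks a binary by its version suffix and registers it via torch.ops.load_library:

    # Sketch only: select the liboptimus_ths build matching the local torch / CUDA
    # versions and register its custom ops (expected to appear as torch.ops.Optimus.*).
    import os
    import torch

    def load_optimus_ops(lib_dir: str = "lib") -> None:
        torch_tag = "torch" + ".".join(torch.__version__.split(".")[:2])  # e.g. "torch2.3"
        cuda_tag = "cu" + (torch.version.cuda or "").replace(".", "")     # e.g. "cu121"
        for name in sorted(os.listdir(lib_dir)):
            if name.endswith(".so") and torch_tag in name and cuda_tag in name:
                torch.ops.load_library(os.path.join(lib_dir, name))
                return
        raise RuntimeError(f"no prebuilt liboptimus_ths for {torch_tag}/{cuda_tag}")

    load_optimus_ops()  # after this, torch.ops.Optimus.fwd should be callable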
modeling_step1.py CHANGED
@@ -76,6 +76,11 @@ class StepAttention(torch.nn.Module):
 
         self.layer_idx = layer_idx
 
+    def flash_attn_func(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
+                        return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
+        softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
+        return torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
+
     def forward(
         self,
         x: torch.Tensor,
@@ -95,24 +100,31 @@ class StepAttention(torch.nn.Module):
         k = rearrange(k, "b s (g d) -> b s g d", g=self.num_groups)
         v = rearrange(v, "b s (g d) -> b s g d", g=self.num_groups)
 
-        k = k.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
-        v = v.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
-
-        attention_mask = build_alibi_cache(
-            k.size(1), self.num_heads, dtype=q.dtype, device=q.device
-        )[:, :, -q.size(1) :, :].contiguous()
-
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        o: torch.Tensor = torch.nn.functional.scaled_dot_product_attention(
-            q, k, v, attn_mask=attention_mask
-        )
-        o = o.transpose(1, 2).flatten(-2, -1)
-
-        o = self.o_proj(o)
-        return o
+        try:
+            if self.head_dim not in (64, 128):
+                raise ValueError("head_dim must be 64 or 128")
+            attn_output = self.flash_attn_func(q, k, v)
+            attn_output = attn_output.flatten(-2, -1)
+        except:
+            k = k.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
+            v = v.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
+
+            attention_mask = build_alibi_cache(
+                k.size(1), self.num_heads, dtype=q.dtype, device=q.device
+            )[:, :, -q.size(1) :, :].contiguous()
+
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            attn_output: torch.Tensor = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=attention_mask
+            )
+
+            attn_output = attn_output.transpose(1, 2).flatten(-2, -1)
+
+        out = self.o_proj(attn_output)
+        return out, None  # attn weights are not returned
 
 
 class StepMLP(torch.nn.Module):
@@ -153,26 +165,26 @@ class StepLayer(torch.nn.Module):
 
     def forward(
         self,
-        x,
+        hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ):
-        def f(x):
-            x = self.input_layernorm(x)
-            x = self.self_attn(x, past_key_value, attention_mask, cache_position)
-            return x
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, self_attn_weights = self.self_attn(hidden_states, past_key_value, attention_mask, cache_position)
+        hidden_states = residual + hidden_states
 
-        x = x + f(x)
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
 
-        def f(x):
-            x = self.post_attention_layernorm(x)
-            x = self.mlp(x)
-            return x
-
-        x = x + f(x)
-
-        return x
+        outputs = (hidden_states,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
+        return outputs
 
 
 class StepPreTrainedModel(PreTrainedModel):
@@ -241,9 +253,16 @@ class Step1Model(StepPreTrainedModel):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = False
-        output_hidden_states = False
-
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -274,22 +293,37 @@ class Step1Model(StepPreTrainedModel):
 
         hidden_states = inputs_embeds
 
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 past_key_value=past_key_values,
                 cache_position=cache_position,
+                output_attentions=output_attentions,
            )
 
-            hidden_states = layer_outputs
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
         output = BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=hidden_states,
+            hidden_states=all_hidden_states,
             attentions=None,
         )
         return output if return_dict else output.to_tuple()
@@ -313,12 +347,6 @@ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    # def get_output_embeddings(self):
-    #     return self.lm_head
-
-    # def set_output_embeddings(self, new_embeddings):
-    #     self.lm_head = new_embeddings
-
     def set_decoder(self, decoder):
         self.model = decoder
 
@@ -338,14 +366,11 @@ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_attentions = False
-        output_hidden_states = False
-        # output_hidden_states = (
-        #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        # )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
@@ -368,15 +393,12 @@ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
 
         logits = self.lm_head(hidden_states)
 
-        # logits = torch.matmul(hidden_states, lm_stat)
-
         loss = None
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **kwargs
             )
 
         if not return_dict:
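
The net effect of the modeling changes is that output_attentions and output_hidden_states are no longer hard-coded to False: per-layer hidden states are collected and returned in the output object (attention weights are still returned as None by StepAttention). A short usage sketch, reusing the model and tokenizer loaded as in the config.json example above:

    # Exercise the new output_hidden_states path.
    import torch

    inputs = tokenizer("hello", return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True, return_dict=True)

    # hidden_states is now a tuple: one entry per decoder layer input plus the
    # final normed hidden state; previously only the last hidden state was returned.
    print(len(out.hidden_states), out.hidden_states[-1].shape)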
tokenizer_config.json CHANGED
@@ -1,4 +1,5 @@
 {
+  "add_bos_token": true,
   "bos_token": "<s>",
   "clean_up_tokenization_spaces": false,
   "eos_token": "</s>",
@@ -9,6 +10,7 @@
   "sp_model_kwargs": {},
   "tokenizer_class": "LlamaTokenizer",
   "unk_token": "<unk>",
-  "use_default_system_prompt": false
+  "use_default_system_prompt": false,
+  "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|BOT|>system\nYou are a helpful assistant.<|EOT|>' }}{% endif %}{{'<|BOT|>' + (message['role'] if message['role'] != 'user' else 'human') + '\n' + message['content'] + '<|EOT|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|BOT|>assistant\n' }}{% endif %}"
 }
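
The new chat_template wraps each turn in <|BOT|>role ... <|EOT|> markers, maps the user role to "human", and injects a default system prompt when the first message is not a system message. A rendering sketch with the tokenizer loaded as above:

    # Render a single user turn through the new chat template.
    messages = [{"role": "user", "content": "Hi"}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(text)
    # Expected rendering per the template:
    # <|BOT|>system
    # You are a helpful assistant.<|EOT|><|BOT|>human
    # Hi<|EOT|><|BOT|>assistant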