add custom ops
- .gitattributes +3 -0
- config.json +1 -1
- lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so +3 -0
- lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so +3 -0
- lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so +3 -0
- modeling_step1.py +73 -51
- tokenizer_config.json +3 -1
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
config.json
CHANGED
@@ -6,7 +6,7 @@
     "AutoConfig": "configuration_step1.Step1Config",
     "AutoModelForCausalLM": "modeling_step1.Step1ForCausalLM"
   },
-  "model_type": "
+  "model_type": "step1",
   "bos_token_id": 1,
   "pad_token_id": 0,
   "eos_token_id": 3,
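With "model_type" set to "step1", the auto_map above routes AutoConfig and AutoModelForCausalLM to the repo's custom classes. A minimal loading sketch, assuming remote code is trusted; the repo id below is a placeholder, not something this commit states:

# Hypothetical usage sketch; only the auto_map/model_type wiring is from this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/step1"  # placeholder repo id for illustration
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
# trust_remote_code=True lets transformers import configuration_step1.py /
# modeling_step1.py from the repo instead of a built-in architecture.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)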
lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e018916e5e93fb904be6b34af32e71d03ba9e888d8c086a43a5c9fcacda661a1
+size 31250408
lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee23bba95f7806364e101e285720892b755a176d603842fb4646822800ac2344
+size 31250472
lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fa1a77f035203ff90a071218f775381f705269ef454163474d22501684b7e1f
+size 31258792
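The three binaries package the same Optimus custom-op kernels built against different torch/CUDA ABIs. A sketch of how a caller might pick and register the matching build; the selection logic is an assumption, since this commit only adds the files:

# Hypothetical selection helper; the commit does not show how the library is chosen.
import glob
import os

import torch

torch_mm = ".".join(torch.__version__.split(".")[:2])  # e.g. "2.5" from "2.5.1+cu124"
candidates = glob.glob(os.path.join("lib", f"liboptimus_ths-torch{torch_mm}-cu*.so"))
if candidates:
    # Registers the custom operators under torch.ops.Optimus.* in this process.
    torch.ops.load_library(candidates[0])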
modeling_step1.py
CHANGED
@@ -76,6 +76,11 @@ class StepAttention(torch.nn.Module):
 
         self.layer_idx = layer_idx
 
+    def flash_attn_func(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
+                        return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
+        softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
+        return torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
+
     def forward(
         self,
         x: torch.Tensor,
@@ -95,24 +100,31 @@ class StepAttention(torch.nn.Module):
         k = rearrange(k, "b s (g d) -> b s g d", g=self.num_groups)
         v = rearrange(v, "b s (g d) -> b s g d", g=self.num_groups)
 
-            q, k, v, attn_mask=attention_mask
-        )
-        o = o.transpose(1, 2).flatten(-2, -1)
-        return
+        try:
+            if self.head_dim not in (64, 128):
+                raise ValueError("head_dim must be 64 or 128")
+            attn_output = self.flash_attn_func(q, k, v)
+            attn_output = attn_output.flatten(-2, -1)
+        except:
+            k = k.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
+            v = v.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
+
+            attention_mask = build_alibi_cache(
+                k.size(1), self.num_heads, dtype=q.dtype, device=q.device
+            )[:, :, -q.size(1) :, :].contiguous()
+
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            attn_output: torch.Tensor = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=attention_mask
+            )
+
+            attn_output = attn_output.transpose(1, 2).flatten(-2, -1)
+
+        out = self.o_proj(attn_output)
+        return out, None  # attn weights are not returned
 
 
 class StepMLP(torch.nn.Module):
@@ -153,26 +165,26 @@ class StepLayer(torch.nn.Module):
 
     def forward(
         self,
+        hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ):
-        x = x + f(x)
-        return x
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states, self_attn_weights = self.self_attn(hidden_states, past_key_value, attention_mask, cache_position)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states, )
+        if output_attentions:
+            outputs += (self_attn_weights,)
+        return outputs
 
 
 class StepPreTrainedModel(PreTrainedModel):
@@ -241,9 +253,16 @@ class Step1Model(StepPreTrainedModel):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions =
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -274,22 +293,37 @@
 
         hidden_states = inputs_embeds
 
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
             layer_outputs = decoder_layer(
                 hidden_states,
                 attention_mask=causal_mask,
                 past_key_value=past_key_values,
                 cache_position=cache_position,
+                output_attentions=output_attentions,
             )
 
-            hidden_states = layer_outputs
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
         output = BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
-            hidden_states=
+            hidden_states=all_hidden_states,
             attentions=None,
         )
         return output if return_dict else output.to_tuple()
@@ -313,12 +347,6 @@ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
     def set_input_embeddings(self, value):
         self.model.embed_tokens = value
 
-    # def get_output_embeddings(self):
-    #     return self.lm_head
-
-    # def set_output_embeddings(self, new_embeddings):
-    #     self.lm_head = new_embeddings
-
     def set_decoder(self, decoder):
         self.model = decoder
 
@@ -338,14 +366,11 @@ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        # )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
@@ -368,15 +393,12 @@ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
 
         logits = self.lm_head(hidden_states)
 
-        # logits = torch.matmul(hidden_states, lm_stat)
-
         loss = None
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **kwargs
             )
 
         if not return_dict:
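The attention rewrite tries the fused Optimus flash-attention op first and falls back to stock SDPA with an ALiBi bias when the custom op fails or head_dim is unsupported. A standalone sketch of that fallback path; build_alibi_bias below is a stand-in for the repo's build_alibi_cache, whose implementation this diff does not show:

# Assumes standard ALiBi slopes; only the repeat_interleave/SDPA flow is from the diff.
import torch


def build_alibi_bias(seq_len, num_heads, dtype, device):
    # Standard ALiBi: per-head slope 2^(-8i/n) times the (non-positive) relative distance.
    slopes = torch.tensor(
        [2 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)],
        dtype=dtype, device=device,
    )
    pos = torch.arange(seq_len, device=device)
    rel = (pos[None, :] - pos[:, None]).clamp(max=0).to(dtype)  # (s, s), j - i for j <= i
    bias = slopes[:, None, None] * rel[None, :, :]              # (h, s, s)
    causal = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device).triu(1)
    return (bias + causal)[None]                                # (1, h, s, s), broadcasts over batch


b, s, h, g, d = 2, 16, 8, 2, 64
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, g, d)
v = torch.randn(b, s, g, d)

# Expand the g KV groups so each of the h query heads has a matching KV head (GQA fallback).
k = k.repeat_interleave(h // g, dim=-2)
v = v.repeat_interleave(h // g, dim=-2)

mask = build_alibi_bias(s, h, q.dtype, q.device)[:, :, -s:, :]
o = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=mask
)
o = o.transpose(1, 2).flatten(-2, -1)  # (b, s, h * d)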
tokenizer_config.json
CHANGED
@@ -1,4 +1,5 @@
 {
+  "add_bos_token": true,
   "bos_token": "<s>",
   "clean_up_tokenization_spaces": false,
   "eos_token": "</s>",
@@ -9,6 +10,7 @@
   "sp_model_kwargs": {},
   "tokenizer_class": "LlamaTokenizer",
   "unk_token": "<unk>",
-  "use_default_system_prompt": false
+  "use_default_system_prompt": false,
+  "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|BOT|>system\nYou are a helpful assistant.<|EOT|>' }}{% endif %}{{'<|BOT|>' + (message['role'] if message['role'] != 'user' else 'human') + '\n' + message['content'] + '<|EOT|>'}}{% endfor %}{% if add_generation_prompt %}{{ '<|BOT|>assistant\n' }}{% endif %}"
 }
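The new chat_template wraps each turn in <|BOT|>role ... <|EOT|>, renames the user role to human, and injects a default system prompt when the conversation lacks one. A quick rendering check via the standard transformers API; the repo id is again a placeholder:

# Sketch using the stock apply_chat_template API; repo id is hypothetical.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("org/step1", trust_remote_code=True)
messages = [{"role": "user", "content": "Hello!"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Expected output, per the template:
# <|BOT|>system
# You are a helpful assistant.<|EOT|><|BOT|>human
# Hello!<|EOT|><|BOT|>assistant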