Upload model

- config.json (+6 -1)
- modeling_mamba.py (+74 -80)
config.json
CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaLMHeadModel"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaLMHeadModel"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
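With "architectures" and the "AutoModelForCausalLM" entry added to "auto_map", the checkpoint can now be resolved through the transformers Auto classes, provided the repository's custom code is trusted. A minimal loading sketch; the repo id below is a placeholder, not the actual repository name:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "your-namespace/your-mamba-checkpoint"  # placeholder repo id (assumption)

# auto_map routes AutoConfig -> configuration_mamba.MambaConfig and
# AutoModelForCausalLM -> modeling_mamba.MambaLMHeadModel, both shipped in the
# repository, so trust_remote_code=True is required.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # matches the "torch_dtype" now recorded in config.json
)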
modeling_mamba.py
CHANGED
@@ -380,23 +380,17 @@ class MambaModel(MambaPretrainedModel):
             **kwargs,
         )
 
-        # self.embedding = nn.Embedding(
-        #     num_embeddings=config.vocab_size,
-        #     embedding_dim=config.d_model,
-        # )
-
-
         self.embedding = nn.Embedding(
-            num_embeddings=config.vocab_size,
-            embedding_dim=config.d_model,
+            num_embeddings=self.config.vocab_size,
+            embedding_dim=self.config.d_model,
         )
 
         self.layers = nn.ModuleList(
-            [ResidualBlock(config) for _ in range(self.config.n_layer)]
+            [ResidualBlock(self.config) for _ in range(self.config.n_layer)]
         )
         # self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
         # # self.norm_f = RMSNorm(d_model=embedding_dim)
-        self.norm_f = RMSNorm(config.d_model)
+        self.norm_f = RMSNorm(self.config.d_model)
 
         # self.gradient_checkpointing = False
         # # self.post_init()
@@ -454,54 +448,54 @@ class MambaModel(MambaPretrainedModel):
     # def set_input_embeddings(self, value):
     #     self.embed_out = value
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        output_hidden_states=False,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+        # ) -> BaseModelOutput:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        batch_size = input_ids.shape[0]
+        hidden_size = self.config.hidden_size
+        hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
+        sequence_length = input_ids.shape[1]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        last_hidden_state = self.embedding(input_ids)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        for layer in self.layers:
+            last_hidden_state = layer(last_hidden_state)
+            assert last_hidden_state.shape == (
+                batch_size,
+                sequence_length,
+                hidden_size,
+            ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+            hidden_states += (last_hidden_state,)
+
+        last_hidden_state = self.norm_f(last_hidden_state)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        assert (
+            len(hidden_states) == self.config.n_layer + 2
+        ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
+
+        # return BaseModelOutput(
+        return BaseModelOutputWithPast(
+            hidden_states=hidden_states if output_hidden_states else None,
+            last_hidden_state=last_hidden_state,
+        )
 
 
 # Influences:
@@ -538,31 +532,31 @@ class MambaLMHeadModel(MambaPretrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-
-
-
-
-
-
-
-
-
-
-
-
+    def forward(
+        self, input_ids, output_hidden_states=False, **kwargs
+    ) -> CausalLMOutput:
+        batch_size = input_ids.shape[0]
+        sequence_length = input_ids.shape[1]
+        vocab_size = self.config.vocab_size
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        outputs = self.backbone(
+            input_ids=input_ids,
+            output_hidden_states=output_hidden_states,
+        )
 
-
+        last_hidden_state = outputs.last_hidden_state
 
-
-
-
-
-
+        logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+            self.lm_head(
+                last_hidden_state,
+            )
+        )
 
-
-
-
-
+        return CausalLMOutput(
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            logits=logits,
+        )
 
     # # def prepare_inputs_for_generation(
     # #     self, input_ids, attention_mask=None, **model_kwargs