mjschock committed
Commit d78c29a · verified · 1 Parent(s): 7134c6a

Upload model

Files changed (2):
  1. config.json +6 -1
  2. modeling_mamba.py +74 -80
config.json CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaLMHeadModel"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaLMHeadModel"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
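
The new "architectures" and "auto_map" entries are what let this checkpoint be resolved through the transformers Auto classes together with the repository's custom code, and "torch_dtype": "float32" records the dtype the weights were saved in. A minimal loading sketch, assuming a hypothetical repo id (the actual repository name is not part of this commit):

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "mjschock/mamba"  # hypothetical id, for illustration only

# trust_remote_code=True is required so that configuration_mamba.py and
# modeling_mamba.py are fetched from the repo and wired up via "auto_map".
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # expected: MambaLMHeadModel
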
modeling_mamba.py CHANGED
@@ -380,23 +380,17 @@ class MambaModel(MambaPretrainedModel):
             **kwargs,
         )
 
-        # self.embedding = nn.Embedding(
-        #     num_embeddings=config.vocab_size,
-        #     embedding_dim=config.d_model,
-        # )
-
-
         self.embedding = nn.Embedding(
-            num_embeddings=config.vocab_size,
-            embedding_dim=config.d_model,
+            num_embeddings=self.config.vocab_size,
+            embedding_dim=self.config.d_model,
         )
 
         self.layers = nn.ModuleList(
-            [ResidualBlock(config) for _ in range(self.config.n_layer)]
+            [ResidualBlock(self.config) for _ in range(self.config.n_layer)]
         )
         # self.layers = nn.ModuleList([MambaBlock(config) for _ in range(config.n_layer)])
         # # self.norm_f = RMSNorm(d_model=embedding_dim)
-        self.norm_f = RMSNorm(config.d_model)
+        self.norm_f = RMSNorm(self.config.d_model)
 
         # self.gradient_checkpointing = False
         # # self.post_init()
@@ -454,54 +448,54 @@ class MambaModel(MambaPretrainedModel):
     # def set_input_embeddings(self, value):
     #     self.embed_out = value
 
-    # def forward(
-    #     self,
-    #     input_ids: torch.LongTensor = None,
-    #     output_hidden_states=False,
-    #     return_dict: Optional[bool] = None,
-    #     **kwargs,
-    #     # ) -> BaseModelOutput:
-    # ) -> Union[Tuple, BaseModelOutputWithPast]:
-    #     batch_size = input_ids.shape[0]
-    #     hidden_size = self.config.hidden_size
-    #     hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
-    #     sequence_length = input_ids.shape[1]
-    #     output_hidden_states = output_hidden_states or self.config.output_hidden_states
-
-    #     last_hidden_state = self.embed_out(input_ids)
-    #     assert last_hidden_state.shape == (
-    #         batch_size,
-    #         sequence_length,
-    #         hidden_size,
-    #     ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
-    #     hidden_states += (last_hidden_state,)
-
-    #     for layer in self.layers:
-    #         last_hidden_state = layer(last_hidden_state)
-    #         assert last_hidden_state.shape == (
-    #             batch_size,
-    #             sequence_length,
-    #             hidden_size,
-    #         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
-    #         hidden_states += (last_hidden_state,)
-
-    #     last_hidden_state = self.norm_f(last_hidden_state)
-    #     assert last_hidden_state.shape == (
-    #         batch_size,
-    #         sequence_length,
-    #         hidden_size,
-    #     ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
-    #     hidden_states += (last_hidden_state,)
-
-    #     assert (
-    #         len(hidden_states) == self.config.n_layer + 2
-    #     ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
-
-    #     # return BaseModelOutput(
-    #     return BaseModelOutputWithPast(
-    #         hidden_states=hidden_states if output_hidden_states else None,
-    #         last_hidden_state=last_hidden_state,
-    #     )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        output_hidden_states=False,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+        # ) -> BaseModelOutput:
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        batch_size = input_ids.shape[0]
+        hidden_size = self.config.hidden_size
+        hidden_states: Tuple[Tensor[(batch_size, sequence_length, hidden_size)]] = ()
+        sequence_length = input_ids.shape[1]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        last_hidden_state = self.embedding(input_ids)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        for layer in self.layers:
+            last_hidden_state = layer(last_hidden_state)
+            assert last_hidden_state.shape == (
+                batch_size,
+                sequence_length,
+                hidden_size,
+            ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+            hidden_states += (last_hidden_state,)
+
+        last_hidden_state = self.norm_f(last_hidden_state)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        assert (
+            len(hidden_states) == self.config.n_layer + 2
+        ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
+
+        # return BaseModelOutput(
+        return BaseModelOutputWithPast(
+            hidden_states=hidden_states if output_hidden_states else None,
+            last_hidden_state=last_hidden_state,
+        )
 
 
 # Influences:
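
The hunk above activates the previously commented-out MambaModel.forward, swapping self.embed_out for the self.embedding module defined in __init__ while keeping the shape asserts. A small sanity-check sketch of the invariants those asserts encode, reusing the model loaded earlier and assuming the MambaModel instance is exposed as model.backbone (as the MambaLMHeadModel.forward in the next hunk suggests):

import torch

batch_size, sequence_length = 1, 8
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, sequence_length))

# With output_hidden_states=True the backbone returns n_layer + 2 hidden states:
# the embedding output, one per ResidualBlock, and the final RMSNorm output,
# each of shape (batch_size, sequence_length, hidden_size).
with torch.no_grad():
    outputs = model.backbone(input_ids=input_ids, output_hidden_states=True)

assert outputs.last_hidden_state.shape == (batch_size, sequence_length, model.config.hidden_size)
assert len(outputs.hidden_states) == model.config.n_layer + 2
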
 
@@ -538,31 +532,31 @@ class MambaLMHeadModel(MambaPretrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    # # def forward(
-    # #     self, input_ids, output_hidden_states=False, **kwargs
-    # # ) -> CausalLMOutput:
-    # #     batch_size = input_ids.shape[0]
-    # #     sequence_length = input_ids.shape[1]
-    # #     vocab_size = self.config.vocab_size
-    # #     output_hidden_states = output_hidden_states or self.config.output_hidden_states
-
-    # #     outputs = self.backbone(
-    # #         input_ids=input_ids,
-    # #         output_hidden_states=output_hidden_states,
-    # #     )
+    def forward(
+        self, input_ids, output_hidden_states=False, **kwargs
+    ) -> CausalLMOutput:
+        batch_size = input_ids.shape[0]
+        sequence_length = input_ids.shape[1]
+        vocab_size = self.config.vocab_size
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        outputs = self.backbone(
+            input_ids=input_ids,
+            output_hidden_states=output_hidden_states,
+        )
 
-    # #     last_hidden_state = outputs.last_hidden_state
+        last_hidden_state = outputs.last_hidden_state
 
-    # #     logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
-    # #         self.lm_head(
-    # #             last_hidden_state,
-    # #         )
-    # #     )
+        logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+            self.lm_head(
+                last_hidden_state,
+            )
+        )
 
-    # #     return CausalLMOutput(
-    # #         hidden_states=outputs.hidden_states if output_hidden_states else None,
-    # #         logits=logits,
-    # #     )
+        return CausalLMOutput(
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            logits=logits,
+        )
 
     # # def prepare_inputs_for_generation(
     # #     self, input_ids, attention_mask=None, **model_kwargs
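
With MambaLMHeadModel.forward enabled as well, the full causal-LM path can be smoke-tested. A rough sketch, continuing with the model loaded above and using random token ids rather than assuming any particular tokenizer:

import torch

batch_size, sequence_length = 2, 16
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, sequence_length))

with torch.no_grad():
    output = model(input_ids)

# lm_head projects the backbone's last hidden state onto the vocabulary,
# so logits come back as (batch_size, sequence_length, vocab_size) in float32,
# matching the "torch_dtype": "float32" now recorded in config.json.
assert output.logits.shape == (batch_size, sequence_length, model.config.vocab_size)
assert output.logits.dtype == torch.float32
next_token = output.logits[:, -1, :].argmax(dim=-1)  # greedy choice of the next token

Since prepare_inputs_for_generation is still commented out in this revision, generation helpers such as model.generate are not exercised here.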