mjschock committed · commit 0214425 (verified) · 1 parent: c47ffb9

Upload model

Files changed (3):
  1. config.json (+6 -1)
  2. model.safetensors (+2 -2)
  3. modeling_mamba.py (+87 -26)
config.json CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
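
The new `auto_map` entry is what lets the `AutoModelForCausalLM` factory resolve to the repository's own `modeling_mamba.MambaModelForCausalLM` when remote code is trusted. A minimal loading sketch; the repo id below is hypothetical and should be replaced with the repository this commit belongs to:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical repo id, used only for illustration.
repo_id = "mjschock/mamba"

# trust_remote_code=True allows transformers to import configuration_mamba.MambaConfig
# and modeling_mamba.MambaModelForCausalLM from the repository, as declared in auto_map.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(type(config).__name__)  # MambaConfig
print(type(model).__name__)   # MambaModelForCausalLM
```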
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220
-size 516567560
+oid sha256:1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd
+size 671027808
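
The checkpoint grows by roughly 154 MB, which is consistent with the `__init__` change below no longer tying `lm_head.weight` to the embedding matrix: an untied `d_model x vocab_size` float32 tensor is now serialized on its own. A rough back-of-the-envelope check, assuming `d_model = 768` (the usual width for this 24-layer config; not visible in the hunk above):

```python
# Size delta reported by the diff above.
old_size = 516_567_560
new_size = 671_027_808
delta = new_size - old_size                 # 154,460,248 bytes

# One untied lm_head weight matrix stored in float32 (assumed d_model = 768).
d_model, vocab_size = 768, 50_280
lm_head_bytes = d_model * vocab_size * 4    # 154,460,160 bytes

# The small remainder is safetensors header overhead for the extra tensor entry.
print(delta - lm_head_bytes)                # 88
```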
modeling_mamba.py CHANGED
@@ -313,12 +313,29 @@ class MambaModel(MambaPreTrainedModel):
 class MambaModelForCausalLM(MambaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
-    def __init__(self, config):
-        super().__init__(config)
-        self.backbone = MambaModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-        self.lm_head.weight = self.backbone.embedding.weight
+    def __init__(self, config, **kwargs):
+        # super().__init__(config)
+        # self.backbone = MambaModel(config)
+        # self.vocab_size = config.vocab_size
+        # self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        # self.lm_head.weight = self.backbone.embedding.weight
+        # self.post_init()
+
+        super().__init__(
+            config,
+            **kwargs,
+        )
+
+        self.backbone = MambaModel(
+            config=self.config,
+        )
+
+        self.lm_head = nn.Linear(
+            in_features=self.config.d_model,
+            out_features=self.config.vocab_size,
+            bias=False,
+        )
+
         self.post_init()
 
     # def get_input_embeddings(self):
@@ -339,47 +356,91 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
     # def get_decoder(self):
     #     return self.model
 
+    # def forward(
+    #     self,
+    #     input_ids: torch.LongTensor = None,
+    #     labels: Optional[torch.LongTensor] = None,
+    #     output_attentions: Optional[bool] = None,
+    #     output_hidden_states: Optional[bool] = None,
+    #     return_dict: Optional[bool] = None,
+    #     **kwargs,
+    # ) -> Union[Tuple, CausalLMOutputWithPast]:
+    #     outputs = self.backbone(
+    #         input_ids=input_ids,
+    #         return_dict=return_dict,
+    #     )
+    #     hidden_states = outputs[0]
+    #     logits = self.lm_head(hidden_states)
+    #     logits = logits.float()
+    #     loss = None
+
+    #     if labels is not None:
+    #         shift_logits = logits[..., :-1, :].contiguous()
+    #         shift_labels = labels[..., 1:].contiguous()
+    #         loss_fct = CrossEntropyLoss()
+    #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
+    #         shift_labels = shift_labels.view(-1)
+
+    #         shift_labels = shift_labels.to(shift_logits.device)
+    #         loss = loss_fct(shift_logits, shift_labels)
+
+    #     if not return_dict:
+    #         output = (logits,) + outputs[1:]
+    #         return (loss,) + output if loss is not None else output
+
+    #     return CausalLMOutputWithPast(
+    #         loss=loss,
+    #         logits=logits,
+    #         hidden_states=outputs.hidden_states,
+    #     )
+
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids,
         labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        output_hidden_states=False,
         **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> CausalLMOutputWithPast:
+        batch_size = input_ids.shape[0]
+        sequence_length = input_ids.shape[1]
+        vocab_size = self.config.vocab_size
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
         outputs = self.backbone(
             input_ids=input_ids,
-            return_dict=return_dict,
+            output_hidden_states=output_hidden_states,
         )
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
-        loss = None
 
-        if labels is not None:
+        last_hidden_state = outputs.last_hidden_state
+
+        logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+            self.lm_head(
+                last_hidden_state,
+            )
+        )
+
+        if labels:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_logits = shift_logits.view(-1, vocab_size)
             shift_labels = shift_labels.view(-1)
 
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
 
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+        else:
+            loss = None
 
         return CausalLMOutputWithPast(
-            loss=loss,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
             logits=logits,
-            hidden_states=outputs.hidden_states,
+            loss=loss,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        model_inputs = {"input_ids": input_ids}
-        return model_inputs
+    # def prepare_inputs_for_generation(self, input_ids, **kwargs):
+    #     model_inputs = {"input_ids": input_ids}
+    #     return model_inputs
 
 
 class MambaModelForSequenceClassification(MambaPreTrainedModel):
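
With these changes `forward` always returns a `CausalLMOutputWithPast`: logits are projected from `outputs.last_hidden_state` through `lm_head`, `loss` is computed only when `labels` are supplied, and `hidden_states` is populated only when hidden states are requested. A minimal smoke-test sketch against the uploaded weights; the repo id is hypothetical, substitute the actual repository:

```python
import torch
from transformers import AutoModelForCausalLM

repo_id = "mjschock/mamba"  # hypothetical repo id, for illustration only

model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Any token ids below vocab_size (50280) are fine for a shape check.
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))

with torch.no_grad():
    out = model(input_ids=input_ids, output_hidden_states=True)

print(out.logits.shape)               # torch.Size([1, 16, 50280])
print(out.loss)                       # None -- no labels were passed
print(out.hidden_states is not None)  # True, because hidden states were requested
```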