Commit b6f9af2 (verified)
mjschock committed · 1 Parent(s): 7e02380

Upload model

Files changed (3):
  1. config.json +6 -1
  2. model.safetensors +2 -2
  3. modeling_mamba.py +75 -18
config.json CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -15,6 +19,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
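
Note: the new "architectures" and "auto_map" entries are what let transformers resolve the custom classes shipped in this repo when it is loaded with trust_remote_code=True. A minimal sketch of the intended usage (the repo id below is a placeholder, not taken from this commit):

    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "<user>/<model>"  # placeholder; substitute the actual Hub repo id

    # "AutoConfig" maps to configuration_mamba.MambaConfig,
    # "AutoModelForCausalLM" maps to modeling_mamba.MambaModelForCausalLM.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)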
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:287cad4048030ae246aeda26c0e703b838c50422fe89f19099298c034b25e7b5
-size 516565384
+oid sha256:699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220
+size 516567560
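
Note: only the Git LFS pointer changes here; the weights themselves live in LFS storage. A downloaded copy can be checked against the pointer's oid and size, for example (the local path is an assumption):

    import hashlib
    import os

    path = "model.safetensors"  # assumed local path of the downloaded file
    expected_oid = "699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220"
    expected_size = 516567560

    assert os.path.getsize(path) == expected_size, "size mismatch"

    sha256 = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
    assert sha256.hexdigest() == expected_oid, "oid mismatch"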
modeling_mamba.py CHANGED
@@ -8,8 +8,7 @@ from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    QuestionAnsweringModelOutput,
-    SequenceClassifierOutput,
+    SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel
 
@@ -320,9 +319,9 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
         **kwargs,
     ) -> CausalLMOutputWithPast:
         batch_size = input_ids.shape[0]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
         sequence_length = input_ids.shape[1]
         vocab_size = self.config.vocab_size
-        output_hidden_states = output_hidden_states or self.config.output_hidden_states
 
         outputs = self.backbone(
             input_ids=input_ids,
@@ -337,7 +336,7 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
             )
         )
 
-        if labels:
+        if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss()
@@ -364,17 +363,75 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
         }
 
 
-class MambaModelForSequenceClassification(MambaPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        self.model = MambaModel(config)
-        # self.classifier = nn.Linear(config.d_model, config.num_labels)
-        # self.post_init()
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> SequenceClassifierOutput:
-        pass
+# class MambaModelForSequenceClassification(MambaModelForCausalLM):
+#     def __init__(
+#         self,
+#         config,
+#         id2label={0: "NEGATIVE", 1: "POSITIVE"},
+#         label2id={"NEGATIVE": 0, "POSITIVE": 1},
+#         num_labels=2,
+#         **kwargs,
+#     ):
+#         super().__init__(
+#             config,
+#             **kwargs,
+#         )
+
+#         self.id2label = id2label
+#         self.label2id = label2id
+#         self.num_labels = num_labels  # TODO: config.num_labels
+
+#         self.score = nn.Linear(
+#             in_features=self.config.vocab_size,
+#             out_features=self.num_labels,
+#             bias=False,
+#         )
+
+#     def forward(
+#         self,
+#         input_ids: Optional[torch.Tensor] = None,
+#         labels: Optional[torch.Tensor] = None,
+#         output_hidden_states=False,
+#         **kwargs,
+#     ) -> SequenceClassifierOutputWithPast:
+#         batch_size = input_ids.shape[0]
+#         hidden_size = self.config.vocab_size
+#         hidden_states: Tuple[
+#             torch.Tensor[(batch_size, sequence_length, hidden_size)]
+#         ] = ()
+#         num_labels = self.num_labels  # TODO: config.num_labels
+#         sequence_length = input_ids.shape[1]
+#         vocab_size = self.config.vocab_size
+#         output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+#         outputs = super().forward(
+#             input_ids=input_ids,
+#             labels=None,
+#             output_hidden_states=output_hidden_states,
+#             **kwargs,
+#         )
+
+#         last_hidden_state = outputs.logits
+#         assert last_hidden_state.shape == (
+#             batch_size,
+#             sequence_length,
+#             hidden_size,
+#         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+#         hidden_states += (last_hidden_state,)
+
+#         logits: torch.FloatTensor[batch_size, num_labels] = self.score(
+#             last_hidden_state[:, -1, :]  # TODO: Check if this makes sense
+#         )
+
+#         if labels is not None:
+#             loss_fct = CrossEntropyLoss()
+#             loss = loss_fct(logits, labels)
+
+#         else:
+#             loss = None
+
+#         return SequenceClassifierOutputWithPast(
+#             loss=loss,
+#             logits=logits,
+#             hidden_states=hidden_states,
+#         )
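
Note: the guard change from "if labels:" to "if labels is not None:" matters because truthiness is undefined for multi-element tensors. A minimal illustration of the difference:

    import torch

    labels = torch.tensor([[0, 1, 2]])

    # Old guard: bool(labels) raises
    # "RuntimeError: Boolean value of Tensor with more than one element is ambiguous",
    # and a single-element tensor holding 0 would silently skip the loss.
    #
    # New guard: only checks whether labels were passed at all.
    if labels is not None:
        print("labels provided -> compute the shifted cross-entropy loss")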