Upload model
- config.json +6 -1
- model.safetensors +2 -2
- modeling_mamba.py +75 -18
config.json
CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -15,6 +19,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
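With the new "architectures" field and the extended "auto_map", the checkpoint can be loaded through the transformers Auto classes once remote code is allowed. A minimal loading sketch follows; "<repo_id>" is a placeholder for this repository's id, not a value taken from the diff:

from transformers import AutoConfig, AutoModelForCausalLM

# trust_remote_code=True lets transformers import configuration_mamba.py and
# modeling_mamba.py from the repository, as wired up by the "auto_map" entries above.
config = AutoConfig.from_pretrained("<repo_id>", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("<repo_id>", trust_remote_code=True)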
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:699ed6f59fb948186f449c5031e0dc659d504c90d7e018302aa1e190cdb40220
+size 516567560
modeling_mamba.py
CHANGED
@@ -8,8 +8,7 @@ from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-
-    SequenceClassifierOutput,
+    SequenceClassifierOutputWithPast,
 )
 from transformers.modeling_utils import PreTrainedModel

@@ -320,9 +319,9 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
         **kwargs,
     ) -> CausalLMOutputWithPast:
         batch_size = input_ids.shape[0]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
         sequence_length = input_ids.shape[1]
         vocab_size = self.config.vocab_size
-        output_hidden_states = output_hidden_states or self.config.output_hidden_states

         outputs = self.backbone(
             input_ids=input_ids,
@@ -337,7 +336,7 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
             )
         )

-        if labels:
+        if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss()
@@ -364,17 +363,75 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
         }


-class MambaModelForSequenceClassification(
-
-
-
-
-
-
-
-
-
-
-
-
-
+# class MambaModelForSequenceClassification(MambaModelForCausalLM):
+#     def __init__(
+#         self,
+#         config,
+#         id2label={0: "NEGATIVE", 1: "POSITIVE"},
+#         label2id={"NEGATIVE": 0, "POSITIVE": 1},
+#         num_labels=2,
+#         **kwargs,
+#     ):
+#         super().__init__(
+#             config,
+#             **kwargs,
+#         )
+
+#         self.id2label = id2label
+#         self.label2id = label2id
+#         self.num_labels = num_labels  # TODO: config.num_labels
+
+#         self.score = nn.Linear(
+#             in_features=self.config.vocab_size,
+#             out_features=self.num_labels,
+#             bias=False,
+#         )
+
+#     def forward(
+#         self,
+#         input_ids: Optional[torch.Tensor] = None,
+#         labels: Optional[torch.Tensor] = None,
+#         output_hidden_states=False,
+#         **kwargs,
+#     ) -> SequenceClassifierOutputWithPast:
+#         batch_size = input_ids.shape[0]
+#         hidden_size = self.config.vocab_size
+#         hidden_states: Tuple[
+#             torch.Tensor[(batch_size, sequence_length, hidden_size)]
+#         ] = ()
+#         num_labels = self.num_labels  # TODO: config.num_labels
+#         sequence_length = input_ids.shape[1]
+#         vocab_size = self.config.vocab_size
+#         output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+#         outputs = super().forward(
+#             input_ids=input_ids,
+#             labels=None,
+#             output_hidden_states=output_hidden_states,
+#             **kwargs,
+#         )
+
+#         last_hidden_state = outputs.logits
+#         assert last_hidden_state.shape == (
+#             batch_size,
+#             sequence_length,
+#             hidden_size,
+#         ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+#         hidden_states += (last_hidden_state,)
+
+#         logits: torch.FloatTensor[batch_size, num_labels] = self.score(
+#             last_hidden_state[:, -1, :]  # TODO: Check if this makes sense
+#         )
+
+#         if labels is not None:
+#             loss_fct = CrossEntropyLoss()
+#             loss = loss_fct(logits, labels)
+
+#         else:
+#             loss = None
+
+#         return SequenceClassifierOutputWithPast(
+#             loss=loss,
+#             logits=logits,
+#             hidden_states=hidden_states,
+#         )
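The guard change from "if labels:" to "if labels is not None:" in the loss branch is a real bug fix, not a style tweak: taking the truth value of a labels tensor with more than one element raises a RuntimeError, so the old check would crash during training (and, for a one-element tensor, would silently skip the loss whenever that label happened to be 0). A standalone sketch of the difference, independent of the model code above:

import torch

labels = torch.tensor([[0, 1, 2]])

# Old check: truthiness of a multi-element tensor is undefined and raises.
try:
    if labels:
        pass
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

# New check: only asks whether labels were passed at all, which is what the
# training loss branch needs.
if labels is not None:
    print("compute loss")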