Upload model
- config.json  +6 -1
- modeling_mamba.py  +4 -14
config.json
CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
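The new "auto_map" entry is what lets transformers resolve the custom classes at load time: with trust_remote_code=True, AutoConfig imports configuration_mamba.MambaConfig and AutoModelForCausalLM imports modeling_mamba.MambaModelForCausalLM from this repository. A minimal loading sketch (the repo id below is a placeholder, not this repository's actual name):

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; substitute the actual model repository.
repo = "your-org/mamba-checkpoint"

# trust_remote_code=True allows transformers to import the classes named in auto_map.
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)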
modeling_mamba.py
CHANGED
@@ -311,18 +311,9 @@ class MambaModel(MambaPreTrainedModel):
         )

 class MambaModelForCausalLM(MambaPreTrainedModel):
-    _tied_weights_keys = [
-        "lm_head.weight",  # will remove this since it's a duplicate of backbone.embedding.weight
-    ]
+    _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config, **kwargs):
-        # super().__init__(config)
-        # self.backbone = MambaModel(config)
-        # self.vocab_size = config.vocab_size
-        # self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-        # self.lm_head.weight = self.backbone.embedding.weight
-        # self.post_init()
-
         super().__init__(
             config,
             **kwargs,
@@ -338,7 +329,6 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
             bias=False,
         )

-        # self.lm_head.weight = self.backbone.embedding.weight
         self.post_init()

     def _tie_weights(self):
@@ -444,9 +434,9 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
             loss=loss,
         )

-
-
-
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        model_inputs = {"input_ids": input_ids}
+        return model_inputs


 class MambaModelForSequenceClassification(MambaPreTrainedModel):
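With "lm_head.weight" listed in _tied_weights_keys and prepare_inputs_for_generation defined (GenerationMixin calls it at every decoding step), the checkpoint can be driven by the stock generate() loop. A hedged end-to-end sketch, assuming the repository ships a compatible tokenizer and using a placeholder repo id:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-org/mamba-checkpoint"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo)  # assumes tokenizer files are present in the repo
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

prompt = "Mamba is a state-space model that"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    # generate() calls prepare_inputs_for_generation() before each forward pass;
    # the implementation above simply feeds the full input_ids back in (no KV cache).
    output_ids = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))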