Upload model
- config.json +6 -1
- model.safetensors +2 -2
- modeling_mamba.py +87 -26
config.json
CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
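The new "architectures" and "auto_map" entries are what let transformers resolve the custom classes at load time: AutoConfig is routed to configuration_mamba.MambaConfig and AutoModelForCausalLM to modeling_mamba.MambaModelForCausalLM once remote code is trusted. A minimal loading sketch (the repo id is a placeholder, not something defined by this commit):

from transformers import AutoConfig, AutoModelForCausalLM

# "user/mamba-custom" is a hypothetical repo id used only for illustration.
# trust_remote_code=True lets transformers follow the auto_map entries above and
# import MambaConfig / MambaModelForCausalLM from the repository's own .py files.
config = AutoConfig.from_pretrained("user/mamba-custom", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("user/mamba-custom", trust_remote_code=True)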
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd
+size 671027808
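The pointer file records only the digest and byte count; the weights themselves (about 671 MB, consistent with the float32 torch_dtype added to config.json) live in LFS storage. A small sketch for verifying a local download against this pointer (the file path is assumed):

import hashlib
from pathlib import Path

# Path to the LFS-resolved weights file; adjust to wherever the checkpoint was downloaded.
weights = Path("model.safetensors")

# Both expected values are taken directly from the updated pointer above.
assert weights.stat().st_size == 671027808
assert hashlib.sha256(weights.read_bytes()).hexdigest() == (
    "1bd3ca62665de4bfabff9d443f87a11090a10e505c0ccb56e6f9ca495b6e05bd"
)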
modeling_mamba.py
CHANGED
@@ -313,12 +313,29 @@ class MambaModel(MambaPreTrainedModel):
 class MambaModelForCausalLM(MambaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

-    def __init__(self, config):
-        super().__init__(config)
-        self.backbone = MambaModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-        self.lm_head.weight = self.backbone.embedding.weight
+    def __init__(self, config, **kwargs):
+        # super().__init__(config)
+        # self.backbone = MambaModel(config)
+        # self.vocab_size = config.vocab_size
+        # self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        # self.lm_head.weight = self.backbone.embedding.weight
+        # self.post_init()
+
+        super().__init__(
+            config,
+            **kwargs,
+        )
+
+        self.backbone = MambaModel(
+            config=self.config,
+        )
+
+        self.lm_head = nn.Linear(
+            in_features=self.config.d_model,
+            out_features=self.config.vocab_size,
+            bias=False,
+        )
+
         self.post_init()

     # def get_input_embeddings(self):
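The rewritten __init__ drops the explicit self.lm_head.weight = self.backbone.embedding.weight assignment, so any weight sharing now has to come from _tied_weights_keys together with post_init(). A small sanity-check sketch (not part of the commit; it assumes the backbone.embedding attribute from the old code still exists on MambaModel):

def lm_head_is_tied(model) -> bool:
    # True only if lm_head.weight and the backbone embedding share the same
    # storage, i.e. tying via _tied_weights_keys / post_init() actually took effect.
    return (
        model.lm_head.weight.data_ptr()
        == model.backbone.embedding.weight.data_ptr()
    )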
@@ -339,47 +356,91 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
     # def get_decoder(self):
     #     return self.model

+    # def forward(
+    #     self,
+    #     input_ids: torch.LongTensor = None,
+    #     labels: Optional[torch.LongTensor] = None,
+    #     output_attentions: Optional[bool] = None,
+    #     output_hidden_states: Optional[bool] = None,
+    #     return_dict: Optional[bool] = None,
+    #     **kwargs,
+    # ) -> Union[Tuple, CausalLMOutputWithPast]:
+    #     outputs = self.backbone(
+    #         input_ids=input_ids,
+    #         return_dict=return_dict,
+    #     )
+    #     hidden_states = outputs[0]
+    #     logits = self.lm_head(hidden_states)
+    #     logits = logits.float()
+    #     loss = None
+
+    #     if labels is not None:
+    #         shift_logits = logits[..., :-1, :].contiguous()
+    #         shift_labels = labels[..., 1:].contiguous()
+    #         loss_fct = CrossEntropyLoss()
+    #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
+    #         shift_labels = shift_labels.view(-1)
+
+    #         shift_labels = shift_labels.to(shift_logits.device)
+    #         loss = loss_fct(shift_logits, shift_labels)
+
+    #     if not return_dict:
+    #         output = (logits,) + outputs[1:]
+    #         return (loss,) + output if loss is not None else output
+
+    #     return CausalLMOutputWithPast(
+    #         loss=loss,
+    #         logits=logits,
+    #         hidden_states=outputs.hidden_states,
+    #     )
+
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids,
         labels: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        output_hidden_states=False,
         **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
+    ) -> CausalLMOutputWithPast:
+        batch_size = input_ids.shape[0]
+        sequence_length = input_ids.shape[1]
+        vocab_size = self.config.vocab_size
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
         outputs = self.backbone(
             input_ids=input_ids,
-            return_dict=return_dict,
+            output_hidden_states=output_hidden_states,
         )
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
-        loss = None

-        if labels is not None:
+        last_hidden_state = outputs.last_hidden_state
+
+        logits: torch.FloatTensor[batch_size, sequence_length, vocab_size] = (
+            self.lm_head(
+                last_hidden_state,
+            )
+        )
+
+        if labels:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_logits = shift_logits.view(-1, vocab_size)
             shift_labels = shift_labels.view(-1)

             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+        else:
+            loss = None

         return CausalLMOutputWithPast(
-            loss=loss,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
             logits=logits,
-            hidden_states=outputs.hidden_states,
+            loss=loss,
         )

-    def prepare_inputs_for_generation(self, input_ids, **kwargs):
-        model_inputs = {"input_ids": input_ids}
-        return model_inputs
+    # def prepare_inputs_for_generation(self, input_ids, **kwargs):
+    #     model_inputs = {"input_ids": input_ids}
+    #     return model_inputs


 class MambaModelForSequenceClassification(MambaPreTrainedModel):
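The new forward always returns a CausalLMOutputWithPast (the tuple / return_dict path and prepare_inputs_for_generation are commented out). A minimal sketch of calling it directly, assuming MambaConfig can be instantiated from its defaults and that configuration_mamba.py / modeling_mamba.py are on the import path:

import torch

from configuration_mamba import MambaConfig
from modeling_mamba import MambaModelForCausalLM

# Build a randomly initialised model; this commit's checkpoint uses n_layer=24 and
# vocab_size=50280, but any valid MambaConfig is enough to exercise forward().
config = MambaConfig()
model = MambaModelForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # (batch_size, sequence_length)
out = model(input_ids=input_ids, output_hidden_states=True)

print(out.logits.shape)  # (1, 16, config.vocab_size)
print(out.loss)          # None -- no labels were passed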