"""Hugging Face-compatible wrapper classes around mamba_ssm's MambaLMHeadModel."""

from typing import Optional

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import GenerationMixin, PreTrainedModel
from transformers.generation import TextStreamer

from .configuration_mamba import MambaConfig


class MambaModel(PreTrainedModel):
    """PreTrainedModel shim that delegates all computation to mamba_ssm."""

    config_class = MambaConfig

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)

        # The reference CUDA implementation does the actual work.
        self.model = MambaLMHeadModel(
            config,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **kwargs,
    ):
        # Forward by keyword so arguments cannot be silently misordered.
        # inference_params carries the recurrent state between decoding steps;
        # num_last_tokens=1 restricts the LM head to the final position.
        # Extra HF-style kwargs (e.g. attention_mask) are accepted and ignored,
        # since mamba_ssm's forward does not take them.
        return self.model.forward(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=num_last_tokens,
        )
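
# A minimal sketch of stepping MambaModel.forward manually (assumption: the
# mamba_ssm InferenceParams API; nothing in this module depends on this):
#
#     from mamba_ssm.utils.generation import InferenceParams
#
#     params = InferenceParams(max_seqlen=256, max_batch_size=1)
#     logits = model(prompt_ids, inference_params=params, num_last_tokens=1).logits
#     params.seqlen_offset += prompt_ids.shape[1]
#     next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
#     logits = model(next_id, inference_params=params, num_last_tokens=1).logits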


class MambaModelForCausalLM(MambaModel, GenerationMixin):
    def generate(
        self,
        input_ids,
        max_length: int = 2048,
        top_k: int = 1,
        top_p: float = 0.0,
        temperature: float = 1.0,
        return_dict_in_generate: bool = False,
        output_scores: bool = False,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        teacher_outputs: Optional[torch.Tensor] = None,
        vocab_size: Optional[int] = None,
        cg: bool = False,
        enable_timing: bool = False,
        streamer: Optional[TextStreamer] = None,
        **kwargs,
    ):
        # Bypass GenerationMixin's decoding loop in favor of mamba_ssm's,
        # which supports CUDA graph capture (cg=True) for fast incremental decoding.
        return self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_token_id,
            teacher_outputs=teacher_outputs,
            vocab_size=vocab_size,
            cg=cg,
            enable_timing=enable_timing,
            streamer=streamer,
        )
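

# A minimal end-to-end sketch, kept under __main__ so importing this module has
# no side effects. Assumptions (not guaranteed by this file): the repo's
# MambaConfig supplies the fields MambaLMHeadModel reads (d_model, n_layer,
# vocab_size, ...), a CUDA device is available, and the GPT-NeoX tokenizer
# matches the vocabulary, as it does for the released Mamba checkpoints.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = MambaConfig()  # assumed defaults; override to match your checkpoint
    model = MambaModelForCausalLM(config, device="cuda", dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

    input_ids = tokenizer("Mamba is", return_tensors="pt").input_ids.to("cuda")
    # With return_dict_in_generate=False (the default) mamba_ssm returns the
    # full sequence tensor, prompt included.
    output_ids = model.generate(input_ids, max_length=64, top_k=50, top_p=0.9, temperature=0.8)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))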