from typing import Optional

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import GenerationMixin, PreTrainedModel
from transformers.generation import TextStreamer

from .configuration_mamba import MambaConfig


class MambaModel(PreTrainedModel):
    """Thin Hugging Face ``PreTrainedModel`` wrapper around the reference
    ``MambaLMHeadModel`` from ``mamba_ssm``."""

    config_class = MambaConfig

    def __init__(
        self,
        config,
        initializer_cfg=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        # Delegate all parameters and computation to the mamba_ssm model.
        self.model = MambaLMHeadModel(
            config,
            initializer_cfg=initializer_cfg,
            device=device,
            dtype=dtype,
        )
    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **kwargs,
    ):
        # Pass arguments by keyword so they cannot be silently misaligned;
        # extra Hugging Face kwargs (e.g. attention_mask) are ignored.
        return self.model.forward(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=num_last_tokens,
        )
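
# A minimal sketch of stateful, token-by-token decoding through ``forward``,
# assuming mamba_ssm's ``InferenceParams`` (the shapes and variable names
# below are illustrative, not part of this module):
#
#   from mamba_ssm.utils.generation import InferenceParams
#
#   params = InferenceParams(max_seqlen=256, max_batch_size=1)
#   out = model(input_ids, inference_params=params, num_last_tokens=1)
#   next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
#   params.seqlen_offset += input_ids.shape[1]
#   out = model(next_token, inference_params=params, num_last_tokens=1)
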

class MambaModelForCausalLM(MambaModel, GenerationMixin):
    def generate(
        self,
        input_ids,
        max_length: int = 2048,
        top_k: int = 1,
        top_p: float = 0.0,
        temperature: float = 1.0,
        return_dict_in_generate: bool = False,
        output_scores: bool = False,
        repetition_penalty: float = 1.0,
        eos_token_id: Optional[int] = None,
        teacher_outputs: Optional[torch.Tensor] = None,
        vocab_size: Optional[int] = None,
        cg: bool = False,
        enable_timing: bool = False,
        streamer: Optional[TextStreamer] = None,
        **kwargs,
    ):
        # Override GenerationMixin.generate and delegate to mamba_ssm's own
        # decoding loop instead of the HF sampling loop; it supports
        # CUDA-graph capture via ``cg`` and per-step timing.
        return self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_token_id,
            teacher_outputs=teacher_outputs,
            vocab_size=vocab_size,
            cg=cg,
            enable_timing=enable_timing,
            streamer=streamer,
        )
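
# Minimal usage sketch (assumptions: ``from_pretrained`` is backed by a
# checkpoint compatible with this wrapper, the tokenizer and checkpoint
# names are placeholders, and a CUDA device is available since the fused
# mamba_ssm kernels are GPU-only):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
#   model = MambaModelForCausalLM.from_pretrained("org/mamba-checkpoint").to("cuda")
#   input_ids = tokenizer("Mamba is", return_tensors="pt").input_ids.to("cuda")
#   sequences = model.generate(input_ids, max_length=64, top_k=50, top_p=0.9)
#   print(tokenizer.decode(sequences[0]))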