mjschock commited on
Commit
abaaf5a
·
verified ·
1 Parent(s): 19b3607

Upload model

Browse files
Files changed (3) hide show
  1. config.json +6 -1
  2. model.safetensors +3 -0
  3. modeling_mamba.py +80 -0
config.json CHANGED
@@ -1,6 +1,10 @@
1
  {
 
 
 
2
  "auto_map": {
3
- "AutoConfig": "configuration_mamba.MambaConfig"
 
4
  },
5
  "d_model": 768,
6
  "fused_add_norm": true,
@@ -10,6 +14,7 @@
10
  "residual_in_fp32": true,
11
  "rms_norm": true,
12
  "ssm_cfg": {},
 
13
  "transformers_version": "4.37.2",
14
  "vocab_size": 50277
15
  }
 
1
  {
2
+ "architectures": [
3
+ "MambaModel"
4
+ ],
5
  "auto_map": {
6
+ "AutoConfig": "configuration_mamba.MambaConfig",
7
+ "AutoModel": "modeling_mamba.MambaModel"
8
  },
9
  "d_model": 768,
10
  "fused_add_norm": true,
 
14
  "residual_in_fp32": true,
15
  "rms_norm": true,
16
  "ssm_cfg": {},
17
+ "torch_dtype": "float16",
18
  "transformers_version": "4.37.2",
19
  "vocab_size": 50277
20
  }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6504b24e9ba95e4a6bad94a346c849040623647d1a99a47f4f5e1cd32cbd9572
3
+ size 259551392
modeling_mamba.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
3
+ from transformers import GenerationMixin, PreTrainedModel
4
+ from transformers.generation import TextStreamer
5
+
6
+ from mamba_model.configuration_mamba import MambaConfig
7
+
8
+ class MambaModel(PreTrainedModel):
9
+ config_class = MambaConfig
10
+
11
+ def __init__(
12
+ self,
13
+ config,
14
+ initializer_cfg=None,
15
+ device=None,
16
+ dtype=None,
17
+ **kwargs,
18
+ ):
19
+ super().__init__(
20
+ config,
21
+ **kwargs,
22
+ )
23
+
24
+ self.model = MambaLMHeadModel(
25
+ config,
26
+ initializer_cfg=initializer_cfg,
27
+ device=device,
28
+ dtype=dtype,
29
+ )
30
+
31
+ def forward(
32
+ self,
33
+ input_ids,
34
+ position_ids=None,
35
+ inference_params=None,
36
+ num_last_tokens=0,
37
+ **kwargs,
38
+ ):
39
+ return self.model.forward(
40
+ input_ids,
41
+ position_ids,
42
+ inference_params,
43
+ num_last_tokens
44
+ )
45
+
46
+ class MambaModelForCausalLM(MambaModel, GenerationMixin):
47
+ def generate(
48
+ self,
49
+ input_ids,
50
+ max_length,
51
+ top_k=1,
52
+ top_p=0.0,
53
+ temperature=1.0,
54
+ return_dict_in_generate=False,
55
+ output_scores=False,
56
+ repetition_penalty=1.0,
57
+ eos_token_id=None,
58
+ teacher_outputs=None,
59
+ vocab_size=None,
60
+ cg=False,
61
+ enable_timing=False,
62
+ streamer: Optional[TextStreamer] = None,
63
+ **kwargs,
64
+ ):
65
+ return self.model.generate(
66
+ input_ids=input_ids,
67
+ max_length=max_length,
68
+ top_k=top_k,
69
+ top_p=top_p,
70
+ temperature=temperature,
71
+ return_dict_in_generate=return_dict_in_generate,
72
+ output_scores=output_scores,
73
+ repetition_penalty=repetition_penalty,
74
+ eos_token_id=eos_token_id,
75
+ teacher_outputs=teacher_outputs,
76
+ vocab_size=vocab_size,
77
+ cg=cg,
78
+ enable_timing=enable_timing,
79
+ streamer = streamer,
80
+ )