mjschock committed
Commit 728612d · verified · 1 Parent(s): 1877ce7

Upload model

Files changed (2)
  1. config.json +6 -1
  2. modeling_mamba.py +60 -11
config.json CHANGED
@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
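
With "architectures" and the expanded "auto_map" entries in place, transformers can resolve the custom config and model classes from this repo when remote code is trusted. A minimal loading sketch, assuming a placeholder repo id that is not part of this commit:

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id for illustration; substitute the actual model repo.
repo_id = "user/mamba-model"

# auto_map routes AutoConfig to configuration_mamba.MambaConfig and
# AutoModelForCausalLM to modeling_mamba.MambaModelForCausalLM,
# which requires trust_remote_code=True.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

print(type(model).__name__)  # expected: MambaModelForCausalLM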
modeling_mamba.py CHANGED
@@ -241,25 +241,74 @@ class MambaModel(MambaPreTrainedModel):
     # def set_input_embeddings(self, value):
     #     self.embedding = value
 
+    # def forward(
+    #     self,
+    #     input_ids: torch.LongTensor = None,
+    #     **kwargs,
+    # ) -> Union[Tuple, BaseModelOutputWithPast]:
+    #     x = self.embedding(input_ids)
+    #     all_hidden_states = list()
+    #     for layer in self.layers:
+    #         x = layer(x)
+    #         all_hidden_states.append(x)
+
+    #     hidden_states = self.norm_f(x)
+
+    #     return BaseModelOutputWithPast(
+    #         last_hidden_state=hidden_states,
+    #         hidden_states=all_hidden_states,
+    #     )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        output_hidden_states=False,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        x = self.embedding(input_ids)
-        all_hidden_states = list()
-        for layer in self.layers:
-            x = layer(x)
-            all_hidden_states.append(x)
-
-        hidden_states = self.norm_f(x)
+    # ) -> BaseModelOutput:
+    # ) -> Union[Tuple, BaseModelOutputWithPast]:
+    ) -> BaseModelOutputWithPast:
+        batch_size = input_ids.shape[0]
+        hidden_size = self.config.d_model
+        hidden_states: Tuple[torch.Tensor[(batch_size, sequence_length, hidden_size)]] = ()
+        sequence_length = input_ids.shape[1]
+        output_hidden_states = output_hidden_states or self.config.output_hidden_states
+
+        last_hidden_state = self.embedding(input_ids)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
 
+        for layer in self.layers:
+            last_hidden_state = layer(last_hidden_state)
+            assert last_hidden_state.shape == (
+                batch_size,
+                sequence_length,
+                hidden_size,
+            ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+            hidden_states += (last_hidden_state,)
+
+        last_hidden_state = self.norm_f(last_hidden_state)
+        assert last_hidden_state.shape == (
+            batch_size,
+            sequence_length,
+            hidden_size,
+        ), f"{last_hidden_state.shape} != {(batch_size, sequence_length, hidden_size)}"
+        hidden_states += (last_hidden_state,)
+
+        assert (
+            len(hidden_states) == self.config.n_layer + 2
+        ), f"{len(hidden_states)} != {self.config.n_layer + 2}"
+
+        # return BaseModelOutput(
         return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            hidden_states=all_hidden_states,
+            hidden_states=hidden_states if output_hidden_states else None,
+            last_hidden_state=last_hidden_state,
         )
 
-
 class MambaModelForCausalLM(MambaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
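
The reworked MambaModel.forward collects the embedding output, every layer output, and the post-norm output, and only returns them when output_hidden_states is set. A rough sketch of the contract the new assertions enforce, assuming the repo's modules are importable locally and that MambaConfig() provides usable defaults (neither is shown in this commit):

import torch

from configuration_mamba import MambaConfig  # assumed importable from this repo
from modeling_mamba import MambaModel

config = MambaConfig()   # assumption: default values produce a valid model
model = MambaModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # (batch, seq_len)
out = model(input_ids, output_hidden_states=True)

# Embedding output + one tensor per layer + final norm output = n_layer + 2 states.
assert len(out.hidden_states) == config.n_layer + 2
assert out.last_hidden_state.shape == (1, 16, config.d_model)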