mjschock committed
Commit c587c9c · verified · 1 parent: a0b31be

Upload model

Files changed (2):
  1. config.json +6 -1
  2. modeling_mamba.py +7 -85
config.json CHANGED

@@ -1,6 +1,10 @@
 {
+  "architectures": [
+    "MambaModelForCausalLM"
+  ],
   "auto_map": {
-    "AutoConfig": "configuration_mamba.MambaConfig"
+    "AutoConfig": "configuration_mamba.MambaConfig",
+    "AutoModelForCausalLM": "modeling_mamba.MambaModelForCausalLM"
   },
   "bias": false,
   "conv_bias": true,
@@ -14,6 +18,7 @@
   "model_type": "mamba",
   "n_layer": 24,
   "pad_vocab_size_multiple": 8,
+  "torch_dtype": "float32",
   "transformers_version": "4.37.2",
   "vocab_size": 50280
 }
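The new "architectures" and "auto_map"."AutoModelForCausalLM" entries are what let the stock Auto classes resolve the custom implementation shipped in this repo (configuration_mamba.MambaConfig and modeling_mamba.MambaModelForCausalLM) when the checkpoint is loaded with trust_remote_code=True. A minimal loading sketch; the repo id below is a placeholder, not something stated in this commit:

    # Minimal loading sketch; "mjschock/mamba" is a placeholder repo id.
    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "mjschock/mamba"  # placeholder -- substitute the actual checkpoint

    # trust_remote_code=True makes AutoConfig / AutoModelForCausalLM follow the
    # auto_map entries above to the custom classes defined inside the repo.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    # Smoke test with dummy token ids; vocab_size is 50280 per config.json.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    outputs = model(input_ids=input_ids)
    # The causal-LM head is expected to return logits of shape
    # (batch, seq_len, vocab_size), i.e. torch.Size([1, 8, 50280]) here.
    print(outputs.logits.shape)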
modeling_mamba.py CHANGED

@@ -230,48 +230,26 @@ class MambaModel(MambaPreTrainedModel):
         self.config = config

         self.embedding = nn.Embedding(config.vocab_size, config.d_model)
-        self.layers = nn.ModuleList([MambaBlock(config, layer_idx) for layer_idx in range(config.n_layer)])
+        self.layers = nn.ModuleList(
+            [MambaBlock(config, layer_idx) for layer_idx in range(config.n_layer)]
+        )
         self.norm_f = MambaRMSNorm(config.d_model)

         self.gradient_checkpointing = False
         self.post_init()

-    # def get_input_embeddings(self):
-    #     return self.embedding
-
-    # def set_input_embeddings(self, value):
-    #     self.embedding = value
-
-    # def forward(
-    #     self,
-    #     input_ids: torch.LongTensor = None,
-    #     **kwargs,
-    # ) -> Union[Tuple, BaseModelOutputWithPast]:
-    #     x = self.embedding(input_ids)
-    #     all_hidden_states = list()
-    #     for layer in self.layers:
-    #         x = layer(x)
-    #         all_hidden_states.append(x)
-
-    #     hidden_states = self.norm_f(x)
-
-    #     return BaseModelOutputWithPast(
-    #         last_hidden_state=hidden_states,
-    #         hidden_states=all_hidden_states,
-    #     )
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
         output_hidden_states=False,
         return_dict: Optional[bool] = None,
         **kwargs,
-    # ) -> BaseModelOutput:
-    # ) -> Union[Tuple, BaseModelOutputWithPast]:
     ) -> BaseModelOutputWithPast:
         batch_size = input_ids.shape[0]
         hidden_size = self.config.d_model
-        hidden_states: Tuple[torch.Tensor[(batch_size, sequence_length, hidden_size)]] = ()
+        hidden_states: Tuple[
+            torch.Tensor[(batch_size, sequence_length, hidden_size)]
+        ] = ()
         sequence_length = input_ids.shape[1]
         output_hidden_states = output_hidden_states or self.config.output_hidden_states

@@ -304,12 +282,12 @@ class MambaModel(MambaPreTrainedModel):
             len(hidden_states) == self.config.n_layer + 2
         ), f"{len(hidden_states)} != {self.config.n_layer + 2}"

-        # return BaseModelOutput(
        return BaseModelOutputWithPast(
            hidden_states=hidden_states if output_hidden_states else None,
            last_hidden_state=last_hidden_state,
        )

+
class MambaModelForCausalLM(MambaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

@@ -334,62 +312,6 @@ class MambaModelForCausalLM(MambaPreTrainedModel):
    def _tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

-    # def get_input_embeddings(self):
-    #     return self.model.embedding
-
-    # def set_input_embeddings(self, value):
-    #     self.model.embedding = value
-
-    # def get_output_embeddings(self):
-    #     return self.lm_head
-
-    # def set_output_embeddings(self, new_embeddings):
-    #     self.lm_head = new_embeddings
-
-    # def set_decoder(self, decoder):
-    #     self.model = decoder
-
-    # def get_decoder(self):
-    #     return self.model
-
-    # def forward(
-    #     self,
-    #     input_ids: torch.LongTensor = None,
-    #     labels: Optional[torch.LongTensor] = None,
-    #     output_attentions: Optional[bool] = None,
-    #     output_hidden_states: Optional[bool] = None,
-    #     return_dict: Optional[bool] = None,
-    #     **kwargs,
-    # ) -> Union[Tuple, CausalLMOutputWithPast]:
-    #     outputs = self.backbone(
-    #         input_ids=input_ids,
-    #         return_dict=return_dict,
-    #     )
-    #     hidden_states = outputs[0]
-    #     logits = self.lm_head(hidden_states)
-    #     logits = logits.float()
-    #     loss = None
-
-    #     if labels is not None:
-    #         shift_logits = logits[..., :-1, :].contiguous()
-    #         shift_labels = labels[..., 1:].contiguous()
-    #         loss_fct = CrossEntropyLoss()
-    #         shift_logits = shift_logits.view(-1, self.config.vocab_size)
-    #         shift_labels = shift_labels.view(-1)
-
-    #         shift_labels = shift_labels.to(shift_logits.device)
-    #         loss = loss_fct(shift_logits, shift_labels)
-
-    #     if not return_dict:
-    #         output = (logits,) + outputs[1:]
-    #         return (loss,) + output if loss is not None else output
-
-    #     return CausalLMOutputWithPast(
-    #         loss=loss,
-    #         logits=logits,
-    #         hidden_states=outputs.hidden_states,
-    #     )
-
    def forward(
        self,
        input_ids,
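For reference, the large commented-out forward removed from MambaModelForCausalLM implemented the standard causal-LM training loss: logits are shifted one position against the labels, both are flattened, and the result is fed to cross-entropy. A self-contained sketch of that shift-and-flatten pattern; shapes are illustrative and this is not the live forward, which is truncated in the diff above:

    # Standalone sketch of the next-token loss used by the deleted commented-out code.
    import torch
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 2, 8, 50280
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Predict token t+1 from position t: drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss = F.cross_entropy(
        shift_logits.view(-1, vocab_size),  # (batch * (seq_len - 1), vocab_size)
        shift_labels.view(-1),              # (batch * (seq_len - 1),)
    )
    print(loss.item())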