Nanobit commited on
Commit
8b9c15b
·
unverified ·
1 Parent(s): 9e1480e

feat: exclude mamba blocks for jamba (#1578)

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +4 -0
src/axolotl/utils/models.py CHANGED
@@ -1,4 +1,5 @@
1
  """Module for models and model loading"""
 
2
  # pylint: disable=too-many-lines
3
 
4
  import logging
@@ -504,6 +505,9 @@ def load_model(
504
  bnb_config = {
505
  "load_in_8bit": True,
506
  }
 
 
 
507
  model_kwargs["quantization_config"] = BitsAndBytesConfig(
508
  **bnb_config,
509
  )
 
1
  """Module for models and model loading"""
2
+
3
  # pylint: disable=too-many-lines
4
 
5
  import logging
 
505
  bnb_config = {
506
  "load_in_8bit": True,
507
  }
508
+ # Exclude mamba blocks from int8 quantization for jamba
509
+ if cfg.model_config_type == "jamba":
510
+ bnb_config["llm_int8_skip_modules"] = ["mamba"]
511
  model_kwargs["quantization_config"] = BitsAndBytesConfig(
512
  **bnb_config,
513
  )