feat: exclude mamba blocks from int8 quantization for jamba (#1578)
Browse files
src/axolotl/utils/models.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
"""Module for models and model loading"""
|
|
|
2 |
# pylint: disable=too-many-lines
|
3 |
|
4 |
import logging
|
@@ -504,6 +505,9 @@ def load_model(
|
|
504 |
bnb_config = {
|
505 |
"load_in_8bit": True,
|
506 |
}
|
|
|
|
|
|
|
507 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
508 |
**bnb_config,
|
509 |
)
|
|
|
1 |
"""Module for models and model loading"""
|
2 |
+
|
3 |
# pylint: disable=too-many-lines
|
4 |
|
5 |
import logging
|
|
|
505 |
bnb_config = {
|
506 |
"load_in_8bit": True,
|
507 |
}
|
508 |
+
# Exclude mamba blocks from int8 quantization for jamba
|
509 |
+
if cfg.model_config_type == "jamba":
|
510 |
+
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
511 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
512 |
**bnb_config,
|
513 |
)
|