Unsloth optims for Llama (#1609)
Browse files* WIP for unsloth integrations
* import the unsloth code in the right context
* add unsloth mlp, qkv, o lora optimizations
* apply unsloth mlp and qkv kernels
src/axolotl/monkeypatch/unsloth_.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""module for patching with unsloth optimizations"""
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
import logging
|
5 |
+
import re
|
6 |
+
import types
|
7 |
+
from typing import Tuple
|
8 |
+
|
9 |
+
from peft import PeftModelForCausalLM
|
10 |
+
from transformers.models.llama.modeling_llama import (
|
11 |
+
LlamaFlashAttention2,
|
12 |
+
LlamaForCausalLM,
|
13 |
+
)
|
14 |
+
|
15 |
+
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
|
16 |
+
|
17 |
+
ORIGINAL_CEL_CODE = """ if labels is not None:
|
18 |
+
# Shift so that tokens < n predict n
|
19 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
20 |
+
shift_labels = labels[..., 1:].contiguous()
|
21 |
+
# Flatten the tokens
|
22 |
+
loss_fct = CrossEntropyLoss()
|
23 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
24 |
+
shift_labels = shift_labels.view(-1)
|
25 |
+
# Enable model parallelism
|
26 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
27 |
+
loss = loss_fct(shift_logits, shift_labels)
|
28 |
+
"""
|
29 |
+
|
30 |
+
PATCHED_CEL_CODE = """ if labels is not None:
|
31 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
32 |
+
shift_labels = labels[..., 1:].contiguous()
|
33 |
+
loss = fast_cross_entropy_loss(
|
34 |
+
logits = shift_logits,
|
35 |
+
labels = shift_labels,
|
36 |
+
)
|
37 |
+
"""
|
38 |
+
|
39 |
+
ORIGINAL_QKV_CODE = """
|
40 |
+
query_states = self.q_proj(hidden_states)
|
41 |
+
key_states = self.k_proj(hidden_states)
|
42 |
+
value_states = self.v_proj(hidden_states)
|
43 |
+
""".lstrip(
|
44 |
+
"\n"
|
45 |
+
)
|
46 |
+
|
47 |
+
PATCHED_QKV_CODE = """
|
48 |
+
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
|
49 |
+
""".lstrip(
|
50 |
+
"\n"
|
51 |
+
)
|
52 |
+
|
53 |
+
ORIGINAL_O_CODE = """
|
54 |
+
attn_output = self.o_proj(attn_output)
|
55 |
+
""".lstrip(
|
56 |
+
"\n"
|
57 |
+
)
|
58 |
+
|
59 |
+
PATCHED_O_CODE = """
|
60 |
+
attn_output = self.apply_o(self, attn_output)
|
61 |
+
""".lstrip(
|
62 |
+
"\n"
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def original_apply_qkv(self, hidden_states):
|
67 |
+
query_states = self.q_proj(hidden_states)
|
68 |
+
key_states = self.k_proj(hidden_states)
|
69 |
+
value_states = self.v_proj(hidden_states)
|
70 |
+
return query_states, key_states, value_states
|
71 |
+
|
72 |
+
|
73 |
+
def original_apply_o(self, hidden_states):
|
74 |
+
attn_output = self.o_proj(hidden_states)
|
75 |
+
return attn_output
|
76 |
+
|
77 |
+
|
78 |
+
def get_forward_code() -> str:
|
79 |
+
forward = inspect.getsource(LlamaForCausalLM.forward)
|
80 |
+
return forward
|
81 |
+
|
82 |
+
|
83 |
+
def test_cel_is_patchable() -> bool:
|
84 |
+
forward = get_forward_code()
|
85 |
+
return ORIGINAL_CEL_CODE in forward
|
86 |
+
|
87 |
+
|
88 |
+
def get_self_attn_code() -> str:
|
89 |
+
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
90 |
+
return forward
|
91 |
+
|
92 |
+
|
93 |
+
def test_self_attn_is_patchable() -> bool:
|
94 |
+
qkv = get_self_attn_code()
|
95 |
+
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
|
96 |
+
|
97 |
+
|
98 |
+
def integrate_cross_entropy_loss_patch():
|
99 |
+
forward = get_forward_code()
|
100 |
+
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
101 |
+
forward, _ = detab_code(forward)
|
102 |
+
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
103 |
+
|
104 |
+
forward = forward.replace(
|
105 |
+
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
106 |
+
)
|
107 |
+
forward = forward.replace(
|
108 |
+
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
109 |
+
"",
|
110 |
+
)
|
111 |
+
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
112 |
+
forward = forward.replace(
|
113 |
+
"def forward(",
|
114 |
+
"def fast_cross_entropy_loss_forward(",
|
115 |
+
1,
|
116 |
+
)
|
117 |
+
|
118 |
+
# load imports necessary
|
119 |
+
import transformers.models.llama.modeling_llama
|
120 |
+
|
121 |
+
items_to_import = []
|
122 |
+
for item in dir(transformers.models.llama.modeling_llama):
|
123 |
+
if item in forward:
|
124 |
+
items_to_import.append(item)
|
125 |
+
|
126 |
+
exec( # pylint: disable=exec-used # nosec B102
|
127 |
+
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
128 |
+
globals(),
|
129 |
+
)
|
130 |
+
|
131 |
+
exec( # pylint: disable=exec-used # nosec B102
|
132 |
+
"from transformers.models.llama.modeling_llama import ("
|
133 |
+
+ ", ".join(x for x in items_to_import)
|
134 |
+
+ ")",
|
135 |
+
globals(),
|
136 |
+
)
|
137 |
+
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
138 |
+
print("patching unsloth fast_cross_entropy_loss")
|
139 |
+
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
140 |
+
|
141 |
+
|
142 |
+
def detab_code(code: str) -> Tuple[str, str]:
|
143 |
+
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
144 |
+
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
145 |
+
return code, spaces
|
146 |
+
|
147 |
+
|
148 |
+
def patch_self_attn_lora():
|
149 |
+
self_attn_forward = get_self_attn_code()
|
150 |
+
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
151 |
+
self_attn_forward
|
152 |
+
)
|
153 |
+
self_attn_forward, _ = detab_code(self_attn_forward)
|
154 |
+
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
|
155 |
+
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
|
156 |
+
|
157 |
+
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
158 |
+
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
159 |
+
self_attn_forward = self_attn_forward.replace(
|
160 |
+
"def forward(",
|
161 |
+
"def unsloth_attn_forward(",
|
162 |
+
1,
|
163 |
+
)
|
164 |
+
|
165 |
+
# load imports necessary
|
166 |
+
import transformers.models.llama.modeling_llama
|
167 |
+
|
168 |
+
items_to_import = []
|
169 |
+
for item in dir(transformers.models.llama.modeling_llama):
|
170 |
+
if item in self_attn_forward:
|
171 |
+
items_to_import.append(item)
|
172 |
+
|
173 |
+
exec( # pylint: disable=exec-used # nosec B102
|
174 |
+
"from transformers.models.llama.modeling_llama import ("
|
175 |
+
+ ", ".join(x for x in items_to_import)
|
176 |
+
+ ")",
|
177 |
+
globals(),
|
178 |
+
)
|
179 |
+
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
180 |
+
print("patching unsloth attn lora")
|
181 |
+
LlamaFlashAttention2.forward = (
|
182 |
+
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
183 |
+
)
|
184 |
+
|
185 |
+
|
186 |
+
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
187 |
+
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
188 |
+
from unsloth.kernels import apply_lora_mlp_swiglu
|
189 |
+
|
190 |
+
apply_lora_mlp = apply_lora_mlp_swiglu
|
191 |
+
elif peft_model.base_model.config.model_type == "gemma":
|
192 |
+
from unsloth.kernels import apply_lora_mlp_geglu_approx
|
193 |
+
|
194 |
+
apply_lora_mlp = apply_lora_mlp_geglu_approx
|
195 |
+
else:
|
196 |
+
raise NotImplementedError(
|
197 |
+
f"Model type {peft_model.base_model.config.model_type} not supported"
|
198 |
+
)
|
199 |
+
|
200 |
+
for idx, layer in enumerate(peft_model.model.model.layers):
|
201 |
+
layer_modules = [
|
202 |
+
getattr(layer.mlp, linear_proj)
|
203 |
+
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
|
204 |
+
]
|
205 |
+
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
206 |
+
mlp_no_bias = all(
|
207 |
+
getattr(module, "base_layer", module).bias is None
|
208 |
+
for module in layer_modules
|
209 |
+
)
|
210 |
+
mlp_not_dora = all(
|
211 |
+
getattr(module, "lora_magnitude_vector", None) is None
|
212 |
+
for module in layer_modules
|
213 |
+
)
|
214 |
+
|
215 |
+
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
216 |
+
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
217 |
+
else:
|
218 |
+
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
219 |
+
|
220 |
+
|
221 |
+
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
222 |
+
from unsloth.kernels import apply_lora_o, apply_lora_qkv
|
223 |
+
|
224 |
+
for idx, layer in enumerate(peft_model.model.model.layers):
|
225 |
+
if cfg.unsloth_lora_qkv:
|
226 |
+
layer_modules = [
|
227 |
+
getattr(layer.self_attn, linear_proj)
|
228 |
+
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
229 |
+
]
|
230 |
+
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
231 |
+
qkv_no_bias = all(
|
232 |
+
getattr(module, "base_layer", module).bias is None
|
233 |
+
for module in layer_modules
|
234 |
+
)
|
235 |
+
qkv_not_dora = all(
|
236 |
+
getattr(module, "lora_magnitude_vector", None) is None
|
237 |
+
for module in layer_modules
|
238 |
+
)
|
239 |
+
|
240 |
+
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
|
241 |
+
layer.self_attn.apply_qkv = apply_lora_qkv
|
242 |
+
else:
|
243 |
+
layer.self_attn.apply_qkv = original_apply_qkv
|
244 |
+
logging.warning(
|
245 |
+
"unable to apply unsloth lora qkv patch to layer %d", idx
|
246 |
+
)
|
247 |
+
if cfg.unsloth_lora_o:
|
248 |
+
layer_modules = [
|
249 |
+
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
250 |
+
]
|
251 |
+
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
252 |
+
o_no_bias = all(
|
253 |
+
getattr(module, "base_layer", module).bias is None
|
254 |
+
for module in layer_modules
|
255 |
+
)
|
256 |
+
o_not_dora = all(
|
257 |
+
getattr(module, "lora_magnitude_vector", None) is None
|
258 |
+
for module in layer_modules
|
259 |
+
)
|
260 |
+
|
261 |
+
if is_o_lora and o_no_bias and o_not_dora:
|
262 |
+
layer.self_attn.apply_o = apply_lora_o
|
263 |
+
else:
|
264 |
+
layer.self_attn.apply_o = original_apply_o
|
265 |
+
logging.warning(
|
266 |
+
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
267 |
+
)
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -549,6 +549,11 @@ class AxolotlInputConfig(
|
|
549 |
flash_attn_fuse_mlp: Optional[bool] = None
|
550 |
flash_optimum: Optional[bool] = None
|
551 |
|
|
|
|
|
|
|
|
|
|
|
552 |
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
553 |
fsdp: Optional[List[str]] = None
|
554 |
fsdp_config: Optional[Dict[str, Any]] = None
|
|
|
549 |
flash_attn_fuse_mlp: Optional[bool] = None
|
550 |
flash_optimum: Optional[bool] = None
|
551 |
|
552 |
+
unsloth_cross_entropy_loss: Optional[bool] = None
|
553 |
+
unsloth_lora_mlp: Optional[bool] = None
|
554 |
+
unsloth_lora_qkv: Optional[bool] = None
|
555 |
+
unsloth_lora_o: Optional[bool] = None
|
556 |
+
|
557 |
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
558 |
fsdp: Optional[List[str]] = None
|
559 |
fsdp_config: Optional[Dict[str, Any]] = None
|
src/axolotl/utils/models.py
CHANGED
@@ -390,6 +390,16 @@ def load_model(
|
|
390 |
"Shifted-sparse attention not currently implemented without flash attention."
|
391 |
)
|
392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
393 |
# Modify mistral derived models
|
394 |
if (
|
395 |
cfg.model_config_type == "mistral"
|
@@ -828,6 +838,15 @@ def load_model(
|
|
828 |
if cfg.adapter is not None:
|
829 |
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
830 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
831 |
# TODO resume_from_checkpoint handling
|
832 |
return model, lora_config
|
833 |
|
|
|
390 |
"Shifted-sparse attention not currently implemented without flash attention."
|
391 |
)
|
392 |
|
393 |
+
if cfg.unsloth_cross_entropy_loss:
|
394 |
+
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
395 |
+
|
396 |
+
integrate_cross_entropy_loss_patch()
|
397 |
+
|
398 |
+
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
399 |
+
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
400 |
+
|
401 |
+
patch_self_attn_lora()
|
402 |
+
|
403 |
# Modify mistral derived models
|
404 |
if (
|
405 |
cfg.model_config_type == "mistral"
|
|
|
838 |
if cfg.adapter is not None:
|
839 |
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
840 |
|
841 |
+
if cfg.unsloth_lora_mlp:
|
842 |
+
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
843 |
+
|
844 |
+
integrate_lora_mlp_patch(model)
|
845 |
+
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
846 |
+
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
|
847 |
+
|
848 |
+
integrate_lora_patch(model, cfg)
|
849 |
+
|
850 |
# TODO resume_from_checkpoint handling
|
851 |
return model, lora_config
|
852 |
|