winglian commited on
Commit
8a1572a
·
unverified ·
1 Parent(s): 702a669

Unsloth optims for Llama (#1609)

Browse files

* WIP for unsloth integrations

* import the unsloth code in the right context

* add unsloth mlp, qkv, o lora optimizations

* apply unsloth mlp and qkv kernels

src/axolotl/monkeypatch/unsloth_.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """module for patching with unsloth optimizations"""
2
+
3
+ import inspect
4
+ import logging
5
+ import re
6
+ import types
7
+ from typing import Tuple
8
+
9
+ from peft import PeftModelForCausalLM
10
+ from transformers.models.llama.modeling_llama import (
11
+ LlamaFlashAttention2,
12
+ LlamaForCausalLM,
13
+ )
14
+
15
+ LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
16
+
17
+ ORIGINAL_CEL_CODE = """ if labels is not None:
18
+ # Shift so that tokens < n predict n
19
+ shift_logits = logits[..., :-1, :].contiguous()
20
+ shift_labels = labels[..., 1:].contiguous()
21
+ # Flatten the tokens
22
+ loss_fct = CrossEntropyLoss()
23
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
24
+ shift_labels = shift_labels.view(-1)
25
+ # Enable model parallelism
26
+ shift_labels = shift_labels.to(shift_logits.device)
27
+ loss = loss_fct(shift_logits, shift_labels)
28
+ """
29
+
30
+ PATCHED_CEL_CODE = """ if labels is not None:
31
+ shift_logits = logits[..., :-1, :].contiguous()
32
+ shift_labels = labels[..., 1:].contiguous()
33
+ loss = fast_cross_entropy_loss(
34
+ logits = shift_logits,
35
+ labels = shift_labels,
36
+ )
37
+ """
38
+
39
+ ORIGINAL_QKV_CODE = """
40
+ query_states = self.q_proj(hidden_states)
41
+ key_states = self.k_proj(hidden_states)
42
+ value_states = self.v_proj(hidden_states)
43
+ """.lstrip(
44
+ "\n"
45
+ )
46
+
47
+ PATCHED_QKV_CODE = """
48
+ query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
49
+ """.lstrip(
50
+ "\n"
51
+ )
52
+
53
+ ORIGINAL_O_CODE = """
54
+ attn_output = self.o_proj(attn_output)
55
+ """.lstrip(
56
+ "\n"
57
+ )
58
+
59
+ PATCHED_O_CODE = """
60
+ attn_output = self.apply_o(self, attn_output)
61
+ """.lstrip(
62
+ "\n"
63
+ )
64
+
65
+
66
+ def original_apply_qkv(self, hidden_states):
67
+ query_states = self.q_proj(hidden_states)
68
+ key_states = self.k_proj(hidden_states)
69
+ value_states = self.v_proj(hidden_states)
70
+ return query_states, key_states, value_states
71
+
72
+
73
+ def original_apply_o(self, hidden_states):
74
+ attn_output = self.o_proj(hidden_states)
75
+ return attn_output
76
+
77
+
78
+ def get_forward_code() -> str:
79
+ forward = inspect.getsource(LlamaForCausalLM.forward)
80
+ return forward
81
+
82
+
83
+ def test_cel_is_patchable() -> bool:
84
+ forward = get_forward_code()
85
+ return ORIGINAL_CEL_CODE in forward
86
+
87
+
88
+ def get_self_attn_code() -> str:
89
+ forward = inspect.getsource(LlamaFlashAttention2.forward)
90
+ return forward
91
+
92
+
93
+ def test_self_attn_is_patchable() -> bool:
94
+ qkv = get_self_attn_code()
95
+ return ORIGINAL_QKV_CODE in qkv and ORIGINAL_QKV_CODE in qkv
96
+
97
+
98
+ def integrate_cross_entropy_loss_patch():
99
+ forward = get_forward_code()
100
+ LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
101
+ forward, _ = detab_code(forward)
102
+ assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
103
+
104
+ forward = forward.replace(
105
+ "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
106
+ )
107
+ forward = forward.replace(
108
+ "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
109
+ "",
110
+ )
111
+ forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
112
+ forward = forward.replace(
113
+ "def forward(",
114
+ "def fast_cross_entropy_loss_forward(",
115
+ 1,
116
+ )
117
+
118
+ # load imports necessary
119
+ import transformers.models.llama.modeling_llama
120
+
121
+ items_to_import = []
122
+ for item in dir(transformers.models.llama.modeling_llama):
123
+ if item in forward:
124
+ items_to_import.append(item)
125
+
126
+ exec( # pylint: disable=exec-used # nosec B102
127
+ "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
128
+ globals(),
129
+ )
130
+
131
+ exec( # pylint: disable=exec-used # nosec B102
132
+ "from transformers.models.llama.modeling_llama import ("
133
+ + ", ".join(x for x in items_to_import)
134
+ + ")",
135
+ globals(),
136
+ )
137
+ exec(forward, globals()) # pylint: disable=exec-used # nosec B102
138
+ print("patching unsloth fast_cross_entropy_loss")
139
+ LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
140
+
141
+
142
+ def detab_code(code: str) -> Tuple[str, str]:
143
+ spaces = re.match(r"([\s\t]{1,})", code).group(0)
144
+ code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
145
+ return code, spaces
146
+
147
+
148
+ def patch_self_attn_lora():
149
+ self_attn_forward = get_self_attn_code()
150
+ LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
151
+ self_attn_forward
152
+ )
153
+ self_attn_forward, _ = detab_code(self_attn_forward)
154
+ assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
155
+ assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
156
+
157
+ self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
158
+ self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
159
+ self_attn_forward = self_attn_forward.replace(
160
+ "def forward(",
161
+ "def unsloth_attn_forward(",
162
+ 1,
163
+ )
164
+
165
+ # load imports necessary
166
+ import transformers.models.llama.modeling_llama
167
+
168
+ items_to_import = []
169
+ for item in dir(transformers.models.llama.modeling_llama):
170
+ if item in self_attn_forward:
171
+ items_to_import.append(item)
172
+
173
+ exec( # pylint: disable=exec-used # nosec B102
174
+ "from transformers.models.llama.modeling_llama import ("
175
+ + ", ".join(x for x in items_to_import)
176
+ + ")",
177
+ globals(),
178
+ )
179
+ exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
180
+ print("patching unsloth attn lora")
181
+ LlamaFlashAttention2.forward = (
182
+ unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
183
+ )
184
+
185
+
186
+ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
187
+ if peft_model.base_model.config.model_type in ["llama", "mistral"]:
188
+ from unsloth.kernels import apply_lora_mlp_swiglu
189
+
190
+ apply_lora_mlp = apply_lora_mlp_swiglu
191
+ elif peft_model.base_model.config.model_type == "gemma":
192
+ from unsloth.kernels import apply_lora_mlp_geglu_approx
193
+
194
+ apply_lora_mlp = apply_lora_mlp_geglu_approx
195
+ else:
196
+ raise NotImplementedError(
197
+ f"Model type {peft_model.base_model.config.model_type} not supported"
198
+ )
199
+
200
+ for idx, layer in enumerate(peft_model.model.model.layers):
201
+ layer_modules = [
202
+ getattr(layer.mlp, linear_proj)
203
+ for linear_proj in ["gate_proj", "up_proj", "down_proj"]
204
+ ]
205
+ is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
206
+ mlp_no_bias = all(
207
+ getattr(module, "base_layer", module).bias is None
208
+ for module in layer_modules
209
+ )
210
+ mlp_not_dora = all(
211
+ getattr(module, "lora_magnitude_vector", None) is None
212
+ for module in layer_modules
213
+ )
214
+
215
+ if is_mlp_lora and mlp_no_bias and mlp_not_dora:
216
+ layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
217
+ else:
218
+ logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
219
+
220
+
221
+ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
222
+ from unsloth.kernels import apply_lora_o, apply_lora_qkv
223
+
224
+ for idx, layer in enumerate(peft_model.model.model.layers):
225
+ if cfg.unsloth_lora_qkv:
226
+ layer_modules = [
227
+ getattr(layer.self_attn, linear_proj)
228
+ for linear_proj in ["q_proj", "k_proj", "v_proj"]
229
+ ]
230
+ is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
231
+ qkv_no_bias = all(
232
+ getattr(module, "base_layer", module).bias is None
233
+ for module in layer_modules
234
+ )
235
+ qkv_not_dora = all(
236
+ getattr(module, "lora_magnitude_vector", None) is None
237
+ for module in layer_modules
238
+ )
239
+
240
+ if is_qkv_lora and qkv_no_bias and qkv_not_dora:
241
+ layer.self_attn.apply_qkv = apply_lora_qkv
242
+ else:
243
+ layer.self_attn.apply_qkv = original_apply_qkv
244
+ logging.warning(
245
+ "unable to apply unsloth lora qkv patch to layer %d", idx
246
+ )
247
+ if cfg.unsloth_lora_o:
248
+ layer_modules = [
249
+ getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
250
+ ]
251
+ is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
252
+ o_no_bias = all(
253
+ getattr(module, "base_layer", module).bias is None
254
+ for module in layer_modules
255
+ )
256
+ o_not_dora = all(
257
+ getattr(module, "lora_magnitude_vector", None) is None
258
+ for module in layer_modules
259
+ )
260
+
261
+ if is_o_lora and o_no_bias and o_not_dora:
262
+ layer.self_attn.apply_o = apply_lora_o
263
+ else:
264
+ layer.self_attn.apply_o = original_apply_o
265
+ logging.warning(
266
+ "unable to apply unsloth lora o_proj patch to layer %d", idx
267
+ )
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -549,6 +549,11 @@ class AxolotlInputConfig(
549
  flash_attn_fuse_mlp: Optional[bool] = None
550
  flash_optimum: Optional[bool] = None
551
 
 
 
 
 
 
552
  deepspeed: Optional[Union[str, Dict[str, Any]]] = None
553
  fsdp: Optional[List[str]] = None
554
  fsdp_config: Optional[Dict[str, Any]] = None
 
549
  flash_attn_fuse_mlp: Optional[bool] = None
550
  flash_optimum: Optional[bool] = None
551
 
552
+ unsloth_cross_entropy_loss: Optional[bool] = None
553
+ unsloth_lora_mlp: Optional[bool] = None
554
+ unsloth_lora_qkv: Optional[bool] = None
555
+ unsloth_lora_o: Optional[bool] = None
556
+
557
  deepspeed: Optional[Union[str, Dict[str, Any]]] = None
558
  fsdp: Optional[List[str]] = None
559
  fsdp_config: Optional[Dict[str, Any]] = None
src/axolotl/utils/models.py CHANGED
@@ -390,6 +390,16 @@ def load_model(
390
  "Shifted-sparse attention not currently implemented without flash attention."
391
  )
392
 
 
 
 
 
 
 
 
 
 
 
393
  # Modify mistral derived models
394
  if (
395
  cfg.model_config_type == "mistral"
@@ -828,6 +838,15 @@ def load_model(
828
  if cfg.adapter is not None:
829
  log_gpu_memory_usage(LOG, "after adapters", model.device)
830
 
 
 
 
 
 
 
 
 
 
831
  # TODO resume_from_checkpoint handling
832
  return model, lora_config
833
 
 
390
  "Shifted-sparse attention not currently implemented without flash attention."
391
  )
392
 
393
+ if cfg.unsloth_cross_entropy_loss:
394
+ from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
395
+
396
+ integrate_cross_entropy_loss_patch()
397
+
398
+ if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
399
+ from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
400
+
401
+ patch_self_attn_lora()
402
+
403
  # Modify mistral derived models
404
  if (
405
  cfg.model_config_type == "mistral"
 
838
  if cfg.adapter is not None:
839
  log_gpu_memory_usage(LOG, "after adapters", model.device)
840
 
841
+ if cfg.unsloth_lora_mlp:
842
+ from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
843
+
844
+ integrate_lora_mlp_patch(model)
845
+ if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
846
+ from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
847
+
848
+ integrate_lora_patch(model, cfg)
849
+
850
  # TODO resume_from_checkpoint handling
851
  return model, lora_config
852