|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import gc |
|
import math |
|
from typing import Optional, Tuple, List, Union |
|
from ._utils import * |
|
from ._utils import __version__ |
|
from torch.nn.functional import scaled_dot_product_attention |
|
from transformers import __version__ as transformers_version |
|
from transformers.models.llama.modeling_llama import ( |
|
logger, |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
) |
|
from transformers.modeling_attn_mask_utils import ( |
|
_prepare_4d_causal_attention_mask_for_sdpa, |
|
) |
|
from ..kernels import * |
|
from ..tokenizer_utils import * |
|
if HAS_FLASH_ATTENTION: |
|
from flash_attn import flash_attn_func |
|
|
|
|
|
from transformers.models.llama.modeling_llama import ( |
|
LlamaAttention, |
|
LlamaDecoderLayer, |
|
LlamaModel, |
|
LlamaForCausalLM, |
|
) |
|
|
|
|
|
try: |
|
from transformers.models.llama.modeling_llama import ( |
|
LlamaSdpaAttention, |
|
LlamaFlashAttention2, |
|
) |
|
except: |
|
LlamaSdpaAttention = LlamaAttention |
|
LlamaFlashAttention2 = LlamaAttention |
|
pass |
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig |
|
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING |
|
from transformers import set_seed as transformers_set_seed |
|
from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model |
|
from peft import PeftModelForCausalLM |
|
from ..save import patch_saving_functions |
|
import re, os, inspect, math, sys |
|
try: |
|
from huggingface_hub.utils import get_token |
|
except: |
|
|
|
from huggingface_hub.utils._token import get_token |
|
pass |
|
from triton import __version__ as triton_version |
|
BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if xformers is not None else None |
|
|
|
|
|
def original_apply_qkv(self, X): |
|
Q = self.q_proj(X) |
|
K = self.k_proj(X) |
|
V = self.v_proj(X) |
|
return Q, K, V |
|
pass |
|
|
|
|
|
def original_apply_o(self, X): |
|
O = self.o_proj(X) |
|
return O |
|
pass |
|
|
|
from math import sqrt as math_sqrt |
|
KV_CACHE_INCREMENT = 256 |
|
torch_nn_functional_softmax = torch.nn.functional.softmax |
|
|
|
|
|
def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs,): |
|
if "past_key_values" in kwargs: |
|
input_ids = input_ids[:,[-1]] |
|
kwargs["attention_mask"] = kwargs["attention_mask"][:,[-1]] |
|
if "cache_position" in kwargs: |
|
kwargs["position_ids"] = kwargs["cache_position"] |
|
return { "input_ids" : input_ids, **kwargs, } |
|
pass |
|
|
|
|
|
def fix_prepare_inputs_for_generation(module): |
|
|
|
if hasattr(module, "prepare_inputs_for_generation"): |
|
module.prepare_inputs_for_generation = _fast_prepare_inputs_for_generation |
|
pass |
|
pass |
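
# --- Illustration (hedged sketch, not part of the library) ------------------
# During generation only the newest token needs to pass through the model,
# since every earlier key/value already lives in `past_key_values`. The
# hypothetical helper below only demonstrates the trimming performed by the
# patched hook above; it is defined purely for illustration and never called
# at import time.
def _demo_prepare_inputs_trims_to_last_token():
    input_ids = torch.tensor([[1, 2, 3, 4]])
    mask      = torch.ones_like(input_ids)
    out = _fast_prepare_inputs_for_generation(
        None, input_ids, past_key_values = "dummy", attention_mask = mask,
    )
    # Only the last token (and the matching attention mask column) survives.
    assert out["input_ids"].tolist() == [[4]]
    assert out["attention_mask"].shape == (1, 1)
    return out
pass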
|
|
|
torch_matmul = torch.matmul |
|
def LlamaAttention_fast_forward_inference( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_value: Optional[Tuple[torch.Tensor]], |
|
position_ids, |
|
do_prefill = False, |
|
attention_mask = None, |
|
): |
|
""" |
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L406 |
|
Fast inference using KV cache. |
|
QK^T can be computed in 4 chunks |
|
|
|
[Q, q] @ [K, k].T where q, k are the new tokens. |
|
[QK^T, Qk^T] |
|
[qK^T, qk^T] |
|
|
|
Since the attention mask wipes Qk^T, we just get |
|
[QK^T, 0] |
|
[qK^T, qk^T] |
|
|
|
Since softmax is row-wise, we get |
|
softmax([QK^T, 0]) |
|
softmax([qK^T, qk^T]) |
|
|
|
We then multiply by [V] |
|
[v] |
|
softmax([QK^T, 0]) [softmax(QK^T)V] * |
|
softmax([qK^T, qk^T]) [softmax([qK^T, qk^T]) @ [V, v]] |
|
|
|
    But notice that the starred term [softmax(QK^T)V] is just the previous attention output.
    We only need to compute the final row (a numeric sanity check of this is sketched right after this function).
|
|
|
This means we can pass in a row of Q, but we need to |
|
remember K and V, which are called the KV cache. |
|
""" |
|
Xn = hidden_states |
|
bsz, _, hd = hidden_states.size() |
|
K1, V1 = past_key_value |
|
dtype = Xn.dtype |
|
|
|
n_heads = self.num_heads |
|
n_groups = self.num_key_value_groups |
|
n_kv_heads = self.num_key_value_heads |
|
head_dim = self.head_dim |
|
attention_size = n_heads*head_dim |
|
|
|
seq_len = K1.shape[-2] |
|
kv_seq_len = seq_len + 1 |
|
|
|
|
|
|
|
if do_prefill: |
|
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") |
|
self.paged_attention_K = self.paged_attention[:,0] |
|
self.paged_attention_V = self.paged_attention[:,1] |
|
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) |
|
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) |
|
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") |
|
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") |
|
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") |
|
|
|
|
|
if attention_size != self.hidden_size: |
|
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") |
|
else: |
|
self.temp_O = self.temp_QA[1][:,:,:self.hidden_size] |
|
pass |
|
|
|
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") |
|
self.scalar = 1.0 / math_sqrt(self.head_dim) |
|
self.half_head_dim = head_dim // 2 |
|
elif kv_seq_len >= self.paged_attention.shape[0]: |
|
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim)) |
|
self.paged_attention_K = self.paged_attention[:,0] |
|
self.paged_attention_V = self.paged_attention[:,1] |
|
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT)) |
|
pass |
|
|
|
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) |
|
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) |
|
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) |
|
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) |
|
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) |
|
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) |
|
cos, sin = self.rotary_emb.get_cached(kv_seq_len) |
|
cos = cos[position_ids].unsqueeze(1) |
|
sin = sin[position_ids].unsqueeze(1) |
|
h = self.half_head_dim |
|
|
|
RH_Q = self.RH_Q |
|
RH_Q[:,:,:,:h] = Qn[:,:,:,h:] |
|
RH_Q[:,:,:,h:] = Qn[:,:,:,:h] |
|
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) |
|
Qn *= cos |
|
Qn.addcmul_(RH_Q, sin) |
|
|
|
RH_K = RH_Q[:,:n_kv_heads,:,:] |
|
RH_K[:,:,:,:h] = Kn[:,:,:,h:] |
|
RH_K[:,:,:,h:] = Kn[:,:,:,:h] |
|
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) |
|
Kn *= cos |
|
Kn.addcmul_(RH_K, sin) |
|
|
|
|
|
|
|
|
|
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3) |
|
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3) |
|
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) |
|
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) |
|
|
|
|
|
sliding_window = getattr(self.config, "sliding_window", None) |
|
if sliding_window is not None and kv_seq_len > sliding_window: |
|
|
|
slicing_tokens = 1 - sliding_window |
|
Knn = Kn[:, :, slicing_tokens:, :] |
|
Vnn = Vn[:, :, slicing_tokens:, :] |
|
else: |
|
Knn, Vnn = Kn, Vn |
|
pass |
|
|
|
|
|
_, _, cached_len, _ = Knn.shape |
|
if n_groups != 1: |
|
Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) |
|
Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) |
|
Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) |
|
Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim) |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
if bsz == 1: |
|
Qn *= self.scalar |
|
|
|
A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) |
|
|
|
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) |
|
A = torch_matmul(A, Vnn, out = Qn) |
|
else: |
|
A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) |
|
pass |
|
A = A.transpose(1, 2) |
|
A = A.reshape(bsz, 1, attention_size) |
|
A = fast_linear_forward(self.o_proj, A, out = self.temp_O) |
|
return A, (Kn, Vn) |
|
pass |
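
# --- Illustration (hedged sketch, not part of the library) ------------------
# Numeric sanity check for the chunked QK^T argument in the docstring above:
# under a causal mask, the newest token's attention output only needs the new
# query row against *all* cached keys/values. The helper below is hypothetical
# and exists purely as a demonstration; it is never called at import time.
def _demo_kv_cache_last_row_equivalence():
    torch.manual_seed(0)
    bsz, n_heads, seq_len, head_dim = 1, 4, 8, 16
    Q = torch.randn(bsz, n_heads, seq_len, head_dim)
    K = torch.randn(bsz, n_heads, seq_len, head_dim)
    V = torch.randn(bsz, n_heads, seq_len, head_dim)

    # Full causal attention over all tokens at once.
    full = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal = True)

    # KV-cache style: only the last query row against all keys / values.
    q_new  = Q[:, :, -1:, :]
    scores = torch.matmul(q_new, K.transpose(2, 3)) / math.sqrt(head_dim)
    probs  = torch.nn.functional.softmax(scores, dim = -1, dtype = torch.float32).to(Q.dtype)
    last   = torch.matmul(probs, V)

    assert torch.allclose(full[:, :, -1:, :], last, atol = 1e-4)
    return last
pass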
|
|
|
|
|
torch_nn_functional_silu = torch.nn.functional.silu |
|
def fast_swiglu_inference(self, X): |
|
|
|
|
|
bsz, _, hd = X.shape |
|
|
|
|
|
|
|
gate = fast_linear_forward(self.gate_proj, X) |
|
up = fast_linear_forward(self. up_proj, X) |
|
gate = torch_nn_functional_silu(gate, inplace = True) |
|
gate *= up |
|
|
|
|
|
down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd]) |
|
return down |
|
pass |
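
# --- Illustration (hedged sketch, not part of the library) ------------------
# The fused path above computes the standard Llama MLP:
#     down_proj( silu(gate_proj(x)) * up_proj(x) )
# A hypothetical plain-PyTorch reference for comparison; never called here.
def _demo_swiglu_reference(mlp, X):
    gate = torch.nn.functional.silu(mlp.gate_proj(X))
    return mlp.down_proj(gate * mlp.up_proj(X))
pass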
|
|
|
|
|
def fast_rms_layernorm_inference(self, X): |
|
old_dtype = X.dtype |
|
XX = X.to(torch.float32) |
|
variance = XX.square().mean(-1, keepdim = True) |
|
variance += self.variance_epsilon |
|
XX *= variance.rsqrt_() |
|
X = XX.to(old_dtype) |
|
X *= self.weight |
|
return X |
|
pass |
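
# --- Illustration (hedged sketch, not part of the library) ------------------
# RMSNorm only normalises by the root-mean-square (no mean subtraction):
#     y = x / sqrt(mean(x^2) + eps) * weight
# A hypothetical reference implementation for comparison; never called here.
def _demo_rms_layernorm_reference(layernorm, X):
    XX = X.to(torch.float32)
    XX = XX * torch.rsqrt(XX.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
    return XX.to(X.dtype) * layernorm.weight
pass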
|
|
|
|
|
def fast_rms_layernorm_inference_gemma(self, X, out_weight = None): |
|
XX = X.to(torch.float32) |
|
variance = XX.square().mean(-1, keepdim = True) |
|
variance += self.variance_epsilon |
|
XX *= variance.rsqrt_() |
|
|
|
if out_weight is None: |
|
out_weight = self.weight + 1.0 |
|
else: |
|
out_weight[:] = self.weight |
|
out_weight += 1.0 |
|
pass |
|
|
|
XX *= out_weight |
|
return XX.to(X.dtype) |
|
pass |
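
# --- Illustration (hedged sketch, not part of the library) ------------------
# Gemma stores its RMSNorm weight as an offset from 1, so the effective scale
# is (weight + 1.0), applied in float32 before the final cast back to the
# input dtype. A hypothetical reference implementation; never called here.
def _demo_gemma_rms_layernorm_reference(layernorm, X):
    XX = X.to(torch.float32)
    XX = XX * torch.rsqrt(XX.square().mean(-1, keepdim = True) + layernorm.variance_epsilon)
    return (XX * (layernorm.weight.float() + 1.0)).to(X.dtype)
pass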
|
|
|
|
|
|
|
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options) |
|
def fast_layernorm_compiled(layernorm, X): |
|
old_dtype = X.dtype |
|
X = X.float() |
|
mean = X.mean(-1, keepdim = True) |
|
Xbar = X - mean |
|
X = Xbar * torch.rsqrt(Xbar.square().mean(-1, keepdim = True) + \ |
|
layernorm.variance_epsilon) * \ |
|
layernorm.weight.float() |
|
return X.to(old_dtype) |
|
pass |
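
# --- Illustration (hedged sketch, not part of the library) ------------------
# The compiled path above is ordinary bias-free LayerNorm. A hypothetical
# equivalence check against torch.nn.functional.layer_norm; never called here.
def _demo_layernorm_matches_functional(layernorm, X):
    expected = torch.nn.functional.layer_norm(
        X.float(), X.shape[-1:],
        weight = layernorm.weight.float(),
        bias   = None,
        eps    = layernorm.variance_epsilon,
    ).to(X.dtype)
    return torch.allclose(fast_layernorm_compiled(layernorm, X), expected, atol = 1e-4)
pass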
|
|
|
|
|
|
|
def LlamaAttention_fast_forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
causal_mask: Optional[BlockDiagonalCausalMask] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
padding_mask: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
*args, **kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
|
|
|
if hasattr(self, "paged_attention"): |
|
del self.paged_attention_K |
|
del self.paged_attention_V |
|
del self.paged_attention |
|
del self.temp_QA |
|
del self.temp_KV |
|
del self.RH_Q |
|
del self.attention |
|
pass |
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
n_heads = self.num_heads |
|
n_groups = self.num_key_value_groups |
|
n_kv_heads = self.num_key_value_heads |
|
head_dim = self.head_dim |
|
assert(n_kv_heads * n_groups == n_heads) |
|
|
|
Q, K, V = self.apply_qkv(self, hidden_states) |
|
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) |
|
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) |
|
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) |
|
|
|
kv_seq_len = K.shape[-2] |
|
if past_key_value is not None: |
|
kv_seq_len += past_key_value[0].shape[-2] |
|
|
|
if position_embeddings: |
|
cos, sin = position_embeddings |
|
else: |
|
|
|
rotary_emb = self.rotary_emb |
|
rotary_emb.extend_rope_embedding(V, seq_len=kv_seq_len) |
|
|
|
if position_ids is None: |
|
|
|
cos, sin = rotary_emb.get_cached(kv_seq_len) |
|
else: |
|
cos, sin = rotary_emb(V, seq_len=kv_seq_len) |
|
|
|
Q, K = ( |
|
fast_rope_embedding(Q, K, cos, sin) |
|
if position_ids is None |
|
else inplace_rope_embedding(Q, K, cos, sin, position_ids) |
|
) |
|
|
|
if past_key_value is not None: |
|
K = torch.cat([past_key_value[0], K], dim = 2) |
|
V = torch.cat([past_key_value[1], V], dim = 2) |
|
pass |
|
past_key_value = (K, V) if use_cache else None |
|
|
|
|
|
if (not HAS_FLASH_ATTENTION and attention_mask is None): |
|
|
|
|
|
Q = Q.transpose(1, 2) |
|
K = K.transpose(1, 2) |
|
V = V.transpose(1, 2) |
|
|
|
|
|
if n_groups != 1: |
|
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) |
|
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) |
|
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) |
|
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) |
|
if hidden_states.requires_grad: |
|
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) |
|
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) |
|
else: |
|
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) |
|
pass |
|
A = xformers_attention(Q, K, V, attn_bias = causal_mask) |
|
A = A.view(bsz, q_len, n_heads, head_dim) |
|
|
|
elif HAS_FLASH_ATTENTION and attention_mask is None: |
|
Q = Q.transpose(1, 2) |
|
K = K.transpose(1, 2) |
|
V = V.transpose(1, 2) |
|
A = flash_attn_func(Q, K, V, causal = True) |
|
else: |
|
|
|
if n_groups != 1: |
|
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) |
|
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) |
|
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) |
|
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) |
|
pass |
|
|
|
|
|
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() |
|
|
|
|
|
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) |
|
|
|
A = A.transpose(1, 2).contiguous() |
|
pass |
|
attn_output = A.reshape(bsz, q_len, n_heads*head_dim) |
|
attn_output = self.apply_o(self, attn_output) |
|
attn_weights = None |
|
return attn_output, attn_weights, past_key_value |
|
pass |
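
# --- Illustration (hedged sketch, not part of the library) ------------------
# Grouped-query attention: n_kv_heads key/value heads are shared across
# n_heads = n_kv_heads * n_groups query heads, so K/V are expanded before the
# attention kernel. The expand + reshape used above is equivalent to the usual
# "repeat_kv" trick. A hypothetical standalone check; never called here.
def _demo_repeat_kv_equivalence(K, n_groups):
    bsz, n_kv_heads, seq_len, head_dim = K.shape
    expanded = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, seq_len, head_dim)
    expanded = expanded.reshape(bsz, n_kv_heads * n_groups, seq_len, head_dim)
    return torch.equal(expanded, K.repeat_interleave(n_groups, dim = 1))
pass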
|
|
|
|
|
|
|
def LlamaDecoderLayer_fast_forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
causal_mask = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
padding_mask: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
*args, **kwargs, |
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
|
""" |
|
Args: |
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size |
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
use_cache (`bool`, *optional*): |
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
|
(see `past_key_values`). |
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states |
|
""" |
|
if use_cache and hasattr(self, "_flag_for_generation"): |
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states) |
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states = hidden_states, |
|
causal_mask = causal_mask, |
|
attention_mask = attention_mask, |
|
position_ids = position_ids, |
|
past_key_value = past_key_value, |
|
output_attentions = output_attentions, |
|
use_cache = use_cache, |
|
padding_mask = padding_mask, |
|
position_embeddings = position_embeddings, |
|
) |
|
hidden_states += residual |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states) |
|
hidden_states = fast_swiglu_inference(self.mlp, hidden_states) |
|
hidden_states += residual |
|
else: |
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) |
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states = hidden_states, |
|
causal_mask = causal_mask, |
|
attention_mask = attention_mask, |
|
position_ids = position_ids, |
|
past_key_value = past_key_value, |
|
output_attentions = output_attentions, |
|
use_cache = use_cache, |
|
padding_mask = padding_mask, |
|
position_embeddings = position_embeddings, |
|
) |
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = residual + hidden_states |
|
pass |
|
|
|
outputs = (hidden_states,) |
|
if output_attentions: outputs += (self_attn_weights,) |
|
if use_cache: outputs += (present_key_value,) |
|
return outputs |
|
pass |
|
|
|
|
|
|
|
__DTYPE_MAP = { |
|
"float32": torch.float32, |
|
torch.float32: torch.float32, |
|
"float16": torch.float16, |
|
torch.float16: torch.float16, |
|
"bfloat16": torch.bfloat16, |
|
torch.bfloat16: torch.bfloat16, |
|
} |
|
|
|
|
|
def LlamaModel_fast_forward( |
|
self, |
|
input_ids: torch.LongTensor, |
|
causal_mask: Optional[BlockDiagonalCausalMask] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
*args, **kwargs, |
|
) -> Union[Tuple, BaseModelOutputWithPast]: |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
assert(output_attentions is False) |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError("Unsloth: You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") |
|
elif input_ids is not None: |
|
batch_size, seq_length = input_ids.shape |
|
elif inputs_embeds is not None: |
|
batch_size, seq_length, _ = inputs_embeds.shape |
|
else: |
|
raise ValueError("Unsloth: You have to specify either decoder_input_ids or decoder_inputs_embeds") |
|
|
|
seq_length_with_past = seq_length |
|
|
|
|
|
if hasattr(self, "max_seq_length"): |
|
if seq_length > self.max_seq_length: |
|
logger.warning_once( |
|
f"Unsloth: Input IDs of length {seq_length} > the model's max sequence length of {self.max_seq_length}.\n"\ |
|
"We shall truncate it ourselves. It's imperative if you correct this issue first." |
|
) |
|
if input_ids is not None: |
|
input_ids = input_ids[:,:self.max_seq_length] |
|
elif inputs_embeds is not None: |
|
inputs_embeds = inputs_embeds[:,:self.max_seq_length,:] |
|
pass |
|
pass |
|
|
|
past_key_values_length = 0 |
|
|
|
if past_key_values is not None: |
|
past_key_values_length = past_key_values[0][0].shape[2] |
|
seq_length_with_past = seq_length_with_past + past_key_values_length |
|
pass |
|
|
|
|
|
if False: |
|
position_ids = torch.arange( |
|
past_key_values_length, seq_length + past_key_values_length, |
|
dtype = torch.int32, |
|
device = "cuda:0", |
|
) |
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) |
|
elif position_ids is not None: |
|
position_ids = position_ids.view(-1, seq_length).to(torch.int32) |
|
else: |
|
position_ids = None |
|
pass |
|
|
|
if position_ids is not None: |
|
if position_ids.shape[0] != batch_size: |
|
position_ids = position_ids.repeat((batch_size, 1)) |
|
pass |
|
|
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) |
|
if torch_dtype is not None: |
|
inputs_embeds = inputs_embeds.to(torch_dtype) |
|
else: |
|
raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") |
|
pass |
|
|
|
|
|
IS_GEMMA = self.config.model_type.startswith("gemma") |
|
IS_GEMMA2 = self.config.model_type.startswith("gemma2") |
|
IS_COHERE = self.config.model_type.startswith("cohere") |
|
IS_GRANITE = self.config.model_type.startswith("granite") |
|
train_embed_tokens = self.embed_tokens.weight.requires_grad |
|
|
|
if IS_GEMMA: |
|
|
|
|
|
|
|
|
|
normalizer = torch.tensor(math_sqrt(self.config.hidden_size), dtype = inputs_embeds.dtype) |
|
|
|
if train_embed_tokens: |
|
|
|
inputs_embeds = inputs_embeds * normalizer |
|
else: |
|
inputs_requires_grad = inputs_embeds.requires_grad |
|
if not inputs_embeds.is_leaf: |
|
inputs_embeds = inputs_embeds.detach() |
|
inputs_requires_grad = True |
|
elif inputs_requires_grad: |
|
inputs_embeds.requires_grad_(False) |
|
pass |
|
inputs_embeds *= normalizer |
|
|
|
if inputs_requires_grad: inputs_embeds.requires_grad_(True) |
|
pass |
|
pass |
|
|
|
|
|
|
|
if self._has_no_labels and (attention_mask is not None) and (past_key_values is None) and \ |
|
(not train_embed_tokens): |
|
|
|
|
|
inputs_requires_grad = inputs_embeds.requires_grad |
|
if not inputs_embeds.is_leaf: |
|
inputs_embeds = inputs_embeds.detach() |
|
inputs_requires_grad = True |
|
elif inputs_requires_grad: |
|
inputs_embeds.requires_grad_(False) |
|
pass |
|
inputs_embeds *= attention_mask.unsqueeze(0).transpose(0, 1).transpose(1, 2) |
|
if inputs_requires_grad: inputs_embeds.requires_grad_(True) |
|
pass |
|
|
|
|
|
if attention_mask is None: |
|
padding_mask = None |
|
elif self.training: |
|
attention_mask = None |
|
padding_mask = None |
|
else: |
|
|
|
|
|
|
|
padding_mask = None |
|
|
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( |
|
attention_mask, |
|
(batch_size, seq_length), |
|
inputs_embeds, |
|
past_key_values_length, |
|
sliding_window = getattr(self.config, "sliding_window", None), |
|
) |
|
pass |
|
|
|
hidden_states = inputs_embeds |
|
if IS_GRANITE: |
|
hidden_states = self.embedding_multiplier * hidden_states |
|
|
|
if past_key_values is None and self.training: |
|
use_cache = False |
|
|
|
|
|
|
|
|
|
|
|
pass |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = () if use_cache else None |
|
|
|
|
|
if hasattr(self, "_gradient_checkpointing_boundaries"): |
|
boundaries = self._gradient_checkpointing_boundaries |
|
else: |
|
boundaries = None |
|
pass |
|
|
|
|
|
gradient_checkpointing = False |
|
offloaded_gradient_checkpointing = False |
|
|
|
if (self.gradient_checkpointing and self.training and not use_cache): |
|
|
|
gradient_checkpointing = True |
|
|
|
if output_attentions is False and hasattr(self, "_offloaded_gradient_checkpointing"): |
|
offloaded_gradient_checkpointing = True |
|
pass |
|
|
|
|
|
use_static_mask = True |
|
dynamic_SWA_mask = None |
|
dynamic_GA_mask = None |
|
if IS_GEMMA2: |
|
if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None: |
|
self.SWA_mask = True |
|
self.GA_mask = False |
|
elif attention_mask is not None: |
|
|
|
|
|
|
|
dynamic_SWA_mask = _prepare_4d_causal_attention_mask_for_sdpa( |
|
attention_mask, |
|
(batch_size, seq_length), |
|
inputs_embeds, |
|
past_key_values_length, |
|
sliding_window = self.config.sliding_window, |
|
)[0][0] |
|
dynamic_GA_mask = _prepare_4d_causal_attention_mask_for_sdpa( |
|
attention_mask, |
|
(batch_size, seq_length), |
|
inputs_embeds, |
|
past_key_values_length, |
|
sliding_window = None, |
|
)[0][0] |
|
use_static_mask = False |
|
|
|
elif not hasattr(self, "SWA_mask"): |
|
if HAS_FLEX_ATTENTION: |
|
|
|
self.SWA_mask = create_flex_attention_sliding_window_mask(self.max_seq_length, self.config.sliding_window) |
|
self.GA_mask = create_flex_attention_causal_mask(self.max_seq_length) |
|
else: |
|
n = self.max_seq_length |
|
|
|
|
|
|
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
self.SWA_mask = AttentionMaskConverter( |
|
is_causal = True, |
|
sliding_window = self.config.sliding_window, |
|
)\ |
|
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ |
|
.squeeze(0).squeeze(0) |
|
|
|
self.GA_mask = AttentionMaskConverter( |
|
is_causal = True, |
|
)\ |
|
.to_causal_4d(1, n, n, dtype = inputs_embeds.dtype, device = "cuda:0",)\ |
|
.squeeze(0).squeeze(0) |
|
pass |
|
pass |
|
pass |
|
|
|
if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"): |
|
|
|
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) |
|
else: |
|
position_embeddings = None |
|
|
|
|
|
for idx, decoder_layer in enumerate(self.layers): |
|
|
|
if output_hidden_states: all_hidden_states += (hidden_states,) |
|
past_key_value = past_key_values[idx] if past_key_values is not None else None |
|
|
|
mask = causal_mask |
|
if IS_GEMMA2: |
|
if (idx % 2 == 0): |
|
mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask |
|
else: |
|
mask = self. GA_mask if use_static_mask else dynamic_GA_mask |
|
pass |
|
|
|
if offloaded_gradient_checkpointing: |
|
hidden_states = Unsloth_Offloaded_Gradient_Checkpointer.apply( |
|
decoder_layer, |
|
hidden_states, |
|
mask, |
|
attention_mask, |
|
position_ids, |
|
past_key_values, |
|
output_attentions, |
|
use_cache, |
|
None, |
|
position_embeddings, |
|
)[0] |
|
|
|
elif gradient_checkpointing: |
|
def create_custom_forward(module): |
|
def custom_forward(*inputs): |
|
return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) |
|
return custom_forward |
|
pass |
|
|
|
layer_outputs = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(decoder_layer), |
|
hidden_states, |
|
mask, |
|
attention_mask, |
|
position_ids, |
|
use_reentrant = True, |
|
preserve_rng_state = False, |
|
) |
|
hidden_states = layer_outputs[0] |
|
|
|
else: |
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
causal_mask=mask, |
|
attention_mask = attention_mask, |
|
position_ids = position_ids, |
|
past_key_value = past_key_value, |
|
output_attentions = output_attentions, |
|
use_cache = use_cache, |
|
padding_mask = padding_mask, |
|
position_embeddings = position_embeddings, |
|
) |
|
hidden_states = layer_outputs[0] |
|
pass |
|
|
|
if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
|
if output_attentions: all_self_attns += (layer_outputs[1],) |
|
pass |
|
|
|
|
|
if use_cache: |
|
hidden_states = \ |
|
(fast_rms_layernorm_inference_gemma if IS_GEMMA else fast_rms_layernorm_inference)\ |
|
(self.norm, hidden_states) |
|
elif IS_COHERE: |
|
hidden_states = self.norm(hidden_states) |
|
else: |
|
hidden_states = fast_rms_layernorm(self.norm, hidden_states, gemma = IS_GEMMA) |
|
pass |
|
|
|
if output_hidden_states: all_hidden_states += (hidden_states,) |
|
next_cache = next_decoder_cache if use_cache else None |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
pass |
|
|
|
|
|
|
|
def LlamaModel_fast_forward_inference( |
|
self, |
|
input_ids, |
|
past_key_values, |
|
position_ids, |
|
attention_mask = None, |
|
): |
|
input_ids = input_ids[:,:self.max_seq_length] |
|
hidden_states = self.model.embed_tokens(input_ids) |
|
hidden_states = hidden_states.to(self.config.torch_dtype) |
|
bsz, q_len, hd = hidden_states.shape |
|
seq_len = past_key_values[0][0].shape[-2] |
|
if bsz != 1: |
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( |
|
attention_mask, |
|
(bsz, q_len), |
|
hidden_states, |
|
seq_len, |
|
sliding_window = getattr(self.config, "sliding_window", None), |
|
) |
|
else: |
|
attention_mask = None |
|
pass |
|
|
|
next_decoder_cache = [] |
|
for idx, decoder_layer in enumerate(self.model.layers): |
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) |
|
hidden_states, present_key_value = LlamaAttention_fast_forward_inference( |
|
decoder_layer.self_attn, |
|
hidden_states = hidden_states, |
|
past_key_value = past_key_values[idx], |
|
position_ids = position_ids, |
|
attention_mask = attention_mask, |
|
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), |
|
) |
|
hidden_states += residual |
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) |
|
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) |
|
hidden_states += residual |
|
|
|
next_decoder_cache.append(present_key_value) |
|
pass |
|
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) |
|
|
|
return BaseModelOutputWithPast( |
|
last_hidden_state = hidden_states, |
|
past_key_values = next_decoder_cache, |
|
hidden_states = [], |
|
attentions = [], |
|
) |
|
pass |
|
|
|
|
|
def CausalLM_fast_forward(fast_forward_inference): |
|
def _CausalLM_fast_forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
causal_mask: Optional[BlockDiagonalCausalMask] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
num_logits_to_keep: Optional[int] = 0, |
|
*args, **kwargs, |
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
|
|
if past_key_values is not None: |
|
outputs = fast_forward_inference( |
|
self, |
|
input_ids, |
|
past_key_values, |
|
position_ids = position_ids, |
|
attention_mask = attention_mask, |
|
) |
|
else: |
|
causal_mask = xformers.attn_bias.LowerTriangularMask() |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
self.model._has_no_labels = labels is None |
|
outputs = self.model( |
|
input_ids=input_ids, |
|
causal_mask=causal_mask, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
pass |
|
hidden_states = outputs[0] |
|
|
|
bsz, q_len, hd = hidden_states.shape |
|
lm_head = self.lm_head.weight |
|
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) |
|
logit_scaling = getattr(self.config, "logit_scale", 0) |
|
|
|
if bsz == 1 and q_len == 1: |
|
logits = torch.mv(lm_head, hidden_states.ravel().to(lm_head.dtype)) |
|
logits = logits.unsqueeze(0).unsqueeze(0) |
|
elif num_logits_to_keep != 0: |
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype)) |
|
else: |
|
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" |
|
|
|
if bsz*q_len <= 1024: RETURN_LOGITS = True |
|
|
|
if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: |
|
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) |
|
loss = fused_linear_cross_entropy( |
|
hidden_states = hidden_states, |
|
lm_weight = lm_head, |
|
labels = labels, |
|
num_items_in_batch = n_items, |
|
logit_softcapping = logit_softcapping, |
|
) |
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
output = CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=EMPTY_LOGITS, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
return output |
|
pass |
|
logits = self.lm_head(hidden_states.to(lm_head.dtype)) |
|
pass |
|
|
|
torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) |
|
if torch_dtype is not None: |
|
logits = logits.to(torch_dtype) |
|
else: |
|
raise TypeError("Unsloth: torch_dtype for models is not bfloat16, float16 or float32!") |
|
pass |
|
|
|
loss = None |
|
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) |
|
logit_scaling = getattr(self.config, "logit_scale", 0) |
|
if self.config.model_type == "granite": |
|
|
|
|
|
|
|
|
|
logit_scaling = 1 / getattr(self.config, "logits_scaling", 1) |
|
|
|
if labels is not None: |
|
shift_logits = logits |
|
if not hasattr(self, "extra_ignored_labels"): |
|
|
|
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") |
|
pass |
|
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) |
|
loss = fast_cross_entropy_loss( |
|
logits = shift_logits, |
|
labels = shift_labels, |
|
logit_softcapping = logit_softcapping, |
|
logit_scaling = logit_scaling, |
|
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), |
|
) |
|
else: |
|
if logit_scaling != 0: |
|
if logits.requires_grad: |
|
logits = logit_scaling * logits |
|
else: |
|
logits *= logit_scaling |
|
pass |
|
pass |
|
if logit_softcapping != 0: |
|
if logits.requires_grad: |
|
logits = (1.0 / logit_softcapping) * logits |
|
logits = torch.tanh(logits) |
|
logits = logit_softcapping * logits |
|
else: |
|
logits *= (1.0 / logit_softcapping) |
|
torch.tanh(logits, out = logits) |
|
logits *= logit_softcapping |
|
pass |
|
pass |
|
pass |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
pass |
|
return _CausalLM_fast_forward |
|
pass |
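
# --- Illustration (hedged sketch, not part of the library) ------------------
# Final logit softcapping (Gemma-2 style) squashes logits into (-cap, +cap)
# via cap * tanh(logits / cap); logit scaling (Cohere / Granite style) is a
# plain multiply. A hypothetical out-of-place version of the branch above,
# shown for clarity; never called here.
def _demo_logit_postprocess(logits, logit_scaling = 0, logit_softcapping = 0):
    if logit_scaling != 0:
        logits = logits * logit_scaling
    if logit_softcapping != 0:
        logits = logit_softcapping * torch.tanh(logits / logit_softcapping)
    return logits
pass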
|
|
|
|
|
@torch._disable_dynamo |
|
def PeftModelForCausalLM_fast_forward( |
|
self, |
|
input_ids=None, |
|
causal_mask=None, |
|
attention_mask=None, |
|
inputs_embeds=None, |
|
labels=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
task_ids=None, |
|
num_logits_to_keep=0, |
|
**kwargs, |
|
): |
|
return self.base_model( |
|
input_ids=input_ids, |
|
causal_mask=causal_mask, |
|
attention_mask=attention_mask, |
|
inputs_embeds=inputs_embeds, |
|
labels=labels, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
num_logits_to_keep=num_logits_to_keep, |
|
**kwargs, |
|
) |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LlamaRotaryEmbedding(torch.nn.Module): |
|
|
|
|
|
|
|
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, |
|
config = None, |
|
): |
|
super().__init__() |
|
if config is not None: |
|
|
|
base = config.rope_theta |
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 |
|
dim = int((config.hidden_size // config.num_attention_heads)) |
|
device = "cuda" |
|
max_position_embeddings = config.max_position_embeddings |
|
pass |
|
|
|
self.dim = dim |
|
self.max_position_embeddings = max_position_embeddings |
|
self.base = base |
|
|
|
self.current_rope_size = min(4 * 8192, self.max_position_embeddings) |
|
|
|
|
|
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) |
|
pass |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
|
|
|
|
self.current_rope_size = seq_len |
|
inv_freq = 1.0 / ( |
|
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim) |
|
) |
|
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float() |
|
|
|
freqs = torch.outer(t, inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) |
|
pass |
|
|
|
def forward(self, x, position_ids=None, seq_len=None): |
|
|
|
if seq_len > self.current_rope_size: |
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
|
|
|
return ( |
|
self.cos_cached[:seq_len].to(dtype = x.dtype), |
|
self.sin_cached[:seq_len].to(dtype = x.dtype), |
|
) |
|
pass |
|
|
|
def get_cached(self, seq_len = None): |
|
return self.cos_cached, self.sin_cached |
|
pass |
|
|
|
def extend_rope_embedding(self, x, seq_len): |
|
if seq_len <= self.current_rope_size: return |
|
|
|
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 |
|
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype) |
|
pass |
|
pass |
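
# --- Illustration (hedged sketch, not part of the library) ------------------
# The cached cos/sin above get applied with the usual "rotate half" rule:
#     q' = q * cos + rotate_half(q) * sin,   rotate_half(q) = [-q2, q1]
# which matches the in-place RH_Q / RH_K updates in the inference path earlier
# in this file. A hypothetical standalone version; never called here.
def _demo_apply_rope(Q, cos, sin):
    # Q: (bsz, n_heads, seq_len, head_dim); cos / sin: (seq_len, head_dim)
    half    = Q.shape[-1] // 2
    rotated = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
    return Q * cos + rotated * sin
pass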
|
|
|
|
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): |
|
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" |
|
|
|
|
|
|
|
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, |
|
config = None, |
|
): |
|
self.scaling_factor = scaling_factor |
|
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config) |
|
pass |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
self.current_rope_size = seq_len |
|
inv_freq = 1.0 / ( |
|
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim) |
|
) |
|
t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float() |
|
t = t / self.scaling_factor |
|
|
|
freqs = torch.outer(t, inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) |
|
pass |
|
pass |
|
|
|
|
|
|
|
|
|
class LlamaExtendedRotaryEmbedding(torch.nn.Module): |
|
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, |
|
config = None, |
|
): |
|
super().__init__() |
|
if config is not None: |
|
|
|
base = config.rope_theta |
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 |
|
dim = int((config.hidden_size // config.num_attention_heads)) |
|
device = "cuda" |
|
max_position_embeddings = config.max_position_embeddings |
|
pass |
|
|
|
self.dim = dim |
|
self.max_position_embeddings = max_position_embeddings |
|
self.base = base |
|
|
|
self.current_rope_size = min(4 * 8192, self.max_position_embeddings) |
|
|
|
|
|
inv_freq = 1.0 / ( |
|
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim) |
|
) |
|
inv_freq = self.apply_scaling(inv_freq) |
|
self.register_buffer("inv_freq", inv_freq, persistent = False) |
|
|
|
|
|
self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) |
|
pass |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
|
|
|
|
self.current_rope_size = seq_len |
|
|
|
t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float() |
|
|
|
freqs = torch.outer(t, self.inv_freq) |
|
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) |
|
self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) |
|
pass |
|
|
|
|
|
def apply_scaling(self, freqs: torch.Tensor): |
|
|
|
scale_factor = 8 |
|
low_freq_factor = 1 |
|
high_freq_factor = 4 |
|
old_context_len = 8192 |
|
|
|
low_freq_wavelen = old_context_len / low_freq_factor |
|
high_freq_wavelen = old_context_len / high_freq_factor |
|
new_freqs = [] |
|
for freq in freqs: |
|
wavelen = 2 * math.pi / freq |
|
if wavelen < high_freq_wavelen: |
|
new_freqs.append(freq) |
|
elif wavelen > low_freq_wavelen: |
|
new_freqs.append(freq / scale_factor) |
|
else: |
|
assert low_freq_wavelen != high_freq_wavelen |
|
smooth = (old_context_len / wavelen - low_freq_factor) / ( |
|
high_freq_factor - low_freq_factor |
|
) |
|
new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) |
|
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) |
|
pass |
|
|
|
def forward(self, x, position_ids=None, seq_len=None): |
|
|
|
if seq_len > self.current_rope_size: |
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
|
|
|
return ( |
|
self.cos_cached[:seq_len].to(dtype = x.dtype), |
|
self.sin_cached[:seq_len].to(dtype = x.dtype), |
|
) |
|
pass |
|
|
|
def get_cached(self, seq_len = None): |
|
return self.cos_cached, self.sin_cached |
|
pass |
|
|
|
def extend_rope_embedding(self, x, seq_len): |
|
if seq_len <= self.current_rope_size: return |
|
|
|
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 |
|
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype) |
|
pass |
|
pass |
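
# --- Illustration (hedged sketch, not part of the library) ------------------
# apply_scaling above implements the Llama-3.1 frequency schedule: components
# whose wavelength is shorter than old_context_len / high_freq_factor are kept,
# wavelengths longer than old_context_len / low_freq_factor are divided by
# scale_factor, and the band in between is linearly interpolated. A
# hypothetical vectorised equivalent; never called here.
def _demo_llama3_rope_scaling(inv_freq, scale_factor = 8, low_freq_factor = 1,
                              high_freq_factor = 4, old_context_len = 8192):
    wavelen = 2 * math.pi / inv_freq
    smooth  = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    scaled  = (1 - smooth) * inv_freq / scale_factor + smooth * inv_freq
    scaled  = torch.where(wavelen > old_context_len / low_freq_factor, inv_freq / scale_factor, scaled)
    return torch.where(wavelen < old_context_len / high_freq_factor, inv_freq, scaled)
pass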
|
|
|
|
|
class LongRopeRotaryEmbedding(torch.nn.Module): |
|
|
|
def __init__(self, |
|
dim = None, |
|
max_position_embeddings = 131072, |
|
original_max_position_embeddings = 4096, |
|
base = 10000, |
|
short_factor = None, |
|
long_factor = None, |
|
device = None, |
|
config = None, |
|
): |
|
super().__init__() |
|
assert(short_factor is not None) |
|
assert(long_factor is not None) |
|
assert(type(original_max_position_embeddings) is int) |
|
|
|
if config is not None: |
|
|
|
base = config.rope_theta |
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 |
|
dim = int((config.hidden_size // config.num_attention_heads)) |
|
device = "cuda" |
|
max_position_embeddings = config.max_position_embeddings |
|
pass |
|
|
|
self.dim = dim |
|
self.max_position_embeddings = max_position_embeddings |
|
self.original_max_position_embeddings = original_max_position_embeddings |
|
self.base = base |
|
|
|
self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings) |
|
|
|
|
|
|
|
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device="cpu").float() / self.dim |
|
short_factor = torch.tensor(short_factor, device = "cpu", dtype = torch.float32) |
|
long_factor = torch.tensor(long_factor, device = "cpu", dtype = torch.float32) |
|
short_inv_freq = 1.0 / (short_factor * self.base**inv_freq_shape) |
|
long_inv_freq = 1.0 / (long_factor * self.base**inv_freq_shape) |
|
|
|
|
|
scale = self.max_position_embeddings / self.original_max_position_embeddings |
|
if scale <= 1.0: |
|
scaling_factor = 1.0 |
|
else: |
|
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) |
|
pass |
|
self.scaling_factor = scaling_factor |
|
|
|
|
|
self.register_buffer("short_inv_freq", short_inv_freq, persistent = False) |
|
self.register_buffer("long_inv_freq", long_inv_freq, persistent = False) |
|
|
|
|
|
|
|
|
|
dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 |
|
t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() |
|
freqs = torch.outer(t, self.short_inv_freq) |
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) |
|
sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) |
|
self.register_buffer("short_cos_cached", cos_cached, persistent=False) |
|
self.register_buffer("short_sin_cached", sin_cached, persistent=False) |
|
pass |
|
|
|
def _set_cos_sin_cache(self, seq_len, device, dtype): |
|
|
|
|
|
self.current_rope_size = seq_len |
|
|
|
t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float() |
|
|
|
freqs = torch.outer(t, self.long_inv_freq) |
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) |
|
sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) |
|
self.register_buffer("long_cos_cached", cos_cached, persistent=False) |
|
self.register_buffer("long_sin_cached", sin_cached, persistent=False) |
|
pass |
|
|
|
def forward(self, x, position_ids=None, seq_len=None): |
|
|
|
if seq_len > self.current_rope_size: |
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) |
|
|
|
if seq_len < self.original_max_position_embeddings: |
|
return ( |
|
self.short_cos_cached[:seq_len].to(dtype = x.dtype), |
|
self.short_sin_cached[:seq_len].to(dtype = x.dtype), |
|
) |
|
else: |
|
return ( |
|
self.long_cos_cached[:seq_len].to(dtype = x.dtype), |
|
self.long_sin_cached[:seq_len].to(dtype = x.dtype), |
|
) |
|
pass |
|
pass |
|
|
|
def get_cached(self, seq_len = None): |
|
if seq_len < self.original_max_position_embeddings: |
|
return self.short_cos_cached, self.short_sin_cached |
|
return self.long_cos_cached, self.long_sin_cached |
|
pass |
|
|
|
def extend_rope_embedding(self, x, seq_len): |
|
if seq_len <= self.current_rope_size: return |
|
|
|
self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 |
|
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype) |
|
pass |
|
pass |
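
# --- Illustration (hedged sketch, not part of the library) ------------------
# LongRoPE (Phi-3 style) keeps two rescaled inverse-frequency sets: the short
# factors cover prompts up to original_max_position_embeddings, the long
# factors cover anything beyond, and both caches are multiplied by the same
# attention scaling term sqrt(1 + log(scale) / log(original_len)) computed in
# __init__ above. A tiny hypothetical example of that term; never called here.
def _demo_longrope_attention_scaling(max_len = 131072, original_len = 4096):
    scale = max_len / original_len
    return 1.0 if scale <= 1.0 else math.sqrt(1 + math.log(scale) / math.log(original_len))
pass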
|
|
|
|
|
def _wrap_fast_inference(generate, device_type, dtype, model): |
|
|
|
@torch.inference_mode |
|
def _fast_generate(*args, **kwargs): |
|
|
|
if hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"): |
|
if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: |
|
if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > model.config.max_position_embeddings: |
|
raise ValueError( |
|
f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ |
|
'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' |
|
) |
|
pass |
|
|
|
|
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
internal_model._flag_for_generation = True |
|
internal_model = internal_model.model |
|
pass |
|
internal_model._flag_for_generation = True |
|
|
|
|
|
if accelerate_new_send_to_device is not None: |
|
import accelerate.utils.operations |
|
accelerate.utils.operations.send_to_device = accelerate_new_send_to_device |
|
pass |
|
|
|
|
|
kwargs["cache_implementation"] = "dynamic" |
|
|
|
kwargs["num_logits_to_keep"] = 1 |
|
|
|
|
|
kwargs.pop("token_type_ids", None) |
|
|
|
|
|
model_eos_token_id = getattr(model.config, "eos_token_id", None) |
|
if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"): |
|
model_eos_token_id = model_eos_token_id[0] |
|
|
|
kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with torch.autocast(device_type = device_type, dtype = dtype): |
|
output = generate(*args, **kwargs) |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation |
|
internal_model = internal_model.model |
|
pass |
|
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation |
|
|
|
|
|
if accelerate_new_send_to_device is not None: |
|
accelerate.utils.operations.send_to_device = accelerate_old_send_to_device |
|
pass |
|
|
|
return output |
|
pass |
|
return _fast_generate |
|
pass |
|
|
|
|
|
class FastLlamaModel: |
|
|
|
@staticmethod |
|
def pre_patch(): |
|
init_name, function = patch_llama_rope_scaling( |
|
model_name = "llama", |
|
rope_module = LlamaRotaryEmbedding, |
|
scaled_rope_module = LlamaLinearScalingRotaryEmbedding, |
|
extended_rope_module = LlamaExtendedRotaryEmbedding, |
|
attention_module = LlamaAttention, |
|
longrope_module = LongRopeRotaryEmbedding, |
|
) |
|
if init_name is not None: |
|
exec(function, globals()) |
|
LlamaAttention.__init__ = eval(init_name) |
|
pass |
|
LlamaAttention .forward = LlamaAttention_fast_forward |
|
LlamaSdpaAttention .forward = LlamaAttention_fast_forward |
|
LlamaFlashAttention2.forward = LlamaAttention_fast_forward |
|
LlamaDecoderLayer .forward = LlamaDecoderLayer_fast_forward |
|
LlamaModel .forward = LlamaModel_fast_forward |
|
LlamaForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference) |
|
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward |
|
fix_prepare_inputs_for_generation(LlamaForCausalLM) |
|
|
|
|
|
|
|
|
|
|
|
|
|
import transformers.models.llama.modeling_llama |
|
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = LlamaRotaryEmbedding |
|
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding = LlamaLinearScalingRotaryEmbedding |
|
return |
|
pass |
|
|
|
|
|
@staticmethod |
|
def from_pretrained( |
|
model_name = "unsloth/llama-3-8b-bnb-4bit", |
|
max_seq_length = None, |
|
dtype = None, |
|
load_in_4bit = True, |
|
token = None, |
|
device_map = "sequential", |
|
rope_scaling = None, |
|
fix_tokenizer = True, |
|
model_patcher = None, |
|
tokenizer_name = None, |
|
trust_remote_code = False, |
|
**kwargs, |
|
): |
|
if trust_remote_code: |
|
print( |
|
"Unsloth: WARNING `trust_remote_code` is True.\n"\ |
|
"Are you certain you want to do remote code execution?" |
|
) |
|
pass |
|
if token is None: token = get_token() |
|
if model_patcher is None: model_patcher = FastLlamaModel |
|
SUPPORTS_BFLOAT16 = is_bfloat16_supported() |
|
gpu_stats = torch.cuda.get_device_properties(0) |
|
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3) |
|
|
|
statistics = \ |
|
f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.\n"\ |
|
f" \\\ /| GPU: {gpu_stats.name}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ |
|
f"O^O/ \_/ \\ Torch: {torch.__version__}. CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {torch.version.cuda}. Triton: {triton_version}\n"\ |
|
f"\ / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. FA2 = {HAS_FLASH_ATTENTION}]\n"\ |
|
f' "-____-" Free Apache license: http://github.com/unslothai/unsloth' |
|
print(statistics) |
|
|
|
|
|
old_hf_transfer = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") |
|
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") == "1": |
|
print("Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!") |
|
pass |
|
|
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer |
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" |
|
|
|
model_patcher.pre_patch() |
|
get_statistics() |
|
|
|
if dtype is None: |
|
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 |
|
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16: |
|
logger.warning_once("Device does not support bfloat16. Will change to float16.") |
|
dtype = torch.float16 |
|
|
|
assert(dtype == torch.float16 or dtype == torch.bfloat16 or dtype == torch.float32) |
|
|
|
|
|
model_config = AutoConfig.from_pretrained(model_name, token = token) |
|
model_max_seq_length = model_config.max_position_embeddings |
|
|
|
|
|
model_function = MODEL_FOR_CAUSAL_LM_MAPPING[model_config.__class__] |
|
has_rope_scaling = False |
|
try: |
|
with open(inspect.getfile(model_function), "r") as file: |
|
has_rope_scaling = "self.config.rope_scaling" in file.read() |
|
except: pass |
|
has_rope_scaling = True |
|
|
|
|
|
if max_seq_length is None: |
|
max_seq_length = model_max_seq_length |
|
pass |
|
|
|
if (rope_scaling is None) and (max_seq_length > model_max_seq_length): |
|
|
|
rope_scaling = max_seq_length / model_max_seq_length |
|
|
|
logger.warning_once( |
|
f"Unsloth: {model_name} can only handle sequence lengths of at most "\ |
|
f"{model_max_seq_length}.\nBut with kaiokendev's RoPE scaling of "\ |
|
f"{round(rope_scaling, 3)}, it can be magically be extended to "\ |
|
f"{max_seq_length}!" |
|
) |
|
|
|
|
|
if not has_rope_scaling: |
|
raise RuntimeError( |
|
"However, {model_name} doesn't support RoPE Scaling!\n"\ |
|
"Please file a feature request at https://github.com/unslothai/unsloth." |
|
) |
|
pass |
|
|
|
rope_scaling = {"type": "linear", "factor": rope_scaling,} |
|
|
|
|
|
kwargs["rope_scaling"] = rope_scaling |
|
pass |
|
|
|
pre_check = check_nvidia() |
|
|
|
bnb_config = None |
|
if load_in_4bit: |
|
bnb_config = BitsAndBytesConfig( |
|
load_in_4bit = True, |
|
bnb_4bit_use_double_quant = True, |
|
bnb_4bit_quant_type = "nf4", |
|
bnb_4bit_compute_dtype = dtype, |
|
) |
|
pass |
|
|
|
|
|
|
|
max_position_embeddings = max(max_seq_length, model_max_seq_length) |
|
kwargs.pop("attn_implementation", None); |
|
|
|
|
|
if load_in_4bit: kwargs["quantization_config"] = bnb_config |
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
device_map = device_map, |
|
torch_dtype = dtype, |
|
|
|
token = token, |
|
max_position_embeddings = max_position_embeddings, |
|
trust_remote_code = trust_remote_code, |
|
attn_implementation = "eager", |
|
**kwargs, |
|
) |
|
|
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer |
|
|
|
post_check = check_nvidia() |
|
|
|
|
|
tokenizer_name = model_name if tokenizer_name is None else tokenizer_name |
|
tokenizer = load_correct_tokenizer( |
|
tokenizer_name = tokenizer_name, |
|
model_max_length = max_position_embeddings, |
|
padding_side = "right", |
|
token = token, |
|
trust_remote_code = trust_remote_code, |
|
fix_tokenizer = fix_tokenizer, |
|
) |
|
|
|
model, tokenizer = patch_tokenizer(model, tokenizer) |
|
model, tokenizer = model_patcher.post_patch(model, tokenizer) |
|
|
|
|
|
for idx, layer in enumerate(model.model.layers): |
|
layer.self_attn.apply_qkv = original_apply_qkv |
|
layer.self_attn.apply_o = original_apply_o |
|
pass |
|
|
|
|
|
from transformers.trainer import Trainer |
|
try: |
|
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop": |
|
inner_training_loop = inspect.getsource(Trainer._inner_training_loop) |
|
Trainer._original_training_loop = inner_training_loop |
|
else: |
|
inner_training_loop = Trainer._original_training_loop |
|
except: |
|
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') |
|
pass |
|
|
|
if ((post_check - pre_check) >= 1).sum() > 1: |
|
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') |
|
|
|
import transformers.trainer |
|
items_in_trainer = dir(transformers.trainer) |
|
good_items = [] |
|
for item in items_in_trainer: |
|
|
|
if item.startswith(("deepspeed", "xm", "met", "smp")): continue |
|
if item in inner_training_loop: good_items.append(item) |
|
pass |
|
exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals()) |
|
|
|
        start = re.search(r'logger\.info\([\"\'].+?Running training', inner_training_loop).span(0)[0]
|
end = inner_training_loop.find("\n\n", start) |
|
original_debug = inner_training_loop[start:end] |
|
        spaces = re.search(r'\n([\s\t]{1,})', original_debug).group(0)[1:]
|
        front_spaces = re.match(r'([\s\t]{1,})', inner_training_loop).group(0)
|
|
|
|
|
|
|
debug_info = """debug_info = \\ |
|
f"==((====))== Unsloth - 2x faster free finetuning | Num GPUs = {args.world_size}\\n"\\ |
|
f" {chr(92)}{chr(92)} /| Num examples = {num_examples:,} | Num Epochs = {num_train_epochs:,}\\n"\\ |
|
f"O^O/ {chr(92)}_/ {chr(92)} Batch size per device = {self._train_batch_size:,} | Gradient Accumulation steps = {args.gradient_accumulation_steps}\\n"\\ |
|
f"{chr(92)} / Total batch size = {total_train_batch_size:,} | Total steps = {max_steps:,}\\n"\\ |
|
f' "-____-" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}' |
|
logger.warning(debug_info) |
|
import subprocess, re, gc, numpy as np |
|
a = np.array([0,]) |
|
try: |
|
a = subprocess.check_output('nvidia-smi --query-gpu=memory.used --format=csv', shell = True) |
|
a = re.findall(rb'([\\d]{1,})[\\s]{1,}M', a) |
|
a = np.array([int(x.decode('utf-8'))/1024 for x in a]) |
|
except: |
|
if not torch.cuda.is_available(): |
|
raise RuntimeError('Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!') |
|
if ((a - PRE_CHECK) >= 1).sum() > 1: |
|
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') |
|
for _ in range(3): |
|
gc.collect() |
|
torch.cuda.empty_cache()""" |
|
|
|
debug_info = debug_info.split('\n') |
|
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) |
|
inner_training_loop = inner_training_loop.replace(original_debug, debug_info) |
|
|
|
debug_info = """n_total_devices = total_train_batch_size // \\ |
|
args.gradient_accumulation_steps // self._train_batch_size |
|
if n_total_devices > 1: |
|
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!') |
|
debug_info =""" |
|
debug_info = debug_info.split('\n') |
|
debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) |
|
inner_training_loop = inner_training_loop.replace("debug_info =", debug_info, 1) |
|
|
|
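# Strip the common leading indentation so the rewritten training loop can be exec'd at module scope.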
front_spaces = re.match(r"[\t\s]{1,}", inner_training_loop).group(0) |
|
inner_training_loop = re.sub(r"^" + front_spaces, "", inner_training_loop, flags = re.MULTILINE) |
|
inner_training_loop = inner_training_loop.replace( |
|
"train_dataloader = tpu_spmd_dataloader(train_dataloader)", |
|
"raise RuntimeError('Unsloth: TPUs are not yet supported!')" |
|
) |
|
inner_training_loop = inner_training_loop.replace( |
|
"self.accelerator.free_memory()", |
|
"self.accelerator.free_memory()\n" + \ |
|
front_spaces + "if self.is_deepspeed_enabled:"\ |
|
"raise RuntimeError('Unsloth: Deepspeed is not yet supported!')\n", 1, |
|
) |
|
|
|
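# Right after the dataloader is created, shrink the batch size / gradient accumulation if more than one device is detected.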
check_batches = """train_dataloader = self.get_train_dataloader() |
|
ga = args.gradient_accumulation_steps |
|
bsz = self._train_batch_size |
|
total_batches = bsz * ga * args.world_size |
|
n_total_devices = total_batches // ga // bsz |
|
if n_total_devices > 1: |
|
logger.warning_once('Unsloth currently does not support multi GPU setups - but we are working on it!') |
|
divisor = n_total_devices / 1 |
|
bsz = self._train_batch_size = max(int(bsz / divisor), 1) |
|
if total_batches // ga // bsz > 1: |
|
divisor = n_total_devices / 1 |
|
ga = args.gradient_accumulation_steps = max(int(ga / divisor), 1)""" |
|
check_batches = check_batches.split('\n') |
|
check_batches = "\n".join([check_batches[0]] + [front_spaces + x[8:] for x in check_batches[1:]]) |
|
inner_training_loop = inner_training_loop.replace( |
|
"train_dataloader = self.get_train_dataloader()", |
|
check_batches, 1, |
|
) |
|
inner_training_loop = inner_training_loop.replace( |
|
"_inner_training_loop", |
|
"_fast_inner_training_loop", 1, |
|
) |
|
exec(inner_training_loop, globals()) |
|
|
|
Trainer._inner_training_loop = _fast_inner_training_loop |
|
inner_training_loop = inner_training_loop.replace( |
|
"is_torch_tpu_available()", |
|
"False", |
|
) |
|
if "n_total_devices >" not in inner_training_loop: |
|
raise RuntimeError('Unsloth currently does not support multi GPU setups - but we are working on it!') |
|
pass |
|
inner_training_loop = inner_training_loop.replace( |
|
"is_sagemaker_mp_enabled()", |
|
"False", |
|
) |
|
exec(inner_training_loop, globals()) |
|
Trainer._inner_training_loop = _fast_inner_training_loop |
|
|
|
|
|
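# Propagate max_seq_length down every nested .model attribute.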
model.max_seq_length = max_position_embeddings |
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
internal_model.max_seq_length = max_position_embeddings |
|
internal_model = internal_model.model |
|
pass |
|
internal_model.max_seq_length = max_position_embeddings |
|
|
|
|
|
if fix_tokenizer: |
|
tokenizer = check_tokenizer( |
|
model = model, |
|
tokenizer = tokenizer, |
|
model_name = model_name, |
|
model_max_length = max_position_embeddings, |
|
padding_side = "right", |
|
token = token, |
|
) |
|
pass |
|
patch_saving_functions(tokenizer) |
|
|
|
|
|
|
|
if False: |
|
name = model.config._name_or_path |
|
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"): |
|
name = name[:len(name) - len("-bnb-4bit")] |
|
model.config.update({"_name_or_path" : name}) |
|
pass |
|
pass |
|
|
|
|
|
model.config.update({"unsloth_version" : __version__}) |
|
|
|
|
|
patch_saving_functions(model) |
|
Trainer._inner_training_loop = _fast_inner_training_loop |
|
|
|
|
|
patch_gradient_accumulation_fix(Trainer) |
|
|
|
|
|
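# Default to left padding and stash the tokenizer on each nested module as _saved_temp_tokenizer.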
tokenizer.padding_side = "left" |
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
internal_model._saved_temp_tokenizer = tokenizer |
|
internal_model = internal_model.model |
|
pass |
|
internal_model._saved_temp_tokenizer = tokenizer |
|
|
|
return model, tokenizer |
|
pass |
|
|
|
|
|
@staticmethod |
|
def post_patch(model, tokenizer): |
|
model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = True) |
|
return model, tokenizer |
|
pass |
|
|
|
|
|
@staticmethod |
|
def get_peft_model( |
|
model, |
|
r = 16, |
|
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", |
|
"gate_proj", "up_proj", "down_proj"], |
|
lora_alpha = 16, |
|
lora_dropout = 0, |
|
bias = "none", |
|
layers_to_transform = None, |
|
layers_pattern = None, |
|
use_gradient_checkpointing = True, |
|
random_state = 3407, |
|
max_seq_length = 2048, |
|
use_rslora = False, |
|
modules_to_save = None, |
|
init_lora_weights = True, |
|
loftq_config = {}, |
|
temporary_location = "_unsloth_temporary_saved_buffers", |
|
**kwargs, |
|
): |
|
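# A typical call might look like this (mirroring the defaults above):
#   model = FastLlamaModel.get_peft_model(
#       model,
#       r = 16, lora_alpha = 16, lora_dropout = 0,
#       use_gradient_checkpointing = "unsloth",
#   )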
transformers_set_seed(random_state) |
|
|
|
if type(r) is not int: |
|
raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.") |
|
if r <= 0: |
|
raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.") |
|
|
|
if isinstance(model, PeftModelForCausalLM): |
|
|
|
assert(hasattr(model, "peft_config")) |
|
|
|
peft_config = model.peft_config["default"].to_dict() |
|
check_parameters = [ |
|
"r", "lora_alpha", "lora_dropout", |
|
"bias", "layers_to_transform", "layers_pattern", |
|
"use_rslora", "init_lora_weights", |
|
] |
|
check_all = True |
|
for param in check_parameters: |
|
check_all = check_all and (peft_config[param] == eval(param)) |
|
pass |
|
|
|
|
|
old_target_modules = list(peft_config["target_modules"]) |
|
modules_to_save = peft_config["modules_to_save"] |
|
if modules_to_save is None: modules_to_save = {} |
|
modules_to_save = list(modules_to_save) |
|
old_target_modules += modules_to_save |
|
|
|
|
|
new_target_modules = list(target_modules) + \ |
|
list(modules_to_save if modules_to_save is not None else []) |
|
|
|
|
|
new_target_modules = set(new_target_modules) |
|
check_all = check_all and ( |
|
len(set(old_target_modules) ^ new_target_modules) == 0 |
|
) |
|
|
|
check_all = check_all and ( |
|
(loftq_config == {} or loftq_config is None) and \ |
|
(peft_config["loftq_config"] == {} or peft_config["loftq_config"] is None) |
|
) |
|
|
|
if check_all: |
|
|
|
logger.warning( |
|
"Unsloth: Already have LoRA adapters! We shall skip this step." |
|
) |
|
|
|
|
|
|
|
if "embed_tokens" in new_target_modules: |
|
print("Unsloth: Training embed_tokens in mixed precision to save VRAM") |
|
|
|
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype |
|
model.model.model.embed_tokens.modules_to_save.default\ |
|
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) |
|
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) |
|
|
|
|
|
model.model.model.embed_tokens.original_module\ |
|
.to(device = "cpu", non_blocking = True) |
|
model.model.model.embed_tokens.original_module.requires_grad_(False) |
|
pass |
|
|
|
if "lm_head" in new_target_modules: |
|
print("Unsloth: Training lm_head in mixed precision to save VRAM") |
|
|
|
dtype = model.model.lm_head.modules_to_save.default.weight.dtype
|
model.model.lm_head.modules_to_save.default\ |
|
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) |
|
model.model.lm_head.modules_to_save.default.requires_grad_(True) |
|
|
|
|
|
model.model.lm_head.original_module\ |
|
.to(device = "cpu", non_blocking = True) |
|
model.model.lm_head.original_module.requires_grad_(False) |
|
pass |
|
|
|
return model |
|
else: |
|
raise TypeError( |
|
"Unsloth: Your model already has LoRA adapters. Your new parameters are different." |
|
) |
|
pass |
|
pass |
|
|
|
if loftq_config is None: loftq_config = {} |
|
|
|
signature = str(inspect.signature(LoraConfig)) |
|
SUPPORTS_LOFTQ = "loftq_config" in signature |
|
SUPPORTS_RSLORA = "use_rslora" in signature |
|
|
|
assert(max_seq_length <= model.max_seq_length) |
|
|
|
if lora_dropout != 0: |
|
logger.warning_once( |
|
f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"\ |
|
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit." |
|
) |
|
pass |
|
|
|
if bias != "none": |
|
logger.warning_once( |
|
f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"\ |
|
f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit." |
|
) |
|
pass |
|
|
|
if not (type(init_lora_weights) is bool or \ |
|
init_lora_weights == "gaussian" or init_lora_weights == "loftq"): |
|
raise ValueError( |
|
'Unsloth: `init_lora_weights` must be one of [True, False, "gaussian", "loftq"].'
|
) |
|
pass |
|
|
|
if init_lora_weights == "loftq": |
|
|
|
if not SUPPORTS_LOFTQ: |
|
import peft |
|
raise RuntimeError(

f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"\

"Please install PEFT 0.7.2 or higher.\n"\

"You can also install from source: `pip install git+https://github.com/huggingface/peft.git`"
|
) |
|
pass |
|
|
|
if loftq_config == {}: |
|
from peft import LoftQConfig |
|
logger.warning_once(

"Unsloth: init_lora_weights = `loftq` is set, but no `loftq_config` was provided.\n"\

"We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
|
) |
|
loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1) |
|
pass |
|
|
|
if hasattr(model.config, "quantization_config"): |
|
raise ValueError( |
|
"Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\ |
|
"Reload your model without any quantization by setting `load_in_4bit = False`." |
|
) |
|
pass |
|
pass |
|
|
|
assert(type(use_rslora) is bool) |
|
if use_rslora: |
|
if not SUPPORTS_RSLORA: |
|
|
|
import peft |
|
raise RuntimeError(

f"Unsloth: Your PEFT version of {peft.__version__} does not support `use_rslora`.\n"\

"Please install PEFT 0.7.2 or higher.\n"\

"You can also install from source: `pip install git+https://github.com/huggingface/peft.git`"
|
) |
|
pass |
|
pass |
|
|
|
accepted_modules = frozenset(("q_proj", "k_proj", "v_proj", "o_proj", |
|
"gate_proj", "up_proj", "down_proj",),) |
|
model.config.update({"unsloth_version" : __version__}) |
|
|
|
if type(modules_to_save) is tuple: |
|
modules_to_save = list(modules_to_save) |
|
pass |
|
|
|
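# Split the requested target_modules into LoRA targets and modules (lm_head / embed_tokens) trained in full via modules_to_save.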
train_lm_head = False |
|
train_embed_tokens = False |
|
final_modules = [] |
|
for module in target_modules: |
|
if module == "lm_head": |
|
|
|
|
|
|
|
|
|
train_lm_head = True |
|
if modules_to_save is None: modules_to_save = ["lm_head"] |
|
else: modules_to_save.append("lm_head") |
|
|
|
elif module == "embed_tokens": |
|
|
|
|
|
|
|
|
|
train_embed_tokens = True |
|
if modules_to_save is None: modules_to_save = ["embed_tokens"] |
|
else: modules_to_save.append("embed_tokens") |
|
|
|
else: |
|
try: |
|
assert(module in accepted_modules) |
|
final_modules.append(module) |
|
except AssertionError as e: |
|
final_modules.append(module) |
|
print( |
|
"Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"\ |
|
"Beware - your finetuning might be noticeably slower!" |
|
) |
|
pass |
|
pass |
|
pass |
|
|
|
|
|
if hasattr(model, "_need_to_train_embeddings"): |
|
if not train_lm_head or not train_embed_tokens: |
|
print(

"Unsloth: You added new tokens but did not specify if you wanted to "\

"train the lm_head and embed_tokens.\nWe must turn them on for you."
|
) |
|
train_lm_head = True |
|
train_embed_tokens = True |
|
|
|
if modules_to_save is None: modules_to_save = ["embed_tokens"] |
|
else: modules_to_save.append("embed_tokens") |
|
|
|
if modules_to_save is None: modules_to_save = ["lm_head"] |
|
else: modules_to_save.append("lm_head") |
|
pass |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if modules_to_save is not None: |
|
for module in modules_to_save: |
|
if module == "lm_head": |
|
train_lm_head = True |
|
elif module == "embed_tokens": |
|
train_embed_tokens = True |
|
else: |
|
raise TypeError(

f"Unsloth: Module = {module} is not allowed. Only 'lm_head' and 'embed_tokens' are allowed."
|
) |
|
pass |
|
pass |
|
if isinstance(modules_to_save, (tuple, list)): |
|
modules_to_save = list(set(modules_to_save)) |
|
pass |
|
|
|
|
|
arguments = dict( |
|
r = r, |
|
lora_alpha = lora_alpha, |
|
target_modules = final_modules, |
|
lora_dropout = lora_dropout, |
|
bias = bias, |
|
task_type = TaskType.CAUSAL_LM, |
|
layers_to_transform = layers_to_transform, |
|
init_lora_weights = init_lora_weights, |
|
loftq_config = loftq_config, |
|
use_rslora = use_rslora, |
|
modules_to_save = modules_to_save, |
|
**kwargs, |
|
) |
|
if not SUPPORTS_LOFTQ: del arguments["loftq_config"] |
|
if not SUPPORTS_RSLORA: del arguments["use_rslora"] |
|
|
|
_saved_temp_tokenizer = model._saved_temp_tokenizer |
|
|
|
lora_config = LoraConfig(**arguments) |
|
|
|
|
|
input_embeddings_device = model.get_input_embeddings().weight.device
|
output_embeddings_device = model.get_output_embeddings().weight.device |
|
|
|
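# With Unsloth gradient checkpointing, offload the embedding matrices that will be retrained to disk first to free VRAM.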
if use_gradient_checkpointing == "unsloth": |
|
if train_embed_tokens: |
|
print("Unsloth: Offloading input_embeddings to disk to save VRAM") |
|
offload_input_embeddings(model, temporary_location) |
|
pass |
|
|
|
|
|
for _ in range(3): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
pass |
|
|
|
if train_lm_head: |
|
print("Unsloth: Offloading output_embeddings to disk to save VRAM") |
|
offload_output_embeddings(model, temporary_location) |
|
pass |
|
|
|
|
|
for _ in range(3): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
pass |
|
pass |
|
|
|
model = _get_peft_model(model, lora_config) |
|
|
|
model._saved_temp_tokenizer = _saved_temp_tokenizer |
|
|
|
model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing) |
|
|
|
|
|
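# Keep the trainable modules_to_save copies on the GPU; float16 copies are upcast to float32.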
if train_embed_tokens: |
|
print("Unsloth: Training embed_tokens in mixed precision to save VRAM") |
|
assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) |
|
|
|
dtype = model.model.model.embed_tokens.modules_to_save.default.weight.dtype |
|
model.model.model.embed_tokens.modules_to_save.default\ |
|
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) |
|
model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) |
|
pass |
|
|
|
if train_lm_head: |
|
print("Unsloth: Training lm_head in mixed precision to save VRAM") |
|
assert(hasattr(model.model.lm_head, "modules_to_save")) |
|
|
|
dtype = model.model.lm_head.modules_to_save.default.weight.dtype |
|
model.model.lm_head.modules_to_save.default\ |
|
.to(device = "cuda:0", dtype=(dtype if (dtype != torch.float16) else torch.float32), non_blocking = True) |
|
model.model.lm_head.modules_to_save.default.requires_grad_(True) |
|
pass |
|
|
|
|
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "right" |
|
pass |
|
internal_model = internal_model.model |
|
pass |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "right" |
|
pass |
|
|
|
|
|
for _ in range(3): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
pass |
|
|
|
return model |
|
pass |
|
|
|
|
|
@staticmethod |
|
def patch_peft_model( |
|
model, |
|
use_gradient_checkpointing = True, |
|
): |
|
if not isinstance(model, PeftModelForCausalLM): |
|
raise TypeError( |
|
"Unsloth: Your model needs to call `.get_peft_model` first!" |
|
) |
|
pass |
|
|
|
|
|
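# Pick the fused LoRA MLP kernel: SwiGLU models versus Gemma's approximate GeGLU.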
model_type = model.config.model_type |
|
|
|
if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu |
|
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu |
|
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu |
|
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx |
|
elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx |
|
elif model_type == "cohere": apply_lora_mlp = apply_lora_mlp_swiglu |
|
elif model_type == "granite": apply_lora_mlp = apply_lora_mlp_swiglu |
|
else: |
|
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!") |
|
pass |
|
|
|
model = prepare_model_for_kbit_training( |
|
model, |
|
use_gradient_checkpointing = use_gradient_checkpointing, |
|
use_reentrant = True, |
|
) |
|
|
|
|
|
for active_adapter in model.peft_config.keys(): |
|
|
|
if False: |
|
name = model.peft_config[active_adapter].base_model_name_or_path |
|
if name.startswith("unsloth/") and name.endswith("-bnb-4bit"): |
|
name = name[:len(name) - len("-bnb-4bit")] |
|
model.peft_config[active_adapter].base_model_name_or_path = name |
|
pass |
|
|
|
|
|
|
|
pass |
|
|
|
from transformers.trainer import Trainer |
|
if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop": |
|
raise RuntimeError( |
|
'Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so '\ |
|
'enabling it will require much more work, so we have to prioritize. Please understand!\n'\ |
|
'We do have a separate beta version, which you can contact us about!\n'\ |
|
'Thank you for your understanding and we appreciate it immensely!' |
|
) |
|
pass |
|
|
|
|
|
|
|
all_configs = model.peft_config |
|
for key, current_config in all_configs.items(): |
|
if hasattr(current_config, "loftq_config") and current_config.loftq_config is None: |
|
new_args = current_config.__dict__ |
|
new_args["loftq_config"] = {} |
|
current_config = current_config.__class__(**new_args) |
|
all_configs[key] = current_config |
|
pass |
|
pass |
|
|
|
|
|
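# Count how many MLP, QKV and O projections end up on the fast LoRA autograd path.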
n_mlp = 0 |
|
n_qkv = 0 |
|
n_o = 0 |
|
import types |
|
|
|
active_adapter = model.active_adapters[0] if \ |
|
hasattr(model, "active_adapters") else model.active_adapter |
|
|
|
|
|
lora_dropout = model.peft_config[active_adapter].lora_dropout |
|
bias = model.peft_config[active_adapter].bias |
|
|
|
|
|
from functools import partial |
|
_apply_lora_mlp = \ |
|
partial(apply_lora_mlp, inplace = False) \ |
|
if model_type == "cohere" else \ |
|
apply_lora_mlp |
|
pass |
|
|
|
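# The fused functions assume no LoRA dropout, no bias term and no DoRA magnitude vectors.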
if lora_dropout == 0 and bias == "none": |
|
for idx, layer in enumerate(model.model.model.layers): |
|
|
|
|
|
gate_proj = layer.mlp.gate_proj |
|
up_proj = layer.mlp. up_proj |
|
down_proj = layer.mlp.down_proj |
|
|
|
if hasattr(gate_proj, "lora_A") and \ |
|
hasattr( up_proj, "lora_A") and \ |
|
hasattr(down_proj, "lora_A") and \ |
|
(getattr(gate_proj, "base_layer", gate_proj).bias is None) and \ |
|
(getattr( up_proj, "base_layer", up_proj).bias is None) and \ |
|
(getattr(down_proj, "base_layer", down_proj).bias is None) and \ |
|
(len(getattr(gate_proj, "lora_magnitude_vector", []) or []) == 0) and \ |
|
(len(getattr( up_proj, "lora_magnitude_vector", []) or []) == 0) and \ |
|
(len(getattr(down_proj, "lora_magnitude_vector", []) or []) == 0): |
|
|
|
|
|
layer.mlp.forward = types.MethodType(_apply_lora_mlp, layer.mlp) |
|
n_mlp += 1 |
|
else: |
|
logger.warning_once( |
|
"Not an error, but Unsloth cannot patch MLP layers with our manual autograd engine since either LoRA adapters\n"\ |
|
"are not enabled or a bias term (like in Qwen) is used." |
|
) |
|
pass |
|
|
|
|
|
q_proj = layer.self_attn.q_proj |
|
k_proj = layer.self_attn.k_proj |
|
v_proj = layer.self_attn.v_proj |
|
if hasattr(q_proj, "lora_A") and \ |
|
hasattr(k_proj, "lora_A") and \ |
|
hasattr(v_proj, "lora_A") and \ |
|
(getattr(q_proj, "base_layer", q_proj).bias is None) and \ |
|
(getattr(k_proj, "base_layer", k_proj).bias is None) and \ |
|
(getattr(v_proj, "base_layer", v_proj).bias is None) and \ |
|
(len(getattr(q_proj, "lora_magnitude_vector", []) or []) == 0) and \ |
|
(len(getattr(k_proj, "lora_magnitude_vector", []) or []) == 0) and \ |
|
(len(getattr(v_proj, "lora_magnitude_vector", []) or []) == 0): |
|
|
|
layer.self_attn.apply_qkv = apply_lora_qkv |
|
n_qkv += 1 |
|
else: |
|
if model_type == "qwen2": n_qkv += 1 |
|
else: |
|
logger.warning_once( |
|
"Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\ |
|
"are not enabled or a bias term (like in Qwen) is used." |
|
) |
|
pass |
|
pass |
|
|
|
|
|
o_proj = layer.self_attn.o_proj |
|
if hasattr(o_proj, "lora_A") and \ |
|
(getattr(o_proj, "base_layer", o_proj).bias is None) and \ |
|
(len(getattr(o_proj, "lora_magnitude_vector", []) or []) == 0): |
|
|
|
layer.self_attn.apply_o = apply_lora_o |
|
n_o += 1 |
|
else: |
|
logger.warning_once( |
|
"Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters\n"\ |
|
"are not enabled or a bias term (like in Qwen) is used." |
|
) |
|
pass |
|
pass |
|
pass |
|
|
|
logger.warning_once( |
|
f"Unsloth {__version__} patched {len(model.model.model.layers)} layers with "\ |
|
f"{n_qkv} QKV layers, {n_o} O layers and {n_mlp} MLP layers.", |
|
) |
|
patch_saving_functions(model) |
|
|
|
|
|
|
|
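# Pre-allocate a column of -100 (ignore index) labels sized to max_seq_length, reused when labels are shifted in the fast forward pass.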
max_seq_length = model.max_seq_length |
|
extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") |
|
model.model.extra_ignored_labels = extra_ignored_labels |
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
internal_model.max_seq_length = max_seq_length |
|
internal_model = internal_model.model |
|
pass |
|
internal_model.max_seq_length = max_seq_length |
|
|
|
|
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "right" |
|
pass |
|
internal_model = internal_model.model |
|
pass |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "right" |
|
pass |
|
|
|
|
|
for _ in range(3): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
pass |
|
return model |
|
pass |
|
|
|
|
|
@staticmethod |
|
def for_inference(model): |
|
|
|
|
|
|
|
|
|
|
|
internal_model = model |
|
internal_model.gradient_checkpointing = False |
|
internal_model.training = False |
|
|
|
while hasattr(internal_model, "model"): |
|
internal_model = internal_model.model |
|
internal_model.gradient_checkpointing = False |
|
internal_model.training = False |
|
pass |
|
if hasattr(internal_model, "training"): |
|
internal_model.training = False |
|
pass |
|
|
|
|
|
internal_model = model |
|
while not hasattr(internal_model, "lm_head"): |
|
internal_model = internal_model.model |
|
pass |
|
lm_head = internal_model.lm_head.weight |
|
device_type = lm_head.device.type |
|
dtype = model.config.torch_dtype |
|
|
|
if type(dtype) is str: |
|
if dtype == "float16": dtype = torch.float16 |
|
elif dtype == "bfloat16": dtype = torch.bfloat16 |
|
pass |
|
|
|
|
|
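# Wrap .generate once with the fast inference wrapper, keeping the original so for_training can restore it.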
if model.generate.__name__ != "_fast_generate": |
|
model._unwrapped_old_generate = model.generate |
|
model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) |
|
pass |
|
|
|
|
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "left" |
|
pass |
|
internal_model = internal_model.model |
|
pass |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "left" |
|
pass |
|
|
|
|
|
if hasattr(model, "get_input_embeddings"): |
|
embeddings = model.get_input_embeddings() |
|
if hasattr(embeddings, "training"): embeddings.training = False |
|
pass |
|
if hasattr(model, "get_output_embeddings"): |
|
embeddings = model.get_output_embeddings() |
|
if hasattr(embeddings, "training"): embeddings.training = False |
|
pass |
|
|
|
return model |
|
pass |
|
|
|
|
|
@staticmethod |
|
def for_training(model, use_gradient_checkpointing = True): |
|
internal_model = model |
|
internal_model.gradient_checkpointing = use_gradient_checkpointing |
|
internal_model.training = True |
|
|
|
|
|
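# Drop any fast-LoRA buffers cached on parameters during inference.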
for param in model.parameters(): |
|
if hasattr(param, "_fast_lora"): |
|
del param._fast_lora |
|
pass |
|
|
|
while hasattr(internal_model, "model"): |
|
internal_model = internal_model.model |
|
internal_model.gradient_checkpointing = use_gradient_checkpointing |
|
internal_model.training = True |
|
pass |
|
if hasattr(internal_model, "training"): |
|
internal_model.training = True |
|
pass |
|
|
|
|
|
if hasattr(model, "_unwrapped_old_generate"): |
|
model.generate = model._unwrapped_old_generate |
|
del model._unwrapped_old_generate |
|
pass |
|
|
|
|
|
internal_model = model |
|
while hasattr(internal_model, "model"): |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "right" |
|
pass |
|
internal_model = internal_model.model |
|
pass |
|
if hasattr(internal_model, "_saved_temp_tokenizer"): |
|
internal_model._saved_temp_tokenizer.padding_side = "right" |
|
pass |
|
|
|
|
|
if hasattr(model, "get_input_embeddings"): |
|
embeddings = model.get_input_embeddings() |
|
if hasattr(embeddings, "training"): embeddings.training = True |
|
pass |
|
if hasattr(model, "get_output_embeddings"): |
|
embeddings = model.get_output_embeddings() |
|
if hasattr(embeddings, "training"): embeddings.training = True |
|
pass |
|
|
|
return model |
|
pass |
|
pass |
|
|
|
|