from .llama import * |
|
import os |
|
from ._utils import __version__ |
|
from .llama import ( |
|
LlamaRotaryEmbedding, |
|
LlamaLinearScalingRotaryEmbedding, |
|
) |
|
from .mistral import * |
|
|
|
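# Granite model classes landed in transformers 4.45.0; raise a clear error on older versions.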
try: |
|
from transformers.models.granite.modeling_granite import ( |
|
GraniteAttention, |
|
GraniteDecoderLayer, |
|
GraniteModel, |
|
GraniteForCausalLM, |
|
) |
|
except ImportError:
|
from packaging.version import Version |
|
|
|
transformers_version = Version(transformers_version) |
|
if not transformers_version >= Version("4.45.0"): |
|
raise ImportError( |
|
f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\ |
|
f"The minimum required version is 4.42.3.\n"\ |
|
f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\ |
|
f"to obtain the latest transformers build, then restart this session."\ |
|
) |
|
pass |
|
pass |
|
|
|
from transformers.modeling_attn_mask_utils import ( |
|
_prepare_4d_causal_attention_mask_for_sdpa, |
|
) |
|
|
|
|
|
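# Some transformers releases ship separate SDPA / FlashAttention2 attention classes
# for Granite; when they are absent, alias them to the base GraniteAttention.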
try: |
|
from transformers.models.granite.modeling_granite import ( |
|
GraniteSdpaAttention, |
|
GraniteFlashAttention2, |
|
) |
|
except ImportError:
|
GraniteSdpaAttention = GraniteAttention |
|
GraniteFlashAttention2 = GraniteAttention |
|
pass |
|
|
|
def GraniteAttention_fast_forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
causal_mask: Optional[BlockDiagonalCausalMask] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
padding_mask: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
*args, **kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
|
|
|
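    # Drop inference-only KV cache buffers left over from a previous generation call.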
if hasattr(self, "paged_attention"): |
|
del self.paged_attention_K |
|
del self.paged_attention_V |
|
del self.paged_attention |
|
del self.temp_QA |
|
del self.temp_KV |
|
del self.RH_Q |
|
del self.attention |
|
pass |
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
n_heads = self.num_heads |
|
n_groups = self.num_key_value_groups |
|
n_kv_heads = self.num_key_value_heads |
|
head_dim = self.head_dim |
|
    assert n_kv_heads * n_groups == n_heads
|
|
|
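    # Project hidden states to Q, K, V and reshape to (bsz, n_heads, seq_len, head_dim).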
Q, K, V = self.apply_qkv(self, hidden_states) |
|
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) |
|
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) |
|
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) |
|
|
|
kv_seq_len = K.shape[-2] |
|
if past_key_value is not None: |
|
kv_seq_len += past_key_value[0].shape[-2] |
|
|
|
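    # Rotary cos / sin are precomputed once by the model and passed down as position_embeddings.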
assert position_embeddings is not None |
|
cos, sin = position_embeddings |
|
if position_ids is None: |
|
Q, K = fast_rope_embedding(Q, K, cos, sin) |
|
else: |
|
Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) |
|
|
|
if past_key_value is not None: |
|
K = torch.cat([past_key_value[0], K], dim = 2) |
|
V = torch.cat([past_key_value[1], V], dim = 2) |
|
pass |
|
past_key_value = (K, V) if use_cache else None |
|
|
|
|
|
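    # Attention dispatch: xformers when Flash-Attention is unavailable and no mask is given,
    # Flash-Attention when available, otherwise SDPA with an explicit attention mask.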
if (not HAS_FLASH_ATTENTION and attention_mask is None): |
|
|
|
Q = Q.transpose(1, 2) |
|
K = K.transpose(1, 2) |
|
V = V.transpose(1, 2) |
|
K_M = V_M = bsz * kv_seq_len |
|
Q_M = bsz * q_len |
|
|
|
|
|
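        # Grouped-query attention: broadcast each K / V head across its query group.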
K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) |
|
V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) |
|
K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) |
|
V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) |
|
if hidden_states.requires_grad: |
|
K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) |
|
V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) |
|
else: |
|
|
|
Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) |
|
pass |
|
|
|
A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling) |
|
A = A.view(bsz, q_len, n_heads, head_dim) |
|
|
|
elif HAS_FLASH_ATTENTION and attention_mask is None: |
|
Q = Q.transpose(1, 2) |
|
K = K.transpose(1, 2) |
|
V = V.transpose(1, 2) |
|
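        # Flash-Attention path: causal attention over the full window (GQA is handled natively).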
window = (kv_seq_len, kv_seq_len) |
|
A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling) |
|
else: |
|
|
|
|
|
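        # SDPA fallback with an explicit attention mask; expand K / V to match the query heads.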
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) |
|
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) |
|
K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) |
|
V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) |
|
|
|
|
|
|
|
Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() |
|
|
|
|
|
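        # attention_mask already encodes causality, so is_causal stays False here.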
A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False) |
|
|
|
A = A.transpose(1, 2).contiguous() |
|
pass |
|
|
|
attn_output = A.reshape(bsz, q_len, n_heads*head_dim) |
|
attn_output = self.apply_o(self, attn_output) |
|
attn_weights = None |
|
return attn_output, attn_weights, past_key_value |
|
pass |
|
|
|
|
|
def GraniteDecoderLayer_fast_forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
causal_mask: Optional[BlockDiagonalCausalMask] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: Optional[bool] = False, |
|
use_cache: Optional[bool] = False, |
|
padding_mask: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
*args, **kwargs, |
|
): |
|
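    # Fast inference path (fused RMSNorm / SwiGLU kernels) during generation, standard path
    # otherwise; Granite scales every residual branch by config.residual_multiplier.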
if use_cache and hasattr(self, "_flag_for_generation"): |
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states) |
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
causal_mask=causal_mask, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
padding_mask=padding_mask, |
|
position_embeddings = position_embeddings, |
|
_flag_for_generation=self._flag_for_generation, |
|
) |
|
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states) |
|
hidden_states = fast_swiglu_inference(self.mlp, hidden_states) |
|
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) |
|
else: |
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) |
|
hidden_states, self_attn_weights, present_key_value = self.self_attn( |
|
hidden_states=hidden_states, |
|
causal_mask=causal_mask, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
padding_mask=padding_mask, |
|
position_embeddings = position_embeddings, |
|
) |
|
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) |
|
pass |
|
|
|
outputs = (hidden_states,) |
|
if output_attentions: outputs += (self_attn_weights,) |
|
if use_cache: outputs += (present_key_value,) |
|
return outputs |
|
pass |
|
|
|
|
|
from math import sqrt as math_sqrt |
|
KV_CACHE_INCREMENT = 256 |
|
torch_nn_functional_softmax = torch.nn.functional.softmax |
|
torch_matmul = torch.matmul |
|
torch_tanh = torch.tanh |
|
|
|
def GraniteAttention_fast_forward_inference( |
|
self, |
|
hidden_states: torch.Tensor, |
|
past_key_value: Optional[Tuple[torch.Tensor]], |
|
position_ids, |
|
do_prefill = False, |
|
attention_mask = None, |
|
use_sliding_window = False, |
|
position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
): |
|
|
|
    assert position_embeddings is not None, "The Granite model requires position embeddings to be specified"
|
|
|
Xn = hidden_states |
|
bsz, _, hd = hidden_states.size() |
|
K1, V1 = past_key_value |
|
dtype = Xn.dtype |
|
|
|
n_heads = self.num_heads |
|
n_groups = self.num_key_value_groups |
|
n_kv_heads = self.num_key_value_heads |
|
head_dim = self.head_dim |
|
attention_size = n_heads*head_dim |
|
|
|
seq_len = K1.shape[-2] |
|
kv_seq_len = seq_len + 1 |
|
|
|
|
|
|
|
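    # On the first decode step, pre-allocate the paged KV cache and scratch buffers;
    # afterwards grow them in KV_CACHE_INCREMENT chunks whenever the cache fills up.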
if do_prefill: |
|
self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") |
|
self.paged_attention_K = self.paged_attention[:,0] |
|
self.paged_attention_V = self.paged_attention[:,1] |
|
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) |
|
self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) |
|
self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") |
|
self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") |
|
self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") |
|
|
|
self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") |
|
self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") |
|
|
|
|
|
self.half_head_dim = head_dim // 2 |
|
elif kv_seq_len >= self.paged_attention.shape[0]: |
|
self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim)) |
|
self.paged_attention_K = self.paged_attention[:,0] |
|
self.paged_attention_V = self.paged_attention[:,1] |
|
self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT)) |
|
pass |
|
|
|
Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) |
|
Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) |
|
Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) |
|
Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) |
|
Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) |
|
Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) |
|
|
|
|
|
|
|
cos, sin = position_embeddings |
|
cos, sin = cos[position_ids], sin[position_ids] |
|
h = self.half_head_dim |
|
|
|
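    # Apply RoPE in place via the rotate-half trick: x <- x * cos + rotate_half(x) * sin.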
RH_Q = self.RH_Q |
|
RH_Q[:,:,:,:h] = Qn[:,:,:,h:] |
|
RH_Q[:,:,:,h:] = Qn[:,:,:,:h] |
|
torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) |
|
Qn *= cos |
|
Qn.addcmul_(RH_Q, sin) |
|
|
|
RH_K = RH_Q[:,:n_kv_heads,:,:] |
|
RH_K[:,:,:,:h] = Kn[:,:,:,h:] |
|
RH_K[:,:,:,h:] = Kn[:,:,:,:h] |
|
torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) |
|
Kn *= cos |
|
Kn.addcmul_(RH_K, sin) |
|
|
|
|
|
|
|
|
|
self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3) |
|
self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3) |
|
Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) |
|
Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) |
|
|
|
|
|
_, _, cached_len, _ = Kn.shape |
|
if n_groups != 1: |
|
Kn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) |
|
Vn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) |
|
Kn = Kn.reshape(bsz, n_heads, cached_len, head_dim) |
|
Vn = Vn.reshape(bsz, n_heads, cached_len, head_dim) |
|
pass |
|
|
|
|
|
|
|
|
|
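    # Granite scales attention scores by self.scaling (the config's attention multiplier)
    # rather than the usual 1 / sqrt(head_dim).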
Qn *= self.scaling |
|
A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) |
|
|
|
|
|
|
|
A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32) |
|
A = torch_matmul(A, Vn, out = Qn) |
|
|
|
|
|
|
|
A = A.transpose(1, 2) |
|
A = A.reshape(bsz, 1, attention_size) |
|
A = fast_linear_forward(self.o_proj, A, out = self.temp_O) |
|
return A, (Kn, Vn) |
|
pass |
|
|
|
|
|
|
|
|
|
def GraniteModel_fast_forward_inference( |
|
self, |
|
input_ids, |
|
past_key_values, |
|
position_ids, |
|
attention_mask = None, |
|
): |
|
input_ids = input_ids[:,:self.max_seq_length] |
|
hidden_states = self.model.embed_tokens(input_ids) |
|
hidden_states = hidden_states.to(self.config.torch_dtype) |
|
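    # Granite scales token embeddings by embedding_multiplier before the decoder stack.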
hidden_states *= self.model.embedding_multiplier |
|
|
|
bsz, q_len, hd = hidden_states.shape |
|
seq_len = past_key_values[0][0].shape[-2] |
|
if bsz != 1: |
|
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( |
|
attention_mask, |
|
(bsz, q_len), |
|
hidden_states, |
|
seq_len, |
|
) |
|
else: |
|
attention_mask = None |
|
pass |
|
|
|
position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length) |
|
|
|
next_decoder_cache = [] |
|
for idx, decoder_layer in enumerate(self.model.layers): |
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) |
|
hidden_states, present_key_value = GraniteAttention_fast_forward_inference( |
|
decoder_layer.self_attn, |
|
hidden_states = hidden_states, |
|
past_key_value = past_key_values[idx], |
|
position_ids = position_ids, |
|
attention_mask = attention_mask, |
|
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), |
|
position_embeddings = position_embeddings, |
|
) |
|
|
|
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) |
|
|
|
residual = hidden_states |
|
hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) |
|
hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) |
|
hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) |
|
|
|
next_decoder_cache.append(present_key_value) |
|
pass |
|
hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) |
|
|
|
return BaseModelOutputWithPast( |
|
last_hidden_state = hidden_states, |
|
past_key_values = next_decoder_cache, |
|
hidden_states = [], |
|
attentions = [], |
|
) |
|
pass |
|
|
|
class GraniteRotaryEmbedding(LlamaRotaryEmbedding): |
|
def __init__(self, config): |
|
super().__init__(config = config) |
|
|
|
class FastGraniteModel(FastLlamaModel): |
|
|
|
@staticmethod |
|
def pre_patch(): |
|
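        # Patch RoPE scaling into GraniteAttention.__init__, then swap in the fast forwards.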
init_name, function = patch_linear_scaling( |
|
model_name = "granite", |
|
rope_module = GraniteRotaryEmbedding, |
|
scaled_rope_module = LlamaLinearScalingRotaryEmbedding, |
|
attention_module = GraniteAttention, |
|
) |
|
if init_name is not None: |
|
exec(function, globals()) |
|
GraniteAttention.__init__ = eval(init_name) |
|
pass |
|
GraniteAttention .forward = GraniteAttention_fast_forward |
|
GraniteSdpaAttention .forward = GraniteAttention_fast_forward |
|
GraniteFlashAttention2.forward = GraniteAttention_fast_forward |
|
GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward |
|
GraniteModel .forward = LlamaModel_fast_forward |
|
GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference) |
|
PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward |
|
fix_prepare_inputs_for_generation(GraniteForCausalLM) |
|
|
|
import transformers.models.granite.modeling_granite |
|
transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = GraniteRotaryEmbedding |
|
|
|
return |
|
pass |
|
|
|
|
|
@staticmethod |
|
def post_patch(model): |
|
|
|
|
|
|
|
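        # Rebuild embed_tokens and lm_head as plain modules that reuse the original weights.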
model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) |
|
model.config.update({"unsloth_version" : __version__}) |
|
|
|
|
|
lm_head = torch.nn.Linear(1, 1, bias = None) |
|
del lm_head.weight |
|
lm_head.weight = model.lm_head.weight |
|
lm_head.in_features = lm_head.weight.shape[1] |
|
lm_head.out_features = lm_head.weight.shape[0] |
|
model.lm_head = lm_head |
|
|
|
|
|
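        # Granite ties lm_head to embed_tokens; re-tie them if the weights are not already shared.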
if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): |
|
lm_head = torch.nn.Linear(1, 1, bias = None) |
|
del lm_head.weight |
|
lm_head.weight = model.model.embed_tokens.weight |
|
lm_head.in_features = lm_head.weight.shape[1] |
|
lm_head.out_features = lm_head.weight.shape[0] |
|
model.lm_head = lm_head |
|
pass |
|
|
|
|
|
|
|
correct_dtype = lm_head.weight.dtype |
|
|
|
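        # Keep 4-bit quant states and RoPE caches in the same dtype as lm_head to avoid
        # dtype mismatches at inference time.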
for name, module in model.named_modules(): |
|
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): |
|
weight = module.weight |
|
quant_state = weight.quant_state |
|
|
|
if type(quant_state) is list: |
|
|
|
module.weight.quant_state[2] = correct_dtype |
|
else: |
|
|
|
quant_state.dtype = correct_dtype |
|
pass |
|
pass |
|
|
|
if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")): |
|
|
|
if hasattr(module, "cos_cached") and \ |
|
(module.cos_cached.dtype != correct_dtype): |
|
|
|
module.cos_cached = module.cos_cached.to(correct_dtype) |
|
module.sin_cached = module.sin_cached.to(correct_dtype) |
|
|
|
elif hasattr(module, "short_cos_cached") and \ |
|
(module.short_cos_cached.dtype != correct_dtype): |
|
|
|
module.short_cos_cached = module.short_cos_cached.to(correct_dtype) |
|
module.short_sin_cached = module.short_sin_cached.to(correct_dtype) |
|
pass |
|
pass |
|
pass |
|
|
|
|
|
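        # Release temporary buffers created while patching.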
import gc |
|
for _ in range(3): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
return model |
|
pass |
|
pass |
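# Usage sketch (assumption, not part of this module): these patches are normally applied
# through Unsloth's loader, e.g.
#   from unsloth import FastLanguageModel
#   model, tokenizer = FastLanguageModel.from_pretrained("ibm-granite/granite-3.0-2b-instruct")
# which is expected to call FastGraniteModel.pre_patch() before loading the weights and
# FastGraniteModel.post_patch(model) afterwards.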
|
|
|
|