from .llama import *
from ._utils import __version__
import math

try:
    from transformers.models.gemma.modeling_gemma import (
        GemmaAttention,
        GemmaDecoderLayer,
        GemmaModel,
        GemmaForCausalLM,
        GemmaRotaryEmbedding,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except:
    from packaging.version import Version
    transformers_version = Version(transformers_version)
    if not transformers_version >= Version("4.38"):
        raise ImportError(
            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
            f"The minimum required version is 4.38.\n"\
            f'Try `pip install --upgrade "transformers>=4.38"`\n'\
            f"to obtain the latest transformers build, then restart this session."\
        )
    pass
pass

from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)

try:
    from transformers.models.gemma.modeling_gemma import (
        GemmaSdpaAttention,
        GemmaFlashAttention2,
    )
except:
    GemmaSdpaAttention   = GemmaAttention
    GemmaFlashAttention2 = GemmaAttention
pass

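# Fused GeGLU MLP for single-token decoding: computes
# down_proj(gelu_tanh(gate_proj(X)) * up_proj(X)) via fast_linear_forward,
# reusing the `up` buffer as the output of down_proj to avoid an extra allocation.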
torch_nn_functional_gelu = torch.nn.functional.gelu
def fast_geglu_inference(self, X):
    bsz, _, hd = X.shape

    gate = fast_linear_forward(self.gate_proj, X)
    up   = fast_linear_forward(self.  up_proj, X)
    gate = torch_nn_functional_gelu(gate, approximate = "tanh")
    gate *= up

    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
    return down
pass

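# Decoder layer forward with two paths: a fused inference path (single-token
# generation, marked by `_flag_for_generation`) that uses the Gemma-specific
# RMS layernorm and GeGLU kernels, and a standard path for training / prefill.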
def GemmaDecoderLayer_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
):
    if use_cache and hasattr(self, "_flag_for_generation"):
        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")

        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )
        hidden_states += residual

        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
        hidden_states = fast_geglu_inference(self.mlp, hidden_states)
        hidden_states += residual
    else:
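        # Training / prefill path: standard Triton RMS layernorm (the gemma = True
        # flag selects Gemma's (1 + weight) variant) and the module's own MLP forward.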
        residual = hidden_states
        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
    pass

    outputs = (hidden_states,)
    if output_attentions: outputs += (self_attn_weights,)
    if use_cache: outputs += (present_key_value,)
    return outputs
pass

from math import sqrt as math_sqrt

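# Single-token decoding path for the whole Gemma model. Gemma scales the token
# embeddings by sqrt(hidden_size); the multiplier is materialised as a tensor in the
# embeddings' dtype before the decoder layers run.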
def GemmaModel_fast_forward_inference(
    self,
    input_ids,
    past_key_values,
    position_ids,
    attention_mask = None,
):
    out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
    input_ids = input_ids[:,:self.max_seq_length]
    hidden_states = self.model.embed_tokens(input_ids)
    hidden_states = hidden_states.to(self.config.torch_dtype)

    hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)

    bsz, q_len, hd = hidden_states.shape
    seq_len = past_key_values[0][0].shape[-2]
    if bsz != 1:
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (bsz, q_len),
            hidden_states,
            seq_len,
        )
    pass

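    # Run each decoder layer with the fused inference kernels, collecting the
    # updated KV cache entries as we go.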
    next_decoder_cache = []
    for idx, decoder_layer in enumerate(self.model.layers):
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
            decoder_layer.self_attn,
            hidden_states = hidden_states,
            past_key_value = past_key_values[idx],
            position_ids = position_ids,
            attention_mask = attention_mask,
            do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
        )
        hidden_states += residual

        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
        hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
        hidden_states += residual

        next_decoder_cache.append(present_key_value)
    pass
    hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)

    return BaseModelOutputWithPast(
        last_hidden_state = hidden_states,
        past_key_values = next_decoder_cache,
        hidden_states = [],
        attentions = [],
    )
pass

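# Rotary embedding whose cos / sin caches are built in float32 on the CPU
# (positions / timescale) and only then moved to the GPU, instead of being computed
# in the model dtype. The cache starts at min(4 * 8192, max_position_embeddings)
# positions and is grown on demand via extend_rope_embedding.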
class GemmaFixedRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
        config = None,
    ):
        super().__init__()
        if config is not None: return
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)

        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
    pass

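    # Rebuild the cos / sin caches for `seq_len` positions. The timescale for
    # dimension pair i is base**(2i / dim), so the rotation angle is
    # positions / timescale, duplicated across both halves of the head dimension.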
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.current_rope_size = seq_len

        freq_exponents = (2.0 / self.dim) * (
            torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
        )
        timescale = self.base**freq_exponents
        positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
        radians_new = positions[..., None] / timescale[None, None, :]
        radians_new = radians_new.squeeze(0)

        emb = torch.cat((radians_new, radians_new), dim = -1)

        cos = emb.cos().to(device = "cuda:0", non_blocking = True)
        sin = emb.sin().to(device = "cuda:0", non_blocking = True)
        self.register_buffer("cos_cached", cos, persistent = False)
        self.register_buffer("sin_cached", sin, persistent = False)
    pass

    def forward(self, x, position_ids=None, seq_len=None):
        if seq_len > self.current_rope_size:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
    pass

    def get_cached(self, seq_len = None):
        return self.cos_cached, self.sin_cached
    pass

    def extend_rope_embedding(self, x, seq_len):
        if seq_len <= self.current_rope_size: return
        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
        self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
    pass
pass

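# Linear-scaling variant: identical cache construction, but positions are divided
# by `scaling_factor` before the angles are computed.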
class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
        config = None,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
    pass

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.current_rope_size = seq_len

        freq_exponents = (2.0 / self.dim) * (
            torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
        )
        timescale = self.base**freq_exponents
        positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
        positions = positions / self.scaling_factor
        radians_new = positions[..., None] / timescale[None, None, :]
        radians_new = radians_new.squeeze(0)

        emb = torch.cat((radians_new, radians_new), dim = -1)

        cos = emb.cos().to(device = "cuda:0", non_blocking = True)
        sin = emb.sin().to(device = "cuda:0", non_blocking = True)
        self.register_buffer("cos_cached", cos, persistent = False)
        self.register_buffer("sin_cached", sin, persistent = False)
    pass
pass

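# Patches Hugging Face's Gemma modelling code in place: the attention, decoder layer,
# model and causal-LM forwards are replaced with the fast Llama-style implementations,
# and the rotary embedding class is swapped for the fixed one above.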
class FastGemmaModel(FastLlamaModel):

    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name         = "gemma",
            rope_module        = GemmaFixedRotaryEmbedding,
            scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
            attention_module   = GemmaAttention,
        )
        if init_name is not None:
            exec(function, globals())
            GemmaAttention.__init__ = eval(init_name)
        pass
        GemmaAttention      .forward = LlamaAttention_fast_forward
        GemmaSdpaAttention  .forward = LlamaAttention_fast_forward
        GemmaFlashAttention2.forward = LlamaAttention_fast_forward
        GemmaDecoderLayer   .forward = GemmaDecoderLayer_fast_forward
        GemmaModel          .forward = LlamaModel_fast_forward
        GemmaForCausalLM    .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        fix_prepare_inputs_for_generation(GemmaForCausalLM)

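        # Also replace the rotary embedding class at module level so any layers built
        # after this point pick up the fixed implementation.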
        import transformers.models.gemma.modeling_gemma
        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
        return
    pass

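    # Post-load fixes: run patch_model_and_tokenizer with downcast_rope = False, freeze
    # every parameter except the LoRA adapter weights, and expose GemmaRMSNorm's epsilon
    # under Llama's `variance_epsilon` attribute name.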
    @staticmethod
    def post_patch(model, tokenizer):
        model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False)

        from transformers.models.gemma.modeling_gemma import GemmaRMSNorm

        for name, param in model.named_parameters():
            if ".lora_A." in name or ".lora_B." in name:
                param.requires_grad_(True)
            else:
                param.requires_grad_(False)
        pass

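        # GemmaRMSNorm stores its epsilon as `eps`; copy it to `variance_epsilon` so
        # code written against LlamaRMSNorm can read the same attribute.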
        for name, module in model.named_modules():
            if isinstance(module, GemmaRMSNorm):
                if not hasattr(module, "variance_epsilon"):
                    module.variance_epsilon = module.eps
        pass

        import gc
        for _ in range(3):
            gc.collect()
            torch.cuda.empty_cache()
        return model, tokenizer
    pass
pass
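
# Illustrative usage sketch (commented out; not part of the public API, since unsloth's
# own loaders are the intended entry point). Applying these patches around a plain
# transformers load might look roughly like the following. The AutoModelForCausalLM /
# AutoTokenizer calls and the checkpoint name are assumptions for illustration only.
#
#   FastGemmaModel.pre_patch()   # patch the Gemma classes before building the model
#   model     = AutoModelForCausalLM.from_pretrained("google/gemma-7b")
#   tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
#   model, tokenizer = FastGemmaModel.post_patch(model, tokenizer)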