# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. __version__ = "2024.12.12" __all__ = [ "prepare_model_for_kbit_training", "xformers", "xformers_attention", "xformers_version", "__version__", "HAS_FLASH_ATTENTION", "HAS_FLASH_ATTENTION_SOFTCAPPING", "PRE_CHECK", "platform_system", "patch_tokenizer", "get_statistics", "Unsloth_Offloaded_Gradient_Checkpointer", "offload_to_disk", "offload_input_embeddings", "offload_output_embeddings", "is_bfloat16_supported", "unsloth_offloaded_gradient_checkpoint", "torch_compile_options", "patch_linear_scaling", "patch_llama_rope_scaling", "check_nvidia", "create_boolean_mask", "torch_amp_custom_fwd", "torch_amp_custom_bwd", "accelerate_old_send_to_device", "accelerate_new_send_to_device", "patch_gradient_accumulation_fix", "patch_compiling_bitsandbytes", "patch_regional_compilation", "patch_layernorm", "patch_torch_compile", "patch_model_and_tokenizer", "patch_unsloth_gradient_checkpointing", "unpatch_unsloth_gradient_checkpointing", "patch_gradient_checkpointing", "unpatch_gradient_checkpointing", "HAS_CUT_CROSS_ENTROPY", "EMPTY_LOGITS", "fused_linear_cross_entropy", "patch_unsloth_smart_gradient_checkpointing", "unpatch_unsloth_smart_gradient_checkpointing", "create_gradient_checkpointing_buffer", "patch_compiled_autograd", "process_vision_info", "unsloth_compile_transformers", "patch_fast_lora", ] import torch from typing import Union, Optional, List, Any, Callable, Tuple from platform import system as platform_system platform_system = platform_system() import numpy as np import warnings, subprocess, re, inspect, psutil, os, math from unsloth_zoo.utils import Version from unsloth_zoo.tokenizer_utils import ( patch_tokenizer as _patch_tokenizer, ) from unsloth_zoo.patching_utils import ( patch_compiling_bitsandbytes, patch_layernorm, patch_torch_compile, patch_model_and_tokenizer, patch_compiled_autograd, ) from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, unsloth_offloaded_gradient_checkpoint, patch_unsloth_gradient_checkpointing, unpatch_unsloth_gradient_checkpointing, Unsloth_Gradient_Checkpointer, unsloth_gradient_checkpoint, patch_gradient_checkpointing, unpatch_gradient_checkpointing, patch_unsloth_smart_gradient_checkpointing, unpatch_unsloth_smart_gradient_checkpointing, create_gradient_checkpointing_buffer, ) from unsloth_zoo.loss_utils import ( HAS_CUT_CROSS_ENTROPY, fused_linear_cross_entropy, ) from unsloth_zoo.vision_utils import ( process_vision_info, ) from unsloth_zoo.compiler import ( get_transformers_model_type, unsloth_compile_transformers as _unsloth_compile_transformers, ) # ============================================= # Disable some warnings which can get annoying warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers") warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate") warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing") warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess") # Stop "Special tokens have been added in the vocabulary, ..." import logging logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1) # Ignore logging messages class HideLoggingMessage(logging.Filter): def __init__(self, text): self.text = text def filter(self, x): return not (self.text in x.getMessage()) pass # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. from transformers.training_args import logger as transformers_training_args_logger transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) del transformers_training_args_logger # Using the default loss: `ForCausalLMLoss`. try: from transformers.modeling_utils import logger as transformers_modeling_utils_logger transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss")) del transformers_modeling_utils_logger except: pass # The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function. try: from accelerate.utils.modeling import logger as accelerate_utils_modeling_logger accelerate_utils_modeling_logger.addFilter(HideLoggingMessage("The model weights are not tied")) del accelerate_utils_modeling_logger except: pass # Setting `pad_token_id` to `eos_token_id` try: from transformers.generation.utils import logger as transformers_generation_utils_logger transformers_generation_utils_logger.addFilter(HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`")) del transformers_generation_utils_logger except: pass # ============================================= # ============================================= # Edits all Config files to enable RoPE Scaling for all models # Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now. def patch_mistral_nemo_config(config): if "head_dim (" not in config: add_head_dim = "If it is not specified, will default to `8`.\n"\ " head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\ " The attention head dimension." config = config.replace("If it is not specified, will default to `8`.", add_head_dim) add_head_dim = "num_key_value_heads=8,\n head_dim=None," config = config.replace("num_key_value_heads=8,", add_head_dim) add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n" config = config.replace("self.sliding_window = sliding_window", add_head_dim) pass return config pass from transformers import __version__ as transformers_version from transformers import PretrainedConfig model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite"] for model_name in model_architectures: config_filepath = f"transformers.models.{model_name}.configuration_{model_name}" model_filepath = f"transformers.models.{model_name}.modeling_{model_name}" config_filename = f"{model_name.title()}Config" exec(f"from {config_filepath} import {config_filename}", globals()) try: config = inspect.getsource(eval(config_filename)) except: continue if "rope_scaling" in config: continue config = re.sub( r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:", r"rope_scaling=None,"\ r"\n **kwargs):\n"\ r"\n self.rope_scaling = rope_scaling\n", config, ) # Just for Mistral Nemo if model_name == "mistral": if Version(transformers_version) <= Version("4.42.4"): config = patch_mistral_nemo_config(config) pass exec(config, globals()) exec(f"import {config_filepath}", globals()) exec(f"{config_filepath}.{config_filename} = {config_filename}", globals()) pass # ============================================= # ============================================= # torch.cuda.amp.custom_fwd is deprecated >= 2.4 torch_version = torch.__version__ if Version(torch_version) < Version("2.4.0"): torch_amp_custom_fwd = torch.cuda.amp.custom_fwd torch_amp_custom_bwd = torch.cuda.amp.custom_bwd else: torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda") torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda") pass # ============================================= # ============================================= # Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0' import transformers.cache_utils if hasattr(transformers.cache_utils, "DynamicCache") and \ transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__": source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__) start = source.find("def") spaces = start*" " source = source.split("\n") source = "\n".join(x[start:] for x in source) where = source.find("raise KeyError") source = source[:where] + \ f"if len(self) == 0:\n{spaces}{spaces}"\ " raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \ f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:] source = source.replace("__getitem__", "__cache_utils_getitem__", 1) exec(source) transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__ pass # ============================================= # ============================================= # Weird Databricks errors from transformers.utils import is_openai_available if is_openai_available(): try: from openai import OpenAI except: print("Unsloth: OpenAI failed to import - ignoring for now.") import transformers.utils def _is_openai_available(): return False transformers.utils.is_openai_available = _is_openai_available pass pass # ============================================= # Get Flash Attention v2 if Ampere (RTX 30xx, A100) import bitsandbytes as bnb from transformers import AutoTokenizer from transformers.utils.import_utils import _is_package_available major_version, minor_version = torch.cuda.get_device_capability() SUPPORTS_BFLOAT16 = False HAS_FLASH_ATTENTION = False HAS_FLASH_ATTENTION_SOFTCAPPING = False if major_version >= 8: SUPPORTS_BFLOAT16 = True if _is_package_available("flash_attn"): # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl" try: from flash_attn.flash_attn_interface import flash_attn_cuda HAS_FLASH_ATTENTION = True # Also check for softcapping from flash_attn import __version__ as flash_attn_version HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3") if not HAS_FLASH_ATTENTION_SOFTCAPPING: print( "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\ "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\ "To update flash-attn, do the below:\n"\ '\npip install --no-deps --upgrade "flash-attn>=2.6.3"' ) except: print( "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\ "A possible explanation is you have a new CUDA version which isn't\n"\ "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\ "We shall now use Xformers instead, which does not have any performance hits!\n"\ "We found this negligible impact by benchmarking on 1x A100." ) # Stop Flash Attention from importing! import transformers.utils.import_utils transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False import transformers.utils transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False HAS_FLASH_ATTENTION = False pass else: HAS_FLASH_ATTENTION = False else: # Tri Dao's benchmark shows xformers is faster for now. HAS_FLASH_ATTENTION = False pass from transformers.models.llama.modeling_llama import logger # ============================================= # Get Xformers try: from xformers import __version__ as xformers_version # Temporarily disable 0.0.27 and higher - inference issues if False: #Version(xformers_version) >= Version("0.0.27"): raise ImportError( "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\ "then press Disconnect Runtime and then Restart it.\n"\ "\n"\ "%%capture\n" "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n" '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n' '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\ '\n'\ f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\ 'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"' ) pass if Version(torch_version) < Version("2.2.0") and Version(xformers_version) >= Version("0.0.24"): raise ImportError( f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\ f"Please install xformers < 0.0.24 for torch = {torch_version}." ) elif Version(torch_version) < Version("2.3.0") and Version(xformers_version) >= Version("0.0.26"): raise ImportError( f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\ f"Please install xformers < 0.0.26 for torch = {torch_version}." ) elif Version(torch_version) < Version("2.4.0") and Version(xformers_version) > Version("0.0.27"): raise ImportError( f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\ f"Please install xformers <= 0.0.27 for torch = {torch_version}." ) pass from xformers._cpp_lib import _register_extensions try: _register_extensions() # Check if C++ modules are loaded correctly except Exception as error: raise ImportError( "Unsloth: Xformers was not installed correctly.\n"\ "Please install xformers separately first.\n"\ "Then confirm if it's correctly installed by running:\n"\ "python -m xformers.info\n\n" "Longer error message:\n" + str(error) ) pass import xformers.ops.fmha as xformers xformers_attention = xformers.memory_efficient_attention except: xformers = None xformers_attention = None xformers_version = None pass # Check TRL version from trl import __version__ as trl_version # Unsloth now supports all TRL versions! if False:#Version(trl_version) >= Version("0.9.0"): raise ImportError( "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\ "then press Disconnect Runtime and then Restart it.\n"\ "\n"\ "%%capture\n" "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n" '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n' '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\ '\n'\ f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\ 'Please downgrade TRL via `pip install --force-reinstall trl' ) pass # ============================================= # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout' accelerate_old_send_to_device = None accelerate_new_send_to_device = None if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"): import accelerate.utils.operations if hasattr(accelerate.utils.operations, "send_to_device") and \ accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device": accelerate_old_send_to_device = accelerate.utils.operations.send_to_device from accelerate.utils.operations import * send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device) send_to_device = re.sub( r"([ ]{4,})return tensor\.to\(device\)", r"\1try: return tensor.to(device)\n\1except: return tensor", send_to_device, ).replace("def send_to_device", "def _fixed_send_to_device") exec(send_to_device) # accelerate.utils.operations.send_to_device = _fixed_send_to_device accelerate_new_send_to_device = _fixed_send_to_device pass pass # Transformers 4.46 breaks dynamic caching. This is a hack import transformers.generation.configuration_utils if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"): if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list: transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic") pass pass # ============================================= # ============================================= # Torch compile settings UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1" # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) def is_big_gpu(index): sms = torch.cuda.get_device_properties(index).multi_processor_count if sms < 80: # V100 # log.warning("not enough SMs to use max_autotune_gemm mode") return False return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu patch_torch_compile( debug = UNSLOTH_COMPILE_DEBUG, O3 = UNSLOTH_COMPILE_MAXIMUM, ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, ) torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, "trace.enabled" : UNSLOTH_COMPILE_DEBUG, "triton.cudagraphs" : False, } import accelerate def torch_compile_kwargs(*args, **kwargs): print("Unsloth: Enabled auto compiling") return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options,} pass accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs del accelerate def patch_regional_compilation(): # Regional torch 2.5 Recompilation - weirdly very slow?? if torch.nn.ModuleList.__name__ == "UnslothModuleList": return # Only works for torch 2.5 if Version(torch.__version__) < Version("2.5.0"): return old_module_list = torch.nn.ModuleList os.environ["UNSLOTH_PATCHED"] = "1" def UnslothModuleList(*args, **kwargs): if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list: args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])] return old_module_list(*args, **kwargs) pass UnslothModuleList.__doc__ = old_module_list.__doc__ torch.nn.ModuleList = UnslothModuleList return pass # ============================================= def prepare_model_for_kbit_training( model : Any, use_gradient_checkpointing : Optional = True, use_reentrant : Optional[bool] = True, ) -> Any: """ Calculates where to place the gradient checkpoints given n_layers. We also freeze all other layers's gradients Args: model: Any LlamaModel with layers. use_gradient_checkpointing (`bool`, *optional*): Default enabled. Provides memory savings by not saving all activations, but only some. use_reentrant (`bool`, *optional*): https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354 Optimal gradient checkpointing algorithm which will be the default in future Pytorch versions. """ # Freeze all parameters except LoRA with torch.no_grad(): for name, param in model.named_parameters(): if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name: param.requires_grad_(True) # Also must be in float32! if param.dtype != torch.float32: name = name.replace("base_model", "model", 1) layer_number = re.search(r"\.[\d]{1,}\.", name).group(0) name = name.replace(layer_number, f"[{layer_number[1:-1]}].") name = name.replace(".weight", "", 1) exec(f"{name}.to(torch.float32)") pass else: param.requires_grad_(False) pass pass # Gradient checkpointing! if use_gradient_checkpointing == "unsloth": # Saves VRAM! original_model = model while hasattr(original_model, "model"): original_model._offloaded_gradient_checkpointing = True original_model = original_model.model pass original_model._offloaded_gradient_checkpointing = True model.gradient_checkpointing_enable() elif use_gradient_checkpointing == True: model.gradient_checkpointing_enable() pass # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad. if use_reentrant: if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() else: def make_inputs_require_grad(module, input, output): output.requires_grad_(True) model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) return model pass # ============================================= # Weirdly LoraLayer.update_layer downcasts PEFT layers to float16?? # For mixed precision, we need it to be in float32 not float16. from peft import __version__ as peft_version if Version(peft_version) < Version("0.12.0"): from peft.tuners.lora.layer import LoraLayer try: source = inspect.getsource(LoraLayer.update_layer) text = "if weight is not None:\n" start = source.find(text) + len(text) end = source.find("self.to(weight.device)", start) spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0] source = source.replace(source[start : end], spaces) spaces = len(re.match(r"[\s]{1,}", source).group(0)) lines = source.split("\n") source = "\n".join(x[spaces:] for x in lines) source = re.sub("([^\.])nn\.", r"\1torch.nn.", source) source = source.replace("def update_layer", "def LoraLayer_update_layer") exec(source, globals()) # Fix up incorrect downcasting of LoRA weights from peft.tuners.lora.layer import LoraLayer LoraLayer.update_layer = LoraLayer_update_layer from peft.tuners.lora import LoraLayer LoraLayer.update_layer = LoraLayer_update_layer except: logger.warning_once( "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\ "Luckily, your training run will still work in the meantime!" ) pass pass # ============================================= import psutil def _get_statistics(statistics = None, force_download = True): # We log some basic stats about which environment is being used. # We simply download a README.md file from HF - all data is made public. # This is simply so we can check if some envs are broken or not. # You can disable this by commenting the below out try: n_cpus = psutil.cpu_count(logical = False) keynames = "\n" + "\n".join(os.environ.keys()) if statistics is not None: pass elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab" elif "\nCOLAB_" in keynames: statistics = "colabpro" elif "\nKAGGLE_" in keynames: statistics = "kaggle" elif "\nRUNPOD_" in keynames: statistics = "runpod" elif "\nAWS_" in keynames: statistics = "aws" elif "\nAZURE_" in keynames: statistics = "azure" # elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp" elif "\nINVOCATION_ID" in keynames: statistics = "lambda" # else: statistics = "other" else: def try_vllm_check(): vendor_files = ( "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", "/sys/class/dmi/id/product_name", "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor", ) from pathlib import Path for vendor_file in vendor_files: path = Path(vendor_file) if path.is_file(): file_content = path.read_text().lower() if "amazon" in file_content: return "aws" elif "microsoft corporation" in file_content: return "azure" elif "google" in file_content: return "gcp" return "other" pass try: statistics = try_vllm_check() except: statistics = "other" pass if statistics is not None: from transformers import AutoModelForCausalLM stats_model = AutoModelForCausalLM.from_pretrained( f"unslothai/{statistics}", force_download = force_download, ) del stats_model pass except: pass pass def get_statistics(): # We log some basic stats about which environment is being used. # We simply download a README.md file from HF - all data is made public. # This is simply so we can check if some envs are broken or not. # You can disable this by setting UNSLOTH_DISABLE_STATISTICS import os if "UNSLOTH_DISABLE_STATISTICS" in os.environ: return from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled disabled = False if not are_progress_bars_disabled(): disable_progress_bars() disabled = True pass _get_statistics(None) _get_statistics("repeat", force_download = False) try: vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024 if vram <= 8 : vram = 8 elif vram <= 16: vram = 16 elif vram <= 20: vram = 20 elif vram <= 24: vram = 24 elif vram <= 40: vram = 40 elif vram <= 48: vram = 48 elif vram <= 80: vram = 80 else: vram = 96 _get_statistics(f"vram-{vram}") except: pass pass try: devices = torch.cuda.device_count() _get_statistics(f"{devices if devices <= 8 else 9}") except: pass if disabled: enable_progress_bars() pass # ============================================= # Fixes Bitsandbytes to remove missing warnings from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod from inspect import getsource from accelerate.utils.dataclasses import DistributedType BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__) BitsAndBytesConfig__init__ = re.sub( r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n", "", BitsAndBytesConfig__init__, flags = re.MULTILINE, ) BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n") length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0)) BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__) BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace( "__init__", "_BitsAndBytesConfig__init__", ) def _prepare_backend( self, cpu = False, sagemaker_dp = False, backend: str = None, ) -> tuple[str, DistributedType]: return None, DistributedType.NO pass import accelerate.state accelerate.state.PartialState._prepare_backend = _prepare_backend import accelerate.accelerator prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare) prepare = prepare.split("\n") spaces = prepare[0].find("def") prepare = "\n".join(x[spaces:] for x in prepare) x = "for obj in args:" s = " "*spaces prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1) exec(prepare, globals()) accelerate.accelerator.Accelerator.prepare = prepare exec(BitsAndBytesConfig__init__, globals()) import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ # ============================================= # Offloading to disk for modules (lm_head, embed_tokens) import pickle def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"): file_location = os.path.join(temporary_location, model.config._name_or_path) if not os.path.exists(file_location): os.makedirs(file_location) pass filename = os.path.join(file_location, f"{name}.pt") W = W.weight if hasattr(W, "weight") else W torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,) offloaded_W = torch.load(filename, map_location = "cpu", mmap = True) offloaded_W._offloaded_file_location = filename return offloaded_W pass def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"): offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location) new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W) new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location model.set_input_embeddings(new_input_embeddings) return pass def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"): offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location) new_output_embeddings = torch.nn.Linear(1, 1, bias = None) del new_output_embeddings.weight new_output_embeddings.weight = offloaded_W new_output_embeddings.in_features = offloaded_W.shape[1] new_output_embeddings.out_features = offloaded_W.shape[0] new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location model.set_output_embeddings(new_output_embeddings) return pass # Fixes a weird Torch 2.3 bug which says T4s have bfloat16 def is_bfloat16_supported(): return SUPPORTS_BFLOAT16 pass # Patches models to add RoPE Scaling def patch_linear_scaling( model_name = "gemma2", rope_module = None, scaled_rope_module = None, attention_module = None, ): assert(rope_module is not None and scaled_rope_module is not None) assert(attention_module is not None) rope_name = rope_module.__name__ scaled_rope_name = scaled_rope_module.__name__ model_filepath = f"transformers.models.{model_name}.modeling_{model_name}" exec_code = \ f"import torch.nn as nn\n"\ f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\ f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" try: function = inspect.getsource(attention_module.__init__) except: # Most likely already patched! return None, None where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) init_name = f"{model_name.title()}Attention__init__" function = function.replace("def __init__", f"def {init_name}") function = function.replace( "super().__init__()", f"super({model_name.title()}Attention, self).__init__()", ) fix_rope_function = """ if getattr(self.config, "rope_scaling", None) is None: self.rotary_emb = {rope_function}( dim = self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = {scaled_rope_function}( dim = self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}") pass """ fix_rope_function = fix_rope_function.format( rope_function = rope_module.__name__, scaled_rope_function = scaled_rope_module.__name__, ) rotary_emb = re.findall( "self.rotary_emb = .+?\)", function, flags = re.DOTALL | re.MULTILINE, ) if len(rotary_emb) == 0: return None, function rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function return init_name, function pass # Patches for Llama-3 LlamaExtendedRotaryEmbedding def patch_llama_rope_scaling( model_name = "llama", rope_module = None, scaled_rope_module = None, extended_rope_module = None, attention_module = None, longrope_module = None, ): assert(\ rope_module is not None and \ scaled_rope_module is not None and \ extended_rope_module is not None ) assert(attention_module is not None) rope_name = rope_module.__name__ scaled_rope_name = scaled_rope_module.__name__ model_filepath = f"transformers.models.{model_name}.modeling_{model_name}" exec_code = \ f"import torch.nn as nn\n"\ f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\ f"from {model_filepath} import logger, "\ f"{model_name.title()}Attention, {model_name.title()}Config" try: function = inspect.getsource(attention_module.__init__) except: # Most likely already patched! return None, None where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) init_name = f"{model_name.title()}Attention__init__" function = function.replace("def __init__", f"def {init_name}") function = function.replace( "super().__init__()", f"super({model_name.title()}Attention, self).__init__()", ) fix_rope_function = """ if getattr(self.config, "rope_scaling", None) is None: self.rotary_emb = {rope_function}( dim = self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type1 = self.config.rope_scaling.get("type", None) scaling_type2 = self.config.rope_scaling.get("rope_type", None) scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2 scaling_factor = self.config.rope_scaling.get("factor") if scaling_type == "linear": self.rotary_emb = {scaled_rope_function}( dim = self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "llama3": self.rotary_emb = {extended_rope_function}( dim = self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) elif scaling_type == "longrope": self.rotary_emb = {longrope_rope_function}( dim = self.head_dim, max_position_embeddings = self.max_position_embeddings, original_max_position_embeddings = self.config.original_max_position_embeddings, base = self.rope_theta, short_factor = self.config.rope_scaling['short_factor'], long_factor = self.config.rope_scaling['long_factor' ], ) else: raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}") pass """ fix_rope_function = fix_rope_function.format( rope_function = rope_module.__name__, scaled_rope_function = scaled_rope_module.__name__, extended_rope_function = extended_rope_module.__name__, longrope_rope_function = \ (longrope_module if longrope_module is not None else rope_module).__name__ ) rotary_emb = re.findall( "self.rotary_emb = .+?\)", function, flags = re.DOTALL | re.MULTILINE, ) if len(rotary_emb) == 0: return None, function rotary_emb = rotary_emb[0] function = function.replace(rotary_emb, fix_rope_function, 1) function = exec_code + "\n\n" + function return init_name, function pass def check_nvidia(): # Unsloth doesn't work yet on AMD devices - we're working on it! output = np.array([0,]) try: output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True) output = re.findall(rb'([\d]{1,})[\s]{1,}M', output) output = np.array([int(x.decode('utf-8'))/1024 for x in output]) except: if not torch.cuda.is_available(): raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!") return output pass PRE_CHECK = check_nvidia() def create_boolean_mask(n = 4096, sliding_window = 2048): # Creates a boolean mask for attention mask = torch.ones(n, n, dtype = torch.bool) if sliding_window == 0: return torch.triu(mask, diagonal = 1, out = mask) pass torch.triu(mask, diagonal = 0, out = mask) torch.triu(mask.T, diagonal = -sliding_window, out = mask.T) mask = mask.T torch.logical_not(mask, out = mask) return mask pass def test_mask_creation(): from transformers.modeling_attn_mask_utils import AttentionMaskConverter for n in range(2, 23): for s in range(1, 23): correct_mask = AttentionMaskConverter( is_causal = True, sliding_window = s, ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0) correct_mask = (correct_mask == correct_mask.min()) our_mask = create_boolean_mask(n = n, sliding_window = s) assert(torch.all(correct_mask == our_mask)) pass correct_mask = AttentionMaskConverter( is_causal = True, sliding_window = None, ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0) correct_mask = (correct_mask == correct_mask.min()) our_mask = create_boolean_mask(n = n, sliding_window = 0) assert(torch.all(correct_mask == our_mask)) pass pass def _unsloth_get_batch_samples(self, epoch_iterator, num_batches): batch_samples = [] num_items_in_batch = None # Check if model allows **kwargs model = self.model f = model.base_model.model.forward if hasattr(model, "base_model") else model.forward has_kwargs = tuple(inspect.signature(f).parameters.values())[-1].kind == inspect._VAR_KEYWORD # Iterate to find all batches for _ in range(num_batches): try: batch_samples += [next(epoch_iterator)] except StopIteration: break pass # Get num_items_in_batch if has_kwargs and len(batch_samples) > 0 and "labels" in batch_samples[0]: try: num_items_in_batch = sum( [(x["labels"][..., 1:] != -100).sum() for x in batch_samples] ) if self.args.average_tokens_across_devices: num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.item() except Exception as exception: logger.warning_once(exception) pass return batch_samples, num_items_in_batch pass def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): num_items_in_batch = None if "num_items_in_batch" in kwargs: num_items_in_batch = kwargs["num_items_in_batch"] if num_items_in_batch is None: # Remove it since the model does not support it! kwargs.pop("num_items_in_batch") elif "num_items_in_batch" not in inputs: inputs["num_items_in_batch"] = num_items_in_batch pass pass if num_items_in_batch is None: name = (model.base_model.model if hasattr(model, "base_model") else model).__class__.__name__ logger.warning_once( f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"\ "Using gradient accumulation will be very slightly less accurate.\n"\ "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient" ) pass return self._old_compute_loss(model, inputs, *args, **kwargs) pass def patch_gradient_accumulation_fix(Trainer): # Fixes gradient accumulation import inspect if hasattr(Trainer, "get_batch_samples"): if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples": return if \ not inspect.getsource(Trainer.get_batch_samples).strip()\ .endswith("return batch_samples, num_items_in_batch"): raise NotImplementedError("Unsloth: Please make a Github issue immediately!!") else: if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples": Trainer.get_batch_samples = _unsloth_get_batch_samples pass # Also fix passing in num_items_in_batch if not hasattr(Trainer, "_old_compute_loss"): Trainer._old_compute_loss = Trainer.compute_loss Trainer.compute_loss = _unsloth_pre_compute_loss pass pass else: logger.warning_once( "Unsloth: We fixed a gradient accumulation bug, "\ "but it seems like you don't have the latest transformers version!\n"\ "Please update transformers, TRL and unsloth via:\n"\ '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`' ) pass # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps if Trainer.training_step.__name__ == "_unsloth_training_step": return if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return function = inspect.getsource(Trainer.training_step) where = function.find("def") function = function.split("\n") function = "\n".join(x[where:] for x in function) # Import all variables that need importing import transformers.trainer items_in_trainer = dir(transformers.trainer) good_items = [] for item in items_in_trainer: # TODO: Support Deepspeed if item.startswith(("deepspeed", "xm", "met", "smp")): continue if item in function: good_items.append(item) pass exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals()) # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already # summed it up and did the division before hand, we have to negate it. function = function.replace( "loss *= self.args.gradient_accumulation_steps", "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps", ) function = function.replace("def training_step", "def _unsloth_training_step", 1) # Fix 4.47.0 issue where num_items_in_batch was removed # See https://github.com/huggingface/transformers/pull/35121 function = function.replace( "if self.model_accepts_loss_kwargs:", "if False:", ) # Fix when num_items_in_batch is nothing # https://github.com/huggingface/transformers/pull/35207 function = re.sub( r"else:\n"\ r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"\ r"(.+?)if num_items_in_batch is None\:\n"\ r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps", "else:\n"\ "\2if num_items_in_batch is None:\n"\ "\3loss = loss / self.args.gradient_accumulation_steps\n"\ "\1self.accelerator.backward(loss, **kwargs)", function, ) exec(function, globals()) Trainer.training_step = _unsloth_training_step pass def patch_tokenizer(model, tokenizer): model, tokenizer = _patch_tokenizer(model, tokenizer) if model is not None: model.config.update({"unsloth_version" : __version__}) return model, tokenizer pass def patch_fast_lora(): import peft.tuners.lora.bnb peft.tuners.lora.bnb.Linear4bit.forward = fast_lora_forward pass def unsloth_compile_transformers( model_name, token = None, revision = None, trust_remote_code = False, sdpa_dynamic_mask = True, sdpa_bool_masks = True, sdpa_gqa_replace = True, sdpa_dynamic_compile = True, compile_attention = True, disable_causal_masks = True, compile_torch_modules = True, compile_custom_modules = True, compile_function_calls = True, fuse_lm_head = True, gradient_checkpointing = True, manual_replacements = True, fast_lora_forwards = True, fast_residual_stream = True, accurate_accumulation = True, epilogue_fusion = True, max_autotune = False, shape_padding = True, cudagraphs = False, debug = False, fullgraph = True, import_from_cache = False, disable = False, return_logits = False, ): if Version(torch_version) < Version("2.4.0"): print( "="*30 + \ "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"\ f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"\ "For now your models will not get optimized, but will still work for now!" ) return pass if disable: return model_types = get_transformers_model_type( model_name = model_name, token = token, revision = revision, trust_remote_code = trust_remote_code, ) for model_type in model_types: _unsloth_compile_transformers( model_type, sdpa_dynamic_mask = sdpa_dynamic_mask, sdpa_bool_masks = sdpa_bool_masks, sdpa_gqa_replace = sdpa_gqa_replace, sdpa_dynamic_compile = sdpa_dynamic_compile, compile_attention = compile_attention, disable_causal_masks = disable_causal_masks, compile_torch_modules = compile_torch_modules, compile_custom_modules = compile_custom_modules, compile_function_calls = compile_function_calls, fuse_lm_head = fuse_lm_head, gradient_checkpointing = gradient_checkpointing, manual_replacements = manual_replacements, fast_lora_forwards = fast_lora_forwards, fast_residual_stream = fast_residual_stream, accurate_accumulation = accurate_accumulation, epilogue_fusion = epilogue_fusion, max_autotune = max_autotune, shape_padding = shape_padding, cudagraphs = cudagraphs, debug = debug, fullgraph = fullgraph, import_from_cache = import_from_cache, disable = disable, return_logits = return_logits, ) pass return model_types pass # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie # os.environ['UNSLOTH_RETURN_LOGITS'] = '1' LOGITS_ERROR_STRING = \ "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\ 'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n\n'\ "import os\n"\ "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\ "... trainer.train() ..." def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING) def return_none(*args, **kwargs): return None class EmptyLogits: def __init__(self): return def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error __getitem__ = raise_logits_error __getattr__ = raise_getattr_error def __repr__(self): return LOGITS_ERROR_STRING def __str__ (self): return LOGITS_ERROR_STRING pass EMPTY_LOGITS = EmptyLogits() functions = dir(torch.Tensor) for j, function in enumerate(functions): if function.startswith("__") and function.endswith("__"): exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals()) try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals()) except: continue pass