# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Optional
from functools import wraps
import trl
import inspect
from trl import SFTTrainer
from . import is_bfloat16_supported
from unsloth_zoo.training_utils import (
    unsloth_train as _unsloth_train,
)
from unsloth_zoo.vision_utils import (
    UnslothVisionDataCollator,
)
from packaging.version import Version
import dataclasses

__all__ = [
    "UnslothTrainingArguments",
    "UnslothTrainer",
    "unsloth_train",
    "_patch_trl_trainer",
    "UnslothVisionDataCollator",
]

# Unsloth gradient accumulation fix:
from transformers import __version__ as transformers_version

if Version(transformers_version) > Version("4.45.2"):
    # Newer transformers already include the gradient accumulation fix,
    # so simply forward to the normal trainer.
    def unsloth_train(trainer, *args, **kwargs):
        return trainer.train(*args, **kwargs)
    pass
else:
    def unsloth_train(trainer, *args, **kwargs):
        if len(args) != 0 or len(kwargs) != 0:
            raise RuntimeError(
                "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"
                "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"
                "`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`"
            )
        print(
            "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"
            "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"
            "`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`"
        )
        return _unsloth_train(trainer)
    pass
pass

try:
    from trl import SFTConfig as TrainingArguments
except:
    from transformers import TrainingArguments
pass


@dataclass
class UnslothTrainingArguments(TrainingArguments):
    embedding_learning_rate : Optional[float] = field(
        default  = None,
        metadata = {"help" : "Different learning rates for embeddings and lm_head."},
    )
pass


def _create_unsloth_optimizer(
    model,
    optimizer_cls,
    optimizer_kwargs,
    embedding_lr = 5e-5,
):
    lr           = optimizer_kwargs["lr"]
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    param_groups = \
    {
        "non_embeddings" : {},
        "embeddings"     : {},
    }

    # Split trainable parameters into embedding-style modules (saved via
    # `modules_to_save`, e.g. embed_tokens / lm_head) and everything else,
    # so the embedding group can train at a lower learning rate.
    for name, param in model.named_parameters():
        if not param.requires_grad: continue
        if name.endswith("modules_to_save.default.weight"):
            partial_name = name[:-len(".modules_to_save.default.weight")]
            partial_name = partial_name[partial_name.rfind(".") + 1:]
            print(f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.")
            param_groups["embeddings"]    [name] = param
        else:
            param_groups["non_embeddings"][name] = param
        pass
    pass

    optimizer_grouped_parameters = [
        {
            "params"       : list(param_groups["non_embeddings"].values()),
            "weight_decay" : weight_decay,
            "lr"           : lr,
        },
        {
            "params"       : list(param_groups["embeddings"].values()),
            "weight_decay" : weight_decay,
            "lr"           : embedding_lr,
        },
    ]
    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    return optimizer
pass


class UnslothTrainer(SFTTrainer):
    def create_optimizer(self):
        embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None)
        # Fall back to the default optimizer unless a separate embedding LR was requested.
        if embedding_learning_rate is None: return super().create_optimizer()

        if self.optimizer is None:
            optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = _create_unsloth_optimizer(
                self.model,
                optimizer_cls,
                optimizer_kwargs,
                embedding_learning_rate,
            )
        pass
        return self.optimizer
    pass
pass
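# Usage sketch (illustrative only, commented out so importing this module stays
# side-effect free). The `model`, `tokenizer` and `dataset` objects below are
# placeholders assumed to be loaded elsewhere; the two learning rates are
# example values, not recommendations.
#
#     from unsloth.trainer import UnslothTrainer, UnslothTrainingArguments
#
#     args = UnslothTrainingArguments(
#         output_dir                  = "outputs",
#         per_device_train_batch_size = 2,
#         max_steps                   = 60,
#         learning_rate               = 2e-4,  # LoRA / non-embedding parameters
#         embedding_learning_rate     = 5e-6,  # embed_tokens / lm_head via modules_to_save
#     )
#     trainer = UnslothTrainer(
#         model         = model,      # placeholder
#         tokenizer     = tokenizer,  # placeholder
#         train_dataset = dataset,    # placeholder
#         args          = args,
#     )
#     trainer.train()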
# From `trl>=0.13.0`, several parameters were moved from the trainer into its config class.
# We patch the trainer's `__init__` to keep the transition smooth for older call sites.
def _backwards_compatible_trainer(trainer_class, config_class):
    original_init = trainer_class.__init__

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        # All Trainer `tokenizer` arguments are now called `processing_class`
        trainer_params = set(inspect.signature(original_init).parameters.keys())

        if "processing_class" in trainer_params and "tokenizer" in kwargs:
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        pass

        if ("args" in kwargs) and (Version(trl.__version__) >= Version("0.13.0.dev0")):
            training_args = kwargs.pop("args", None)

            # Get parameters that Trainer.__init__ actually expects
            trainer_params.remove('self')
            trainer_params.remove('args')

            # Get fields that should be passed to the Config's init
            config_fields = {
                field.name : field for field in dataclasses.fields(config_class)
                if field.init
            }

            # Create a config dict with the valid fields from training_args
            config_dict = {
                name : getattr(training_args, name)
                for name in config_fields
                if hasattr(training_args, name)
            }

            # Get parameters that exist in the Config but not in TrainingArguments
            from transformers import TrainingArguments
            moved_params = \
                set(inspect.signature(config_class)     .parameters.keys()) - \
                set(inspect.signature(TrainingArguments).parameters.keys())

            # Separate kwargs into trainer kwargs and config kwargs
            trainer_kwargs           = {}
            additional_config_kwargs = {}

            for key, value in kwargs.items():
                if   key in trainer_params:
                    trainer_kwargs[key] = value
                elif key in moved_params or key in config_fields:
                    additional_config_kwargs[key] = value
                else:
                    # Unknown keywords are also forwarded to the Config
                    additional_config_kwargs[key] = value
                pass
            pass

            # Update config_dict with the additional kwargs
            config_dict.update(additional_config_kwargs)

            # Create the Config with all the collected parameters
            config = config_class(**config_dict)

            # Reconstruct kwargs for the Trainer
            kwargs = trainer_kwargs
            kwargs["args"] = config
        pass
        original_init(self, *args, **kwargs)
    pass
    return new_init
pass


def _patch_trl_trainer():
    import trl
    if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return
    if Version(trl.__version__) <= Version("0.11.0"): return

    import trl.trainer
    trl_classes  = dir(trl.trainer)
    trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer"))
    trl_configs  = set(x[:-len("Config")]  for x in trl_classes if x.endswith("Config"))
    trl_classes  = list(trl_trainers & trl_configs)

    # Wrap every {X}Trainer whose matching {X}Config also exists
    for x in trl_classes:
        try:    exec(f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
        except: continue
    pass

    trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True
pass
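# Usage sketch for the backwards-compatibility patch (illustrative only,
# commented out; the objects are placeholders, and `max_seq_length` is assumed
# to still be an `SFTConfig` field on the installed `trl` version). After
# `_patch_trl_trainer()` runs on `trl>=0.13`, a legacy-style call that passes
# `tokenizer=...` and config-only keywords directly to the trainer is rerouted:
# `tokenizer` becomes `processing_class`, and config-only keywords are folded
# into a freshly built config together with the supplied `args`.
#
#     _patch_trl_trainer()
#     trainer = trl.SFTTrainer(
#         model          = model,      # placeholder
#         tokenizer      = tokenizer,  # renamed to `processing_class` by the patch
#         train_dataset  = dataset,    # placeholder
#         args           = TrainingArguments(output_dir = "outputs"),
#         max_seq_length = 2048,       # moved into the rebuilt SFTConfig by the patch
#     )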