import warnings
from dataclasses import dataclass, field
from typing import Optional
from functools import wraps

import trl
import inspect
from trl import SFTTrainer
from . import is_bfloat16_supported
from unsloth_zoo.training_utils import (
    unsloth_train as _unsloth_train,
)
from unsloth_zoo.vision_utils import (
    UnslothVisionDataCollator,
)
from packaging.version import Version
import dataclasses

__all__ = [
    "UnslothTrainingArguments",
    "UnslothTrainer",
    "unsloth_train",
    "_patch_trl_trainer",
    "UnslothVisionDataCollator",
]


# transformers newer than 4.45.2 ships the gradient accumulation fix upstream, so
# `unsloth_train` can simply forward to `trainer.train`; older versions fall back to
# Unsloth's patched training loop from unsloth_zoo.
from transformers import __version__ as transformers_version
if Version(transformers_version) > Version("4.45.2"):
    def unsloth_train(trainer, *args, **kwargs):
        return trainer.train(*args, **kwargs)
    pass
else:
    def unsloth_train(trainer, *args, **kwargs):
        if len(args) != 0 or len(kwargs) != 0:
            raise RuntimeError(
                "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"\
                "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
                '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
            )
        print(
            "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
            "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
            '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`'
        )
        return _unsloth_train(trainer)
    pass
pass


# Prefer TRL's SFTConfig when available; older TRL versions only provide
# transformers.TrainingArguments.
try:
    from trl import SFTConfig as TrainingArguments
except ImportError:
    from transformers import TrainingArguments
pass


@dataclass
class UnslothTrainingArguments(TrainingArguments):
    embedding_learning_rate : Optional[float] = field(
        default = None,
        metadata = {"help" : "Different learning rates for embeddings and lm_head."}
    )
pass
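
# `embedding_learning_rate` above is consumed by UnslothTrainer.create_optimizer below:
# the `modules_to_save` parameters (typically embed_tokens and lm_head) are routed
# through this separate, usually lower, learning rate.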


def _create_unsloth_optimizer(
    model,
    optimizer_cls,
    optimizer_kwargs,
    embedding_lr = 5e-5,
):
    lr = optimizer_kwargs["lr"]
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    # Split trainable parameters into two groups: `modules_to_save` weights
    # (embeddings / lm_head) get `embedding_lr`, everything else keeps `lr`.
    param_groups = \
    {
        "non_embeddings" : {},
        "embeddings"     : {},
    }

    for name, param in model.named_parameters():
        if not param.requires_grad: continue
        if name.endswith("modules_to_save.default.weight"):
            partial_name = name[:-len(".modules_to_save.default.weight")]
            partial_name = partial_name[partial_name.rfind(".") + 1:]
            print(f"Unsloth: Setting lr = {embedding_lr:.2e} instead of {lr:.2e} for {partial_name}.")
            param_groups["embeddings"]    [name] = param
        else:
            param_groups["non_embeddings"][name] = param
        pass
    pass

    # Two optimizer parameter groups sharing weight decay but with separate learning rates.
    optimizer_grouped_parameters = [
        {
            "params"       : list(param_groups["non_embeddings"].values()),
            "weight_decay" : weight_decay,
            "lr"           : lr,
        },
        {
            "params"       : list(param_groups["embeddings"].values()),
            "weight_decay" : weight_decay,
            "lr"           : embedding_lr,
        },
    ]
    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    return optimizer
pass


class UnslothTrainer(SFTTrainer):
    def create_optimizer(self):
        # Use the stock optimizer unless a separate embedding learning rate was requested.
        embedding_learning_rate = getattr(self.args, "embedding_learning_rate", None)
        if embedding_learning_rate is None: return super().create_optimizer()

        if self.optimizer is None:
            optimizer_cls, optimizer_kwargs = SFTTrainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = _create_unsloth_optimizer(
                self.model,
                optimizer_cls,
                optimizer_kwargs,
                embedding_learning_rate,
            )
        pass
        return self.optimizer
    pass
pass
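
# Minimal usage sketch (illustrative only; model, tokenizer and dataset setup omitted,
# and the exact accepted keywords depend on the installed TRL version):
#
#     trainer = UnslothTrainer(
#         model         = model,
#         tokenizer     = tokenizer,
#         train_dataset = dataset,
#         args = UnslothTrainingArguments(
#             output_dir              = "outputs",
#             learning_rate           = 2e-4,
#             embedding_learning_rate = 2e-5,
#         ),
#     )
#     unsloth_train(trainer)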


def _backwards_compatible_trainer(trainer_class, config_class):
    # Wrap a TRL trainer's __init__ so old-style keyword arguments keep working on
    # newer TRL: `tokenizer` is renamed to `processing_class`, and config options
    # passed directly to the trainer are folded into a freshly built config object.
    original_init = trainer_class.__init__

    @wraps(original_init)
    def new_init(self, *args, **kwargs):
        trainer_params = set(inspect.signature(original_init).parameters.keys())

        if "processing_class" in trainer_params and "tokenizer" in kwargs:
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        pass

        if ("args" in kwargs) and (Version(trl.__version__) >= Version("0.13.0.dev0")):
            training_args = kwargs.pop("args", None)

            # Keyword arguments the trainer itself still accepts (besides `self` and `args`).
            trainer_params.remove('self')
            trainer_params.remove('args')

            # Fields the config dataclass accepts in its constructor.
            config_fields = {
                field.name: field for field in dataclasses.fields(config_class)
                if field.init
            }

            # Start from whatever the passed-in args object already defines.
            config_dict = {
                name: getattr(training_args, name)
                for name in config_fields
                if hasattr(training_args, name)
            }

            # Parameters that moved from the trainer signature into the config class.
            from transformers import TrainingArguments
            moved_params = \
                set(inspect.signature(config_class)     .parameters.keys()) - \
                set(inspect.signature(TrainingArguments).parameters.keys())

            # Route every remaining keyword argument to the trainer or the config.
            trainer_kwargs = {}
            additional_config_kwargs = {}

            for key, value in kwargs.items():
                if key in trainer_params: trainer_kwargs[key] = value
                elif key in moved_params or key in config_fields:
                    additional_config_kwargs[key] = value
                else:
                    # Unknown keys are assumed to belong to the config as well.
                    additional_config_kwargs[key] = value
                pass
            pass

            config_dict.update(additional_config_kwargs)

            # Rebuild the config and hand the cleaned-up kwargs to the original __init__.
            config = config_class(**config_dict)

            kwargs = trainer_kwargs
            kwargs["args"] = config
        pass
        original_init(self, *args, **kwargs)
    pass
    return new_init
pass


def _patch_trl_trainer():
    # Patch every TRL trainer that has a matching *Config class so that legacy
    # keyword arguments are remapped by _backwards_compatible_trainer. Runs once.
    import trl
    if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return
    if Version(trl.__version__) <= Version("0.11.0"): return

    import trl.trainer
    trl_classes = dir(trl.trainer)
    trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer"))
    trl_configs  = set(x[:-len("Config")]  for x in trl_classes if x.endswith("Config"))
    trl_classes = list(trl_trainers & trl_configs)

    for x in trl_classes:
        try:    exec(f"trl.{x}Trainer.__init__ = _backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals())
        except Exception: continue
    pass

    trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True
pass
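
# Illustrative effect of the patch (sketch only; class names depend on the installed TRL):
# after `_patch_trl_trainer()`, a legacy call such as
#
#     DPOTrainer(model, tokenizer = tok, args = TrainingArguments(...), max_length = 512)
#
# is rewritten so `tokenizer` becomes `processing_class`, and extra options such as
# `max_length` are folded into a freshly constructed DPOConfig that is passed as `args`.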