from typing import Any, Dict, Optional, Tuple, Union import torch from torch import nn from torch.utils.data import DistributedSampler, RandomSampler from transformers import PreTrainedModel, Trainer, logging from transformers.configuration_fsmt import FSMTConfig from transformers.file_utils import is_torch_tpu_available from transformers.optimization import ( Adafactor, AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup, ) from transformers.trainer_pt_utils import get_tpu_sampler logger = logging.get_logger(__name__) arg_to_scheduler = { "linear": get_linear_schedule_with_warmup, "cosine": get_cosine_schedule_with_warmup, "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, "polynomial": get_polynomial_decay_schedule_with_warmup, "constant": get_constant_schedule, "constant_w_warmup": get_constant_schedule_with_warmup, } class Seq2SeqTrainer(Trainer): def __init__(self, config=None, data_args=None, *args, **kwargs): super().__init__(*args, **kwargs) if config is None: assert isinstance( self.model, PreTrainedModel ), f"If no `config` is passed the model to be trained has to be of type `PreTrainedModel`, but is {self.model.__class__}" self.config = self._actual_model(self.model).config else: self.config = config self.data_args = data_args self.vocab_size = self.config.tgt_vocab_size if isinstance(self.config, FSMTConfig) else self.config.vocab_size if self.args.label_smoothing != 0 or (self.data_args is not None and self.data_args.ignore_pad_token_for_loss): assert ( self.config.pad_token_id is not None ), "Make sure that `config.pad_token_id` is correcly defined when ignoring `pad_token` for loss calculation or doing label smoothing." if self.config.pad_token_id is None and self.config.eos_token_id is not None: logger.warn( f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.." ) if self.args.label_smoothing == 0: self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) else: # dynamically import label_smoothed_nll_loss from seq2seq_utils import label_smoothed_nll_loss self.loss_fn = label_smoothed_nll_loss def create_optimizer_and_scheduler(self, num_training_steps: int): """ Setup the optimizer and the learning rate scheduler. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. """ if self.optimizer is None: no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": self.args.weight_decay, }, { "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0, }, ] if self.args.adafactor: self.optimizer = Adafactor( optimizer_grouped_parameters, lr=self.args.learning_rate, scale_parameter=False, relative_step=False, ) else: self.optimizer = AdamW( optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon ) if self.lr_scheduler is None: self.lr_scheduler = self._get_lr_scheduler(num_training_steps) else: # ignoring --lr_scheduler logger.warn("scheduler is passed to `Seq2SeqTrainer`, `--lr_scheduler` arg is ignored.") def _get_lr_scheduler(self, num_training_steps): schedule_func = arg_to_scheduler[self.args.lr_scheduler] if self.args.lr_scheduler == "constant": scheduler = schedule_func(self.optimizer) elif self.args.lr_scheduler == "constant_w_warmup": scheduler = schedule_func(self.optimizer, num_warmup_steps=self.args.warmup_steps) else: scheduler = schedule_func( self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps ) return scheduler def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset): return None elif is_torch_tpu_available(): return get_tpu_sampler(self.train_dataset) else: if self.args.sortish_sampler: self.train_dataset.make_sortish_sampler( self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1 ) return ( RandomSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset) ) def _compute_loss(self, model, inputs, labels): if self.args.label_smoothing == 0: if self.data_args is not None and self.data_args.ignore_pad_token_for_loss: # force training to ignore pad token logits = model(**inputs, use_cache=False)[0] loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1)) else: # compute usual loss via models loss, logits = model(**inputs, labels=labels, use_cache=False)[:2] else: # compute label smoothed loss logits = model(**inputs, use_cache=False)[0] lprobs = torch.nn.functional.log_softmax(logits, dim=-1) loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id) return loss, logits def compute_loss(self, model, inputs): labels = inputs.pop("labels") loss, _ = self._compute_loss(model, inputs, labels) return loss def prediction_step( self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform an evaluation step on :obj:`model` using obj:`inputs`. Subclass and override to inject custom behavior. Args: model (:obj:`nn.Module`): The model to evaluate. inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument :obj:`labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (:obj:`bool`): Whether or not to return the loss only. Return: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ inputs = self._prepare_inputs(inputs) gen_kwargs = { "max_length": self.data_args.val_max_target_length if self.data_args is not None else self.config.max_length, "num_beams": self.data_args.eval_beams if self.data_args is not None else self.config.num_beams, } if self.args.predict_with_generate and not self.args.prediction_loss_only: generated_tokens = model.generate( inputs["input_ids"], attention_mask=inputs["attention_mask"], **gen_kwargs, ) # in case the batch is shorter than max length, the output should be padded if generated_tokens.shape[-1] < gen_kwargs["max_length"]: generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) labels = inputs.pop("labels") with torch.no_grad(): # compute loss on predict data loss, logits = self._compute_loss(model, inputs, labels) loss = loss.mean().detach() if self.args.prediction_loss_only: return (loss, None, None) logits = generated_tokens if self.args.predict_with_generate else logits if labels.shape[-1] < gen_kwargs["max_length"]: labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) return (loss, logits, labels) def _pad_tensors_to_max_len(self, tensor, max_length): # If PAD token is not defined at least EOS token has to be defined pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else self.config.eos_token_id if pad_token_id is None: raise ValueError( f"Make sure that either `config.pad_token_id` or `config.eos_token_id` is defined if tensor has to be padded to `max_length`={max_length}" ) padded_tensor = pad_token_id * torch.ones( (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device ) padded_tensor[:, : tensor.shape[-1]] = tensor return padded_tensor