import inspect import warnings from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist from torch import nn from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from transformers.generation_logits_process import ( EncoderNoRepeatNGramLogitsProcessor, ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, HammingDiversityLogitsProcessor, InfNanRemoveLogitsProcessor, LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, PrefixConstrainedLogitsProcessor, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, ) from transformers.generation_stopping_criteria import ( MaxLengthCriteria, MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, ) from transformers.pytorch_utils import torch_int_div from transformers.utils import ModelOutput from transformers.generation_utils import ( SampleOutput, BeamSearchOutput, BeamSampleOutput, GreedySearchOutput, GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, GreedySearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, SampleEncoderDecoderOutput, ) from utils import get_jump_chunks from torch.nn.utils.rnn import pad_sequence class GenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`. - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True`. - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and `do_sample=False`. - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if `num_beams>1` and `do_sample=True`. - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if `num_beams>1` and `num_beam_groups>1`. - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or `force_words_ids!=None`. """ def _prepare_model_inputs( self, inputs: Optional[torch.Tensor] = None, bos_token_id: Optional[int] = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: """ This function extracts the model-specific `inputs` for generation. """ # 1. retrieve all kwargs that are non-None or non-model input related. # some encoder-decoder models have different names for model and encoder if ( self.config.is_encoder_decoder and hasattr(self, "encoder") and self.encoder.main_input_name != self.main_input_name ): input_name = self.encoder.main_input_name else: input_name = self.main_input_name model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} # 2. check whether model_input_name is passed as kwarg # if yes and `inputs` is None use kwarg inputs inputs_kwarg = model_kwargs.pop(input_name, None) if inputs_kwarg is not None and inputs is not None: raise ValueError( f"`inputs`: {inputs}` were passed alongside " f"{input_name} which is not allowed." f"Make sure to either pass {inputs} or {input_name}=..." ) elif inputs_kwarg is not None: inputs = inputs_kwarg # 3. models with `input_ids` can also make use of `inputs_embeds` if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" # 4. Only encoder-decoder models can have non `input_ids` input format if not self.config.is_encoder_decoder and input_name != "input_ids": raise ValueError( f"If {input_name} is passed as model-specific keyword " "input then model has to be an encoder-decoder and not a " f"{self.__class__.__name__}." ) # 5. if `inputs` is still None, try to create `input_ids` from BOS token if inputs is None: inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) return inputs, input_name, model_kwargs def _can_retrieve_inputs_from_name( self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor] ) -> torch.Tensor: """ If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved from name """ can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( inspect.signature(self.forward).parameters.keys() ) if can_retrieve_inputs and inputs is not None: raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}") return can_retrieve_inputs def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: """ Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. """ return {"input_ids": input_ids} def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: """ Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. """ return logits def _prepare_input_ids_for_generation( self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] ) -> torch.LongTensor: if self.config.is_encoder_decoder and encoder_outputs is not None: # make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding shape = encoder_outputs.last_hidden_state.size()[:-1] return torch.ones(shape, dtype=torch.long, device=self.device) * -100 if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id def _prepare_attention_mask_for_generation( self, inputs: torch.Tensor, pad_token_id: int, eos_token_id: int, ) -> torch.LongTensor: is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( (eos_token_id is not None) and (pad_token_id != eos_token_id) ) # Check if input is input_ids and padded -> only then is attention_mask defined if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: return inputs.ne(pad_token_id).long() else: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) def _prepare_encoder_decoder_kwargs_for_generation( self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None ) -> Dict[str, Any]: # 1. get encoder encoder = self.get_encoder() # 2. prepare encoder args and encoder kwargs from model kwargs irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] encoder_kwargs = { argument: value for argument, value in model_kwargs.items() if not any(argument.startswith(p) for p in irrelevant_prefix) } print('encoder_kwargs:', encoder_kwargs) # 3. make sure that encoder returns `ModelOutput` model_input_name = model_input_name if model_input_name is not None else self.main_input_name encoder_kwargs["return_dict"] = True encoder_kwargs[model_input_name] = inputs_tensor model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) return model_kwargs def _prepare_decoder_input_ids_for_generation( self, batch_size: int, decoder_start_token_id: int = None, bos_token_id: int = None, model_kwargs: Optional[Dict[str, torch.Tensor]] = None, device: torch.device = None, ) -> torch.LongTensor: if model_kwargs is not None and "decoder_input_ids" in model_kwargs: return model_kwargs.pop("decoder_input_ids") else: decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) if device is None: device = self.device return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id ) bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id elif ( hasattr(self.config, "decoder") and hasattr(self.config.decoder, "decoder_start_token_id") and self.config.decoder.decoder_start_token_id is not None ): return self.config.decoder.decoder_start_token_id elif bos_token_id is not None: return bos_token_id elif ( hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id") and self.config.decoder.bos_token_id is not None ): return self.config.decoder.bos_token_id raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) @staticmethod def _expand_inputs_for_generation( input_ids: torch.LongTensor, expand_size: int = 1, is_encoder_decoder: bool = False, attention_mask: Optional[torch.LongTensor] = None, encoder_outputs: Optional[ModelOutput] = None, **model_kwargs, ) -> Tuple[torch.LongTensor, Dict[str, Any]]: expanded_return_idx = ( torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) ) input_ids = input_ids.index_select(0, expanded_return_idx) if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) if attention_mask is not None: model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) if is_encoder_decoder: if encoder_outputs is None: raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) ) model_kwargs["encoder_outputs"] = encoder_outputs return input_ids, model_kwargs @staticmethod def _update_model_kwargs_for_generation( outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False ) -> Dict[str, Any]: # update past if "past_key_values" in outputs: model_kwargs["past"] = outputs.past_key_values elif "mems" in outputs: model_kwargs["past"] = outputs.mems elif "past_buckets_states" in outputs: model_kwargs["past"] = outputs.past_buckets_states else: model_kwargs["past"] = None # update token_type_ids with last value if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) # update attention mask if not is_encoder_decoder: if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 ) return model_kwargs def _reorder_cache(self, past, beam_idx): raise NotImplementedError( f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" ) def _get_logits_warper( self, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, temperature: Optional[float] = None, num_beams: Optional[int] = None, renormalize_logits: Optional[bool] = None, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances used for multinomial sampling. """ # init warp parameters top_k = top_k if top_k is not None else self.config.top_k top_p = top_p if top_p is not None else self.config.top_p typical_p = typical_p if typical_p is not None else self.config.typical_p temperature = temperature if temperature is not None else self.config.temperature # instantiate warpers list warpers = LogitsProcessorList() # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` if temperature is not None and temperature != 1.0: warpers.append(TemperatureLogitsWarper(temperature)) if top_k is not None and top_k != 0: warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) if top_p is not None and top_p < 1.0: warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) if typical_p is not None and typical_p < 1.0: warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) # `LogitNormalization` should always be the last logit processor, when present if renormalize_logits is True: warpers.append(LogitNormalization()) return warpers def _get_logits_processor( self, repetition_penalty: float, no_repeat_ngram_size: int, encoder_no_repeat_ngram_size: int, input_ids_seq_length: int, encoder_input_ids: torch.LongTensor, bad_words_ids: List[List[int]], min_length: int, max_length: int, eos_token_id: int, forced_bos_token_id: int, forced_eos_token_id: int, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int, num_beam_groups: int, diversity_penalty: float, remove_invalid_values: bool, exponential_decay_length_penalty: Tuple, logits_processor: Optional[LogitsProcessorList], renormalize_logits: Optional[bool], ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] instances used to modify the scores of the language model head. """ processors = LogitsProcessorList() # init warp parameters repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size ) encoder_no_repeat_ngram_size = ( encoder_no_repeat_ngram_size if encoder_no_repeat_ngram_size is not None else self.config.encoder_no_repeat_ngram_size ) bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty forced_bos_token_id = ( forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id ) forced_eos_token_id = ( forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id ) remove_invalid_values = ( remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values ) exponential_decay_length_penalty = ( exponential_decay_length_penalty if exponential_decay_length_penalty is not None else self.config.exponential_decay_length_penalty ) # instantiate processors list # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` if diversity_penalty is not None and diversity_penalty > 0.0: processors.append( HammingDiversityLogitsProcessor( diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups ) ) if repetition_penalty is not None and repetition_penalty != 1.0: processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: if self.config.is_encoder_decoder: processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) else: raise ValueError( "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" ) if bad_words_ids is not None: processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) if min_length is not None and eos_token_id is not None and min_length > 0: processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) if prefix_allowed_tokens_fn is not None: processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) if forced_bos_token_id is not None: processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) if remove_invalid_values is True: processors.append(InfNanRemoveLogitsProcessor()) if exponential_decay_length_penalty is not None: processors.append( ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) ) processors = self._merge_criteria_processor_list(processors, logits_processor) # `LogitNormalization` should always be the last logit processor, when present if renormalize_logits is True: processors.append(LogitNormalization()) return processors def _get_stopping_criteria( self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if max_length is not None: criteria.append(MaxLengthCriteria(max_length=max_length)) if max_time is not None: criteria.append(MaxTimeCriteria(max_time=max_time)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) return criteria def _merge_criteria_processor_list( self, default_list: Union[LogitsProcessorList, StoppingCriteriaList], custom_list: Union[LogitsProcessorList, StoppingCriteriaList], ) -> Union[LogitsProcessorList, StoppingCriteriaList]: if len(custom_list) == 0: return default_list for default in default_list: for custom in custom_list: if type(custom) is type(default): object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" raise ValueError( f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, " f"but it has already been created with the values {default}. {default} has been created by passing the " "corresponding arguments to generate or by the model's config default values. " f"If you just want to change the default values of {object_type} consider passing them as arguments " f"to `generate` instead of using a custom {object_type}." ) default_list.extend(custom_list) return default_list def compute_transition_beam_scores( self, sequences: torch.Tensor, scores: Tuple[torch.Tensor], beam_indices: torch.Tensor, eos_token_id: int = None, ): """compute the transition probabilities of sequences given generation scores and beam indices""" # reshape scores as [vocab_size * batch_size, # generation steps] # with batch_size being 2 * vocab_size and # generation steps being # seq_len - input_length scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) # start of generated tokens cut_idx = sequences.shape[-1] - scores.shape[-1] # adjust for beam indices beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size # compute real indices indices = sequences[:, cut_idx:] + beam_sequence_indices # gather scores and run transition_scores = scores.gather(0, indices) # make sure that if EOS token was used before length of sequence `sequence.shape[-1]` # get first occurence of EOS token eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id if eos_token_id is not None: is_eos_token_id = sequences[:, cut_idx:] == eos_token_id # make sure first eos token still contributes to transition probs is_eos_token_id[:, -1] = False is_eos_token_id = is_eos_token_id.roll(1, -1) # all indices after eos shoud be masked zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool() # zero out padded probs transition_scores.masked_fill_(zero_transition_prob_mask, 0.0) return transition_scores # ADDED FRED def remove_subsets(self, l): #l = [[1, 2, 4, 8], [1, 2, 4, 5, 6], [1, 2, 3], [2, 3, 21], [1, 2, 3, 4], [1, 2, 3, 4, 5, 6, 7]] l2 = l[:] for m in l: for n in l: if set(m).issubset(set(n)) and m != n: l2.remove(m) break return l2 # ADDED FRED @torch.no_grad() def cs_generate( self, inputs: Optional[torch.Tensor] = None, contexts:List[str]=None, #input data model_input:Dict=None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, encoder_no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, max_time: Optional[float] = None, max_new_tokens: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, num_beam_groups: Optional[int] = None, diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, remove_invalid_values: Optional[bool] = None, synced_gpus: Optional[bool] = False, exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, use_kg:bool=False, #added relation_mapper_builder=None, tokenizer=None, max_neig_per_concept=1, #it slows down quite a lot **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: # print(model_input) input_ids = model_input['input_ids'] if "input_commonsense_relations" in model_input: # print(model_input['input_commonsense_relations'].sum()) model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(input_ids.device) if use_kg: all_constraints = [] print('contexts:', contexts[:3]) for context in contexts: constraints = [] print('+++++++') concepts_from_context = relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True, alignment=1) print('concepts_from_context:', concepts_from_context) useful_concepts = [relation_mapper_builder.swow_knowledge.get_related_concepts(concept) for concept in concepts_from_context] if not useful_concepts: useful_concepts = [relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context] useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts] # add spaces # useful_concepts = [[phrase for phrase in concepts if len(phrase.split(' ')) == 1] for concepts in useful_concepts] # useful_concepts = list(itertools.chain.from_iterable(useful_concepts)) # print('useful_concepts:', useful_concepts[:5]) print('-------') print('useful_concepts:', useful_concepts) if concepts_from_context and useful_concepts: for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts): print('neighbour:', neighbour_concepts[:5]) # flexible_words = self.most_similar_words(context_concept, neighbour_concepts) # limit the upperbound # flexible_words = [word for word in flexible_words if word not in context_concept] # remove input concepts flexible_words = [word for word in neighbour_concepts if word not in context_concept] # remove input concepts print('flexible_words:', flexible_words[:5]) if not flexible_words: continue flexible_words_ids: List[List[int]] = tokenizer(flexible_words, add_special_tokens=False).input_ids #add_prefix_space=True, flexible_words_ids = self.remove_subsets(flexible_words_ids) # add_prefix_space=True # flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets flexible_words_ids = flexible_words_ids[:max_neig_per_concept] #print('flexible_words_ids:', flexible_words_ids[:3]) constraint = DisjunctiveConstraint(flexible_words_ids) constraints.append(constraint) all_constraints.extend(constraints) else: all_constraints = None generated_answers_encoded = self.generate(input_ids=input_ids, #attention_mask=model_input["attention_mask"].to(input_ids.device), constraints=all_constraints, min_length=min_length, #max_length=max_length, do_sample=do_sample, early_stopping=early_stopping, #num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, # eos_token_id=tokenizer.eos_token_id, no_repeat_ngram_size=no_repeat_ngram_size, num_return_sequences=num_return_sequences, return_dict_in_generate=return_dict_in_generate, output_attentions=output_attentions, output_scores=output_scores, **model_kwargs, ) return generated_answers_encoded # ADDED FRED @torch.no_grad() def cs_simple_generate( self, inputs: Optional[torch.Tensor] = None, neighbours_contexts:List[List[str]]=None, #input data model_input:Dict=None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, encoder_no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, max_time: Optional[float] = None, max_new_tokens: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, num_beam_groups: Optional[int] = None, diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, remove_invalid_values: Optional[bool] = None, synced_gpus: Optional[bool] = False, exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, use_kg:bool=False, #added relation_mapper_builder=None, tokenizer=None, max_concepts=2, #it slows down quite a lot **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: # print(model_input) input_ids = model_input['input_ids'] if use_kg: all_constraints = [] for context_neighbours in neighbours_contexts: # context_neighbours is a collection of concepts # lets create sub collections of concepts context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3] n_size_chuncks = len(context_neighbours) // max_concepts n_size_chuncks = n_size_chuncks if n_size_chuncks > 0 else 1 sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chuncks)) constraints = [] for sub_concepts in sub_concepts_collection[:max_concepts]: flexible_words_ids: List[List[int]] = tokenizer(sub_concepts, add_special_tokens=False).input_ids #add_prefix_space=True, #flexible_words_ids = self.remove_subsets(flexible_words_ids) flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids] disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids)))) # add_prefix_space=True # flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets #flexible_words_ids = flexible_words_ids[:max_neig_per_concept] #print('flexible_words_ids:', flexible_words_ids[:3]) if not any(disjunctive_set): continue constraint = DisjunctiveConstraint(disjunctive_set) constraints.append(constraint) if not any(constraints): constraints=None all_constraints.append(constraints) else: all_constraints = None if not all_constraints: all_constraints = None generated_answers_encoded = [] #print('all_constraints:', all_constraints) for i, contraints in enumerate(all_constraints): #print('contraints.token_ids:', [x.token_ids for x in contraints]) if "input_commonsense_relations" in model_input: # print(model_input['input_commonsense_relations'].sum()) model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations")[i].unsqueeze(0).to(input_ids.device) #print('model_kwargs.get("attention_mask"):', model_kwargs.get("attention_mask")) model_kwargs["attention_mask"] = model_input.get("attention_mask")[i].unsqueeze(0).to(input_ids.device) gen = self.generate(input_ids=input_ids[i].unsqueeze(0), constraints=contraints, min_length=min_length, #max_length=max_length, do_sample=do_sample, early_stopping=early_stopping, #num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, # eos_token_id=tokenizer.eos_token_id, no_repeat_ngram_size=no_repeat_ngram_size, num_return_sequences=num_return_sequences, return_dict_in_generate=return_dict_in_generate, output_attentions=output_attentions, output_scores=output_scores, **model_kwargs) #print('[gen]:', gen) #print(tokenizer.batch_decode(gen)) generated_answers_encoded.append(gen[0].detach().cpu()) #torch.LongTensor(generated_answers_encoded) #print('generated_answers_encoded:', generated_answers_encoded) return torch.LongTensor(pad_sequence(generated_answers_encoded, batch_first=True, padding_value=tokenizer.pad_token_id)).to(input_ids.device) @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, max_length: Optional[int] = None, min_length: Optional[int] = None, do_sample: Optional[bool] = None, early_stopping: Optional[bool] = None, num_beams: Optional[int] = None, temperature: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, typical_p: Optional[float] = None, repetition_penalty: Optional[float] = None, bad_words_ids: Optional[Iterable[int]] = None, force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, bos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, encoder_no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, max_time: Optional[float] = None, max_new_tokens: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, num_beam_groups: Optional[int] = None, diversity_penalty: Optional[float] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), renormalize_logits: Optional[bool] = None, stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), constraints: Optional[List[Constraint]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, remove_invalid_values: Optional[bool] = None, synced_gpus: Optional[bool] = False, exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, **model_kwargs, ) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head. The method supports the following generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and `do_sample=False`. - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and `do_sample=True`. - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and `do_sample=False`. - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if `num_beams>1` and `do_sample=True`. - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if `num_beams>1` and `num_beam_groups>1`. - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or `force_words_ids!=None`. Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as defined in the model's config (`config.json`) which in turn defaults to the [`~modeling_utils.PretrainedConfig`] of the model. Most of these parameters are explained in more detail in [this blog post](https://huggingface.co/blog/how-to-generate). Parameters: inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. max_length (`int`, *optional*, defaults to `model.config.max_length`): The maximum length of the sequence to be generated. max_new_tokens (`int`, *optional*, defaults to None): The maximum numbers of tokens to generate, ignore the current number of tokens. Use either `max_new_tokens` or `max_length` but not both, they serve the same purpose. min_length (`int`, *optional*, defaults to 10): The minimum length of the sequence to be generated. do_sample (`bool`, *optional*, defaults to `False`): Whether or not to use sampling ; use greedy decoding otherwise. early_stopping (`bool`, *optional*, defaults to `False`): Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. num_beams (`int`, *optional*, defaults to 1): Number of beams for beam search. 1 means no beam search. temperature (`float`, *optional*, defaults to 1.0): The value used to module the next token probabilities. top_k (`int`, *optional*, defaults to 50): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`, *optional*, defaults to 1.0): If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation. repetition_penalty (`float`, *optional*, defaults to 1.0): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. pad_token_id (`int`, *optional*): The id of the *padding* token. bos_token_id (`int`, *optional*): The id of the *beginning-of-sequence* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. length_penalty (`float`, *optional*, defaults to 1.0): Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length. 0.0 means no penalty. Set to values < 0.0 in order to encourage the model to generate longer sequences, to a value > 0.0 in order to encourage the model to produce shorter sequences. no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size can only occur once. encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`. bad_words_ids(`List[List[int]]`, *optional*): List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`. force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. num_return_sequences(`int`, *optional*, defaults to 1): The number of independently computed returned sequences for each element in the batch. max_time(`float`, *optional*, defaults to None): The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) decoder_start_token_id (`int`, *optional*): If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. use_cache: (`bool`, *optional*, defaults to `True`): Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. num_beam_groups (`int`, *optional*, defaults to 1): Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. diversity_penalty (`float`, *optional*, defaults to 0.0): This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): If provided, this function constraints the beam search to allowed tokens only at each step. If not provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). logits_processor (`LogitsProcessorList`, *optional*): Custom logits processors that complement the default logits processors built from arguments and a model's config. If a logit processor is passed that is already created with the arguments or a model's config an error is thrown. This feature is intended for advanced users. renormalize_logits: (`bool`, *optional*, defaults to `False`): Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. stopping_criteria (`StoppingCriteriaList`, *optional*): Custom stopping criteria that complement the default stopping criteria built from arguments and a model's config. If a stopping criteria is passed that is already created with the arguments or a model's config an error is thrown. This feature is intended for advanced users. constraints (`List[Constraint]`, *optional*): Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by `Constraint` objects, in the most sensible way possible. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. forced_bos_token_id (`int`, *optional*): The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target language token. forced_eos_token_id (`int`, *optional*): The id of the token to force as the last generated token when `max_length` is reached. remove_invalid_values (`bool`, *optional*): Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) exponential_decay_length_penalty (`tuple(int, float)`, *optional*): This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Return: [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible [`~utils.ModelOutput`] types are: - [`~generation_utils.GreedySearchDecoderOnlyOutput`], - [`~generation_utils.SampleDecoderOnlyOutput`], - [`~generation_utils.BeamSearchDecoderOnlyOutput`], - [`~generation_utils.BeamSampleDecoderOnlyOutput`] If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible [`~utils.ModelOutput`] types are: - [`~generation_utils.GreedySearchEncoderDecoderOutput`], - [`~generation_utils.SampleEncoderDecoderOutput`], - [`~generation_utils.BeamSearchEncoderDecoderOutput`], - [`~generation_utils.BeamSampleEncoderDecoderOutput`] Examples: Greedy Decoding: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> prompt = "Today I believe we can finally" >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids >>> # generate up to 30 tokens >>> outputs = model.generate(input_ids, do_sample=False, max_length=30) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] ``` Multinomial Sampling: ```python >>> from transformers import AutoTokenizer, AutoModelForCausalLM >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> prompt = "Today I believe we can finally" >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids >>> # sample up to 30 tokens >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT >>> outputs = model.generate(input_ids, do_sample=True, max_length=30) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] ``` Beam-search decoding: ```python >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") >>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") >>> sentence = "Paris is one of the densest populated areas in Europe." >>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids >>> outputs = model.generate(input_ids) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] ```""" # 1. Set generation parameters if not already defined bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id num_beams = num_beams if num_beams is not None else self.config.num_beams length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups do_sample = do_sample if do_sample is not None else self.config.do_sample num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id if eos_token_id is None and hasattr(self.config, "decoder"): eos_token_id = self.config.decoder.eos_token_id if pad_token_id is None and eos_token_id is not None: # special case if pad_token_id is not defined print(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") pad_token_id = eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) # 2. Define model inputs # inputs_tensor has to be defined # model_input_name is defined if model-specific keyword input is passed # otherwise model_input_name is None # all model-specific keyword inputs are removed from `model_kwargs` inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) batch_size = inputs_tensor.shape[0] # 3. Define other model kwargs model_kwargs["output_attentions"] = output_attentions model_kwargs["output_hidden_states"] = output_hidden_states model_kwargs["use_cache"] = use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( inputs_tensor, pad_token_id, eos_token_id ) if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: # if model is encoder decoder encoder_outputs are created # and added to `model_kwargs` model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, model_kwargs, model_input_name ) # 4. Prepare `input_ids` which will be used for auto-regressive generation if self.config.is_encoder_decoder: input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, model_kwargs=model_kwargs, device=inputs_tensor.device, ) else: # if decoder-only then inputs_tensor has to be `input_ids` input_ids = inputs_tensor input_ids_seq_length = input_ids.shape[-1] # 5. Prepare `max_length` depending on other stopping criteria # if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens` if max_length is None and max_new_tokens is not None: max_length = max_new_tokens + input_ids_seq_length elif max_length is not None and max_new_tokens is not None: # Both are set, this is odd, raise a warning warnings.warn( "Both `max_length` and `max_new_tokens` have been set " f"but they serve the same purpose. `max_length` {max_length} " f"will take priority over `max_new_tokens` {max_new_tokens}.", UserWarning, ) # default to config if still None max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length if min_length is not None and min_length > max_length: raise ValueError( f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " f"length ({max_length})" ) if input_ids_seq_length >= max_length: input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" print( f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " "This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." ) # 6. determine generation mode is_constraint_gen_mode = constraints is not None or force_words_ids is not None is_greedy_gen_mode = ( (num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode ) is_sample_gen_mode = ( (num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode ) is_beam_gen_mode = ( (num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode ) is_beam_sample_gen_mode = ( (num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode ) is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode if num_beam_groups > num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") if is_group_beam_gen_mode and do_sample is True: raise ValueError( "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." ) # 7. prepare distribution pre_processing samplers logits_processor = self._get_logits_processor( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, input_ids_seq_length=input_ids_seq_length, encoder_input_ids=inputs_tensor, bad_words_ids=bad_words_ids, min_length=min_length, max_length=max_length, eos_token_id=eos_token_id, forced_bos_token_id=forced_bos_token_id, forced_eos_token_id=forced_eos_token_id, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, num_beams=num_beams, num_beam_groups=num_beam_groups, diversity_penalty=diversity_penalty, remove_invalid_values=remove_invalid_values, exponential_decay_length_penalty=exponential_decay_length_penalty, logits_processor=logits_processor, renormalize_logits=renormalize_logits, ) # 8. prepare stopping criteria stopping_criteria = self._get_stopping_criteria( max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria ) # 9. go into different generation modes if is_greedy_gen_mode: if num_return_sequences > 1: raise ValueError( f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." ) # 10. run greedy search return self.greedy_search( input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams, renormalize_logits=renormalize_logits, ) # 11. expand input_ids with `num_return_sequences` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 12. run sample return self.sample( input_ids, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_gen_mode: if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 10. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) # 12. run beam search return self.beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_beam_sample_gen_mode: # 10. prepare logits warper logits_warper = self._get_logits_warper( top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams, renormalize_logits=renormalize_logits, ) if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 11. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size * num_return_sequences, num_beams=num_beams, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, ) # 12. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams * num_return_sequences, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) # 13. run beam sample return self.beam_sample( input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_group_beam_gen_mode: if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if num_beams % num_beam_groups != 0: raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") # 10. prepare beam search scorer beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, max_length=stopping_criteria.max_length, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, num_beam_groups=num_beam_groups, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) # 12. run beam search return self.group_beam_search( input_ids, beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) elif is_constraint_gen_mode: if num_return_sequences > num_beams: raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") if stopping_criteria.max_length is None: raise ValueError("`max_length` needs to be a stopping_criteria for now.") if num_beams <= 1: raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.") if do_sample: raise ValueError("`do_sample` needs to be false for constrained generation.") if num_beam_groups is not None and num_beam_groups > 1: raise ValueError("`num_beam_groups` not supported yet for constrained generation.") final_constraints = [] if constraints is not None: final_constraints = constraints if force_words_ids is not None: def typeerror(): raise ValueError( "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" f"of positive integers, but is {force_words_ids}." ) if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: typeerror() for word_ids in force_words_ids: if isinstance(word_ids[0], list): if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() if any(not isinstance(token_ids, list) for token_ids in word_ids): typeerror() if any( any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) for token_ids in word_ids ): typeerror() constraint = DisjunctiveConstraint(word_ids) else: if not isinstance(word_ids, list) or len(word_ids) == 0: typeerror() if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): typeerror() constraint = PhrasalConstraint(word_ids) final_constraints.append(constraint) # 10. prepare beam search scorer constrained_beam_scorer = ConstrainedBeamSearchScorer( constraints=final_constraints, batch_size=batch_size, num_beams=num_beams, device=inputs_tensor.device, length_penalty=length_penalty, do_early_stopping=early_stopping, num_beam_hyps_to_keep=num_return_sequences, ) # 11. interleave input_ids with `num_beams` additional sequences per batch input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs ) # 12. run beam search return self.constrained_beam_search( input_ids, constrained_beam_scorer=constrained_beam_scorer, logits_processor=logits_processor, stopping_criteria=stopping_criteria, pad_token_id=pad_token_id, eos_token_id=eos_token_id, output_scores=output_scores, return_dict_in_generate=return_dict_in_generate, synced_gpus=synced_gpus, **model_kwargs, ) def greedy_search( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ) -> Union[GreedySearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: Additional model specific keyword arguments will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "It might be possible to" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> outputs = model.greedy_search( ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ["It might be possible to get a better understanding of the nature of the problem, but it's not"] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) cur_len = input_ids.shape[-1] this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_logits,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # pre-process distribution next_tokens_scores = logits_processor(input_ids, next_token_logits) # argmax next_tokens = torch.argmax(next_tokens_scores, dim=-1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) cur_len = cur_len + 1 # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True if return_dict_in_generate: if self.config.is_encoder_decoder: return GreedySearchEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return GreedySearchDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids def sample( self, input_ids: torch.LongTensor, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ) -> Union[SampleOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. logits_warper (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForCausalLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... TopKLogitsWarper, ... TemperatureLogitsWarper, ... StoppingCriteriaList, ... MaxLengthCriteria, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> model = AutoModelForCausalLM.from_pretrained("gpt2") >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token >>> model.config.pad_token_id = model.config.eos_token_id >>> input_prompt = "Today is a beautiful day, and" >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> # instantiate logits processors >>> logits_warper = LogitsProcessorList( ... [ ... TopKLogitsWarper(50), ... TemperatureLogitsWarper(0.7), ... ] ... ) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT >>> outputs = model.sample( ... input_ids, ... logits_processor=logits_processor, ... logits_warper=logits_warper, ... stopping_criteria=stopping_criteria, ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) # keep track of which sequences are already finished unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) cur_len = input_ids.shape[-1] this_peer_finished = False # used by synced_gpus only # auto-regressive generation while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # forward pass to get next token outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # finished sentences should have their next token be a padding token if eos_token_id is not None: if pad_token_id is None: raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) cur_len = cur_len + 1 # if eos_token was found in one sentence, set sentence to finished if eos_token_id is not None: unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True if return_dict_in_generate: if self.config.is_encoder_decoder: return SampleEncoderDecoderOutput( sequences=input_ids, scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return SampleDecoderOnlyOutput( sequences=input_ids, scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return input_ids def beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... BeamSearchScorer, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> encoder_input_str = "translate English to German: How old are you?" >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids >>> # lets run beam search using 3 beams >>> num_beams = 3 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... num_beams=num_beams, ... device=model.device, ... ) >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) #Normal execution next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores_processed,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) next_indices = torch_int_div(next_tokens, vocab_size) next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None else: num_return_sequences = beam_scorer.num_beam_hyps_to_keep # return only as many indices as sequences beam_indices = tuple( (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) ) beam_indices = sum(beam_indices, ()) if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=beam_indices, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=beam_indices, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"] def beam_sample( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, logits_warper: Optional[LogitsProcessorList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ) -> Union[BeamSampleOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **beam search multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. logits_warper (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... TopKLogitsWarper, ... TemperatureLogitsWarper, ... BeamSearchScorer, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> encoder_input_str = "translate English to German: How old are you?" >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids >>> # lets run beam search using 3 beams >>> num_beams = 3 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... max_length=model.config.max_length, ... num_beams=num_beams, ... device=model.device, ... ) >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] ... ) >>> # instantiate logits processors >>> logits_warper = LogitsProcessorList( ... [ ... TopKLogitsWarper(50), ... TemperatureLogitsWarper(0.7), ... ] ... ) >>> outputs = model.beam_sample( ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None beam_indices = ( tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) next_token_scores = logits_warper(input_ids, next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (logits_warper(input_ids, next_token_scores_processed),) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) probs = nn.functional.softmax(next_token_scores, dim=-1) next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) next_token_scores = torch.gather(next_token_scores, -1, next_tokens) next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) next_tokens = torch.gather(next_tokens, -1, _indices) next_indices = torch_int_div(next_tokens, vocab_size) next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) if return_dict_in_generate and output_scores: beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None else: num_return_sequences = beam_scorer.num_beam_hyps_to_keep # return only as many indices as sequences beam_indices = tuple( (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) ) beam_indices = sum(beam_indices, ()) if self.config.is_encoder_decoder: return BeamSampleEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=beam_indices, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSampleDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=beam_indices, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"] def group_beam_search( self, input_ids: torch.LongTensor, beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = False, **model_kwargs, ): r""" Generates sequences of token ids for models with a language modeling head using **diverse beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. beam_scorer (`BeamScorer`): An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: Additional model specific kwargs that will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.BeamSearchDecoderOnlyOutput`] if [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... HammingDiversityLogitsProcessor, ... BeamSearchScorer, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> encoder_input_str = "translate English to German: How old are you?" >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids >>> # lets run diverse beam search using 6 beams >>> num_beams = 6 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> # instantiate beam scorer >>> beam_scorer = BeamSearchScorer( ... batch_size=1, ... max_length=model.config.max_length, ... num_beams=num_beams, ... device=model.device, ... num_beam_groups=3, ... ) >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> outputs = model.group_beam_search( ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt bist du?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) batch_size = len(beam_scorer._beam_hyps) num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups device = input_ids.device batch_beam_size, cur_len = input_ids.shape if return_dict_in_generate and output_scores: beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] else: beam_indices = None if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in # the same group don't produce same tokens everytime. beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need if output_scores: processed_score = torch.zeros_like(outputs.logits[:, -1, :]) for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx # indices of beams of current group among all sentences in batch batch_group_indices = [] for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] ) group_input_ids = input_ids[batch_group_indices] # select outputs of beams of current group only next_token_logits = outputs.logits[batch_group_indices, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * group_size, vocab_size) vocab_size = next_token_scores.shape[-1] next_token_scores_processed = logits_processor( group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) if output_scores: processed_score[batch_group_indices] = next_token_scores_processed # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True ) next_indices = torch_int_div(next_tokens, vocab_size) next_tokens = next_tokens % vocab_size # stateless beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] if return_dict_in_generate and output_scores: beam_indices[beam_group_idx] = tuple( beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) ) input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group reordering_indices[batch_group_indices] = ( num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size) ) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (processed_score,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None else: beam_indices = sum(beam_indices, ()) num_return_sequences = beam_scorer.num_beam_hyps_to_keep # return only as many indices as sequences beam_indices = tuple( (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) ) beam_indices = sum(beam_indices, ()) if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, beam_indices=beam_indices, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"] def constrained_beam_search( self, input_ids: torch.LongTensor, constrained_beam_scorer: ConstrainedBeamSearchScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, synced_gpus: Optional[bool] = None, **model_kwargs, ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" Generates sequences of token ids for models with a language modeling head using **constrained beam search decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. constrained_beam_scorer (`ConstrainedBeamSearchScorer`): A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and sorted during generation, while satisfying a list of positive constraints. For more information, the documentation of [`ConstrainedBeamSearchScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. logits_warper (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used to warp the prediction score distribution of the language modeling head applied before multinomial sampling at each generation step. max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. pad_token_id (`int`, *optional*): The id of the *padding* token. eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details. output_hidden_states (`bool`, *optional*, defaults to `False`): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details. output_scores (`bool`, *optional*, defaults to `False`): Whether or not to return the prediction scores. See `scores` under returned tensors for more details. return_dict_in_generate (`bool`, *optional*, defaults to `False`): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. Return: [`generation_utilsBeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. Examples: ```python >>> from transformers import ( ... AutoTokenizer, ... AutoModelForSeq2SeqLM, ... LogitsProcessorList, ... MinLengthLogitsProcessor, ... ConstrainedBeamSearchScorer, ... PhrasalConstraint, ... ) >>> import torch >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> encoder_input_str = "translate English to German: How old are you?" >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids >>> # lets run beam search using 3 beams >>> num_beams = 3 >>> # define decoder start token ids >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) >>> input_ids = input_ids * model.config.decoder_start_token_id >>> # add encoder_outputs to model keyword arguments >>> model_kwargs = { ... "encoder_outputs": model.get_encoder()( ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True ... ) ... } >>> constraint_str = "Sie" >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] >>> # instantiate beam scorer >>> beam_scorer = ConstrainedBeamSearchScorer( ... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints ... ) >>> # instantiate logits processors >>> logits_processor = LogitsProcessorList( ... [ ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), ... ] ... ) >>> outputs = model.constrained_beam_search( ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs ... ) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) ['Wie alt sind Sie?'] ```""" # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) if len(stopping_criteria) == 0: warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id output_scores = output_scores if output_scores is not None else self.config.output_scores output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate ) # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None encoder_hidden_states = ( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) batch_size = len(constrained_beam_scorer._beam_hyps) num_beams = constrained_beam_scorer.num_beams batch_beam_size, cur_len = input_ids.shape if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then if this_peer_finished_flag.item() == 0.0: break model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) outputs = self( **model_inputs, return_dict=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if synced_gpus and this_peer_finished: cur_len = cur_len + 1 continue # don't waste resources running the code we don't need next_token_logits = outputs.logits[:, -1, :] # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` # cannot be generated both before and after the `nn.functional.log_softmax` operation. next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) next_token_scores_processed = logits_processor(input_ids, next_token_scores) scores_for_all_vocab = next_token_scores_processed.clone() next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) # Store scores, attentions and hidden_states when required if return_dict_in_generate: if output_scores: scores += (next_token_scores,) if output_attentions: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) ) # reshape for beam search vocab_size = next_token_scores.shape[-1] next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) next_indices = (next_tokens / vocab_size).long() next_tokens = next_tokens % vocab_size # stateless beam_outputs = constrained_beam_scorer.process( input_ids, next_token_scores, next_tokens, next_indices, scores_for_all_vocab, pad_token_id=pad_token_id, eos_token_id=eos_token_id, ) beam_scores = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder ) if model_kwargs["past"] is not None: model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) # increase cur_len cur_len = cur_len + 1 if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): if not synced_gpus: break else: this_peer_finished = True sequence_outputs = constrained_beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, ) if return_dict_in_generate: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: return BeamSearchDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: return sequence_outputs["sequences"] def top_k_top_p_filtering( logits: torch.FloatTensor, top_k: int = 0, top_p: float = 1.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1, ) -> torch.FloatTensor: """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: logits: logits distribution shape (batch size, vocabulary size) top_k (`int`, *optional*, defaults to 0): If > 0, only keep the top k tokens with highest probability (top-k filtering) top_p (`float`, *optional*, defaults to 1.0): If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) min_tokens_to_keep (`int`, *optional*, defaults to 1): Minimumber of tokens we keep per batch example in the output. From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ if top_k > 0: logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( None, logits ) if 0 <= top_p <= 1.0: logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) return logits