# RA-BART / custom_bart / bart_generation_mixin.py
# (Repository page header captured with the file: "MrVicente's picture",
# commit message "added demo base code".)
import inspect
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation_logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.pytorch_utils import torch_int_div
from transformers.utils import ModelOutput
from transformers.generation_utils import (
    GreedySearchOutput, GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, GreedySearchEncoderDecoderOutput,
    BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, BeamSampleDecoderOnlyOutput,
    BeamSampleEncoderDecoderOutput, SampleEncoderDecoderOutput,
    BeamSampleOutput, BeamSearchOutput, SampleOutput,
)

from utils import get_jump_chunks
class GenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`.
        - *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`.
        - *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`.
        - *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
          `num_beams>1` and `do_sample=True`.
        - *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
          `num_beams>1` and `num_beam_groups>1`.
        - *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`],
          if `constraints!=None` or `force_words_ids!=None`.
    """
def _prepare_model_inputs(
    self,
    inputs: Optional[torch.Tensor] = None,
    bos_token_id: Optional[int] = None,
    model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
    """
    Extract the model-specific main input for generation.

    Args:
        inputs: Main input passed positionally, or `None`.
        bos_token_id: Token id used to build `input_ids` when no input is given.
        model_kwargs: Extra keyword arguments; `None` is treated as an empty dict.

    Returns:
        `(inputs, input_name, model_kwargs)` where `model_kwargs` no longer contains
        the entry stored under `input_name`.

    Raises:
        ValueError: If the main input is given both positionally and by keyword, or if
            a decoder-only model receives a main input other than `input_ids`.
    """
    # Robustness: the default is `None`, which would otherwise fail on `.items()`.
    if model_kwargs is None:
        model_kwargs = {}

    # 1. retrieve all kwargs that are non-None or non-model input related.
    # some encoder-decoder models have different names for model and encoder
    if (
        self.config.is_encoder_decoder
        and hasattr(self, "encoder")
        and self.encoder.main_input_name != self.main_input_name
    ):
        input_name = self.encoder.main_input_name
    else:
        input_name = self.main_input_name

    # Only a `None` value stored under the main input name is dropped here.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

    # 2. check whether model_input_name is passed as kwarg
    # if yes and `inputs` is None use kwarg inputs
    inputs_kwarg = model_kwargs.pop(input_name, None)
    if inputs_kwarg is not None and inputs is not None:
        raise ValueError(
            f"`inputs`: {inputs}` were passed alongside "
            f"{input_name} which is not allowed."
            f"Make sure to either pass {inputs} or {input_name}=..."
        )
    elif inputs_kwarg is not None:
        inputs = inputs_kwarg

    # 3. models with `input_ids` can also make use of `inputs_embeds`
    if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
        inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

    # 4. Only encoder-decoder models can have non `input_ids` input format
    if not self.config.is_encoder_decoder and input_name != "input_ids":
        raise ValueError(
            f"If {input_name} is passed as model-specific keyword "
            "input then model has to be an encoder-decoder and not a "
            f"{self.__class__.__name__}."
        )

    # 5. if `inputs` is still None, try to create `input_ids` from BOS token
    if inputs is None:
        inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

    return inputs, input_name, model_kwargs
def _can_retrieve_inputs_from_name(
    self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
) -> bool:
    """
    Tell whether the main input can be taken from `model_kwargs[name]`.

    That is the case when `model_kwargs` holds a non-`None` value under `name` and `name`
    is a parameter of `self.forward`.

    Raises:
        ValueError: If the input is retrievable from `name` but `inputs` was also given.
    """
    is_provided = model_kwargs.get(name, None) is not None
    can_retrieve_inputs = is_provided and name in set(
        inspect.signature(self.forward).parameters.keys()
    )

    if can_retrieve_inputs and inputs is not None:
        raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")

    return can_retrieve_inputs
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
    """
    Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.

    This default implementation ignores `kwargs` and returns `input_ids` under the
    `"input_ids"` key.
    """
    return {"input_ids": input_ids}
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
    """
    Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.

    This default implementation returns `logits` unchanged.
    """
    return logits
def _prepare_input_ids_for_generation(
    self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput]
) -> torch.LongTensor:
    """
    Build placeholder `input_ids` when the caller supplied none.

    With an encoder-decoder config and available `encoder_outputs`, returns a tensor
    filled with -100 whose shape is the encoder hidden state's shape without its last
    dimension. Otherwise returns a `(1, 1)` tensor holding `bos_token_id`.

    Raises:
        ValueError: If `bos_token_id` is needed but is `None`.
    """
    has_encoder_states = self.config.is_encoder_decoder and encoder_outputs is not None
    if has_encoder_states:
        # -100 marks these ids as dummies that must never be used for encoding.
        leading_dims = encoder_outputs.last_hidden_state.size()[:-1]
        return torch.ones(leading_dims, dtype=torch.long, device=self.device) * -100

    if bos_token_id is None:
        raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")

    single_token = torch.ones((1, 1), dtype=torch.long, device=self.device)
    return single_token * bos_token_id
def _prepare_attention_mask_for_generation(
    self,
    inputs: torch.Tensor,
    pad_token_id: int,
    eos_token_id: int,
) -> torch.LongTensor:
    """
    Derive an attention mask from `inputs`.

    Returns `inputs != pad_token_id` (as long) only when `inputs` is a 2-D int/long
    tensor that contains `pad_token_id` and the pad id differs from `eos_token_id`
    (or no EOS id is set). In every other case an all-ones mask over the first two
    dimensions of `inputs` is returned.
    """
    is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
    is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
        (eos_token_id is not None) and (pad_token_id != eos_token_id)
    )
    # Check if input is input_ids and padded -> only then is attention_mask defined
    if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
        return inputs.ne(pad_token_id).long()
    return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
def _prepare_encoder_decoder_kwargs_for_generation(
    self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
    """
    Run the encoder once and store its output in `model_kwargs["encoder_outputs"]`.

    Args:
        inputs_tensor: Main encoder input.
        model_kwargs: Generation kwargs; entries whose name starts with `decoder_`,
            `cross_attn` or `use_cache` are not forwarded to the encoder.
        model_input_name: Keyword under which `inputs_tensor` is passed; defaults to
            `self.main_input_name`.

    Returns:
        The same `model_kwargs` dict, mutated in place.
    """
    # 1. get encoder
    encoder = self.get_encoder()

    # 2. prepare encoder args and encoder kwargs from model kwargs
    irrelevant_prefix = ("decoder_", "cross_attn", "use_cache")
    encoder_kwargs = {
        argument: value
        for argument, value in model_kwargs.items()
        if not argument.startswith(irrelevant_prefix)
    }
    # The debug `print` of `encoder_kwargs` was removed: it dumped whole tensors to
    # stdout on every call.

    # 3. make sure that encoder returns `ModelOutput`
    model_input_name = model_input_name if model_input_name is not None else self.main_input_name
    encoder_kwargs["return_dict"] = True
    encoder_kwargs[model_input_name] = inputs_tensor
    model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)

    return model_kwargs
def _prepare_decoder_input_ids_for_generation(
    self,
    batch_size: int,
    decoder_start_token_id: int = None,
    bos_token_id: int = None,
    model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    device: torch.device = None,
) -> torch.LongTensor:
    """
    Return the initial decoder input ids.

    If `model_kwargs` contains `"decoder_input_ids"`, that entry is popped and returned.
    Otherwise a `(batch_size, 1)` long tensor filled with the resolved decoder start
    token id is created on `device` (default: `self.device`).
    """
    if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
        return model_kwargs.pop("decoder_input_ids")

    decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
    if device is None:
        device = self.device
    return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
    """
    Resolve the token id that starts decoding.

    Lookup order: explicit `decoder_start_token_id` / `self.config.decoder_start_token_id`,
    then `self.config.decoder.decoder_start_token_id`, then explicit `bos_token_id` /
    `self.config.bos_token_id`, then `self.config.decoder.bos_token_id`.

    Raises:
        ValueError: If none of these is set.
    """
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

    # Nested decoder config, when the config has one.
    decoder_config = getattr(self.config, "decoder", None) if hasattr(self.config, "decoder") else None
    nested_start = getattr(decoder_config, "decoder_start_token_id", None)
    nested_bos = getattr(decoder_config, "bos_token_id", None)

    if decoder_start_token_id is not None:
        return decoder_start_token_id
    if nested_start is not None:
        return nested_start
    if bos_token_id is not None:
        return bos_token_id
    if nested_bos is not None:
        return nested_bos
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )
@staticmethod
def _expand_inputs_for_generation(
    input_ids: torch.LongTensor,
    expand_size: int = 1,
    is_encoder_decoder: bool = False,
    attention_mask: Optional[torch.LongTensor] = None,
    encoder_outputs: Optional[ModelOutput] = None,
    **model_kwargs,
) -> Tuple[torch.LongTensor, Dict[str, Any]]:
    """
    Repeat every batch row `expand_size` times, keeping copies of a row adjacent.

    The same row selection is applied to `input_ids`, to `token_type_ids` and
    `attention_mask` when present, and, for encoder-decoder models, to
    `encoder_outputs.last_hidden_state`. Other entries of `model_kwargs` are returned
    untouched.

    Returns:
        `(input_ids, model_kwargs)` with the expanded tensors stored in `model_kwargs`.

    Raises:
        ValueError: If `is_encoder_decoder` is true but `encoder_outputs` is `None`.
    """
    # Row indices [0, 0, ..., 1, 1, ...], each repeated `expand_size` times.
    expanded_return_idx = (
        torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)

    if attention_mask is not None:
        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        if encoder_outputs is None:
            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
    return input_ids, model_kwargs
@staticmethod
def _update_model_kwargs_for_generation(
    outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
    """
    Carry state from one decoding step into the kwargs for the next step.

    - `model_kwargs["past"]` is set from the first of `past_key_values`, `mems`,
      `past_buckets_states` found in `outputs`, or to `None` when none is present.
    - `token_type_ids`, when present, is extended by repeating its last column.
    - For decoder-only models, `attention_mask`, when present, is extended by a column
      of ones.

    Returns:
        The same `model_kwargs` dict, mutated in place.
    """
    # update past
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs.past_key_values
    elif "mems" in outputs:
        model_kwargs["past"] = outputs.mems
    elif "past_buckets_states" in outputs:
        model_kwargs["past"] = outputs.past_buckets_states
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

    return model_kwargs
def _reorder_cache(self, past, beam_idx):
    """
    Placeholder for reordering the cache `past` by `beam_idx`.

    Raises:
        NotImplementedError: Always; the message names the class and module that must
            provide an implementation to enable beam search.
    """
    raise NotImplementedError(
        f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
    )
def _get_logits_warper(
    self,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    typical_p: Optional[float] = None,
    temperature: Optional[float] = None,
    num_beams: Optional[int] = None,
    renormalize_logits: Optional[bool] = None,
) -> LogitsProcessorList:
    """
    This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
    used for multinomial sampling.

    Arguments left as `None` fall back to the attribute of the same name on `self.config`.
    Warpers are appended in the order temperature, top-k, top-p, typical-p, followed by
    `LogitNormalization` when `renormalize_logits` is `True`.
    """
    # init warp parameters
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    typical_p = typical_p if typical_p is not None else self.config.typical_p
    temperature = temperature if temperature is not None else self.config.temperature

    # Keep two candidates when several beams are in use. `num_beams=None` (the declared
    # default) is treated as a single beam instead of failing on `None > 1`.
    min_tokens_to_keep = 2 if (num_beams is not None and num_beams > 1) else 1

    # instantiate warpers list
    warpers = LogitsProcessorList()

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all samplers can be found in `generation_utils_samplers.py`
    if temperature is not None and temperature != 1.0:
        warpers.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=min_tokens_to_keep))
    if top_p is not None and top_p < 1.0:
        warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep))
    if typical_p is not None and typical_p < 1.0:
        warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=min_tokens_to_keep))
    # `LogitNormalization` should always be the last logit processor, when present
    if renormalize_logits is True:
        warpers.append(LogitNormalization())
    return warpers
def _get_logits_processor(
    self,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    encoder_no_repeat_ngram_size: int,
    input_ids_seq_length: int,
    encoder_input_ids: torch.LongTensor,
    bad_words_ids: List[List[int]],
    min_length: int,
    max_length: int,
    eos_token_id: int,
    forced_bos_token_id: int,
    forced_eos_token_id: int,
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
    num_beams: int,
    num_beam_groups: int,
    diversity_penalty: float,
    remove_invalid_values: bool,
    exponential_decay_length_penalty: Tuple,
    logits_processor: Optional[LogitsProcessorList],
    renormalize_logits: Optional[bool],
) -> LogitsProcessorList:
    """
    This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
    instances used to modify the scores of the language model head.

    Arguments passed as `None` fall back to the attribute of the same name on `self.config`
    for every option resolved in the "init warp parameters" section below. The built-in
    processors are merged with the caller's `logits_processor` list, and
    `LogitNormalization` is appended last when `renormalize_logits` is `True`.

    Raises:
        ValueError: If `encoder_no_repeat_ngram_size` is set for a decoder-only model, or
            (from the merge step) if a custom processor duplicates a built-in one.
    """
    processors = LogitsProcessorList()

    # init warp parameters
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    encoder_no_repeat_ngram_size = (
        encoder_no_repeat_ngram_size
        if encoder_no_repeat_ngram_size is not None
        else self.config.encoder_no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty
    forced_bos_token_id = (
        forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
    )
    forced_eos_token_id = (
        forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
    )
    remove_invalid_values = (
        remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
    )
    exponential_decay_length_penalty = (
        exponential_decay_length_penalty
        if exponential_decay_length_penalty is not None
        else self.config.exponential_decay_length_penalty
    )
    # instantiate processors list

    # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
    # all samplers can be found in `generation_utils_samplers.py`
    if diversity_penalty is not None and diversity_penalty > 0.0:
        processors.append(
            HammingDiversityLogitsProcessor(
                diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups
            )
        )
    if repetition_penalty is not None and repetition_penalty != 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
    if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
        processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
    if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
        if self.config.is_encoder_decoder:
            processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids))
        else:
            raise ValueError(
                "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture"
            )
    if bad_words_ids is not None:
        processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
    if min_length is not None and eos_token_id is not None and min_length > 0:
        processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
    if prefix_allowed_tokens_fn is not None:
        processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
    if forced_bos_token_id is not None:
        processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
    if forced_eos_token_id is not None:
        processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
    if remove_invalid_values is True:
        processors.append(InfNanRemoveLogitsProcessor())
    if exponential_decay_length_penalty is not None:
        processors.append(
            ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length)
        )
    processors = self._merge_criteria_processor_list(processors, logits_processor)
    # `LogitNormalization` should always be the last logit processor, when present
    if renormalize_logits is True:
        processors.append(LogitNormalization())
    return processors
def _get_stopping_criteria(
    self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList]
) -> StoppingCriteriaList:
    """
    Build the stopping criteria list for a generation call.

    Adds a `MaxLengthCriteria` when `max_length` is set and a `MaxTimeCriteria` when
    `max_time` is set, then merges the caller's `stopping_criteria` through
    `_merge_criteria_processor_list`.
    """
    criteria = StoppingCriteriaList()
    if max_length is not None:
        criteria.append(MaxLengthCriteria(max_length=max_length))
    if max_time is not None:
        criteria.append(MaxTimeCriteria(max_time=max_time))
    criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
    return criteria
def _merge_criteria_processor_list(
    self,
    default_list: Union[LogitsProcessorList, StoppingCriteriaList],
    custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
    """
    Append the entries of `custom_list` to `default_list` and return `default_list`.

    Raises:
        ValueError: If a custom entry has exactly the same type as an entry already in
            `default_list`.
    """
    if len(custom_list) == 0:
        return default_list
    for default in default_list:
        for custom in custom_list:
            if type(custom) is type(default):
                object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
                raise ValueError(
                    f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, "
                    f"but it has already been created with the values {default}. {default} has been created by passing the "
                    "corresponding arguments to generate or by the model's config default values. "
                    f"If you just want to change the default values of {object_type} consider passing them as arguments "
                    f"to `generate` instead of using a custom {object_type}."
                )
    default_list.extend(custom_list)
    return default_list
def compute_transition_beam_scores(
    self,
    sequences: torch.Tensor,
    scores: Tuple[torch.Tensor],
    beam_indices: torch.Tensor,
    eos_token_id: int = None,
):
    """
    Compute the transition scores of `sequences` given generation scores and beam indices.

    Args:
        sequences: Generated token ids; the last `len(scores)` columns are treated as the
            generated part.
        scores: One score tensor per generation step.
        beam_indices: Per generated token, the beam row it was scored on; multiplied by
            `self.config.vocab_size` to index the flattened scores.
        eos_token_id: EOS id; defaults to `self.config.eos_token_id`.

    Returns:
        A tensor with the score of every generated token. Positions after the first EOS
        of a sequence are set to 0.0.
    """
    # Flatten each step's scores and put steps on the last axis:
    # result is [flattened (row, vocab) index, generation step].
    scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)

    # start of generated tokens
    cut_idx = sequences.shape[-1] - scores.shape[-1]
    # adjust for beam indices
    beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
    # compute real indices
    indices = sequences[:, cut_idx:] + beam_sequence_indices
    # gather scores and run
    transition_scores = scores.gather(0, indices)
    # An EOS may have been produced before the end of the sequence; everything after
    # its first occurrence must not contribute.
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

    if eos_token_id is not None:
        is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
        # make sure first eos token still contributes to transition probs:
        # shift the marker one step right (the cleared last column wraps to the front)
        is_eos_token_id[:, -1] = False
        is_eos_token_id = is_eos_token_id.roll(1, -1)
        # all indices after eos should be masked
        zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
        # zero out padded probs
        transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)

    return transition_scores
def remove_subsets(self, l):
    """
    Return a copy of `l` without the lists whose elements are a subset of another list.

    A list is dropped when the set of its elements is contained in the set of elements
    of some other, non-equal list in `l`. Order of the survivors is preserved and `l`
    itself is not modified. Note that two different orderings of the same elements
    therefore eliminate each other.

    Example: `[[1, 2, 3], [1, 2, 3, 4], [2, 3, 21]]` -> `[[1, 2, 3, 4], [2, 3, 21]]`.
    """
    # Build each set once instead of once per pair.
    as_sets = [set(item) for item in l]
    return [
        item
        for item, item_set in zip(l, as_sets)
        if not any(item_set.issubset(other_set) and item != other for other, other_set in zip(l, as_sets))
    ]
def cs_generate(
inputs: Optional[torch.Tensor] = None,
contexts:List[str]=None, #input data
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
use_kg:bool=False, #added
max_neig_per_concept=1, #it slows down quite a lot
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
# Purpose (as far as visible): when `use_kg` is set, build one `DisjunctiveConstraint`
# per context concept from its knowledge-graph neighbours, then call `self.generate`
# and return its result.
# NOTE(review): this listing has lost its indentation and appears to have lost every
# line that was a single whitespace-free token (e.g. `self,`, `else:`, `continue`,
# closing brackets, one-call `append` lines). `self`, `model_input`, `model_kwargs`,
# `tokenizer` and `relation_mapper_builder` are used below but not declared in the
# visible signature -- presumably they are among the dropped lines; confirm against
# the upstream file before relying on this block.
# print(model_input)
input_ids = model_input['input_ids']
# Forward the commonsense relation tensor (if provided) to the model as
# `relation_inputs`, moved to the device of `input_ids`.
if "input_commonsense_relations" in model_input:
# print(model_input['input_commonsense_relations'].sum())
model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(input_ids.device)
if use_kg:
all_constraints = []
# NOTE(review): the `print` calls in this method are debug output to stdout.
print('contexts:', contexts[:3])
for context in contexts:
constraints = []
concepts_from_context = relation_mapper_builder.get_concepts_from_context(context=context,
clear_common_wds=True, alignment=1)
print('concepts_from_context:', concepts_from_context)
# Neighbours are looked up in `swow_knowledge` first; `knowledge` is the fallback
# when that lookup yields an empty list.
# NOTE(review): the tail of this comprehension is missing from the listing.
useful_concepts = [relation_mapper_builder.swow_knowledge.get_related_concepts(concept) for concept in
if not useful_concepts:
useful_concepts = [relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context]
# NOTE(review): the f-string below reproduces each phrase unchanged; despite the
# trailing "add spaces" remark, no space is added here.
useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts] # add spaces
# useful_concepts = [[phrase for phrase in concepts if len(phrase.split(' ')) == 1] for concepts in useful_concepts]
# useful_concepts = list(itertools.chain.from_iterable(useful_concepts))
# print('useful_concepts:', useful_concepts[:5])
print('useful_concepts:', useful_concepts)
if concepts_from_context and useful_concepts:
# Pair each context concept with its own neighbour list.
for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
print('neighbour:', neighbour_concepts[:5])
# flexible_words = self.most_similar_words(context_concept, neighbour_concepts) # limit the upperbound
# flexible_words = [word for word in flexible_words if word not in context_concept] # remove input concepts
# NOTE(review): `word not in context_concept` is a substring test if
# `context_concept` is a str -- presumably it is; verify.
flexible_words = [word for word in neighbour_concepts if
word not in context_concept] # remove input concepts
print('flexible_words:', flexible_words[:5])
# NOTE(review): the body of this `if` is missing from the listing (presumably
# it skips concepts that have no candidate words).
if not flexible_words:
# Tokenize the candidate words without special tokens, drop token-id lists whose
# ids are a subset of another list, and keep at most `max_neig_per_concept` lists.
flexible_words_ids: List[List[int]] = tokenizer(flexible_words, add_special_tokens=False).input_ids #add_prefix_space=True,
flexible_words_ids = self.remove_subsets(flexible_words_ids)
# add_prefix_space=True
# flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets
flexible_words_ids = flexible_words_ids[:max_neig_per_concept]
#print('flexible_words_ids:', flexible_words_ids[:3])
constraint = DisjunctiveConstraint(flexible_words_ids)
# NOTE(review): the statements that collect `constraint` into `constraints` /
# `all_constraints`, and the branch this reset belongs to, are not visible.
all_constraints = None
# NOTE(review): most arguments of this `self.generate(...)` call are missing from the
# listing; how `all_constraints` is passed cannot be confirmed from here.
generated_answers_encoded = self.generate(input_ids=input_ids,
# eos_token_id=tokenizer.eos_token_id,
return generated_answers_encoded
def cs_simple_generate(
inputs: Optional[torch.Tensor] = None,
neighbours_contexts:List[List[str]]=None, #input data
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
use_kg:bool=False, #added
max_concepts=2, #it slows down quite a lot
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
# Purpose (as far as visible): when `use_kg` is set, turn each example's neighbour
# concepts into `DisjunctiveConstraint`s, call `self.generate` once per example, and
# return the generated sequences padded into one tensor.
# NOTE(review): this listing has lost its indentation and appears to have lost every
# line that was a single whitespace-free token (e.g. `self,`, `else:`, `continue`,
# closing brackets, one-call `append` lines). `self`, `model_input`, `model_kwargs`
# and `tokenizer` are used below but not declared in the visible signature --
# presumably they are among the dropped lines; confirm against the upstream file.
# print(model_input)
input_ids = model_input['input_ids']
if use_kg:
all_constraints = []
for context_neighbours in neighbours_contexts:
# context_neighbours is a collection of concepts
# lets create sub collections of concepts
# Concepts of 3 characters or fewer are dropped; the rest get a leading space.
context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3]
# The jump passed to `get_jump_chunks` is len // max_concepts, but at least 1.
n_size_chuncks = len(context_neighbours) // max_concepts
n_size_chuncks = n_size_chuncks if n_size_chuncks > 0 else 1
# NOTE(review): `get_jump_chunks` comes from the project's `utils` module; its exact
# chunking behaviour is not visible here.
sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chuncks))
constraints = []
# At most `max_concepts` sub-collections are turned into constraints.
for sub_concepts in sub_concepts_collection[:max_concepts]:
flexible_words_ids: List[List[int]] = tokenizer(sub_concepts, add_special_tokens=False).input_ids #add_prefix_space=True,
#flexible_words_ids = self.remove_subsets(flexible_words_ids)
# Only the first token id of every concept is kept ...
flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids]
# ... and duplicates are removed via a set of frozensets (order is not preserved).
disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
# add_prefix_space=True
# flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets
#flexible_words_ids = flexible_words_ids[:max_neig_per_concept]
#print('flexible_words_ids:', flexible_words_ids[:3])
# NOTE(review): the body of this `if` is missing from the listing.
if not any(disjunctive_set):
constraint = DisjunctiveConstraint(disjunctive_set)
# NOTE(review): the lines that collect `constraint` / `constraints`, and the branch
# structure around the `all_constraints = None` resets below, are not visible. As
# shown, `all_constraints` can be `None` when the loop at `enumerate(all_constraints)`
# is reached, which would raise a TypeError -- confirm against the upstream file.
if not any(constraints):
all_constraints = None
if not all_constraints:
all_constraints = None
generated_answers_encoded = []
#print('all_constraints:', all_constraints)
# One `generate` call per example `i`, each on a batch of size 1.
for i, contraints in enumerate(all_constraints):
#print('contraints.token_ids:', [x.token_ids for x in contraints])
if "input_commonsense_relations" in model_input:
# print(model_input['input_commonsense_relations'].sum())
model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations")[i].unsqueeze(0).to(input_ids.device)
#print('model_kwargs.get("attention_mask"):', model_kwargs.get("attention_mask"))
# Row `i` of the relation tensor and of the attention mask is re-batched with
# `unsqueeze(0)` to match the single-example `input_ids` passed to `generate`.
model_kwargs["attention_mask"] = model_input.get("attention_mask")[i].unsqueeze(0).to(input_ids.device)
# NOTE(review): the remaining arguments of this call and the statement that stores
# `gen` in `generated_answers_encoded` are missing from the listing.
gen = self.generate(input_ids=input_ids[i].unsqueeze(0),
# eos_token_id=tokenizer.eos_token_id,
#print('[gen]:', gen)
#print('generated_answers_encoded:', generated_answers_encoded)
# Pad the per-example outputs with `tokenizer.pad_token_id` into one batch-first tensor
# on the device of `input_ids`.
return torch.LongTensor(pad_sequence(generated_answers_encoded, batch_first=True, padding_value=tokenizer.pad_token_id)).to(input_ids.device)
# Entry point for text generation: fills unset arguments from `self.config`, prepares the model inputs,
# logits processors and stopping criteria, selects exactly one decoding mode from the beam / sampling /
# constraint flags, and returns whatever the matching search method returns.
# NOTE(review): this view of the method does not show every line (some closing parentheses and call
# arguments are not visible); the comments added below describe only what is visible here.
def generate(
inputs: Optional[torch.Tensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
bad_words_ids: Optional[Iterable[int]] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
# NOTE(review): this default (and the `stopping_criteria` default below) is an instance created once when
# the function is defined, so the same object is shared by every call that omits the argument.
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
constraints: Optional[List[Constraint]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and
- *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if
`num_beams>1` and `do_sample=True`.
- *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`.
- *constrained beam-search decoding* by calling
[`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or
<Tip warning={true}>
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as
defined in the model's config (`config.json`) which in turn defaults to the
[`~modeling_utils.PretrainedConfig`] of the model.
Most of these parameters are explained in more detail in [this blog
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
max_length (`int`, *optional*, defaults to `model.config.max_length`):
The maximum length of the sequence to be generated.
max_new_tokens (`int`, *optional*, defaults to None):
The maximum numbers of tokens to generate, ignore the current number of tokens. Use either
`max_new_tokens` or `max_length` but not both, they serve the same purpose.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling ; use greedy decoding otherwise.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to 1.0):
The value used to module the next token probabilities.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
are kept for generation.
repetition_penalty (`float`, *optional*, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length.
0.0 means no penalty. Set to values < 0.0 in order to encourage the model to generate longer
sequences, to a value > 0.0 in order to encourage the model to produce shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the
bad_words_ids(`List[List[int]]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the token ids of the words that
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*):
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple
list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`,
this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081),
where one can allow different forms of each word.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
max_time(`float`, *optional*, defaults to None):
The maximum amount of time you allow the computation to run for in seconds. generation will still
finish the current pass after allocated time has been passed.
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape
as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask)
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
use_cache: (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
num_beam_groups (`int`, *optional*, defaults to 1):
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of
beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
diversity_penalty (`float`, *optional*, defaults to 0.0):
This value is subtracted from a beam's score if it generates a token same as any beam from other group
at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and a
model's config. If a logit processor is passed that is already created with the arguments or a model's
config an error is thrown. This feature is intended for advanced users.
renormalize_logits: (`bool`, *optional*, defaults to `False`):
Whether to renormalize the logits after applying all the logits processors or warpers (including the
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
score logits are normalized but some logit processors or warpers break the normalization.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
model's config. If a stopping criteria is passed that is already created with the arguments or a
model's config an error is thrown. This feature is intended for advanced users.
constraints (`List[Constraint]`, *optional*):
Custom constraints that can be added to the generation to ensure that the output will contain the use
of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
forced_bos_token_id (`int`, *optional*):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
remove_invalid_values (`bool`, *optional*):
Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to
crash. Note that using `remove_invalid_values` can slow down generation.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
exponential_decay_length_penalty (`tuple(int, float)`, *optional*):
This Tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been
generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates
where penalty starts and `decay_factor` represents the factor of exponential decay
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
should be prefixed with *decoder_*.
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation_utils.GreedySearchDecoderOnlyOutput`],
- [`~generation_utils.SampleDecoderOnlyOutput`],
- [`~generation_utils.BeamSearchDecoderOnlyOutput`],
- [`~generation_utils.BeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation_utils.GreedySearchEncoderDecoderOutput`],
- [`~generation_utils.SampleEncoderDecoderOutput`],
- [`~generation_utils.BeamSearchEncoderDecoderOutput`],
- [`~generation_utils.BeamSampleEncoderDecoderOutput`]
Greedy Decoding:
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
>>> # generate up to 30 tokens
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
Multinomial Sampling:
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids
>>> # sample up to 30 tokens
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.generate(input_ids, do_sample=True, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the']
Beam-search decoding:
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> sentence = "Paris is one of the densest populated areas in Europe."
>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids
>>> outputs = model.generate(input_ids)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
# 1. Set generation parameters if not already defined
# Every argument left as None is replaced by the attribute of the same name on `self.config`.
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
num_beams = num_beams if num_beams is not None else self.config.num_beams
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
# Fall back to the nested decoder config's EOS id when neither the argument nor the top-level config has one.
if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id
if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
# NOTE(review): the notice is emitted with a plain print(), so it is written to stdout on every such call.
print(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
# otherwise model_input_name is None
# all model-specific keyword inputs are removed from `model_kwargs`
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs)
# The batch size is read from the first dimension of the prepared input tensor.
batch_size = inputs_tensor.shape[0]
# 3. Define other model kwargs
model_kwargs["output_attentions"] = output_attentions
model_kwargs["output_hidden_states"] = output_hidden_states
model_kwargs["use_cache"] = use_cache
# A default attention mask is only built when `forward` has an "attention_mask" parameter, the caller
# supplied none, and no precomputed "encoder_outputs" are present in `model_kwargs`.
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
# Length of the sequence that generation starts from (last dimension of `input_ids`).
input_ids_seq_length = input_ids.shape[-1]
# 5. Prepare `max_length` depending on other stopping criteria
# if `max_new_tokens` is passed, but not `max_length` -> set `max_length = max_new_tokens`
if max_length is None and max_new_tokens is not None:
# `max_new_tokens` is counted on top of the current `input_ids` length.
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
# Both are set, this is odd, raise a warning
"Both `max_length` and `max_new_tokens` have been set "
f"but they serve the same purpose. `max_length` {max_length} "
f"will take priority over `max_new_tokens` {max_new_tokens}.",
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
# Contradictory length bounds are rejected before any generation work starts.
if min_length is not None and min_length > max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
# The starting sequence already reaches `max_length`; the message below only reports this, it does not raise.
if input_ids_seq_length >= max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. "
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``."
# 6. determine generation mode
# Constrained mode is selected whenever `constraints` or `force_words_ids` is given; every other mode
# flag below additionally requires constrained mode to be off.
is_constraint_gen_mode = constraints is not None or force_words_ids is not None
# NOTE(review): `do_sample` is compared by identity (`is True` / `is False`), so a truthy/falsy value that is
# not an actual bool would match none of the four flags below.
is_greedy_gen_mode = (
(num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
is_sample_gen_mode = (
(num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
is_beam_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode
is_beam_sample_gen_mode = (
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
if num_beam_groups > num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
if is_group_beam_gen_mode and do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
# 7. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
# 8. prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
# 9. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
# 10. run greedy search
return self.greedy_search(
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
# 12. run sample
return self.sample(
elif is_beam_gen_mode:
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 10. prepare beam search scorer
beam_scorer = BeamSearchScorer(
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
# 12. run beam search
return self.beam_search(
elif is_beam_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 11. prepare beam search scorer
# In this mode the scorer's batch size and the input expansion below are both scaled by
# `num_return_sequences`, unlike the plain beam-search branch which expands by `num_beams` only.
beam_scorer = BeamSearchScorer(
batch_size=batch_size * num_return_sequences,
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
expand_size=num_beams * num_return_sequences,
# 13. run beam sample
return self.beam_sample(
elif is_group_beam_gen_mode:
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if num_beams % num_beam_groups != 0:
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 10. prepare beam search scorer
beam_scorer = BeamSearchScorer(
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
# 12. run beam search
return self.group_beam_search(
elif is_constraint_gen_mode:
# Constrained generation is validated separately: it needs more than one beam, no sampling and no beam groups.
if num_return_sequences > num_beams:
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
if num_beams <= 1:
raise ValueError("`num_beams` needs to be greater than 1 for constrained genertation.")
if do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")
if num_beam_groups is not None and num_beam_groups > 1:
raise ValueError("`num_beam_groups` not supported yet for constrained generation.")
# Caller-supplied `Constraint` objects, when given, become the starting constraint list.
final_constraints = []
if constraints is not None:
final_constraints = constraints
if force_words_ids is not None:
# NOTE(review): despite its name, this helper raises ValueError, not TypeError.
def typeerror():
raise ValueError(
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
f"of positive integers, but is {force_words_ids}."
if not isinstance(force_words_ids, list) or len(force_words_ids) == 0:
for word_ids in force_words_ids:
# An entry whose first element is itself a list is checked as a list of token-id lists and wrapped
# in a DisjunctiveConstraint.
if isinstance(word_ids[0], list):
if not isinstance(word_ids, list) or len(word_ids) == 0:
if any(not isinstance(token_ids, list) for token_ids in word_ids):
if any(
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
for token_ids in word_ids
constraint = DisjunctiveConstraint(word_ids)
# Presumably the alternative branch (its keyword is not visible in this view): the entry is checked as a
# flat list of non-negative ints and wrapped in a PhrasalConstraint.
if not isinstance(word_ids, list) or len(word_ids) == 0:
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
constraint = PhrasalConstraint(word_ids)
# 10. prepare beam search scorer
constrained_beam_scorer = ConstrainedBeamSearchScorer(
# 11. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
# 12. run beam search
return self.constrained_beam_search(
# Greedy decoding loop: at every step the model is run, the logits of the last position are passed through
# `logits_processor`, the argmax token is appended to `input_ids`, and the loop ends once every row has
# produced `eos_token_id` or `stopping_criteria` fires.
# NOTE(review): this view of the method does not show every line (some closing parentheses, branch keywords
# and call arguments are not visible); the comments added below describe only what is visible here.
def greedy_search(
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
) -> Union[GreedySearchOutput, torch.LongTensor]:
Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
[`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`]
or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "It might be possible to"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.greedy_search(
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["It might be possible to get a better understanding of the nature of the problem, but it's not"]
# init values
# Missing processor / criteria lists are replaced by fresh empty lists; unset ids and output flags are
# read from `self.config`.
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
# `max_length` is deprecated here; when given, `stopping_criteria` is reassigned from
# validate_stopping_criteria(stopping_criteria, max_length).
if max_length is not None:
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
# init attention / hidden states / scores tuples
# Each accumulator is an empty tuple only when `return_dict_in_generate` and its own output flag are both
# set; otherwise it stays None.
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# keep track of which sequences are already finished
# One flag per batch row, all initialised to 1 (row still generating).
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
cur_len = input_ids.shape[-1]
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
# Keep only the logits of the last position for each row.
next_token_logits = outputs.logits[:, -1, :]
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
# The stored score is `next_token_logits`; the `logits_processor` call appears after this point.
scores += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
# pre-process distribution
next_tokens_scores = logits_processor(input_ids, next_token_logits)
# argmax
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
# Rows whose flag is 1 keep their argmax token; rows whose flag is 0 receive `pad_token_id`.
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
cur_len = cur_len + 1
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
# Multiplying by (next_tokens != eos_token_id) zeroes the flag of any row that just emitted EOS;
# flags that are already 0 stay 0.
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished, or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
this_peer_finished = True
# With `return_dict_in_generate` a structured output object is returned (encoder-decoder or decoder-only
# variant); otherwise the plain `input_ids` tensor is returned.
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GreedySearchEncoderDecoderOutput(
return GreedySearchDecoderOnlyOutput(
return input_ids
def sample(
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
) -> Union[SampleOutput, torch.LongTensor]:
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
[`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if
>>> from transformers import (
... AutoTokenizer,
... AutoModelForCausalLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
... StoppingCriteriaList,
... MaxLengthCriteria,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
>>> model.config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> # instantiate logits warpers
>>> logits_warper = LogitsProcessorList(
... [
... TopKLogitsWarper(50),
... TemperatureLogitsWarper(0.7),
... ]
... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.sample(
... input_ids,
... logits_processor=logits_processor,
... logits_warper=logits_warper,
... stopping_criteria=stopping_criteria,
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
# init values
# Any option left as None falls back to an empty list or to the value stored on `self.config`.
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
# `max_length` is deprecated here: when given, it is passed together with the current criteria to
# `validate_stopping_criteria`, whose result replaces `stopping_criteria`.
if max_length is not None:
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
# init attention / hidden states / scores tuples
# Each accumulator is an empty tuple only when `return_dict_in_generate` and its own flag are both set;
# otherwise it stays None.
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# keep track of which sequences are already finished
# One flag per row of `input_ids`: 1 while the row is still generating, 0 once it has produced EOS.
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
cur_len = input_ids.shape[-1]
this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
# Only the logits of the last position are used to choose the next token.
next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution
# Processors run first, then warpers; the warped scores are both stored and sampled from.
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
# sample
# Draw exactly one token id per row from the softmax over the warped scores.
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
# Rows whose flag is 1 keep the sampled token; rows whose flag is 0 receive `pad_token_id`.
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
cur_len = cur_len + 1
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
# Multiplying keeps a flag at 0 once it has dropped, so a finished row never becomes active again.
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
# stop when each sentence is finished, or if we exceed the maximum length
# `scores` is None here unless scores are being collected (see the accumulator setup above the loop).
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
if not synced_gpus:
this_peer_finished = True
# With `return_dict_in_generate` a structured output object is returned; otherwise only the token ids.
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return SampleEncoderDecoderOutput(
return SampleDecoderOnlyOutput(
return input_ids
def beam_search(
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
) -> Union[BeamSearchOutput, torch.LongTensor]:
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
beam_scorer (`BeamScorer`):
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
[`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
>>> from transformers import (
... AutoTokenizer,
... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... BeamSearchScorer,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
>>> # lets run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
... )
... }
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... num_beams=num_beams,
... device=model.device,
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
... ]
... )
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
# init values
# Any option left as None falls back to an empty list or to the value stored on `self.config`.
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
# `max_length` is deprecated here: when given, it is passed together with the current criteria to
# `validate_stopping_criteria`, whose result replaces `stopping_criteria`.
if max_length is not None:
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
# Batch size is taken as the length of the scorer's private `_beam_hyps`; beam count comes from the scorer too.
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
# `input_ids` must carry exactly `num_beams` rows per batch item.
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# `beam_indices` holds one (initially empty) tuple per row and is only tracked when scores are returned.
beam_indices = (
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# The first beam of every batch item starts at score 0 and all other beams at -1e9
# (presumably so that identical starting beams do not all contribute the same candidates -- TODO confirm).
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
# Only the logits of the last position are used to choose the next token.
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
#Normal execution
# The processed per-step log-probs are kept separately because they are what gets stored in `scores`;
# the running beam score is then added to obtain the cumulative score of every continuation.
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores_processed,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
# reshape for beam search
# Flatten beams x vocab for each batch item, keep the 2 * num_beams best candidates, then split every
# flat index back into the beam it came from (integer division) and the token id (remainder).
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
# Reorder the rows to the beams selected by the scorer and append each beam's chosen token.
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
# The cached state has to follow the same beam reordering as `input_ids`.
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
# Each row's history becomes the history of the beam it was continued from plus that beam's index.
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
this_peer_finished = True
# The scorer's `finalize` result is read below through its "sequences" and "sequence_scores" entries.
sequence_outputs = beam_scorer.finalize(
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
# Take the first `num_return_sequences` histories of every batch item, then flatten the slices again.
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
beam_indices = sum(beam_indices, ())
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
return BeamSearchDecoderOnlyOutput(
return sequence_outputs["sequences"]
def beam_sample(
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
) -> Union[BeamSampleOutput, torch.LongTensor]:
Generates sequences of token ids for models with a language modeling head using **beam search multinomial
sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
beam_scorer (`BeamScorer`):
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
logits_warper (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
to warp the prediction score distribution of the language modeling head applied before multinomial
sampling at each generation step.
max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
[`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if
>>> from transformers import (
... AutoTokenizer,
... AutoModelForSeq2SeqLM,
... LogitsProcessorList,
... MinLengthLogitsProcessor,
... TopKLogitsWarper,
... TemperatureLogitsWarper,
... BeamSearchScorer,
... )
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
>>> encoder_input_str = "translate English to German: How old are you?"
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
>>> # lets run beam search using 3 beams
>>> num_beams = 3
>>> # define decoder start token ids
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
>>> input_ids = input_ids * model.config.decoder_start_token_id
>>> # add encoder_outputs to model keyword arguments
>>> model_kwargs = {
... "encoder_outputs": model.get_encoder()(
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
... )
... }
>>> # instantiate beam scorer
>>> beam_scorer = BeamSearchScorer(
... batch_size=1,
... max_length=model.config.max_length,
... num_beams=num_beams,
... device=model.device,
... )
>>> # instantiate logits processors
>>> logits_processor = LogitsProcessorList(
... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
... )
>>> # instantiate logits warpers
>>> logits_warper = LogitsProcessorList(
... [
... TopKLogitsWarper(50),
... TemperatureLogitsWarper(0.7),
... ]
... )
>>> outputs = model.beam_sample(
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?']
# init values
# Any option left as None falls back to an empty list or to the value stored on `self.config`.
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
# `max_length` is deprecated here: when given, it is passed together with the current criteria to
# `validate_stopping_criteria`, whose result replaces `stopping_criteria`.
if max_length is not None:
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
# Batch size is taken as the length of the scorer's private `_beam_hyps`; beam count comes from the scorer too.
# NOTE(review): no check is visible here that `input_ids` has `batch_size * num_beams` rows -- the later
# `.view(batch_size, num_beams * vocab_size)` relies on it; confirm the caller guarantees this.
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# `beam_indices` holds one (initially empty) tuple per row and is only tracked when scores are returned.
beam_indices = (
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
# Every beam starts with a running score of 0.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while True:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
# Only the logits of the last position are used to choose the next token.
next_token_logits = outputs.logits[:, -1, :]
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
# Processors act on the per-step log-probs; the running beam score is added afterwards and the warpers
# are then applied to that cumulative score.
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
# The stored scores are the warped per-step scores, i.e. without the running beam score added.
scores += (logits_warper(input_ids, next_token_scores_processed),)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
# Sample 2 * num_beams candidates per batch item from the softmax over the flattened beams x vocab
# scores, look up the score of every sampled candidate, and order the candidates by descending score.
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
# Split every flat index into the beam it came from (integer division) and the token id (remainder).
next_indices = torch_int_div(next_tokens, vocab_size)
next_tokens = next_tokens % vocab_size
# stateless
beam_outputs = beam_scorer.process(
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
# Reorder the rows to the beams selected by the scorer and append each beam's chosen token.
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
# The cached state has to follow the same beam reordering as `input_ids`.
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
# Each row's history becomes the history of the beam it was continued from plus that beam's index.
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
if not synced_gpus:
this_peer_finished = True
# The scorer's `finalize` result is read below through its "sequences" and "sequence_scores" entries.
sequence_outputs = beam_scorer.finalize(
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
# Take the first `num_return_sequences` histories of every batch item, then flatten the slices again.
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
beam_indices = sum(beam_indices, ())
if self.config.is_encoder_decoder:
return BeamSampleEncoderDecoderOutput(
return BeamSampleDecoderOnlyOutput(
return sequence_outputs["sequences"]
def group_beam_search(
    self,
    input_ids: torch.LongTensor,
    beam_scorer: BeamScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = False,
    **model_kwargs,
):
    r"""
    Generates sequences of token ids for models with a language modeling head using **diverse beam search
    decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        beam_scorer (`BeamScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        max_length (`int`, *optional*, defaults to 20):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3)

        model_kwargs:
            Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
            model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForSeq2SeqLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     HammingDiversityLogitsProcessor,
    ...     BeamSearchScorer,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    >>> encoder_input_str = "translate English to German: How old are you?"
    >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


    >>> # lets run diverse beam search using 6 beams
    >>> num_beams = 6
    >>> # define decoder start token ids
    >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    >>> input_ids = input_ids * model.config.decoder_start_token_id

    >>> # add encoder_outputs to model keyword arguments
    >>> model_kwargs = {
    ...     "encoder_outputs": model.get_encoder()(
    ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    ...     )
    ... }

    >>> # instantiate beam scorer
    >>> beam_scorer = BeamSearchScorer(
    ...     batch_size=1,
    ...     max_length=model.config.max_length,
    ...     num_beams=num_beams,
    ...     device=model.device,
    ...     num_beam_groups=3,
    ... )

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
    ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )

    >>> outputs = model.group_beam_search(
    ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
    ... )

    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ['Wie alt bist du?']
    ```"""
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    # every option left as `None` falls back to the value stored on the model config
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    batch_size = len(beam_scorer._beam_hyps)
    num_beams = beam_scorer.num_beams
    num_beam_groups = beam_scorer.num_beam_groups
    num_sub_beams = num_beams // num_beam_groups
    device = input_ids.device

    batch_beam_size, cur_len = input_ids.shape

    if return_dict_in_generate and output_scores:
        # one tuple of per-beam index histories for every beam group
        beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
    else:
        beam_indices = None

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
    # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
    # the same group don't produce same tokens everytime.
    beam_scores[:, ::num_sub_beams] = 0
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:

        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # predicted tokens in cur_len step
        current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

        # indices which will form the beams in the next time step
        reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

        # do one decoder step on all beams of all sentences in batch
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        if output_scores:
            # filled group by group below, then appended to `scores` once per step
            processed_score = torch.zeros_like(outputs.logits[:, -1, :])

        for beam_group_idx in range(num_beam_groups):
            group_start_idx = beam_group_idx * num_sub_beams
            group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
            group_size = group_end_idx - group_start_idx

            # indices of beams of current group among all sentences in batch
            batch_group_indices = []

            for batch_idx in range(batch_size):
                batch_group_indices.extend(
                    [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                )
            group_input_ids = input_ids[batch_group_indices]

            # select outputs of beams of current group only
            next_token_logits = outputs.logits[batch_group_indices, -1, :]

            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * group_size, vocab_size)
            vocab_size = next_token_scores.shape[-1]

            # `current_tokens` carries the tokens already picked by earlier groups in this step, so that
            # group-aware processors can take them into account
            next_token_scores_processed = logits_processor(
                group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
            )
            next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
            next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

            if output_scores:
                processed_score[batch_group_indices] = next_token_scores_processed

            # reshape for beam search
            next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

            # keep 2 * group_size candidates per batch entry
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
            )

            # split the flattened (beam, token) index back into its two components
            next_indices = torch_int_div(next_tokens, vocab_size)
            next_tokens = next_tokens % vocab_size

            # stateless
            beam_outputs = beam_scorer.process(
                group_input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
            )
            beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            if return_dict_in_generate and output_scores:
                beam_indices[beam_group_idx] = tuple(
                    beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
                )

            # NOTE: on the very first step `input_ids` is still the caller's tensor, so this writes into it
            # in place; later steps operate on the tensor produced by `torch.cat` below.
            input_ids[batch_group_indices] = group_input_ids[beam_idx]
            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            current_tokens[batch_group_indices] = group_input_ids[:, -1]

            # (beam_idx // group_size) -> batch_idx
            # (beam_idx % group_size) -> offset of idx inside the group
            reordering_indices[batch_group_indices] = (
                num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
            )

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (processed_score,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if model_kwargs["past"] is not None:
            # keep the cached key/values aligned with the beams selected above
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)

        # increase cur_len
        cur_len = cur_len + 1

        if beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                # keep calling `forward` until every peer is done
                this_peer_finished = True

    sequence_outputs = beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        else:
            # flatten the per-group histories into a single tuple
            beam_indices = sum(beam_indices, ())
            num_return_sequences = beam_scorer.num_beam_hyps_to_keep
            # return only as many indices as sequences
            beam_indices = tuple(
                (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
            )
            beam_indices = sum(beam_indices, ())

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=beam_indices,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=beam_indices,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
def constrained_beam_search(
    self,
    input_ids: torch.LongTensor,
    constrained_beam_scorer: ConstrainedBeamSearchScorer,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: Optional[bool] = None,
    **model_kwargs,
) -> Union[BeamSearchOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **constrained beam search
    decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        constrained_beam_scorer (`ConstrainedBeamSearchScorer`):
            A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
            sorted during generation, while satisfying a list of positive constraints. For more information, the
            documentation of [`ConstrainedBeamSearchScorer`] should be read.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`, *optional*):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        max_length (`int`, *optional*, defaults to 20):
            **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
            tokens. The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`int`, *optional*):
            The id of the *end-of-sequence* token.
        output_attentions (`bool`, *optional*, defaults to `False`):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more details.
        output_hidden_states (`bool`, *optional*, defaults to `False`):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more details.
        output_scores (`bool`, *optional*, defaults to `False`):
            Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
        return_dict_in_generate (`bool`, *optional*, defaults to `False`):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        synced_gpus (`bool`, *optional*, defaults to `False`):
            Whether to continue running the while loop until max_length (needed for ZeRO stage 3). The default
            of `None` is falsy and therefore behaves like `False`.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.


    Examples:

    ```python
    >>> from transformers import (
    ...     AutoTokenizer,
    ...     AutoModelForSeq2SeqLM,
    ...     LogitsProcessorList,
    ...     MinLengthLogitsProcessor,
    ...     ConstrainedBeamSearchScorer,
    ...     PhrasalConstraint,
    ... )
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

    >>> encoder_input_str = "translate English to German: How old are you?"
    >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


    >>> # lets run beam search using 3 beams
    >>> num_beams = 3
    >>> # define decoder start token ids
    >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
    >>> input_ids = input_ids * model.config.decoder_start_token_id

    >>> # add encoder_outputs to model keyword arguments
    >>> model_kwargs = {
    ...     "encoder_outputs": model.get_encoder()(
    ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    ...     )
    ... }

    >>> constraint_str = "Sie"
    >>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1]  # slice to remove eos token
    >>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)]


    >>> # instantiate beam scorer
    >>> beam_scorer = ConstrainedBeamSearchScorer(
    ...     batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints
    ... )

    >>> # instantiate logits processors
    >>> logits_processor = LogitsProcessorList(
    ...     [
    ...         MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
    ...     ]
    ... )

    >>> outputs = model.constrained_beam_search(
    ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
    ... )

    >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ['Wie alt sind Sie?']
    ```"""
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    if len(stopping_criteria) == 0:
        warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
    # every option left as `None` falls back to the value stored on the model config
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    output_scores = output_scores if output_scores is not None else self.config.output_scores
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    batch_size = len(constrained_beam_scorer._beam_hyps)
    num_beams = constrained_beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # only the first beam of every batch entry starts with score 0; the others start at -1e9
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False  # used by synced_gpus only
    while True:

        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]
        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)

        # copy of the processed per-token scores taken before the running beam scores are added;
        # it is handed to the constrained scorer below
        scores_for_all_vocab = next_token_scores_processed.clone()

        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
        )

        # Integer division keeps the beam index exact; going through a float quotient (`/` then `.long()`)
        # can round incorrectly once `num_beams * vocab_size` exceeds float precision, and the other beam
        # search variants in this class already use `torch_int_div`.
        next_indices = torch_int_div(next_tokens, vocab_size)
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = constrained_beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            scores_for_all_vocab,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if model_kwargs["past"] is not None:
            # keep the cached key/values aligned with the beams selected above
            model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)

        # increase cur_len
        cur_len = cur_len + 1

        if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
            if not synced_gpus:
                break
            else:
                # keep calling `forward` until every peer is done
                this_peer_finished = True

    sequence_outputs = constrained_beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]
def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            If > 0, only keep the top k tokens with highest probability (top-k filtering)
        top_p (`float`, *optional*, defaults to 1.0):
            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
            filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            Value written into every filtered-out position, for both the top-k and the top-p step.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens we keep per batch example in the output.

    Returns:
        `torch.FloatTensor`: the filtered logits. When neither step applies (`top_k <= 0` and `top_p` outside
        `[0, 1]`) the input tensor is returned unchanged.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        # forward `filter_value` here as well, so a custom value is honoured by both filtering steps
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits