|
import inspect |
|
import warnings |
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch import nn |
|
|
|
from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint |
|
from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer |
|
from transformers.generation_logits_process import ( |
|
EncoderNoRepeatNGramLogitsProcessor, |
|
ExponentialDecayLengthPenalty, |
|
ForcedBOSTokenLogitsProcessor, |
|
ForcedEOSTokenLogitsProcessor, |
|
HammingDiversityLogitsProcessor, |
|
InfNanRemoveLogitsProcessor, |
|
LogitNormalization, |
|
LogitsProcessorList, |
|
MinLengthLogitsProcessor, |
|
NoBadWordsLogitsProcessor, |
|
NoRepeatNGramLogitsProcessor, |
|
PrefixConstrainedLogitsProcessor, |
|
RepetitionPenaltyLogitsProcessor, |
|
TemperatureLogitsWarper, |
|
TopKLogitsWarper, |
|
TopPLogitsWarper, |
|
TypicalLogitsWarper, |
|
) |
|
from transformers.generation_stopping_criteria import ( |
|
MaxLengthCriteria, |
|
MaxTimeCriteria, |
|
StoppingCriteria, |
|
StoppingCriteriaList, |
|
validate_stopping_criteria, |
|
) |
|
from transformers.pytorch_utils import torch_int_div |
|
from transformers.utils import ModelOutput |
|
|
|
from transformers.generation_utils import ( |
|
    GreedySearchOutput,
    GreedySearchDecoderOnlyOutput,
    GreedySearchEncoderDecoderOutput,
    SampleOutput,
    SampleDecoderOnlyOutput,
    SampleEncoderDecoderOutput,
    BeamSearchOutput,
    BeamSearchDecoderOnlyOutput,
    BeamSearchEncoderDecoderOutput,
    BeamSampleOutput,
    BeamSampleDecoderOnlyOutput,
    BeamSampleEncoderDecoderOutput,
|
) |
|
from utils import get_jump_chunks |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
class GenerationMixin: |
|
""" |
|
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. |
|
|
|
The class exposes [`~generation_utils.GenerationMixin.generate`], which can be used for: |
|
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and |
|
`do_sample=False`. |
|
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and |
|
`do_sample=True`. |
|
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and |
|
`do_sample=False`. |
|
- *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if |
|
`num_beams>1` and `do_sample=True`. |
|
- *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if |
|
`num_beams>1` and `num_beam_groups>1`. |
|
- *constrained beam-search decoding* by calling [`~generation_utils.GenerationMixin.constrained_beam_search`], |
|
if `constraints!=None` or `force_words_ids!=None`. |
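
    A minimal usage sketch (the checkpoint name and prompt are only illustrative; any model that inherits this
    mixin is driven the same way):

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
    >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    >>> input_ids = tokenizer("translate English to German: How are you?", return_tensors="pt").input_ids
    >>> outputs = model.generate(input_ids, num_beams=4, do_sample=False)
    >>> text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    ```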
|
""" |
|
|
|
def _prepare_model_inputs( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
bos_token_id: Optional[int] = None, |
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: |
|
""" |
|
This function extracts the model-specific `inputs` for generation. |
|
""" |
|
|
|
|
|
if ( |
|
self.config.is_encoder_decoder |
|
and hasattr(self, "encoder") |
|
and self.encoder.main_input_name != self.main_input_name |
|
): |
|
input_name = self.encoder.main_input_name |
|
else: |
|
input_name = self.main_input_name |
|
|
|
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} |
|
|
|
|
|
|
|
inputs_kwarg = model_kwargs.pop(input_name, None) |
|
if inputs_kwarg is not None and inputs is not None: |
|
raise ValueError( |
|
f"`inputs`: {inputs}` were passed alongside " |
|
f"{input_name} which is not allowed." |
|
f"Make sure to either pass {inputs} or {input_name}=..." |
|
) |
|
elif inputs_kwarg is not None: |
|
inputs = inputs_kwarg |
|
|
|
|
|
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs): |
|
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds" |
|
|
|
|
|
if not self.config.is_encoder_decoder and input_name != "input_ids": |
|
raise ValueError( |
|
f"If {input_name} is passed as model-specific keyword " |
|
"input then model has to be an encoder-decoder and not a " |
|
f"{self.__class__.__name__}." |
|
) |
|
|
|
|
|
if inputs is None: |
|
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs")) |
|
|
|
return inputs, input_name, model_kwargs |
|
|
|
def _can_retrieve_inputs_from_name( |
|
self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor] |
|
) -> torch.Tensor: |
|
""" |
|
        If `inputs` is `None` and `name` is passed as a keyword argument that also appears in the forward
        signature, then the inputs can be retrieved from `name`.
|
""" |
|
can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set( |
|
inspect.signature(self.forward).parameters.keys() |
|
) |
|
|
|
if can_retrieve_inputs and inputs is not None: |
|
raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}") |
|
|
|
return can_retrieve_inputs |
|
|
|
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: |
|
""" |
|
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method. |
|
""" |
|
return {"input_ids": input_ids} |
|
|
|
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor: |
|
""" |
|
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method. |
|
""" |
|
return logits |
|
|
|
def _prepare_input_ids_for_generation( |
|
self, bos_token_id: Optional[int], encoder_outputs: Optional[ModelOutput] |
|
) -> torch.LongTensor: |
|
if self.config.is_encoder_decoder and encoder_outputs is not None: |
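            # Dummy input ids filled with -100: since `encoder_outputs` are already provided, these ids only carry
            # the batch/sequence shape and are never actually encoded.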
|
|
|
shape = encoder_outputs.last_hidden_state.size()[:-1] |
|
return torch.ones(shape, dtype=torch.long, device=self.device) * -100 |
|
|
|
if bos_token_id is None: |
|
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") |
|
return torch.ones((1, 1), dtype=torch.long, device=self.device) * bos_token_id |
|
|
|
def _prepare_attention_mask_for_generation( |
|
self, |
|
inputs: torch.Tensor, |
|
pad_token_id: int, |
|
eos_token_id: int, |
|
) -> torch.LongTensor: |
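        # Only infer the mask from the inputs when they are token ids that actually contain the pad token and the
        # pad token is distinct from the EOS token; otherwise default to attending over every position.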
|
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] |
|
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) |
|
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( |
|
(eos_token_id is not None) and (pad_token_id != eos_token_id) |
|
) |
|
|
|
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: |
|
return inputs.ne(pad_token_id).long() |
|
else: |
|
return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) |
|
|
|
def _prepare_encoder_decoder_kwargs_for_generation( |
|
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None |
|
) -> Dict[str, Any]: |
|
|
|
encoder = self.get_encoder() |
|
|
|
|
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] |
|
encoder_kwargs = { |
|
argument: value |
|
for argument, value in model_kwargs.items() |
|
if not any(argument.startswith(p) for p in irrelevant_prefix) |
|
} |
|
print('encoder_kwargs:', encoder_kwargs) |
|
|
|
|
|
model_input_name = model_input_name if model_input_name is not None else self.main_input_name |
|
encoder_kwargs["return_dict"] = True |
|
encoder_kwargs[model_input_name] = inputs_tensor |
|
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs) |
|
|
|
return model_kwargs |
|
|
|
def _prepare_decoder_input_ids_for_generation( |
|
self, |
|
batch_size: int, |
|
decoder_start_token_id: int = None, |
|
bos_token_id: int = None, |
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
device: torch.device = None, |
|
) -> torch.LongTensor: |
|
|
|
if model_kwargs is not None and "decoder_input_ids" in model_kwargs: |
|
return model_kwargs.pop("decoder_input_ids") |
|
else: |
|
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) |
|
if device is None: |
|
device = self.device |
|
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id |
|
|
|
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: |
|
decoder_start_token_id = ( |
|
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id |
|
) |
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id |
|
|
|
if decoder_start_token_id is not None: |
|
return decoder_start_token_id |
|
elif ( |
|
hasattr(self.config, "decoder") |
|
and hasattr(self.config.decoder, "decoder_start_token_id") |
|
and self.config.decoder.decoder_start_token_id is not None |
|
): |
|
return self.config.decoder.decoder_start_token_id |
|
elif bos_token_id is not None: |
|
return bos_token_id |
|
elif ( |
|
hasattr(self.config, "decoder") |
|
and hasattr(self.config.decoder, "bos_token_id") |
|
and self.config.decoder.bos_token_id is not None |
|
): |
|
return self.config.decoder.bos_token_id |
|
raise ValueError( |
|
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." |
|
) |
|
|
|
@staticmethod |
|
def _expand_inputs_for_generation( |
|
input_ids: torch.LongTensor, |
|
expand_size: int = 1, |
|
is_encoder_decoder: bool = False, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
encoder_outputs: Optional[ModelOutput] = None, |
|
**model_kwargs, |
|
) -> Tuple[torch.LongTensor, Dict[str, Any]]: |
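        # Repeat every example `expand_size` times (e.g. once per beam or per returned sequence) and apply the same
        # row selection to `token_type_ids`, `attention_mask` and `encoder_outputs` so that each copy keeps its
        # original conditioning.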
|
expanded_return_idx = ( |
|
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) |
|
) |
|
input_ids = input_ids.index_select(0, expanded_return_idx) |
|
|
|
if "token_type_ids" in model_kwargs: |
|
token_type_ids = model_kwargs["token_type_ids"] |
|
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) |
|
|
|
if attention_mask is not None: |
|
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) |
|
|
|
if is_encoder_decoder: |
|
if encoder_outputs is None: |
|
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") |
|
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( |
|
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) |
|
) |
|
model_kwargs["encoder_outputs"] = encoder_outputs |
|
return input_ids, model_kwargs |
|
|
|
@staticmethod |
|
def _update_model_kwargs_for_generation( |
|
outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False |
|
) -> Dict[str, Any]: |
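        # Carry the model cache (named `past_key_values`, `mems` or `past_buckets_states` depending on the
        # architecture) into the next step; `token_type_ids` are extended by repeating their last column and, for
        # decoder-only models, the attention mask is extended by one position for the token just generated.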
|
|
|
if "past_key_values" in outputs: |
|
model_kwargs["past"] = outputs.past_key_values |
|
elif "mems" in outputs: |
|
model_kwargs["past"] = outputs.mems |
|
elif "past_buckets_states" in outputs: |
|
model_kwargs["past"] = outputs.past_buckets_states |
|
else: |
|
model_kwargs["past"] = None |
|
|
|
|
|
if "token_type_ids" in model_kwargs: |
|
token_type_ids = model_kwargs["token_type_ids"] |
|
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) |
|
|
|
|
|
if not is_encoder_decoder: |
|
if "attention_mask" in model_kwargs: |
|
attention_mask = model_kwargs["attention_mask"] |
|
model_kwargs["attention_mask"] = torch.cat( |
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 |
|
) |
|
|
|
return model_kwargs |
|
|
|
def _reorder_cache(self, past, beam_idx): |
|
raise NotImplementedError( |
|
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" |
|
) |
|
|
|
def _get_logits_warper( |
|
self, |
|
top_k: Optional[int] = None, |
|
top_p: Optional[float] = None, |
|
typical_p: Optional[float] = None, |
|
temperature: Optional[float] = None, |
|
num_beams: Optional[int] = None, |
|
renormalize_logits: Optional[bool] = None, |
|
) -> LogitsProcessorList: |
|
""" |
|
        This method returns a [`LogitsProcessorList`] object that contains all the relevant [`LogitsWarper`] instances
|
used for multinomial sampling. |
|
""" |
|
|
|
|
|
top_k = top_k if top_k is not None else self.config.top_k |
|
top_p = top_p if top_p is not None else self.config.top_p |
|
typical_p = typical_p if typical_p is not None else self.config.typical_p |
|
temperature = temperature if temperature is not None else self.config.temperature |
|
|
|
warpers = LogitsProcessorList() |
|
|
|
|
|
|
|
if temperature is not None and temperature != 1.0: |
|
warpers.append(TemperatureLogitsWarper(temperature)) |
|
if top_k is not None and top_k != 0: |
|
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) |
|
if top_p is not None and top_p < 1.0: |
|
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) |
|
if typical_p is not None and typical_p < 1.0: |
|
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) |
|
|
|
if renormalize_logits is True: |
|
warpers.append(LogitNormalization()) |
|
return warpers |
|
|
|
def _get_logits_processor( |
|
self, |
|
repetition_penalty: float, |
|
no_repeat_ngram_size: int, |
|
encoder_no_repeat_ngram_size: int, |
|
input_ids_seq_length: int, |
|
encoder_input_ids: torch.LongTensor, |
|
bad_words_ids: List[List[int]], |
|
min_length: int, |
|
max_length: int, |
|
eos_token_id: int, |
|
forced_bos_token_id: int, |
|
forced_eos_token_id: int, |
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], |
|
num_beams: int, |
|
num_beam_groups: int, |
|
diversity_penalty: float, |
|
remove_invalid_values: bool, |
|
exponential_decay_length_penalty: Tuple, |
|
logits_processor: Optional[LogitsProcessorList], |
|
renormalize_logits: Optional[bool], |
|
) -> LogitsProcessorList: |
|
""" |
|
        This method returns a [`LogitsProcessorList`] object that contains all the relevant [`LogitsProcessor`]
|
instances used to modify the scores of the language model head. |
|
""" |
|
processors = LogitsProcessorList() |
|
|
|
|
|
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty |
|
no_repeat_ngram_size = ( |
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size |
|
) |
|
encoder_no_repeat_ngram_size = ( |
|
encoder_no_repeat_ngram_size |
|
if encoder_no_repeat_ngram_size is not None |
|
else self.config.encoder_no_repeat_ngram_size |
|
) |
|
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
diversity_penalty = diversity_penalty if diversity_penalty is not None else self.config.diversity_penalty |
|
forced_bos_token_id = ( |
|
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id |
|
) |
|
forced_eos_token_id = ( |
|
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id |
|
) |
|
remove_invalid_values = ( |
|
remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values |
|
) |
|
exponential_decay_length_penalty = ( |
|
exponential_decay_length_penalty |
|
if exponential_decay_length_penalty is not None |
|
else self.config.exponential_decay_length_penalty |
|
) |
|
|
|
|
|
|
|
|
|
if diversity_penalty is not None and diversity_penalty > 0.0: |
|
processors.append( |
|
HammingDiversityLogitsProcessor( |
|
diversity_penalty=diversity_penalty, num_beams=num_beams, num_beam_groups=num_beam_groups |
|
) |
|
) |
|
if repetition_penalty is not None and repetition_penalty != 1.0: |
|
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) |
|
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: |
|
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) |
|
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: |
|
if self.config.is_encoder_decoder: |
|
processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) |
|
else: |
|
raise ValueError( |
|
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" |
|
) |
|
if bad_words_ids is not None: |
|
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) |
|
if min_length is not None and eos_token_id is not None and min_length > 0: |
|
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) |
|
if prefix_allowed_tokens_fn is not None: |
|
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) |
|
if forced_bos_token_id is not None: |
|
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) |
|
if forced_eos_token_id is not None: |
|
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) |
|
if remove_invalid_values is True: |
|
processors.append(InfNanRemoveLogitsProcessor()) |
|
if exponential_decay_length_penalty is not None: |
|
processors.append( |
|
ExponentialDecayLengthPenalty(exponential_decay_length_penalty, eos_token_id, input_ids_seq_length) |
|
) |
|
processors = self._merge_criteria_processor_list(processors, logits_processor) |
|
|
|
if renormalize_logits is True: |
|
processors.append(LogitNormalization()) |
|
return processors |
|
|
|
def _get_stopping_criteria( |
|
self, max_length: Optional[int], max_time: Optional[float], stopping_criteria: Optional[StoppingCriteriaList] |
|
) -> StoppingCriteriaList: |
|
criteria = StoppingCriteriaList() |
|
if max_length is not None: |
|
criteria.append(MaxLengthCriteria(max_length=max_length)) |
|
if max_time is not None: |
|
criteria.append(MaxTimeCriteria(max_time=max_time)) |
|
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) |
|
return criteria |
|
|
|
def _merge_criteria_processor_list( |
|
self, |
|
default_list: Union[LogitsProcessorList, StoppingCriteriaList], |
|
custom_list: Union[LogitsProcessorList, StoppingCriteriaList], |
|
) -> Union[LogitsProcessorList, StoppingCriteriaList]: |
|
if len(custom_list) == 0: |
|
return default_list |
|
for default in default_list: |
|
for custom in custom_list: |
|
if type(custom) is type(default): |
|
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" |
|
raise ValueError( |
|
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to `generate`, " |
|
f"but it has already been created with the values {default}. {default} has been created by passing the " |
|
"corresponding arguments to generate or by the model's config default values. " |
|
f"If you just want to change the default values of {object_type} consider passing them as arguments " |
|
f"to `generate` instead of using a custom {object_type}." |
|
) |
|
default_list.extend(custom_list) |
|
return default_list |
|
|
|
def compute_transition_beam_scores( |
|
self, |
|
sequences: torch.Tensor, |
|
scores: Tuple[torch.Tensor], |
|
beam_indices: torch.Tensor, |
|
eos_token_id: int = None, |
|
): |
|
"""compute the transition probabilities of sequences given generation |
|
scores and beam indices""" |
|
|
|
|
|
|
|
|
|
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1) |
|
|
|
|
|
cut_idx = sequences.shape[-1] - scores.shape[-1] |
|
|
|
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size |
|
|
|
indices = sequences[:, cut_idx:] + beam_sequence_indices |
|
|
|
transition_scores = scores.gather(0, indices) |
|
|
|
|
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
|
|
if eos_token_id is not None: |
|
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id |
|
|
|
is_eos_token_id[:, -1] = False |
|
is_eos_token_id = is_eos_token_id.roll(1, -1) |
|
|
|
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool() |
|
|
|
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0) |
|
|
|
return transition_scores |
|
|
|
|
|
def remove_subsets(self, l): |
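        """Drop every id list in `l` that is a strict subset of another id list in `l`, keeping only the maximal
        ones."""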
|
|
|
l2 = l[:] |
|
for m in l: |
|
for n in l: |
|
if set(m).issubset(set(n)) and m != n: |
|
l2.remove(m) |
|
break |
|
return l2 |
|
|
|
|
|
@torch.no_grad() |
|
def cs_generate( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
contexts:List[str]=None, |
|
model_input:Dict=None, |
|
max_length: Optional[int] = None, |
|
min_length: Optional[int] = None, |
|
do_sample: Optional[bool] = None, |
|
early_stopping: Optional[bool] = None, |
|
num_beams: Optional[int] = None, |
|
temperature: Optional[float] = None, |
|
top_k: Optional[int] = None, |
|
top_p: Optional[float] = None, |
|
typical_p: Optional[float] = None, |
|
repetition_penalty: Optional[float] = None, |
|
bad_words_ids: Optional[Iterable[int]] = None, |
|
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, |
|
bos_token_id: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
length_penalty: Optional[float] = None, |
|
no_repeat_ngram_size: Optional[int] = None, |
|
encoder_no_repeat_ngram_size: Optional[int] = None, |
|
num_return_sequences: Optional[int] = None, |
|
max_time: Optional[float] = None, |
|
max_new_tokens: Optional[int] = None, |
|
decoder_start_token_id: Optional[int] = None, |
|
use_cache: Optional[bool] = None, |
|
num_beam_groups: Optional[int] = None, |
|
diversity_penalty: Optional[float] = None, |
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, |
|
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), |
|
renormalize_logits: Optional[bool] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), |
|
constraints: Optional[List[Constraint]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
forced_bos_token_id: Optional[int] = None, |
|
forced_eos_token_id: Optional[int] = None, |
|
remove_invalid_values: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, |
|
use_kg:bool=False, |
|
relation_mapper_builder=None, |
|
tokenizer=None, |
|
max_neig_per_concept=1, |
|
**model_kwargs, |
|
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: |
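        """
        Commonsense-constrained wrapper around [`~generation_utils.GenerationMixin.generate`]. For every context it
        looks up related concepts through `relation_mapper_builder`, tokenizes the ones that are not already part of
        the context, and turns them into `DisjunctiveConstraint`s (keeping at most `max_neig_per_concept` alternatives
        per concept) that are forwarded to `generate` as `constraints`. With `use_kg=False` it falls back to plain
        unconstrained generation.
        """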
|
|
|
input_ids = model_input['input_ids'] |
|
if "input_commonsense_relations" in model_input: |
|
|
|
model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(input_ids.device) |
|
if use_kg: |
|
all_constraints = [] |
|
print('contexts:', contexts[:3]) |
|
for context in contexts: |
|
constraints = [] |
|
print('+++++++') |
|
concepts_from_context = relation_mapper_builder.get_concepts_from_context(context=context, |
|
clear_common_wds=True, alignment=1) |
|
print('concepts_from_context:', concepts_from_context) |
|
useful_concepts = [relation_mapper_builder.swow_knowledge.get_related_concepts(concept) for concept in |
|
concepts_from_context] |
|
if not useful_concepts: |
|
useful_concepts = [relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context] |
|
useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts] |
|
|
|
|
|
|
|
print('-------') |
|
print('useful_concepts:', useful_concepts) |
|
if concepts_from_context and useful_concepts: |
|
for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts): |
|
print('neighbour:', neighbour_concepts[:5]) |
|
|
|
|
|
flexible_words = [word for word in neighbour_concepts if |
|
word not in context_concept] |
|
print('flexible_words:', flexible_words[:5]) |
|
if not flexible_words: |
|
continue |
|
flexible_words_ids: List[List[int]] = tokenizer(flexible_words, add_special_tokens=False).input_ids |
|
flexible_words_ids = self.remove_subsets(flexible_words_ids) |
|
|
|
|
|
flexible_words_ids = flexible_words_ids[:max_neig_per_concept] |
|
|
|
constraint = DisjunctiveConstraint(flexible_words_ids) |
|
constraints.append(constraint) |
|
all_constraints.extend(constraints) |
|
else: |
|
all_constraints = None |
|
|
|
generated_answers_encoded = self.generate(input_ids=input_ids, |
|
|
|
constraints=all_constraints, |
|
min_length=min_length, |
|
|
|
do_sample=do_sample, |
|
early_stopping=early_stopping, |
|
|
|
temperature=temperature, |
|
top_k=top_k, |
|
top_p=top_p, |
|
|
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
num_return_sequences=num_return_sequences, |
|
return_dict_in_generate=return_dict_in_generate, |
|
output_attentions=output_attentions, |
|
output_scores=output_scores, |
|
**model_kwargs, |
|
) |
|
return generated_answers_encoded |
|
|
|
|
|
@torch.no_grad() |
|
def cs_simple_generate( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
neighbours_contexts:List[List[str]]=None, |
|
model_input:Dict=None, |
|
max_length: Optional[int] = None, |
|
min_length: Optional[int] = None, |
|
do_sample: Optional[bool] = None, |
|
early_stopping: Optional[bool] = None, |
|
num_beams: Optional[int] = None, |
|
temperature: Optional[float] = None, |
|
top_k: Optional[int] = None, |
|
top_p: Optional[float] = None, |
|
typical_p: Optional[float] = None, |
|
repetition_penalty: Optional[float] = None, |
|
bad_words_ids: Optional[Iterable[int]] = None, |
|
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, |
|
bos_token_id: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
length_penalty: Optional[float] = None, |
|
no_repeat_ngram_size: Optional[int] = None, |
|
encoder_no_repeat_ngram_size: Optional[int] = None, |
|
num_return_sequences: Optional[int] = None, |
|
max_time: Optional[float] = None, |
|
max_new_tokens: Optional[int] = None, |
|
decoder_start_token_id: Optional[int] = None, |
|
use_cache: Optional[bool] = None, |
|
num_beam_groups: Optional[int] = None, |
|
diversity_penalty: Optional[float] = None, |
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, |
|
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), |
|
renormalize_logits: Optional[bool] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), |
|
constraints: Optional[List[Constraint]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
forced_bos_token_id: Optional[int] = None, |
|
forced_eos_token_id: Optional[int] = None, |
|
remove_invalid_values: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, |
|
use_kg:bool=False, |
|
relation_mapper_builder=None, |
|
tokenizer=None, |
|
max_concepts=2, |
|
**model_kwargs, |
|
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: |
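        """
        Per-example constrained generation. The pre-computed neighbour concepts of each example are split into at
        most `max_concepts` chunks, each chunk becomes a `DisjunctiveConstraint` over the leading sub-token of every
        concept, and `generate` is called one example at a time. The generated sequences are then padded into a
        single batch tensor with `tokenizer.pad_token_id`.

        A hypothetical call (argument values are only illustrative):

        ```python
        outputs = model.cs_simple_generate(
            model_input=batch,  # dict with "input_ids", "attention_mask" and optionally "input_commonsense_relations"
            neighbours_contexts=neighbours,  # one list of concept strings per example
            use_kg=True,
            tokenizer=tokenizer,
            num_return_sequences=1,
        )
        ```
        """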
|
|
|
input_ids = model_input['input_ids'] |
|
if use_kg: |
|
all_constraints = [] |
|
for context_neighbours in neighbours_contexts: |
|
|
|
|
|
context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3] |
|
                n_size_chunks = len(context_neighbours) // max_concepts
                n_size_chunks = n_size_chunks if n_size_chunks > 0 else 1
                sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chunks))
|
constraints = [] |
|
for sub_concepts in sub_concepts_collection[:max_concepts]: |
|
flexible_words_ids: List[List[int]] = tokenizer(sub_concepts, add_special_tokens=False).input_ids |
|
|
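                    # Keep only the first sub-token of each concept and de-duplicate, so the constraint only forces
                    # the concept's leading token rather than the full phrase.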
|
flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids] |
|
disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids)))) |
|
|
|
|
|
|
|
|
|
|
|
if not any(disjunctive_set): |
|
continue |
|
constraint = DisjunctiveConstraint(disjunctive_set) |
|
constraints.append(constraint) |
|
if not any(constraints): |
|
constraints=None |
|
all_constraints.append(constraints) |
|
else: |
|
all_constraints = None |
|
        if not all_constraints:
            # Fall back to unconstrained generation, one pass per example, so the loop below still works.
            all_constraints = [None] * input_ids.shape[0]

        generated_answers_encoded = []

        for i, constraints in enumerate(all_constraints):
|
|
|
if "input_commonsense_relations" in model_input: |
|
|
|
model_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations")[i].unsqueeze(0).to(input_ids.device) |
|
|
|
model_kwargs["attention_mask"] = model_input.get("attention_mask")[i].unsqueeze(0).to(input_ids.device) |
|
gen = self.generate(input_ids=input_ids[i].unsqueeze(0), |
|
                            constraints=constraints,
|
min_length=min_length, |
|
|
|
do_sample=do_sample, |
|
early_stopping=early_stopping, |
|
|
|
temperature=temperature, |
|
top_k=top_k, |
|
top_p=top_p, |
|
|
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
num_return_sequences=num_return_sequences, |
|
return_dict_in_generate=return_dict_in_generate, |
|
output_attentions=output_attentions, |
|
output_scores=output_scores, |
|
**model_kwargs) |
|
|
|
|
|
generated_answers_encoded.append(gen[0].detach().cpu()) |
|
|
|
|
|
return torch.LongTensor(pad_sequence(generated_answers_encoded, batch_first=True, padding_value=tokenizer.pad_token_id)).to(input_ids.device) |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
max_length: Optional[int] = None, |
|
min_length: Optional[int] = None, |
|
do_sample: Optional[bool] = None, |
|
early_stopping: Optional[bool] = None, |
|
num_beams: Optional[int] = None, |
|
temperature: Optional[float] = None, |
|
top_k: Optional[int] = None, |
|
top_p: Optional[float] = None, |
|
typical_p: Optional[float] = None, |
|
repetition_penalty: Optional[float] = None, |
|
bad_words_ids: Optional[Iterable[int]] = None, |
|
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, |
|
bos_token_id: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
length_penalty: Optional[float] = None, |
|
no_repeat_ngram_size: Optional[int] = None, |
|
encoder_no_repeat_ngram_size: Optional[int] = None, |
|
num_return_sequences: Optional[int] = None, |
|
max_time: Optional[float] = None, |
|
max_new_tokens: Optional[int] = None, |
|
decoder_start_token_id: Optional[int] = None, |
|
use_cache: Optional[bool] = None, |
|
num_beam_groups: Optional[int] = None, |
|
diversity_penalty: Optional[float] = None, |
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, |
|
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), |
|
renormalize_logits: Optional[bool] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), |
|
constraints: Optional[List[Constraint]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
forced_bos_token_id: Optional[int] = None, |
|
forced_eos_token_id: Optional[int] = None, |
|
remove_invalid_values: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, |
|
**model_kwargs, |
|
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]: |
|
r""" |
|
|
|
Generates sequences of token ids for models with a language modeling head. The method supports the following |
|
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: |
|
|
|
- *greedy decoding* by calling [`~generation_utils.GenerationMixin.greedy_search`] if `num_beams=1` and |
|
`do_sample=False`. |
|
- *multinomial sampling* by calling [`~generation_utils.GenerationMixin.sample`] if `num_beams=1` and |
|
`do_sample=True`. |
|
- *beam-search decoding* by calling [`~generation_utils.GenerationMixin.beam_search`] if `num_beams>1` and |
|
`do_sample=False`. |
|
- *beam-search multinomial sampling* by calling [`~generation_utils.GenerationMixin.beam_sample`] if |
|
`num_beams>1` and `do_sample=True`. |
|
- *diverse beam-search decoding* by calling [`~generation_utils.GenerationMixin.group_beam_search`], if |
|
`num_beams>1` and `num_beam_groups>1`. |
|
- *constrained beam-search decoding* by calling |
|
[`~generation_utils.GenerationMixin.constrained_beam_search`], if `constraints!=None` or |
|
`force_words_ids!=None`. |
|
|
|
<Tip warning={true}> |
|
|
|
Apart from `inputs`, all the arguments below will default to the value of the attribute of the same name as |
|
defined in the model's config (`config.json`) which in turn defaults to the |
|
[`~modeling_utils.PretrainedConfig`] of the model. |
|
|
|
</Tip> |
|
|
|
Most of these parameters are explained in more detail in [this blog |
|
post](https://huggingface.co/blog/how-to-generate). |
|
|
|
Parameters: |
|
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): |
|
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the |
|
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` |
|
                should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
|
`input_ids`, `input_values`, `input_features`, or `pixel_values`. |
|
max_length (`int`, *optional*, defaults to `model.config.max_length`): |
|
The maximum length of the sequence to be generated. |
|
max_new_tokens (`int`, *optional*, defaults to None): |
|
                The maximum number of tokens to generate, ignoring the number of tokens in the prompt. Use either
|
`max_new_tokens` or `max_length` but not both, they serve the same purpose. |
|
min_length (`int`, *optional*, defaults to 10): |
|
The minimum length of the sequence to be generated. |
|
do_sample (`bool`, *optional*, defaults to `False`): |
|
                Whether or not to use sampling; use greedy decoding otherwise.
|
early_stopping (`bool`, *optional*, defaults to `False`): |
|
Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not. |
|
num_beams (`int`, *optional*, defaults to 1): |
|
Number of beams for beam search. 1 means no beam search. |
|
temperature (`float`, *optional*, defaults to 1.0): |
|
                The value used to modulate the next token probabilities.
|
top_k (`int`, *optional*, defaults to 50): |
|
The number of highest probability vocabulary tokens to keep for top-k-filtering. |
|
top_p (`float`, *optional*, defaults to 1.0): |
|
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher |
|
are kept for generation. |
|
repetition_penalty (`float`, *optional*, defaults to 1.0): |
|
The parameter for repetition penalty. 1.0 means no penalty. See [this |
|
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
bos_token_id (`int`, *optional*): |
|
The id of the *beginning-of-sequence* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
length_penalty (`float`, *optional*, defaults to 1.0): |
|
                Exponential penalty to the length, used with beam-based generation. It is applied as an exponent to
                the sequence length, which in turn is used to divide the score of the sequence. Since the score is
                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
                while `length_penalty` < 0.0 encourages shorter sequences.
|
no_repeat_ngram_size (`int`, *optional*, defaults to 0): |
|
If set to int > 0, all ngrams of that size can only occur once. |
|
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): |
|
If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the |
|
`decoder_input_ids`. |
|
bad_words_ids(`List[List[int]]`, *optional*): |
|
List of token ids that are not allowed to be generated. In order to get the token ids of the words that |
|
should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, |
|
add_special_tokens=False).input_ids`. |
|
force_words_ids(`List[List[int]]` or `List[List[List[int]]]`, *optional*): |
|
List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple |
|
list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, |
|
this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), |
|
where one can allow different forms of each word. |
|
num_return_sequences(`int`, *optional*, defaults to 1): |
|
The number of independently computed returned sequences for each element in the batch. |
|
max_time(`float`, *optional*, defaults to None): |
|
                The maximum amount of time you allow the computation to run for in seconds. Generation will still
                finish the current pass after the allocated time has elapsed.
|
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens |
|
that are not masked, and 0 for masked tokens. If not provided, will default to a tensor the same shape |
|
as `input_ids` that masks the pad token. [What are attention masks?](../glossary#attention-mask) |
|
decoder_start_token_id (`int`, *optional*): |
|
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. |
|
use_cache: (`bool`, *optional*, defaults to `True`): |
|
Whether or not the model should use the past last key/values attentions (if applicable to the model) to |
|
speed up decoding. |
|
num_beam_groups (`int`, *optional*, defaults to 1): |
|
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of |
|
                beams. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
|
diversity_penalty (`float`, *optional*, defaults to 0.0): |
|
                This value is subtracted from a beam's score if it generates the same token as any beam from another
                group at a particular time step. Note that `diversity_penalty` is only effective if `group beam
                search` is enabled.
|
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): |
|
                If provided, this function constrains the beam search to allowed tokens only at each step. If not
                provided, no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
|
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity |
|
Retrieval](https://arxiv.org/abs/2010.00904). |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
Custom logits processors that complement the default logits processors built from arguments and a |
|
model's config. If a logit processor is passed that is already created with the arguments or a model's |
|
config an error is thrown. This feature is intended for advanced users. |
|
renormalize_logits: (`bool`, *optional*, defaults to `False`): |
|
Whether to renormalize the logits after applying all the logits processors or warpers (including the |
|
                custom ones). It's highly recommended to set this flag to `True` as the search algorithms assume that
                the score logits are normalized, but some logits processors or warpers break the normalization.
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
Custom stopping criteria that complement the default stopping criteria built from arguments and a |
|
model's config. If a stopping criteria is passed that is already created with the arguments or a |
|
model's config an error is thrown. This feature is intended for advanced users. |
|
constraints (`List[Constraint]`, *optional*): |
|
Custom constraints that can be added to the generation to ensure that the output will contain the use |
|
of certain tokens as defined by `Constraint` objects, in the most sensible way possible. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
forced_bos_token_id (`int`, *optional*): |
|
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful |
|
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be |
|
the target language token. |
|
forced_eos_token_id (`int`, *optional*): |
|
The id of the token to force as the last generated token when `max_length` is reached. |
|
remove_invalid_values (`bool`, *optional*): |
|
                Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method from
                crashing. Note that using `remove_invalid_values` can slow down generation.
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
exponential_decay_length_penalty (`tuple(int, float)`, *optional*): |
|
                This tuple adds an exponentially increasing length penalty after a certain number of tokens have been
                generated. The tuple should consist of `(start_index, decay_factor)`, where `start_index` indicates
                where the penalty starts and `decay_factor` represents the factor of exponential decay.
|
|
|
model_kwargs: |
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model |
|
is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs |
|
should be prefixed with *decoder_*. |
|
|
|
Return: |
|
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` |
|
            or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
|
|
|
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible |
|
[`~utils.ModelOutput`] types are: |
|
|
|
- [`~generation_utils.GreedySearchDecoderOnlyOutput`], |
|
- [`~generation_utils.SampleDecoderOnlyOutput`], |
|
- [`~generation_utils.BeamSearchDecoderOnlyOutput`], |
|
- [`~generation_utils.BeamSampleDecoderOnlyOutput`] |
|
|
|
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible |
|
[`~utils.ModelOutput`] types are: |
|
|
|
- [`~generation_utils.GreedySearchEncoderDecoderOutput`], |
|
- [`~generation_utils.SampleEncoderDecoderOutput`], |
|
- [`~generation_utils.BeamSearchEncoderDecoderOutput`], |
|
- [`~generation_utils.BeamSampleEncoderDecoderOutput`] |
|
|
|
Examples: |
|
|
|
Greedy Decoding: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
|
|
|
>>> prompt = "Today I believe we can finally" |
|
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids |
|
|
|
>>> # generate up to 30 tokens |
|
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30) |
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n'] |
|
``` |
|
|
|
Multinomial Sampling: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM |
|
>>> import torch |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
|
|
|
>>> prompt = "Today I believe we can finally" |
|
>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids |
|
|
|
>>> # sample up to 30 tokens |
|
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT |
|
>>> outputs = model.generate(input_ids, do_sample=True, max_length=30) |
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Today I believe we can finally get rid of discrimination," said Rep. Mark Pocan (D-Wis.).\n\n"Just look at the'] |
|
``` |
|
|
|
Beam-search decoding: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de") |
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de") |
|
|
|
>>> sentence = "Paris is one of the densest populated areas in Europe." |
|
>>> input_ids = tokenizer(sentence, return_tensors="pt").input_ids |
|
|
|
>>> outputs = model.generate(input_ids) |
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Paris ist eines der dichtesten besiedelten Gebiete Europas.'] |
|
```""" |
|
|
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id |
|
num_beams = num_beams if num_beams is not None else self.config.num_beams |
|
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty |
|
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping |
|
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups |
|
do_sample = do_sample if do_sample is not None else self.config.do_sample |
|
num_return_sequences = ( |
|
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences |
|
) |
|
|
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
|
|
if eos_token_id is None and hasattr(self.config, "decoder"): |
|
eos_token_id = self.config.decoder.eos_token_id |
|
|
|
if pad_token_id is None and eos_token_id is not None: |
|
|
|
print(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") |
|
pad_token_id = eos_token_id |
|
|
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, bos_token_id, model_kwargs) |
|
batch_size = inputs_tensor.shape[0] |
|
|
|
|
|
model_kwargs["output_attentions"] = output_attentions |
|
model_kwargs["output_hidden_states"] = output_hidden_states |
|
model_kwargs["use_cache"] = use_cache |
|
|
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) |
|
requires_attention_mask = "encoder_outputs" not in model_kwargs |
|
|
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: |
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( |
|
inputs_tensor, pad_token_id, eos_token_id |
|
) |
|
|
|
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: |
|
|
|
|
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( |
|
inputs_tensor, model_kwargs, model_input_name |
|
) |
|
|
|
|
|
if self.config.is_encoder_decoder: |
|
input_ids = self._prepare_decoder_input_ids_for_generation( |
|
batch_size, |
|
decoder_start_token_id=decoder_start_token_id, |
|
bos_token_id=bos_token_id, |
|
model_kwargs=model_kwargs, |
|
device=inputs_tensor.device, |
|
) |
|
else: |
|
|
|
input_ids = inputs_tensor |
|
|
|
input_ids_seq_length = input_ids.shape[-1] |
|
|
|
|
|
|
|
if max_length is None and max_new_tokens is not None: |
|
max_length = max_new_tokens + input_ids_seq_length |
|
elif max_length is not None and max_new_tokens is not None: |
|
|
|
warnings.warn( |
|
"Both `max_length` and `max_new_tokens` have been set " |
|
f"but they serve the same purpose. `max_length` {max_length} " |
|
f"will take priority over `max_new_tokens` {max_new_tokens}.", |
|
UserWarning, |
|
) |
|
|
|
max_length = max_length if max_length is not None else self.config.max_length |
|
min_length = min_length if min_length is not None else self.config.min_length |
|
|
|
if min_length is not None and min_length > max_length: |
|
raise ValueError( |
|
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum " |
|
f"length ({max_length})" |
|
) |
|
if input_ids_seq_length >= max_length: |
|
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" |
|
print( |
|
f"Input length of {input_ids_string} is {input_ids_seq_length}, but ``max_length`` is set to {max_length}. " |
|
"This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``." |
|
) |
|
|
|
|
|
is_constraint_gen_mode = constraints is not None or force_words_ids is not None |
|
is_greedy_gen_mode = ( |
|
(num_beams == 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode |
|
) |
|
is_sample_gen_mode = ( |
|
(num_beams == 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode |
|
) |
|
is_beam_gen_mode = ( |
|
(num_beams > 1) and (num_beam_groups == 1) and do_sample is False and not is_constraint_gen_mode |
|
) |
|
is_beam_sample_gen_mode = ( |
|
(num_beams > 1) and (num_beam_groups == 1) and do_sample is True and not is_constraint_gen_mode |
|
) |
|
is_group_beam_gen_mode = (num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode |
|
|
|
if num_beam_groups > num_beams: |
|
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") |
|
if is_group_beam_gen_mode and do_sample is True: |
|
raise ValueError( |
|
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." |
|
) |
|
|
|
|
|
logits_processor = self._get_logits_processor( |
|
repetition_penalty=repetition_penalty, |
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size, |
|
input_ids_seq_length=input_ids_seq_length, |
|
encoder_input_ids=inputs_tensor, |
|
bad_words_ids=bad_words_ids, |
|
min_length=min_length, |
|
max_length=max_length, |
|
eos_token_id=eos_token_id, |
|
forced_bos_token_id=forced_bos_token_id, |
|
forced_eos_token_id=forced_eos_token_id, |
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, |
|
num_beams=num_beams, |
|
num_beam_groups=num_beam_groups, |
|
diversity_penalty=diversity_penalty, |
|
remove_invalid_values=remove_invalid_values, |
|
exponential_decay_length_penalty=exponential_decay_length_penalty, |
|
logits_processor=logits_processor, |
|
renormalize_logits=renormalize_logits, |
|
) |
|
|
|
|
|
stopping_criteria = self._get_stopping_criteria( |
|
max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria |
|
) |
|
|
|
|
|
if is_greedy_gen_mode: |
|
if num_return_sequences > 1: |
|
raise ValueError( |
|
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search." |
|
) |
|
|
|
|
|
return self.greedy_search( |
|
input_ids, |
|
logits_processor=logits_processor, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
output_scores=output_scores, |
|
return_dict_in_generate=return_dict_in_generate, |
|
synced_gpus=synced_gpus, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_sample_gen_mode: |
|
|
|
logits_warper = self._get_logits_warper( |
|
top_k=top_k, |
|
top_p=top_p, |
|
typical_p=typical_p, |
|
temperature=temperature, |
|
num_beams=num_beams, |
|
renormalize_logits=renormalize_logits, |
|
) |
|
|
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids, |
|
expand_size=num_return_sequences, |
|
is_encoder_decoder=self.config.is_encoder_decoder, |
|
**model_kwargs, |
|
) |
|
|
|
|
|
return self.sample( |
|
input_ids, |
|
logits_processor=logits_processor, |
|
logits_warper=logits_warper, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
output_scores=output_scores, |
|
return_dict_in_generate=return_dict_in_generate, |
|
synced_gpus=synced_gpus, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_beam_gen_mode: |
|
if num_return_sequences > num_beams: |
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") |
|
|
|
if stopping_criteria.max_length is None: |
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
|
|
|
beam_scorer = BeamSearchScorer( |
|
batch_size=batch_size, |
|
num_beams=num_beams, |
|
device=inputs_tensor.device, |
|
length_penalty=length_penalty, |
|
do_early_stopping=early_stopping, |
|
num_beam_hyps_to_keep=num_return_sequences, |
|
) |
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
|
) |
|
|
|
return self.beam_search( |
|
input_ids, |
|
beam_scorer, |
|
logits_processor=logits_processor, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
output_scores=output_scores, |
|
return_dict_in_generate=return_dict_in_generate, |
|
synced_gpus=synced_gpus, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_beam_sample_gen_mode: |
|
|
|
logits_warper = self._get_logits_warper( |
|
top_k=top_k, |
|
top_p=top_p, |
|
typical_p=typical_p, |
|
temperature=temperature, |
|
num_beams=num_beams, |
|
renormalize_logits=renormalize_logits, |
|
) |
|
|
|
if stopping_criteria.max_length is None: |
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
|
beam_scorer = BeamSearchScorer( |
|
batch_size=batch_size * num_return_sequences, |
|
num_beams=num_beams, |
|
device=inputs_tensor.device, |
|
length_penalty=length_penalty, |
|
do_early_stopping=early_stopping, |
|
) |
|
|
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids, |
|
expand_size=num_beams * num_return_sequences, |
|
is_encoder_decoder=self.config.is_encoder_decoder, |
|
**model_kwargs, |
|
) |
|
|
|
|
|
return self.beam_sample( |
|
input_ids, |
|
beam_scorer, |
|
logits_processor=logits_processor, |
|
logits_warper=logits_warper, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
output_scores=output_scores, |
|
return_dict_in_generate=return_dict_in_generate, |
|
synced_gpus=synced_gpus, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_group_beam_gen_mode: |
|
if num_return_sequences > num_beams: |
|
raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") |
|
|
|
if num_beams % num_beam_groups != 0: |
|
raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") |
|
|
|
if stopping_criteria.max_length is None: |
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
|
|
|
beam_scorer = BeamSearchScorer( |
|
batch_size=batch_size, |
|
num_beams=num_beams, |
|
max_length=stopping_criteria.max_length, |
|
device=inputs_tensor.device, |
|
length_penalty=length_penalty, |
|
do_early_stopping=early_stopping, |
|
num_beam_hyps_to_keep=num_return_sequences, |
|
num_beam_groups=num_beam_groups, |
|
) |
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
|
) |
|
|
|
return self.group_beam_search( |
|
input_ids, |
|
beam_scorer, |
|
logits_processor=logits_processor, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
output_scores=output_scores, |
|
return_dict_in_generate=return_dict_in_generate, |
|
synced_gpus=synced_gpus, |
|
**model_kwargs, |
|
) |
|
|
|
elif is_constraint_gen_mode: |
|
if num_return_sequences > num_beams: |
|
                raise ValueError("`num_return_sequences` has to be less than or equal to `num_beams`.")
|
|
|
if stopping_criteria.max_length is None: |
|
raise ValueError("`max_length` needs to be a stopping_criteria for now.") |
|
|
|
if num_beams <= 1: |
|
                raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")
|
|
|
if do_sample: |
|
                raise ValueError("`do_sample` needs to be set to `False` for constrained generation.")
|
|
|
if num_beam_groups is not None and num_beam_groups > 1: |
|
raise ValueError("`num_beam_groups` not supported yet for constrained generation.") |
|
|
|
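            # collect all constraints to enforce: user-provided `constraints` plus any
            # `force_words_ids` converted to Phrasal/Disjunctive constraints below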
final_constraints = [] |
|
if constraints is not None: |
|
final_constraints = constraints |
|
|
|
if force_words_ids is not None: |
|
|
|
def typeerror(): |
|
raise ValueError( |
|
"`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`" |
|
f"of positive integers, but is {force_words_ids}." |
|
) |
|
|
|
if not isinstance(force_words_ids, list) or len(force_words_ids) == 0: |
|
typeerror() |
|
|
|
for word_ids in force_words_ids: |
|
if isinstance(word_ids[0], list): |
|
if not isinstance(word_ids, list) or len(word_ids) == 0: |
|
typeerror() |
|
if any(not isinstance(token_ids, list) for token_ids in word_ids): |
|
typeerror() |
|
if any( |
|
any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) |
|
for token_ids in word_ids |
|
): |
|
typeerror() |
|
|
|
constraint = DisjunctiveConstraint(word_ids) |
|
else: |
|
if not isinstance(word_ids, list) or len(word_ids) == 0: |
|
typeerror() |
|
if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): |
|
typeerror() |
|
|
|
constraint = PhrasalConstraint(word_ids) |
|
final_constraints.append(constraint) |
|
|
|
|
|
constrained_beam_scorer = ConstrainedBeamSearchScorer( |
|
constraints=final_constraints, |
|
batch_size=batch_size, |
|
num_beams=num_beams, |
|
device=inputs_tensor.device, |
|
length_penalty=length_penalty, |
|
do_early_stopping=early_stopping, |
|
num_beam_hyps_to_keep=num_return_sequences, |
|
) |
|
|
|
input_ids, model_kwargs = self._expand_inputs_for_generation( |
|
input_ids, expand_size=num_beams, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs |
|
) |
|
|
|
return self.constrained_beam_search( |
|
input_ids, |
|
constrained_beam_scorer=constrained_beam_scorer, |
|
logits_processor=logits_processor, |
|
stopping_criteria=stopping_criteria, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
output_scores=output_scores, |
|
return_dict_in_generate=return_dict_in_generate, |
|
synced_gpus=synced_gpus, |
|
**model_kwargs, |
|
) |
|
|
|
def greedy_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
**model_kwargs, |
|
) -> Union[GreedySearchOutput, torch.LongTensor]: |
|
r""" |
|
Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be |
|
used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
|
|
|
Parameters: |
|
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
The sequence used as a prompt for the generation. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
|
used to modify the prediction scores of the language modeling head applied at each generation step. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
|
used to tell if the generation loop should stop. |
|
|
|
max_length (`int`, *optional*, defaults to 20): |
|
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
|
tokens. The maximum length of the sequence to be generated. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
model_kwargs: |
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model. |
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`. |
|
|
|
Return: |
|
[`~generation_utils.GreedySearchDecoderOnlyOutput`], [`~generation_utils.GreedySearchEncoderDecoderOutput`] |
|
or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
|
[`~generation_utils.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
|
`return_dict_in_generate=True` or a [`~generation_utils.GreedySearchEncoderDecoderOutput`] if |
|
`model.config.is_encoder_decoder=True`. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import ( |
|
... AutoTokenizer, |
|
... AutoModelForCausalLM, |
|
... LogitsProcessorList, |
|
... MinLengthLogitsProcessor, |
|
... StoppingCriteriaList, |
|
... MaxLengthCriteria, |
|
... ) |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
|
|
|
        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
|
>>> model.config.pad_token_id = model.config.eos_token_id |
|
|
|
>>> input_prompt = "It might be possible to" |
|
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids |
|
|
|
>>> # instantiate logits processors |
|
>>> logits_processor = LogitsProcessorList( |
|
... [ |
|
... MinLengthLogitsProcessor(10, eos_token_id=model.config.eos_token_id), |
|
... ] |
|
... ) |
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) |
|
|
|
>>> outputs = model.greedy_search( |
|
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria |
|
... ) |
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
["It might be possible to get a better understanding of the nature of the problem, but it's not"] |
|
```""" |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
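        # keep track of which sequences are already finished (1 = still generating, 0 = finished)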
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) |
|
cur_len = input_ids.shape[-1] |
|
|
|
this_peer_finished = False |
|
while True: |
|
|
|
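            # with synced_gpus (ZeRO stage 3), every peer must keep calling forward until all
            # peers are done; the all_reduce below sums a per-peer "still running" flag and the
            # loop only breaks once that sum is zero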
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_logits,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
next_tokens_scores = logits_processor(input_ids, next_token_logits) |
|
|
|
|
|
next_tokens = torch.argmax(next_tokens_scores, dim=-1) |
|
|
|
|
|
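            # finished sentences should get a padding token appended instead of a real one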
if eos_token_id is not None: |
|
if pad_token_id is None: |
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) |
|
|
|
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
cur_len = cur_len + 1 |
|
|
|
|
|
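            # if an eos token was generated in a sentence, mark that sentence as finished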
if eos_token_id is not None: |
|
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) |
|
|
|
|
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
if return_dict_in_generate: |
|
if self.config.is_encoder_decoder: |
|
return GreedySearchEncoderDecoderOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return GreedySearchDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return input_ids |
|
|
|
def sample( |
|
self, |
|
input_ids: torch.LongTensor, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
logits_warper: Optional[LogitsProcessorList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
**model_kwargs, |
|
) -> Union[SampleOutput, torch.LongTensor]: |
|
r""" |
|
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and |
|
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
|
|
|
Parameters: |
|
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
The sequence used as a prompt for the generation. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
|
used to modify the prediction scores of the language modeling head applied at each generation step. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
|
used to tell if the generation loop should stop. |
|
logits_warper (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used |
|
to warp the prediction score distribution of the language modeling head applied before multinomial |
|
sampling at each generation step. |
|
max_length (`int`, *optional*, defaults to 20): |
|
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
|
tokens. The maximum length of the sequence to be generated. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
model_kwargs: |
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
|
an encoder-decoder model the kwargs should include `encoder_outputs`. |
|
|
|
Return: |
|
[`~generation_utils.SampleDecoderOnlyOutput`], [`~generation_utils.SampleEncoderDecoderOutput`] or |
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
|
[`~generation_utils.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
|
`return_dict_in_generate=True` or a [`~generation_utils.SampleEncoderDecoderOutput`] if |
|
`model.config.is_encoder_decoder=True`. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import ( |
|
... AutoTokenizer, |
|
... AutoModelForCausalLM, |
|
... LogitsProcessorList, |
|
... MinLengthLogitsProcessor, |
|
... TopKLogitsWarper, |
|
... TemperatureLogitsWarper, |
|
... StoppingCriteriaList, |
|
... MaxLengthCriteria, |
|
... ) |
|
>>> import torch |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
>>> model = AutoModelForCausalLM.from_pretrained("gpt2") |
|
|
|
        >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
|
>>> model.config.pad_token_id = model.config.eos_token_id |
|
|
|
>>> input_prompt = "Today is a beautiful day, and" |
|
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids |
|
|
|
>>> # instantiate logits processors |
|
>>> logits_processor = LogitsProcessorList( |
|
... [ |
|
... MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id), |
|
... ] |
|
... ) |
|
        >>> # instantiate logits warpers
|
>>> logits_warper = LogitsProcessorList( |
|
... [ |
|
... TopKLogitsWarper(50), |
|
... TemperatureLogitsWarper(0.7), |
|
... ] |
|
... ) |
|
|
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) |
|
|
|
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT |
|
>>> outputs = model.sample( |
|
... input_ids, |
|
... logits_processor=logits_processor, |
|
... logits_warper=logits_warper, |
|
... stopping_criteria=stopping_criteria, |
|
... ) |
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] |
|
```""" |
|
|
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
|
|
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) |
|
cur_len = input_ids.shape[-1] |
|
|
|
this_peer_finished = False |
|
|
|
while True: |
|
|
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
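            # pre-process the distribution, then warp it (e.g. top-k / top-p / temperature) before sampling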
next_token_scores = logits_processor(input_ids, next_token_logits) |
|
next_token_scores = logits_warper(input_ids, next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_scores,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
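            # sample the next token from the warped distribution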
probs = nn.functional.softmax(next_token_scores, dim=-1) |
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) |
|
|
|
|
|
if eos_token_id is not None: |
|
if pad_token_id is None: |
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) |
|
|
|
|
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
cur_len = cur_len + 1 |
|
|
|
|
|
if eos_token_id is not None: |
|
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) |
|
|
|
|
|
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
if return_dict_in_generate: |
|
if self.config.is_encoder_decoder: |
|
return SampleEncoderDecoderOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return SampleDecoderOnlyOutput( |
|
sequences=input_ids, |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return input_ids |
|
|
|
def beam_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
beam_scorer: BeamScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
**model_kwargs, |
|
) -> Union[BeamSearchOutput, torch.LongTensor]: |
|
r""" |
|
Generates sequences of token ids for models with a language modeling head using **beam search decoding** and |
|
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
|
|
|
Parameters: |
|
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
The sequence used as a prompt for the generation. |
|
beam_scorer (`BeamScorer`): |
|
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
|
used to modify the prediction scores of the language modeling head applied at each generation step. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
|
used to tell if the generation loop should stop. |
|
max_length (`int`, *optional*, defaults to 20): |
|
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
|
tokens. The maximum length of the sequence to be generated. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
model_kwargs: |
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
|
an encoder-decoder model the kwargs should include `encoder_outputs`. |
|
|
|
Return: |
|
            [`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
|
[`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
|
`return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if |
|
`model.config.is_encoder_decoder=True`. |
|
|
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import ( |
|
... AutoTokenizer, |
|
... AutoModelForSeq2SeqLM, |
|
... LogitsProcessorList, |
|
... MinLengthLogitsProcessor, |
|
... BeamSearchScorer, |
|
... ) |
|
>>> import torch |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") |
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
|
|
>>> encoder_input_str = "translate English to German: How old are you?" |
|
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
|
|
>>> # lets run beam search using 3 beams |
|
>>> num_beams = 3 |
|
>>> # define decoder start token ids |
|
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) |
|
>>> input_ids = input_ids * model.config.decoder_start_token_id |
|
|
|
>>> # add encoder_outputs to model keyword arguments |
|
>>> model_kwargs = { |
|
... "encoder_outputs": model.get_encoder()( |
|
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True |
|
... ) |
|
... } |
|
|
|
>>> # instantiate beam scorer |
|
>>> beam_scorer = BeamSearchScorer( |
|
... batch_size=1, |
|
... num_beams=num_beams, |
|
... device=model.device, |
|
... ) |
|
|
|
>>> # instantiate logits processors |
|
>>> logits_processor = LogitsProcessorList( |
|
... [ |
|
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), |
|
... ] |
|
... ) |
|
|
|
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) |
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Wie alt bist du?'] |
|
```""" |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
if len(stopping_criteria) == 0: |
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
batch_size = len(beam_scorer._beam_hyps) |
|
num_beams = beam_scorer.num_beams |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
if num_beams * batch_size != batch_beam_size: |
|
raise ValueError( |
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
beam_indices = ( |
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None |
|
) |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
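        # initialise the score of the first beam of each batch entry to 0 and the rest to -1e9,
        # so that the first step does not select the same token for every (identical) beam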
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
|
beam_scores[:, 1:] = -1e9 |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
|
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
|
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_scores_processed,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
|
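            # keep the 2 * num_beams best candidates so that, even if up to num_beams of them
            # end in an eos token, there are still enough non-eos continuations for every beam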
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True |
|
) |
|
|
|
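            # the top-k ids index the flattened (num_beams * vocab_size) scores; recover the
            # originating beam and the vocabulary token (e.g. with a vocab of 50257, id 100514
            # comes from beam 2, token 0)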
next_indices = torch_int_div(next_tokens, vocab_size) |
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
beam_outputs = beam_scorer.process( |
|
input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
) |
|
|
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past"] is not None: |
|
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
else: |
|
num_return_sequences = beam_scorer.num_beam_hyps_to_keep |
|
|
|
beam_indices = tuple( |
|
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) |
|
) |
|
beam_indices = sum(beam_indices, ()) |
|
|
|
if self.config.is_encoder_decoder: |
|
return BeamSearchEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=beam_indices, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return BeamSearchDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=beam_indices, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def beam_sample( |
|
self, |
|
input_ids: torch.LongTensor, |
|
beam_scorer: BeamScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
logits_warper: Optional[LogitsProcessorList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
**model_kwargs, |
|
) -> Union[BeamSampleOutput, torch.LongTensor]: |
|
r""" |
|
Generates sequences of token ids for models with a language modeling head using **beam search multinomial |
|
sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
|
|
|
Parameters: |
|
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
The sequence used as a prompt for the generation. |
|
beam_scorer (`BeamScorer`): |
|
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and |
|
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
|
used to modify the prediction scores of the language modeling head applied at each generation step. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
|
used to tell if the generation loop should stop. |
|
logits_warper (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used |
|
to warp the prediction score distribution of the language modeling head applied before multinomial |
|
sampling at each generation step. |
|
max_length (`int`, *optional*, defaults to 20): |
|
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
|
tokens. The maximum length of the sequence to be generated. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
model_kwargs: |
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
|
an encoder-decoder model the kwargs should include `encoder_outputs`. |
|
|
|
Return: |
|
[`~generation_utils.BeamSampleDecoderOnlyOutput`], [`~generation_utils.BeamSampleEncoderDecoderOutput`] or |
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
|
[`~generation_utils.BeamSampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
|
`return_dict_in_generate=True` or a [`~generation_utils.BeamSampleEncoderDecoderOutput`] if |
|
`model.config.is_encoder_decoder=True`. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import ( |
|
... AutoTokenizer, |
|
... AutoModelForSeq2SeqLM, |
|
... LogitsProcessorList, |
|
... MinLengthLogitsProcessor, |
|
... TopKLogitsWarper, |
|
... TemperatureLogitsWarper, |
|
... BeamSearchScorer, |
|
... ) |
|
>>> import torch |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") |
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
|
|
>>> encoder_input_str = "translate English to German: How old are you?" |
|
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
>>> # lets run beam search using 3 beams |
|
>>> num_beams = 3 |
|
>>> # define decoder start token ids |
|
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) |
|
>>> input_ids = input_ids * model.config.decoder_start_token_id |
|
|
|
>>> # add encoder_outputs to model keyword arguments |
|
>>> model_kwargs = { |
|
... "encoder_outputs": model.get_encoder()( |
|
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True |
|
... ) |
|
... } |
|
|
|
>>> # instantiate beam scorer |
|
>>> beam_scorer = BeamSearchScorer( |
|
... batch_size=1, |
|
... max_length=model.config.max_length, |
|
... num_beams=num_beams, |
|
... device=model.device, |
|
... ) |
|
|
|
>>> # instantiate logits processors |
|
>>> logits_processor = LogitsProcessorList( |
|
... [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)] |
|
... ) |
|
        >>> # instantiate logits warpers
|
>>> logits_warper = LogitsProcessorList( |
|
... [ |
|
... TopKLogitsWarper(50), |
|
... TemperatureLogitsWarper(0.7), |
|
... ] |
|
... ) |
|
|
|
>>> outputs = model.beam_sample( |
|
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs |
|
... ) |
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Wie alt bist du?'] |
|
```""" |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
batch_size = len(beam_scorer._beam_hyps) |
|
num_beams = beam_scorer.num_beams |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
beam_indices = ( |
|
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None |
|
) |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
|
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
) |
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
next_token_scores = logits_warper(input_ids, next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (logits_warper(input_ids, next_token_scores_processed),) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
|
probs = nn.functional.softmax(next_token_scores, dim=-1) |
|
|
|
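            # sample 2 * num_beams candidates per batch entry and sort them by score, so the
            # beam scorer can still fill every beam if some candidates end in an eos token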
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) |
|
next_token_scores = torch.gather(next_token_scores, -1, next_tokens) |
|
|
|
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) |
|
next_tokens = torch.gather(next_tokens, -1, _indices) |
|
|
|
next_indices = torch_int_div(next_tokens, vocab_size) |
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
beam_outputs = beam_scorer.process( |
|
input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
) |
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past"] is not None: |
|
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
else: |
|
num_return_sequences = beam_scorer.num_beam_hyps_to_keep |
|
|
|
beam_indices = tuple( |
|
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) |
|
) |
|
beam_indices = sum(beam_indices, ()) |
|
|
|
if self.config.is_encoder_decoder: |
|
return BeamSampleEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=beam_indices, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return BeamSampleDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=beam_indices, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def group_beam_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
beam_scorer: BeamScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = False, |
|
**model_kwargs, |
|
): |
|
r""" |
|
Generates sequences of token ids for models with a language modeling head using **diverse beam search |
|
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
|
|
|
Parameters: |
|
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
The sequence used as a prompt for the generation. |
|
beam_scorer (`BeamScorer`): |
|
                A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
|
sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
|
used to modify the prediction scores of the language modeling head applied at each generation step. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
|
used to tell if the generation loop should stop. |
|
max_length (`int`, *optional*, defaults to 20): |
|
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
|
tokens. The maximum length of the sequence to be generated. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
|
|
model_kwargs: |
|
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If |
|
model is an encoder-decoder model the kwargs should include `encoder_outputs`. |
|
|
|
Return: |
|
[`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or |
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
|
            [`~generation_utils.BeamSearchDecoderOnlyOutput`] if
|
`model.config.is_encoder_decoder=False` and `return_dict_in_generate=True` or a |
|
[`~generation_utils.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import ( |
|
... AutoTokenizer, |
|
... AutoModelForSeq2SeqLM, |
|
... LogitsProcessorList, |
|
... MinLengthLogitsProcessor, |
|
... HammingDiversityLogitsProcessor, |
|
... BeamSearchScorer, |
|
... ) |
|
>>> import torch |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") |
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
|
|
>>> encoder_input_str = "translate English to German: How old are you?" |
|
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
|
|
>>> # lets run diverse beam search using 6 beams |
|
>>> num_beams = 6 |
|
>>> # define decoder start token ids |
|
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) |
|
>>> input_ids = input_ids * model.config.decoder_start_token_id |
|
|
|
>>> # add encoder_outputs to model keyword arguments |
|
>>> model_kwargs = { |
|
... "encoder_outputs": model.get_encoder()( |
|
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True |
|
... ) |
|
... } |
|
|
|
>>> # instantiate beam scorer |
|
>>> beam_scorer = BeamSearchScorer( |
|
... batch_size=1, |
|
... max_length=model.config.max_length, |
|
... num_beams=num_beams, |
|
... device=model.device, |
|
... num_beam_groups=3, |
|
... ) |
|
|
|
>>> # instantiate logits processors |
|
>>> logits_processor = LogitsProcessorList( |
|
... [ |
|
... HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3), |
|
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), |
|
... ] |
|
... ) |
|
|
|
>>> outputs = model.group_beam_search( |
|
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs |
|
... ) |
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Wie alt bist du?'] |
|
```""" |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
batch_size = len(beam_scorer._beam_hyps) |
|
num_beams = beam_scorer.num_beams |
|
num_beam_groups = beam_scorer.num_beam_groups |
|
num_sub_beams = num_beams // num_beam_groups |
|
device = input_ids.device |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)] |
|
else: |
|
beam_indices = None |
|
|
|
if num_beams * batch_size != batch_beam_size: |
|
raise ValueError( |
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) |
|
|
|
|
|
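        # initialise the score of the first beam of each group to 0 and the rest to -1e9,
        # so that beams within the same group do not all produce the same tokens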
beam_scores[:, ::num_sub_beams] = 0 |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
|
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
|
|
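            # tokens predicted in the current step, filled in group by group below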
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) |
|
|
|
|
|
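            # indices that will re-order the cache to form the beams of the next time step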
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) |
|
|
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
if output_scores: |
|
processed_score = torch.zeros_like(outputs.logits[:, -1, :]) |
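
# beams are updated one group at a time; tokens already chosen by earlier groups (`current_tokens`) feed the diversity penalty applied to later groups |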
|
|
|
for beam_group_idx in range(num_beam_groups): |
|
group_start_idx = beam_group_idx * num_sub_beams |
|
group_end_idx = min(group_start_idx + num_sub_beams, num_beams) |
|
group_size = group_end_idx - group_start_idx |
|
|
|
|
|
batch_group_indices = [] |
|
|
|
for batch_idx in range(batch_size): |
|
batch_group_indices.extend( |
|
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] |
|
) |
|
group_input_ids = input_ids[batch_group_indices] |
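
# select the last-step logits of the beams belonging to the current group only |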
|
|
|
|
|
next_token_logits = outputs.logits[batch_group_indices, -1, :] |
|
|
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
)  # (batch_size * group_size, vocab_size) |
|
vocab_size = next_token_scores.shape[-1] |
|
|
|
next_token_scores_processed = logits_processor( |
|
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx |
|
) |
|
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) |
|
next_token_scores = next_token_scores.expand_as(next_token_scores_processed) |
|
|
|
if output_scores: |
|
processed_score[batch_group_indices] = next_token_scores_processed |
|
|
|
|
|
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) |
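
# keep 2 * group_size candidates so that, even if up to group_size of them end in EOS and are finalized, enough continuations remain to keep the group full |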
|
|
|
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True |
|
) |
|
|
|
next_indices = torch_int_div(next_tokens, vocab_size) |
|
next_tokens = next_tokens % vocab_size |
|
|
|
|
|
beam_outputs = beam_scorer.process( |
|
group_input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
) |
|
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
if return_dict_in_generate and output_scores: |
|
beam_indices[beam_group_idx] = tuple( |
|
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0])) |
|
) |
|
|
|
input_ids[batch_group_indices] = group_input_ids[beam_idx] |
|
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
current_tokens[batch_group_indices] = group_input_ids[:, -1] |
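
# translate group-local beam indices into positions in the flat (batch_size * num_beams) layout, so the model cache can be reordered once for all groups after this loop |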
|
|
|
|
|
|
|
reordering_indices[batch_group_indices] = ( |
|
num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size) |
|
) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (processed_score,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) |
|
|
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
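
# reorder the cached key/value states so every beam slot continues from the hypothesis selected for it this step |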
|
if model_kwargs["past"] is not None: |
|
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
else: |
|
beam_indices = sum(beam_indices, ()) |
|
num_return_sequences = beam_scorer.num_beam_hyps_to_keep |
|
|
|
beam_indices = tuple( |
|
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) |
|
) |
|
beam_indices = sum(beam_indices, ()) |
|
|
|
if self.config.is_encoder_decoder: |
|
return BeamSearchEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
beam_indices=beam_indices, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return BeamSearchDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
def constrained_beam_search( |
|
self, |
|
input_ids: torch.LongTensor, |
|
constrained_beam_scorer: ConstrainedBeamSearchScorer, |
|
logits_processor: Optional[LogitsProcessorList] = None, |
|
stopping_criteria: Optional[StoppingCriteriaList] = None, |
|
max_length: Optional[int] = None, |
|
pad_token_id: Optional[int] = None, |
|
eos_token_id: Optional[int] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_scores: Optional[bool] = None, |
|
return_dict_in_generate: Optional[bool] = None, |
|
synced_gpus: Optional[bool] = None, |
|
**model_kwargs, |
|
) -> Union[BeamSearchOutput, torch.LongTensor]: |
|
|
|
r""" |
|
Generates sequences of token ids for models with a language modeling head using **constrained beam search |
|
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. |
|
|
|
Parameters: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
The sequence used as a prompt for the generation. |
|
constrained_beam_scorer (`ConstrainedBeamSearchScorer`): |
|
A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and |
|
sorted during generation, while satisfying a list of positive constraints. For more information, the |
|
documentation of [`ConstrainedBeamSearchScorer`] should be read. |
|
logits_processor (`LogitsProcessorList`, *optional*): |
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] |
|
used to modify the prediction scores of the language modeling head applied at each generation step. |
|
stopping_criteria (`StoppingCriteriaList`, *optional*): |
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] |
|
used to tell if the generation loop should stop. |
|
max_length (`int`, *optional*, defaults to 20): |
|
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated |
|
tokens. The maximum length of the sequence to be generated. |
|
pad_token_id (`int`, *optional*): |
|
The id of the *padding* token. |
|
eos_token_id (`int`, *optional*): |
|
The id of the *end-of-sequence* token. |
|
output_attentions (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more details. |
|
output_hidden_states (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
for more details. |
|
output_scores (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details. |
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
synced_gpus (`bool`, *optional*, defaults to `False`): |
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3) |
|
model_kwargs: |
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is |
|
an encoder-decoder model the kwargs should include `encoder_outputs`. |
|
|
|
Return: |
|
[`~generation_utils.BeamSearchDecoderOnlyOutput`], [`~generation_utils.BeamSearchEncoderDecoderOutput`] or |
|
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a |
|
[`~generation_utils.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and |
|
`return_dict_in_generate=True` or a [`~generation_utils.BeamSearchEncoderDecoderOutput`] if |
|
`model.config.is_encoder_decoder=True`. |
|
|
|
|
|
Examples: |
|
|
|
```python |
|
>>> from transformers import ( |
|
... AutoTokenizer, |
|
... AutoModelForSeq2SeqLM, |
|
... LogitsProcessorList, |
|
... MinLengthLogitsProcessor, |
|
... ConstrainedBeamSearchScorer, |
|
... PhrasalConstraint, |
|
... ) |
|
>>> import torch |
|
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("t5-base") |
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
|
|
>>> encoder_input_str = "translate English to German: How old are you?" |
|
>>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
|
|
>>> # lets run beam search using 3 beams |
|
>>> num_beams = 3 |
|
>>> # define decoder start token ids |
|
>>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) |
|
>>> input_ids = input_ids * model.config.decoder_start_token_id |
|
|
|
>>> # add encoder_outputs to model keyword arguments |
|
>>> model_kwargs = { |
|
... "encoder_outputs": model.get_encoder()( |
|
... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True |
|
... ) |
|
... } |
|
|
|
>>> constraint_str = "Sie" |
|
>>> constraint_token_ids = tokenizer.encode(constraint_str)[:-1] # slice to remove eos token |
|
>>> constraints = [PhrasalConstraint(token_ids=constraint_token_ids)] |
|
|
|
|
|
>>> # instantiate beam scorer |
|
>>> beam_scorer = ConstrainedBeamSearchScorer( |
|
... batch_size=1, num_beams=num_beams, device=model.device, constraints=constraints |
|
... ) |
|
|
|
>>> # instantiate logits processors |
|
>>> logits_processor = LogitsProcessorList( |
|
... [ |
|
... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), |
|
... ] |
|
... ) |
|
|
|
>>> outputs = model.constrained_beam_search( |
|
...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs |
|
... ) |
|
|
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
['Wie alt sind Sie?'] |
|
```""" |
|
|
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
|
if max_length is not None: |
|
warnings.warn( |
|
"`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", |
|
UserWarning, |
|
) |
|
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) |
|
if len(stopping_criteria) == 0: |
|
warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) |
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id |
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id |
|
output_scores = output_scores if output_scores is not None else self.config.output_scores |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict_in_generate = ( |
|
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate |
|
) |
|
|
|
|
|
scores = () if (return_dict_in_generate and output_scores) else None |
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None |
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |
|
|
|
|
|
if return_dict_in_generate and self.config.is_encoder_decoder: |
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |
|
encoder_hidden_states = ( |
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |
|
) |
|
|
|
batch_size = len(constrained_beam_scorer._beam_hyps) |
|
num_beams = constrained_beam_scorer.num_beams |
|
|
|
batch_beam_size, cur_len = input_ids.shape |
|
|
|
if num_beams * batch_size != batch_beam_size: |
|
raise ValueError( |
|
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." |
|
) |
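
# all beams start from the same prompt, so only the first beam of each batch entry keeps score 0; the duplicates get -1e9 so the first top-k does not pick the same token num_beams times |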
|
|
|
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) |
|
beam_scores[:, 1:] = -1e9 |
|
beam_scores = beam_scores.view((batch_size * num_beams,)) |
|
|
|
this_peer_finished = False |
|
while True: |
|
|
|
if synced_gpus: |
|
|
|
|
|
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) |
|
|
|
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) |
|
|
|
if this_peer_finished_flag.item() == 0.0: |
|
break |
|
|
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
|
|
|
outputs = self( |
|
**model_inputs, |
|
return_dict=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
if synced_gpus and this_peer_finished: |
|
cur_len = cur_len + 1 |
|
continue |
|
|
|
next_token_logits = outputs.logits[:, -1, :] |
|
|
|
|
|
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) |
|
next_token_scores = nn.functional.log_softmax( |
|
next_token_logits, dim=-1 |
|
)  # (batch_size * num_beams, vocab_size) |
|
|
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores) |
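
# keep the full-vocabulary scores: the constrained scorer needs them to evaluate tokens that advance a constraint even when those tokens fall outside the top-k candidates |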
|
|
|
scores_for_all_vocab = next_token_scores_processed.clone() |
|
|
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) |
|
|
|
|
|
if return_dict_in_generate: |
|
if output_scores: |
|
scores += (next_token_scores,) |
|
if output_attentions: |
|
decoder_attentions += ( |
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |
|
) |
|
if self.config.is_encoder_decoder: |
|
cross_attentions += (outputs.cross_attentions,) |
|
|
|
if output_hidden_states: |
|
decoder_hidden_states += ( |
|
(outputs.decoder_hidden_states,) |
|
if self.config.is_encoder_decoder |
|
else (outputs.hidden_states,) |
|
) |
|
|
|
|
|
vocab_size = next_token_scores.shape[-1] |
|
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) |
|
|
|
next_token_scores, next_tokens = torch.topk( |
|
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True |
|
) |
|
|
|
next_indices = torch_int_div(next_tokens, vocab_size) |
|
next_tokens = next_tokens % vocab_size |
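
# the constrained scorer ranks the ordinary top-2*num_beams candidates and additionally injects hypotheses that make progress on unmet constraints |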
|
|
|
|
|
beam_outputs = constrained_beam_scorer.process( |
|
input_ids, |
|
next_token_scores, |
|
next_tokens, |
|
next_indices, |
|
scores_for_all_vocab, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
) |
|
beam_scores = beam_outputs["next_beam_scores"] |
|
beam_next_tokens = beam_outputs["next_beam_tokens"] |
|
beam_idx = beam_outputs["next_beam_indices"] |
|
|
|
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) |
|
model_kwargs = self._update_model_kwargs_for_generation( |
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder |
|
) |
|
if model_kwargs["past"] is not None: |
|
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx) |
|
|
|
|
|
cur_len = cur_len + 1 |
|
|
|
if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores): |
|
if not synced_gpus: |
|
break |
|
else: |
|
this_peer_finished = True |
|
|
|
sequence_outputs = constrained_beam_scorer.finalize( |
|
input_ids, |
|
beam_scores, |
|
next_tokens, |
|
next_indices, |
|
pad_token_id=pad_token_id, |
|
eos_token_id=eos_token_id, |
|
max_length=stopping_criteria.max_length, |
|
) |
|
|
|
if return_dict_in_generate: |
|
if not output_scores: |
|
sequence_outputs["sequence_scores"] = None |
|
if self.config.is_encoder_decoder: |
|
return BeamSearchEncoderDecoderOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
encoder_attentions=encoder_attentions, |
|
encoder_hidden_states=encoder_hidden_states, |
|
decoder_attentions=decoder_attentions, |
|
cross_attentions=cross_attentions, |
|
decoder_hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return BeamSearchDecoderOnlyOutput( |
|
sequences=sequence_outputs["sequences"], |
|
sequences_scores=sequence_outputs["sequence_scores"], |
|
scores=scores, |
|
attentions=decoder_attentions, |
|
hidden_states=decoder_hidden_states, |
|
) |
|
else: |
|
return sequence_outputs["sequences"] |
|
|
|
|
|
def top_k_top_p_filtering( |
|
logits: torch.FloatTensor, |
|
top_k: int = 0, |
|
top_p: float = 1.0, |
|
filter_value: float = -float("Inf"), |
|
min_tokens_to_keep: int = 1, |
|
) -> torch.FloatTensor: |
|
""" |
|
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering |
|
|
|
Args: |
|
logits: logits distribution shape (batch size, vocabulary size) |
|
top_k (`int`, *optional*, defaults to 0): |
|
If > 0, only keep the top k tokens with highest probability (top-k filtering) |
|
top_p (`float`, *optional*, defaults to 1.0): |
|
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus |
|
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) |
|
min_tokens_to_keep (`int`, *optional*, defaults to 1): |
|
Minimum number of tokens we keep per batch example in the output. |
|
|
|
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 |
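
Example (a minimal illustrative sketch; the shapes and thresholds below are arbitrary): |

```python |

import torch |

logits = torch.randn(2, 100)  # (batch size, vocabulary size) |

filtered = top_k_top_p_filtering(logits, top_k=10, top_p=0.9)  # filtered positions are set to -inf |

probs = torch.softmax(filtered, dim=-1)  # -inf entries become probability 0 |

next_tokens = torch.multinomial(probs, num_samples=1)  # sample one token per batch entry |

``` |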
|
""" |
|
if top_k > 0: |
|
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( |
|
None, logits |
|
) |
|
|
|
if 0 <= top_p <= 1.0: |
|
logits = TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=min_tokens_to_keep)(None, logits) |
|
|
|
return logits |
|
|