File size: 6,720 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""t5mod_helper.py"""
import torch
from torch import nn
from model.t5mod import T5DecoderYMT3, MultiChannelT5Decoder
from typing import Optional, Callable, Union, Literal
@torch.no_grad()
def task_cond_dec_generate(decoder: Union[T5DecoderYMT3, MultiChannelT5Decoder],
                           decoder_type: Literal["t5", "multi-t5"],
                           embed_tokens: nn.Embedding,
                           lm_head: nn.Module,
                           encoder_hidden_states: torch.FloatTensor,
                           shift_right_fn: Callable,
                           prefix_ids: Optional[torch.LongTensor] = None,
                           max_length: int = 1024,
                           stop_at_eos: bool = True,
                           eos_id: Optional[int] = 1,
                           pad_id: Optional[int] = 0,
                           decoder_start_token_id: Optional[int] = 0,
                           debug: bool = False) -> torch.LongTensor:
    """
    Greedily generate token ids with task conditioning on the decoder side.

    An extension of transformers' `generate()` for models whose conditioning
    (e.g. a task prefix) is applied only on the decoder side. The initial decoder
    input (the shifted prefix, or a single start token) is run through the
    decoder in one pass to build `past_key_values`. Tokens are then generated one
    at a time with argmax decoding, reusing the cache.

    Args:
        decoder: T5DecoderYMT3 or MultiChannelT5Decoder, any decoder with T5Stack
            architecture. Called with `inputs_embeds`, `encoder_hidden_states`,
            optional `past_key_values` and `return_dict=False`. Must return
            `(hidden_states, past_key_values)`.
        decoder_type: "t5" or "multi-t5". "multi-t5" reads `decoder.num_channels`.
        embed_tokens: embedding layer applied to decoder input ids.
        lm_head: maps decoder hidden states to vocabulary logits.
        encoder_hidden_states: (B, T, D) or (B, K, T, D) last encoder hidden states.
        shift_right_fn: applied to `prefix_ids` to build the initial decoder
            input. Presumably it prepends the start token (<bos>); confirm against
            the decoder's implementation.
        prefix_ids: (B, prefix_len) prefix ids, typically used as task
            conditioning. Only its last dimension is used as the prefix length.
            If None or empty, a single `decoder_start_token_id` is used instead.
        max_length: token budget. The cached loop runs at most
            `max_length - prefix_len - 1` steps (-1 for the <eos> token).
        stop_at_eos: stop early once every sequence in the batch has emitted <eos>.
        eos_id: id of the <eos> token (default 1). If None, no <eos> tracking is done.
        pad_id: id of the <pad> token written after a sequence's <eos> (default 0).
        decoder_start_token_id: id of the <bos> token used when there is no
            prefix (default 0).
        debug: print the cache length and output shape at every step.

    Returns:
        pred_ids: (B, L) or (B, C, L) predicted token ids. The leading columns are
        the first-pass predictions, one per position of the initial decoder input.
        The remaining columns are generated one per step. When `eos_id` is set,
        tokens after a sequence's <eos> are replaced with `pad_id`.

    Raises:
        ValueError: if `decoder_type` is unsupported, or if `eos_id` is given
            while `pad_id` is None.
    """
    bsz = encoder_hidden_states.shape[0]
    device = encoder_hidden_states.device

    # Shape of the token ids for one decoding step: (B, 1) or (B, C, 1).
    if decoder_type == "t5":
        step_shape = (bsz, 1)
    elif decoder_type == "multi-t5":
        step_shape = (bsz, decoder.num_channels, 1)
    else:
        raise ValueError(f"decoder_type {decoder_type} is not supported.")

    # Validate before any forward pass so a bad configuration fails immediately.
    if eos_id is not None and pad_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")

    # Initial decoder input: the shifted prefix, or a lone start token.
    if prefix_ids is not None and prefix_ids.numel() > 0:
        dec_input_ids = shift_right_fn(prefix_ids)
        prefix_length = prefix_ids.shape[-1]
    else:
        dec_input_ids = torch.full(step_shape, decoder_start_token_id, dtype=torch.long, device=device)
        prefix_length = 0

    # First pass: feed the whole initial sequence without a cache. This builds
    # past_key_values and yields one prediction per input position.
    dec_hs, past_key_values = decoder(inputs_embeds=embed_tokens(dec_input_ids),
                                      encoder_hidden_states=encoder_hidden_states,
                                      return_dict=False)
    pred_ids = lm_head(dec_hs).argmax(-1)  # (B, L0) or (B, C, L0)

    # Only the newest prediction is fed to the next cached step.
    dec_input_ids = pred_ids[..., -1:]  # (B, 1) or (B, C, 1)

    # 1 = still generating, 0 = finished (has emitted <eos>).
    if eos_id is not None:
        # The last first-pass prediction is fed back like any generated token,
        # so an <eos> there already finishes the sequence.
        unfinished_sequences = dec_input_ids.ne(eos_id).long()
    else:
        unfinished_sequences = torch.ones(step_shape, dtype=torch.long, device=device)

    for i in range(max_length - prefix_length - 1):  # -1 for <eos> token
        # Early-stop once every sequence has finished.
        if stop_at_eos and not unfinished_sequences.any():
            break
        if debug:
            # past_key_values_length determines the positional embedding
            past_key_values_length = past_key_values[0][0].shape[2]
            print(f'i = {i}, past_key_values_length = {past_key_values_length}, pred_ids.shape = {pred_ids.shape}')

        # With a cache, only the last token is passed to the decoder.
        dec_hs, past_key_values = decoder(inputs_embeds=embed_tokens(dec_input_ids),
                                          encoder_hidden_states=encoder_hidden_states,
                                          past_key_values=past_key_values,
                                          return_dict=False)
        next_ids = lm_head(dec_hs).argmax(-1)  # (B, 1) or (B, C, 1)

        # Feed back the raw prediction. The padding below builds a new tensor,
        # so dec_input_ids is not affected by it.
        dec_input_ids = next_ids

        if eos_id is not None:
            # Finished sequences emit <pad> instead of their prediction.
            next_ids = next_ids * unfinished_sequences + pad_id * (1 - unfinished_sequences)
            unfinished_sequences = unfinished_sequences * next_ids.ne(eos_id).long()

        pred_ids = torch.cat((pred_ids, next_ids), dim=-1)  # (B, T') or (B, C, T'), T' grows by 1

    return pred_ids  # (B, L) or (B, C, L)
|