|
from typing import Optional, Tuple

import flax
import jaxlib.xla_extension as jax_xla

from transformers.file_utils import ModelOutput
|
@flax.struct.dataclass
class TransformerVaeOutput(ModelOutput):
    """
    Base class for a Transformer-VAE's outputs.

    Args:
        latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
            Latent codes representing encoded sequences.
        remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`):
            Reconstructed encoder hidden states representing sequences.

    (standard Seq2Seq) Args:
        logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
            tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
            tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (keys and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see the :obj:`past_key_values` input) to speed up sequential
            decoding.
        last_hidden_state (:obj:`jax_xla.DeviceArray`, `optional`):
            Last model hidden state.
        decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute
            the weighted average in the cross-attention heads.
        encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """
|
    logits: jax_xla.DeviceArray = None
    latent_codes: jax_xla.DeviceArray = None
    remade_encoder_hidden_state: jax_xla.DeviceArray = None

    past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
    decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
    decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
    cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
    last_hidden_state: Optional[jax_xla.DeviceArray] = None
    encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
    encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
    encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
|
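# A minimal usage sketch (not part of the original module): all shapes below are
# made up for illustration only (batch_size=2, sequence_length=8, vocab_size=32128,
# n_latent_tokens=4, latent_token_size=32, model_dim=512). It relies only on
# behaviour ModelOutput is known to provide: fields are readable by attribute or
# by string key, and optional fields left unset stay None. The demo is guarded by
# ``if __name__ == "__main__"`` so the module remains importable as a library.
if __name__ == "__main__":
    import jax.numpy as jnp

    outputs = TransformerVaeOutput(
        logits=jnp.zeros((2, 8, 32128)),  # (batch_size, sequence_length, vocab_size)
        latent_codes=jnp.zeros((2, 4, 32)),  # (batch_size, n_latent_tokens, latent_token_size)
        remade_encoder_hidden_state=jnp.zeros((2, 8, 512)),  # (batch_size, n_tokens, model_dim)
    )

    print(outputs.latent_codes.shape)  # attribute access -> (2, 4, 32)
    print(outputs["logits"].shape)  # dict-style access -> (2, 8, 32128)
    print(outputs.past_key_values)  # unset optional field -> None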