# Fraser's picture
# cope that submodules not allowed
# 1d30073
from typing import Optional, Tuple
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze
from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule
from t5_vae_flax_alt.src.vae import VAE
from t5_vae_flax_alt.src.generate import VaeFlaxGenerationMixin
from t5_vae_flax_alt.src.outputs import TransformerVaeOutput
from t5_vae_flax_alt.src.config import T5VaeConfig
@add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
class FlaxT5VaeForAutoencodingModule(nn.Module):
    """
    Flax module that runs the T5 encoder, passes its output through the VAE, and decodes the result with the T5
    decoder followed by the language-modeling head.
    """

    config: T5VaeConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def _get_encoder_module(self):
        """Return the T5 encoder submodule."""
        return self.t5.encoder

    def _get_vae_encoder_module(self):
        """Return the encoder half of the VAE."""
        return self.vae.encoder

    def _get_vae_decoder_module(self):
        """Return the decoder half of the VAE."""
        return self.vae.decoder

    def _get_decoder_module(self):
        """Return the T5 decoder submodule."""
        return self.t5.decoder

    def setup(self):
        """Instantiate the T5 model (from ``config.t5``) and the VAE (from the full config)."""
        # NOTE(review): `self.dtype` is declared on this module but is not forwarded to either submodule here;
        # confirm whether that is intended.
        self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
        self.vae = VAE(self.config)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        latent_codes=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        """
        Adapted from `FlaxT5ForConditionalGenerationModule`

        Args:
            input_ids: token ids given to the T5 encoder.
            attention_mask: mask given to the T5 encoder.
            decoder_input_ids: token ids given to the T5 decoder.
            decoder_attention_mask: mask given to the T5 decoder.
            encoder_outputs: accepted for signature compatibility only; the encoder is always re-run and this
                argument is overwritten before it is read.
            latent_codes: forwarded to ``self.vae`` together with the encoder hidden states.
            output_attentions: forwarded to the encoder and the decoder.
            output_hidden_states: forwarded to the encoder and the decoder.
            return_dict: whether to return a :class:`TransformerVaeOutput`; defaults to
                ``config.use_return_dict`` when ``None``.
            deterministic: forwarded to the encoder and the decoder.

        Returns:
            A :class:`TransformerVaeOutput` when ``return_dict`` is true, otherwise the tuple
            ``(lm_logits, latent_codes) + decoder_outputs[1:] + encoder_outputs``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        hidden_states = encoder_outputs[0]

        # Autoencode
        hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
        # The decoder cross-attends over the VAE output, so the mask is rebuilt (all ones) from its shape.
        encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))

        # Decode
        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)

        # NOTE(review): the rescale above reads `self.config.tie_word_embeddings` while the projection below reads
        # `self.t5.config.tie_word_embeddings`; confirm the two flags are always equal.
        if self.t5.config.tie_word_embeddings:
            shared_embedding = self.t5.shared.variables["params"]["embedding"]
            lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
        else:
            lm_logits = self.t5.lm_head(sequence_output)

        if not return_dict:
            # Build a tuple: concatenating a list with the tuple outputs of the sub-modules raises TypeError.
            return (lm_logits, latent_codes) + tuple(decoder_outputs[1:]) + tuple(encoder_outputs)

        return TransformerVaeOutput(
            logits=lm_logits,
            latent_codes=latent_codes,
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel, VaeFlaxGenerationMixin):
    """
    Base class for the T5-VAE Flax models.

    It builds the underlying module from ``module_class``, initialises its weights, and provides the full
    forward pass plus decoder-cache initialisation. ``encode`` and ``decode`` are left to subclasses.
    """

    config_class = T5VaeConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: T5VaeConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        """Instantiate ``module_class`` with ``config``/``dtype`` and hand it to the parent constructor."""
        super().__init__(
            config,
            self.module_class(config=config, dtype=dtype, **kwargs),
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Run ``module.init`` on dummy inputs of ``input_shape`` and return the resulting ``"params"``."""
        # Dummy inputs: zero token ids for the encoder, ones for both masks and for the decoder ids.
        dummy_input_ids = jnp.zeros(input_shape, dtype="i4")
        dummy_attention_mask = jnp.ones_like(dummy_input_ids)
        dummy_decoder_input_ids = jnp.ones_like(dummy_input_ids)
        dummy_decoder_attention_mask = jnp.ones_like(dummy_input_ids)

        params_key, dropout_key = jax.random.split(rng)
        variables = self.module.init(
            {"params": params_key, "dropout": dropout_key},
            dummy_input_ids,
            dummy_attention_mask,
            dummy_decoder_input_ids,
            dummy_decoder_attention_mask,
        )
        return variables["params"]

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: jnp.ndarray = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Apply the module to encoder and decoder inputs.

        Unset output flags fall back to the config, missing masks default to all ones, and all id/mask arrays
        are converted to ``"i4"`` before the call. ``params`` overrides ``self.params`` when truthy, and
        ``dropout_rng`` (if given) is passed as the ``"dropout"`` rng.

        Raises:
            ValueError: if ``decoder_input_ids`` is ``None``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
            )

        # Default masks attend to every position.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # Only supply a dropout rng when the caller provided one.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        variables = {"params": params or self.params}
        return self.module.apply(
            variables,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def init_cache(self, batch_size, max_length, latent_codes):
        r"""
        Initialise the decoder cache used for fast auto-regressive decoding.

        Args:
            batch_size (:obj:`int`):
                batch size of the initialised cache.
            max_length (:obj:`int`):
                sequence length of the initialised cache, i.e. the maximum decoding length.
            latent_codes:
                accepted as part of the signature but not read by this method; only the decoder is run, on dummy
                inputs, to create the cache.

        Returns:
            The unfrozen ``"cache"`` collection produced by ``module.init``.
        """

        def _run_decoder_only(module, decoder_input_ids, decoder_attention_mask, **kwargs):
            # Only the decoder needs to be called to create the cache variables.
            return module._get_decoder_module()(
                decoder_input_ids,
                decoder_attention_mask,
                **kwargs,
            )

        # Dummy all-ones decoder inputs spanning the full cache length.
        dummy_ids = jnp.ones((batch_size, max_length), dtype="i4")
        dummy_mask = jnp.ones_like(dummy_ids)

        cache_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=dummy_ids,
            decoder_attention_mask=dummy_mask,
            init_cache=True,
            method=_run_decoder_only,
        )
        return unfreeze(cache_variables["cache"])

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Not implemented on the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Not implemented on the base class; always raises ``NotImplementedError``."""
        raise NotImplementedError()
class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
    """
    T5-VAE autoencoding model: wraps :class:`FlaxT5VaeForAutoencodingModule` and adds ``encode`` / ``decode``
    entry points plus the hooks used during generation.
    """

    module_class = FlaxT5VaeForAutoencodingModule

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        '''
        Adapted from `FlaxT5PreTrainedModel`

        Applies the full module (encoder, VAE, decoder). Unset output flags fall back to the config, missing
        masks default to all ones, and id/mask arrays are converted to ``"i4"``.

        Raises:
            ValueError: if ``decoder_input_ids`` is ``None``.
        '''
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
            )

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # prepare decoder inputs
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Run the T5 encoder and feed the first element of its output to the VAE encoder.

        Unset output flags fall back to the config and a missing ``attention_mask`` defaults to all ones.

        Returns:
            Whatever the VAE encoder module returns for the encoder hidden states.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, **kwargs):
            encode_module = module._get_encoder_module()
            vae_encoder_module = module._get_vae_encoder_module()
            return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Expand ``latent_codes`` with the VAE decoder, run the T5 decoder on them and project to vocabulary logits.

        ``past_key_values``, when non-empty, is supplied to the module as its ``"cache"`` collection, which is
        then marked mutable; the updated cache is included in the result. ``None`` and an empty mapping both mean
        "no cache".

        Returns:
            With ``return_dict``: a :class:`FlaxCausalLMOutputWithCrossAttentions` holding ``logits``,
            ``hidden_states``, ``attentions`` and ``cross_attentions`` (plus ``past_key_values`` when a cache was
            used). Otherwise the tuple ``(lm_logits,) + decoder_outputs[1:]``, with the updated cache inserted at
            index 1 when a cache was used.

        Example::

            >>> from transformers import T5Tokenizer

            >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
            >>> latent_codes = model.encode(**inputs)

            >>> decoder_start_token_id = model.config.decoder_start_token_id
            >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

            >>> outputs = model.decode(decoder_input_ids, latent_codes)
            >>> logits = outputs.logits
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_attention_mask is None:
            # NOTE(review): the default mask is shaped from `latent_codes`, while cross-attention runs over
            # `vae_decoder_module(latent_codes)`; this assumes the VAE decoder keeps the first two dims. Confirm.
            batch_size, sequence_length = latent_codes.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxT5Attention module
        #
        # A single flag drives both the `mutable` argument and the unpacking of the result below, so an empty
        # mapping is consistently treated like `None` instead of reaching the cache-unpacking branch.
        use_cache = bool(past_key_values)
        if use_cache:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
            vae_decoder_module = module._get_vae_decoder_module()
            decoder_module = module._get_decoder_module()
            decoder_outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                encoder_hidden_states=vae_decoder_module(latent_codes),
                **kwargs,
            )
            sequence_output = decoder_outputs[0]

            if self.config.tie_word_embeddings:
                # Rescale output before projecting on vocab
                # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
                sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)

            if self.config.tie_word_embeddings:
                shared_embedding = module.t5.shared.variables["params"]["embedding"]
                lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
            else:
                lm_logits = module.t5.lm_head(sequence_output)

            return lm_logits, decoder_outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            latent_codes=latent_codes,
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # With a mutable collection, `apply` returns `(result, mutated_variables)`.
        past = None
        if use_cache:
            (lm_logits, decoder_outputs), past = outputs
        else:
            lm_logits, decoder_outputs = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if use_cache:
            updated_cache = unfreeze(past["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        latent_codes=None,
        **kwargs
    ):
        """
        Build the keyword arguments for the decoding steps of generation.

        Initialises a decoder cache of length ``max_length`` and a static decoder attention mask of the same
        length, and passes ``latent_codes`` and ``attention_mask`` (as ``encoder_attention_mask``) through.

        Returns:
            dict with keys ``past_key_values``, ``latent_codes``, ``encoder_attention_mask`` and
            ``decoder_attention_mask``.
        """
        # initializing the cache
        batch_size, _ = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, latent_codes)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )

        return {
            "past_key_values": past_key_values,
            "latent_codes": latent_codes,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Copy the updated cache from ``model_outputs`` into ``model_kwargs`` and return ``model_kwargs``."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs