from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze

from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule

from model.vae import VAE
from model.outputs import TransformerVaeOutput
from model.config import T5VaeConfig

@add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
class FlaxT5VaeForAutoencodingModule(nn.Module):
    config: T5VaeConfig
    dtype: jnp.dtype = jnp.float32

    def _get_encoder_module(self):
        return self.t5.encoder

    def _get_vae_encoder_module(self):
        return self.vae.encoder

    def _get_vae_decoder_module(self):
        return self.vae.decoder

    def _get_decoder_module(self):
        return self.t5.decoder

    def setup(self):
        self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
        self.vae = VAE(self.config)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        latent_codes=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        """
        Adapted from `FlaxT5ForConditionalGenerationModule`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode the inputs with the T5 encoder.
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = encoder_outputs[0]

        # Compress the encoder hidden states into latent codes, then expand them back
        # into a sequence of hidden states for the decoder to cross-attend over.
        hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
        encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))

        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        sequence_output = decoder_outputs[0]

        if self.t5.config.tie_word_embeddings:
            # Rescale the output before projecting onto the vocabulary, as in the reference T5 implementation.
            sequence_output = sequence_output * (self.t5.config.d_model ** -0.5)
            shared_embedding = self.t5.shared.variables["params"]["embedding"]
            lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
        else:
            lm_logits = self.t5.lm_head(sequence_output)

        if not return_dict:
            return (lm_logits, latent_codes) + decoder_outputs[1:] + encoder_outputs

        return TransformerVaeOutput(
            logits=lm_logits,
            latent_codes=latent_codes,
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5VaeConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: T5VaeConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # Dummy inputs used only to trace the module and initialize its parameters.
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.ones_like(input_ids)
        decoder_attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
        )["params"]

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: jnp.ndarray = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "`decoder_input_ids` was not passed. Make sure to provide both `input_ids` and `decoder_input_ids`."
            )

        # Default to attending over every token when no masks are provided.
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # Handle any PRNG needed for dropout.
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def init_cache(self, batch_size, max_length, latent_codes):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
                ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder.
                Used in the cross-attention of the decoder.
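
        Example (a minimal sketch; assumes ``model`` is a loaded ``FlaxT5VaeForAutoencoding`` and that
        ``latent_codes`` was produced by ``model.encode``; ``max_length=32`` is illustrative)::

            >>> latent_codes = model.encode(**inputs)
            >>> past_key_values = model.init_cache(latent_codes.shape[0], max_length=32, latent_codes=latent_codes)
            >>> # pass `past_key_values` to `decode` for fast auto-regressive decoding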
""" |
|
|
|
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
|
decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
|
|
|
def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs): |
|
vae_decoder_module = module._get_vae_decoder_module() |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
decoder_input_ids, |
|
decoder_attention_mask, |
|
encoder_hidden_states=vae_decoder_module(latent_codes), |
|
**kwargs, |
|
) |
|
|
|
init_variables = self.module.init( |
|
jax.random.PRNGKey(0), |
|
decoder_input_ids=decoder_input_ids, |
|
latent_codes=latent_codes, |
|
decoder_attention_mask=decoder_attention_mask, |
|
init_cache=True, |
|
method=_decoder_forward, |
|
) |
|
return unfreeze(init_variables["cache"]) |
|
|
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        raise NotImplementedError()

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        raise NotImplementedError()


class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
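    r"""
    T5 model with a VAE bottleneck for autoencoding: the T5 encoder output is compressed into latent codes
    and the decoder reconstructs the input text from them.

    Example (a minimal usage sketch; the checkpoint names are illustrative, substitute any trained
    ``FlaxT5VaeForAutoencoding`` checkpoint and its matching tokenizer)::

        >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-vae-base')
        >>> tokenizer = T5Tokenizer.from_pretrained('t5-base')

        >>> inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors='jax')
        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
        >>> logits, latent_codes = outputs.logits, outputs.latent_codes
    """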
    module_class = FlaxT5VaeForAutoencodingModule

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Adapted from `FlaxT5PreTrainedModel`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "`decoder_input_ids` was not passed. Make sure to provide both `input_ids` and `decoder_input_ids`."
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
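        r"""
        Runs the T5 encoder and compresses its final hidden states into latent codes with the VAE encoder.

        Example (a minimal usage sketch; the checkpoint names are illustrative)::

            >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-vae-base')
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-base')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
            >>> latent_codes = model.encode(**inputs)
        """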
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, **kwargs):
            # Run the T5 encoder, then compress its last hidden states with the VAE encoder.
            encode_module = module._get_encoder_module()
            vae_encoder_module = module._get_vae_encoder_module()
            return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example::

            >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
            >>> latent_codes = model.encode(**inputs)

            >>> decoder_start_token_id = model.config.decoder_start_token_id
            >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

            >>> outputs = model.decode(decoder_input_ids, latent_codes)
            >>> logits = outputs.logits
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_attention_mask is None:
            batch_size, sequence_length = latent_codes.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized; it has to be marked as mutable
        # so that the attention modules can update it during decoding.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
            vae_decoder_module = module._get_vae_decoder_module()
            decoder_module = module._get_decoder_module()
            decoder_outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                encoder_hidden_states=vae_decoder_module(latent_codes),
                **kwargs,
            )
            sequence_output = decoder_outputs[0]

            if self.config.tie_word_embeddings:
                # Rescale the output before projecting onto the vocabulary, as in the reference T5 implementation.
                sequence_output = sequence_output * (self.config.d_model ** -0.5)
                shared_embedding = module.t5.shared.variables["params"]["embedding"]
                lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
            else:
                lm_logits = module.t5.lm_head(sequence_output)

            return lm_logits, decoder_outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            latent_codes=latent_codes,
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # Separate the decoder outputs from the updated cache (when the cache was mutable).
        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # Append the updated cache to the outputs.
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        latent_codes=None,
        **kwargs
    ):
        # Initialize the cache used for fast auto-regressive decoding.
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, latent_codes)

        # Usually one would have to put 0's in the attention mask for positions beyond the current input
        # length, but since the decoder is causal those positions are masked anyway. A single static
        # attention mask of length `max_length` therefore works and is cheaper to compile.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )

        return {
            "past_key_values": past_key_values,
            "latent_codes": latent_codes,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs