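"""Custom Flax ALBERT modules for representation-swap interventions.

These classes mirror the Hugging Face Flax ALBERT implementation, but thread an
``interv_dict`` through the encoder so that, inside a chosen layer, per-head
query/key/value vectors (or the layer input) at a given token position can be
swapped between two examples in the batch before attention is computed.
"""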
from typing import Callable, Optional, Tuple
from copy import deepcopy

import numpy as np
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from transformers import AlbertConfig
from transformers.models.albert.modeling_flax_albert import (
    FlaxAlbertEmbeddings,
    FlaxAlbertOnlyMLMHead,
    FlaxAlbertPreTrainedModel,
)
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from transformers.utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from transformers.modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)

class CustomFlaxAlbertSelfAttention(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic=True,
        output_attentions: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # Representations that can be intervened on:
        #   'lay' has shape (batch, seq_len, hidden_size);
        #   'qry', 'key', 'val' have shape (batch, seq_len, num_heads, head_dim).
        reps = {
            'lay': hidden_states,
            'qry': query_states,
            'key': key_states,
            'val': value_states,
        }

        # interv_dict maps a layer id to {rep_name: [(head_id, pos, (example_a, example_b)), ...]}
        # and swaps the selected slice at token position `pos` between the two batch examples.
        if layer_id in interv_dict:
            interv = interv_dict[layer_id]
            for rep_name in ['lay', 'qry', 'key', 'val']:
                if rep_name in interv:
                    # JAX arrays are immutable, so apply the swap with functional `.at[...].set(...)`
                    # updates; both reads come from the untouched `reps[rep_name]`, which keeps the
                    # swap symmetric.
                    new_state = reps[rep_name]
                    for head_id, pos, swap_ids in interv[rep_name]:
                        new_state = new_state.at[swap_ids[0], pos, head_id].set(
                            reps[rep_name][swap_ids[1], pos, head_id]
                        )
                        new_state = new_state.at[swap_ids[1], pos, head_id].set(
                            reps[rep_name][swap_ids[0], pos, head_id]
                        )
                    reps[rep_name] = new_state

        hidden_states = reps['lay']
        query_states = reps['qry']
        key_states = reps['key']
        value_states = reps['val']

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)

        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs

class CustomFlaxAlbertLayer(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = CustomFlaxAlbertSelfAttention(self.config, dtype=self.dtype)
        self.ffn = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
        self.ffn_output = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            layer_id=layer_id,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        attention_output = attention_outputs[0]

        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs

class CustomFlaxAlbertLayerCollection(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            CustomFlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.layers):
            layer_output = albert_layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                layer_id=layer_id,
                interv_type=interv_type,
                interv_dict=interv_dict,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)

class CustomFlaxAlbertLayerCollections(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    layer_index: Optional[str] = None

    def setup(self):
        self.albert_layers = CustomFlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        layer_id: Optional[int] = None,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        outputs = self.albert_layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            layer_id=layer_id,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        return outputs

class CustomFlaxAlbertLayerGroups(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            CustomFlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
            layer_group_output = self.layers[group_idx](
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                layer_id=i,
                interv_type=interv_type,
                interv_dict=interv_dict,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

class CustomFlaxAlbertEncoder(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.albert_layer_groups = CustomFlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )

class CustomFlaxAlbertModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = CustomFlaxAlbertEncoder(self.config, dtype=self.dtype)
        if self.add_pooling_layer:
            self.pooler = nn.Dense(
                self.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
                name="pooler",
            )
            self.pooler_activation = nn.tanh
        else:
            self.pooler = None
            self.pooler_activation = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[np.ndarray] = None,
        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        hidden_states = outputs[0]
        if self.add_pooling_layer:
            pooled = self.pooler(hidden_states[:, 0])
            pooled = self.pooler_activation(pooled)
        else:
            pooled = None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class CustomFlaxAlbertForMaskedLMModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = CustomFlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        interv_type: str = "swap",
        interv_dict: dict = {},
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interv_type=interv_type,
            interv_dict=interv_dict,
        )
        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class CustomFlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    module_class = CustomFlaxAlbertForMaskedLMModule
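

# A minimal usage sketch, not part of the original file: it assumes the public
# "albert-base-v2" checkpoint and AlbertTokenizerFast, and calls `module.apply`
# directly because `FlaxAlbertPreTrainedModel.__call__` does not forward the extra
# `interv_type` / `interv_dict` keyword arguments. The layer/head/position choices
# are hypothetical and only illustrate the expected interv_dict format:
# {layer_id: {rep_name: [(head_id, pos, (example_a, example_b)), ...]}}.
if __name__ == "__main__":
    from transformers import AlbertTokenizerFast

    tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
    model = CustomFlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")

    sentences = [
        "The doctor said [MASK] would arrive soon.",
        "The nurse said [MASK] would arrive soon.",
    ]
    batch = tokenizer(sentences, return_tensors="np", padding=True)

    # Swap head 3's query vectors at token position 2 between examples 0 and 1 in layer 5.
    interv_dict = {5: {"qry": [(3, 2, (0, 1))]}}

    outputs = model.module.apply(
        {"params": model.params},
        batch["input_ids"],
        batch["attention_mask"],
        batch["token_type_ids"],
        None,  # position_ids are built inside CustomFlaxAlbertModule when None
        deterministic=True,
        interv_type="swap",
        interv_dict=interv_dict,
    )
    print(outputs.logits.shape)  # (batch_size, seq_len, vocab_size)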