ydshieh commited on
Commit
e755009
·
1 Parent(s): 686f21e

try load model from hub

Browse files
model.py CHANGED
@@ -9,11 +9,13 @@ from transformers import GPT2Tokenizer
9
  current_path = os.path.dirname(os.path.abspath(__file__))
10
  sys.path.append(current_path)
11
 
12
- # Main model - ViTGPT2LM
13
- # from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
14
 
15
- def predict(image):
 
16
 
 
17
  return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]
18
 
19
 
 
9
  current_path = os.path.dirname(os.path.abspath(__file__))
10
  sys.path.append(current_path)
11
 
12
+ # Main model - ViTGPT2LM
13
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
14
 
15
+ model_name_or_path = 'flax-community/vit-gpt2/checkpoints/ckpt_5/'
16
+ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
17
 
18
def predict(image):
    """Placeholder caption predictor.

    The `image` argument is currently ignored; a fixed triple is returned
    (a caption string, a list of string pieces, and a list of integers).
    """
    return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]
20
 
21
 
vit_gpt2/__init__.py ADDED
File without changes
vit_gpt2/configuration_vit_gpt2.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from transformers import GPT2Config, ViTConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
class ViTGPT2Config(PretrainedConfig):
    """Composite configuration wrapping a ``ViTConfig`` and a ``GPT2Config``.

    Both sub-configurations are passed as keyword arguments:

    - ``vit_config``: mapping of keyword arguments forwarded to ``ViTConfig``.
    - ``gpt2_config``: mapping of keyword arguments forwarded to ``GPT2Config``.

    Raises:
        ValueError: if ``vit_config`` or ``gpt2_config`` is missing or ``None``.
    """

    model_type = "vit-gpt2"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Reject an explicit `None` as well as a missing key: the messages below
        # promise exactly that, and `ViTConfig(**None)` would otherwise fail with
        # an unrelated TypeError.
        if kwargs.get("vit_config") is None:
            raise ValueError("`vit_config` can not be `None`.")

        if kwargs.get("gpt2_config") is None:
            raise ValueError("`gpt2_config` can not be `None`.")

        vit_config = kwargs.pop("vit_config")
        gpt2_config = kwargs.pop("gpt2_config")

        # Rebuild typed config objects from the plain mappings.
        self.vit_config = ViTConfig(**vit_config)
        self.gpt2_config = GPT2Config(**gpt2_config)

    @classmethod
    def from_vit_gpt2_configs(
        cls, vit_config: PretrainedConfig, gpt2_config: PretrainedConfig, **kwargs
    ):
        """Build a ``ViTGPT2Config`` from two already-instantiated configs.

        Each config is serialized with ``to_dict()`` and handed to the
        constructor; extra ``kwargs`` are forwarded unchanged.
        """
        return cls(
            vit_config=vit_config.to_dict(),
            gpt2_config=gpt2_config.to_dict(),
            **kwargs
        )

    def to_dict(self):
        """Return a dict copy of this config with both sub-configs serialized.

        The instance ``__dict__`` is deep-copied, then the ``vit_config`` and
        ``gpt2_config`` entries are replaced by their own ``to_dict()`` output
        and ``model_type`` is set from the class attribute.
        """
        output = copy.deepcopy(self.__dict__)
        output["vit_config"] = self.vit_config.to_dict()
        output["gpt2_config"] = self.gpt2_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
vit_gpt2/modeling_flax_gpt2.py ADDED
@@ -0,0 +1,752 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional, Tuple
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from flax.core.frozen_dict import FrozenDict, unfreeze
22
+ from flax.linen import combine_masks, make_causal_mask
23
+ from flax.linen.attention import dot_product_attention_weights
24
+ from jax import lax
25
+
26
+ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
27
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPast, FlaxCausalLMOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, FlaxSeq2SeqLMOutput
28
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
29
+ from transformers.utils import logging
30
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ _CHECKPOINT_FOR_DOC = "gpt2"
36
+ _CONFIG_FOR_DOC = "GPT2Config"
37
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
38
+
39
+
40
+ GPT2_START_DOCSTRING = r"""
41
+
42
+ This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
43
+ generic methods the library implements for all its model (such as downloading or saving, resizing the input
44
+ embeddings, pruning heads etc.)
45
+
46
+ This model is also a Flax Linen `flax.nn.Module
47
+ <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
48
+ Module and refer to the Flax documentation for all matter related to general usage and behavior.
49
+
50
+ Finally, this model supports inherent JAX features such as:
51
+
52
+ - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
53
+ - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
54
+ - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
55
+ - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
56
+
57
+ Parameters:
58
+ config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
59
+ Initializing with a config file does not load the weights associated with the model, only the
60
+ configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
61
+ model weights.
62
+ """
63
+
64
+ GPT2_INPUTS_DOCSTRING = r"""
65
+ Args:
66
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
67
+ :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
68
+
69
+ Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
70
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
71
+ details.
72
+
73
+ `What are input IDs? <../glossary.html#input-ids>`__
74
+ attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
75
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
76
+
77
+ - 1 for tokens that are **not masked**,
78
+ - 0 for tokens that are **masked**.
79
+
80
+ `What are attention masks? <../glossary.html#attention-mask>`__
81
+ position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
82
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
83
+ config.max_position_embeddings - 1]``.
84
+ past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
85
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
86
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
87
+ output_attentions (:obj:`bool`, `optional`):
88
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
89
+ tensors for more detail.
90
+ output_hidden_states (:obj:`bool`, `optional`):
91
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
92
+ more detail.
93
+ return_dict (:obj:`bool`, `optional`):
94
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
95
+ """
96
+
97
+
98
class FlaxConv1D(nn.Module):
    """Linear projection over the last axis with a transposed kernel layout.

    The ``kernel`` parameter is stored with shape ``(features, in_features)``
    and transposed before the contraction; an optional ``bias`` of shape
    ``(features,)`` is added to the result.
    """

    features: int
    use_bias: bool = True
    dtype: Any = jnp.float32
    precision: Any = None

    @nn.compact
    def __call__(self, inputs):
        inputs = jnp.asarray(inputs, self.dtype)
        in_features = inputs.shape[-1]

        stored_kernel = self.param(
            "kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, in_features)
        )
        kernel = jnp.asarray(stored_kernel.transpose(), self.dtype)

        # Contract the last axis of `inputs` with the first axis of the
        # transposed kernel; there are no batch dimensions.
        contracting_dims = ((inputs.ndim - 1,), (0,))
        batch_dims = ((), ())
        projected = lax.dot_general(
            inputs, kernel, (contracting_dims, batch_dims), precision=self.precision
        )

        if not self.use_bias:
            return projected

        bias = self.param("bias", jax.nn.initializers.zeros, (self.features,))
        return projected + jnp.asarray(bias, self.dtype)
115
+
116
+
117
class FlaxGPT2Attention(nn.Module):
    """Multi-head attention layer used for both self- and cross-attention.

    When ``key_value_states`` is passed to ``__call__`` the keys and values are
    projected from it (cross-attention); otherwise they come from
    ``hidden_states`` (self-attention).
    """

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32
    # When True, a causal mask is built in `setup` and applied in `__call__`.
    causal: bool = True

    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        # Fused query/key/value projection and the output projection.
        self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype)
        self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)

        # Fused key/value projection applied to `key_value_states` (cross-attention).
        self.c_attn_for_k_v = FlaxConv1D(features=2 * self.embed_dim, dtype=self.dtype)

        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

        if self.causal:
            self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")

    def _split_heads(self, hidden_states):
        """Reshape the last axis into ``(num_heads, head_dim)``, keeping the first two axes."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """Inverse of ``_split_heads``: collapse the trailing axes back into ``embed_dim``."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # Advance the write position by the number of query positions just processed.
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those
            # key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Apply attention and return ``(attn_output,)`` or ``(attn_output, attn_weights)``.

        The attention weights are included only when ``output_attentions`` is true.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        qkv_out = self.c_attn(hidden_states)
        query, key, value = jnp.split(qkv_out, 3, axis=2)

        if is_cross_attention:
            # The key/value computed above from `hidden_states` are replaced by
            # projections of `key_value_states`; only `query` is kept.
            _qkv_out = self.c_attn_for_k_v(key_value_states)
            key, value = jnp.split(_qkv_out, 2, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        # NOTE(review): the causal mask is applied whenever `self.causal` is set,
        # including when `key_value_states` is given — confirm that callers build
        # cross-attention instances with the intended `causal` value.
        if self.causal:
            if self.has_variable("cache", "cached_key"):
                # Cached decoding: slice the mask rows starting at the current cache index.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]

            batch_size = hidden_states.shape[0]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # A dropout RNG is only requested when attention dropout is actually active.
        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        if attention_mask is not None:
            # Allowed positions get a bias of 0.0, masked positions -1e4.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Weighted sum of values, then merge heads and apply the output projection.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
263
+
264
+
265
class FlaxGPT2MLP(nn.Module):
    """Feed-forward sub-layer: expand, activate, project back, then dropout."""

    config: GPT2Config
    intermediate_size: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype)
        self.c_proj = FlaxConv1D(cfg.hidden_size, dtype=self.dtype)
        self.act = ACT2FN[cfg.activation_function]
        self.dropout = nn.Dropout(rate=cfg.resid_pdrop)

    def __call__(self, hidden_states, deterministic: bool = True):
        expanded = self.c_fc(hidden_states)
        projected = self.c_proj(self.act(expanded))
        return self.dropout(projected, deterministic=deterministic)
283
+
284
+
285
class FlaxGPT2Block(nn.Module):
    """One decoder block: self-attention, optional cross-attention, then MLP.

    Every sub-layer is preceded by its own layer norm and wrapped in a residual
    connection. The cross-attention sub-layer runs only when
    ``encoder_hidden_states`` is provided.
    """

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        eps = cfg.layer_norm_epsilon
        # Default MLP width is four times the hidden size unless `n_inner` is set.
        inner_dim = 4 * cfg.hidden_size if cfg.n_inner is None else cfg.n_inner

        self.ln_1 = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.attn = FlaxGPT2Attention(cfg, dtype=self.dtype)
        self.ln_3 = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        # NOTE(review): built with FlaxGPT2Attention's default `causal=True`, so
        # the causal mask is also applied to cross-attention — confirm intended.
        self.encoder_attn = FlaxGPT2Attention(config=cfg, dtype=self.dtype)
        self.ln_2 = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.mlp = FlaxGPT2MLP(cfg, inner_dim, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Return ``(hidden_states, [self_attn_weights], [cross_attn_weights])``.

        The attention-weight entries are present only when the attention layers
        return them (``output_attentions``), and the cross-attention entry only
        when ``encoder_hidden_states`` is given.
        """
        use_cross_attention = encoder_hidden_states is not None

        # Self-attention sub-layer with residual connection.
        self_attn_outputs = self.attn(
            self.ln_1(hidden_states),
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        hidden_states = self_attn_outputs[0] + hidden_states

        # Cross-attention sub-layer with residual connection.
        if use_cross_attention:
            skip = hidden_states
            cross_attn_outputs = self.encoder_attn(
                hidden_states=self.ln_3(hidden_states),
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attn_outputs[0] + skip

        # Feed-forward sub-layer with residual connection.
        skip = hidden_states
        mlp_output = self.mlp(self.ln_2(hidden_states), deterministic=deterministic)
        hidden_states = skip + mlp_output

        result = (hidden_states,) + self_attn_outputs[1:]
        if use_cross_attention:
            result = result + cross_attn_outputs[1:]

        return result
352
+
353
+
354
class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    base_model_prefix = "transformer"
    # Subclasses set this to the Flax module class that `__init__` instantiates.
    module_class: nn.Module = None

    def __init__(
        self,
        config: GPT2Config,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        # Build the underlying module, then let the base class take over with it.
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize the module on dummy inputs and return its ``"params"`` collection."""
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # Dummy encoder inputs so the cross-attention parameters are created as well.
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, encoder_hidden_states, encoder_attention_mask, return_dict=False)
        else:
            module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

        return module_init_outputs["params"]

    @classmethod
    def _from_config(cls, config, **kwargs):
        # Plain delegation to the parent implementation.
        return super()._from_config(config, **kwargs)

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # No encoder inputs are passed here, so only the decoder-only path is traced.
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Unset output flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Default encoder mask: all ones over the first two axes of the encoder states.
        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and is handed to the module as the
        # "cache" collection. It must be marked as mutable so that the FlaxGPT2Attention module can update it.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states,
            encoder_attention_mask,
            not train,  # positional `deterministic`
            False,  # positional `init_cache`
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            # Tuple output: the cache is inserted right after the first element.
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs
492
+
493
+
494
class FlaxGPT2BlockCollection(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``FlaxGPT2Block`` layers applied in order."""

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        num_layers = self.config.num_hidden_layers
        # Each block is named by its index in the stack.
        self.blocks = [
            FlaxGPT2Block(self.config, name=str(layer_idx), dtype=self.dtype) for layer_idx in range(num_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run all blocks, optionally collecting per-layer hidden states and attentions.

        Returns a tuple of the non-``None`` results when ``return_dict`` is
        false, otherwise a model-output object whose type depends on whether
        ``encoder_hidden_states`` was given.
        """
        has_encoder = encoder_hidden_states is not None

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and has_encoder) else None

        for block in self.blocks:
            # Record the input of every layer.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = block(
                hidden_states,
                attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                if has_encoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Record the output of the final layer as well.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            collected = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions]
            return tuple(item for item in collected if item is not None)

        if has_encoder:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=None,
                hidden_states=all_hidden_states,
                attentions=all_attentions,
                cross_attentions=all_cross_attentions,
            )

        return FlaxBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
562
+
563
class FlaxGPT2Module(nn.Module):
    """GPT2 backbone: token + position embeddings, block stack, final layer norm."""

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.embed_dim = cfg.hidden_size

        # Token embeddings.
        self.wte = nn.Embed(
            cfg.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        # Position embeddings.
        self.wpe = nn.Embed(
            cfg.max_position_embeddings,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=cfg.embd_pdrop)
        self.h = FlaxGPT2BlockCollection(cfg, dtype=self.dtype)
        self.ln_f = nn.LayerNorm(epsilon=cfg.layer_norm_epsilon, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Embed the inputs, run the block stack and apply the final layer norm."""
        token_embeds = self.wte(input_ids.astype("i4"))
        position_embeds = self.wpe(position_ids.astype("i4"))
        embedded = self.dropout(token_embeds + position_embeds, deterministic=deterministic)

        block_outputs = self.h(
            embedded,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        final_hidden = self.ln_f(block_outputs[0])

        if not return_dict:
            return (final_hidden,) + block_outputs[1:]

        if encoder_hidden_states is not None:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=final_hidden,
                hidden_states=block_outputs.hidden_states,
                attentions=block_outputs.attentions,
                cross_attentions=block_outputs.cross_attentions,
            )

        return FlaxBaseModelOutput(
            last_hidden_state=final_hidden,
            hidden_states=block_outputs.hidden_states,
            attentions=block_outputs.attentions,
        )
636
+
637
@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
    # The pretrained-model wrapper instantiates this module class in its __init__.
    module_class = FlaxGPT2Module
643
+
644
+
645
+ append_call_sample_docstring(
646
+ FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC
647
+ )
648
+
649
+
650
class FlaxGPT2LMHeadModule(nn.Module):
    """GPT2 backbone followed by a bias-free linear language-modeling head.

    When ``config.tie_word_embeddings`` is set, the head is applied with the
    transposed token-embedding matrix instead of its own kernel.
    """

    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.transformer = FlaxGPT2Module(cfg, dtype=self.dtype)
        self.lm_head = nn.Dense(
            cfg.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range, dtype=self.dtype),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Compute vocabulary logits from the backbone's first output."""
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(hidden_states)
        else:
            # Reuse the token embedding matrix (transposed) as the output kernel.
            tied_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": tied_kernel}}, hidden_states)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        if encoder_hidden_states is not None:
            return FlaxSeq2SeqLMOutput(
                logits=lm_logits,
                decoder_hidden_states=transformer_outputs.hidden_states,
                decoder_attentions=transformer_outputs.attentions,
                cross_attentions=transformer_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_hidden_states,
                encoder_hidden_states=None,
                encoder_attentions=None,
            )

        return FlaxCausalLMOutput(
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
712
+
713
@add_start_docstrings(
    """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPT2_START_DOCSTRING,
)
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
    module_class = FlaxGPT2LMHeadModule

    # `jnp.ndarray` replaces `jnp.DeviceArray` in the annotation below: the
    # annotation is evaluated when the class body runs, and `jnp.DeviceArray`
    # no longer exists in recent JAX releases.
    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.ndarray] = None):
        """Build the initial keyword arguments for cached auto-regressive generation.

        Args:
            input_ids: 2-D array of prompt token ids; only its shape is used.
            max_length: length of the initialized cache and of the returned attention mask.
            attention_mask: optional mask for the prompt; when given it is written into
                the start of the all-ones extended mask and also determines ``position_ids``.

        Returns:
            A dict with ``past_key_values``, ``attention_mask`` and ``position_ids``.
        """
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since GPT2 uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Positions count only the attended tokens (cumulative sum of the mask, zero-based).
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance ``position_ids`` by one step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        # Keep only the last position of each row and move it one step ahead.
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
748
+
749
+
750
+ append_call_sample_docstring(
751
+ FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
752
+ )
vit_gpt2/modeling_flax_vit_gpt2.py ADDED
@@ -0,0 +1,704 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+ from jax import lax
8
+ from jax.random import PRNGKey
9
+ from transformers import GPT2Config, FlaxViTModel, ViTConfig
10
+ from transformers.modeling_flax_outputs import (
11
+ FlaxCausalLMOutputWithCrossAttentions,
12
+ FlaxSeq2SeqLMOutput,
13
+ FlaxSeq2SeqModelOutput,
14
+ )
15
+ from transformers.models.bart.modeling_flax_bart import (
16
+ shift_tokens_right,
17
+ )
18
+ from .modeling_flax_gpt2 import (
19
+ FlaxGPT2Module,
20
+ FlaxGPT2Model,
21
+ FlaxPreTrainedModel
22
+ )
23
+ from transformers.models.vit.modeling_flax_vit import FlaxViTModule
24
+
25
+ from .configuration_vit_gpt2 import ViTGPT2Config
26
+
27
+
28
class FlaxViTGPT2Module(nn.Module):
    """Encoder-decoder backbone: a ViT image encoder feeding a GPT2 decoder.

    The decoder receives the encoder's first output as
    ``encoder_hidden_states``. No language-modeling head is applied here.
    """

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
        self.decoder = FlaxGPT2Module(self.config.gpt2_config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the ViT encoder sub-module."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the GPT2 decoder sub-module."""
        return self.decoder

    def __call__(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        position_ids,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode ``pixel_values`` and decode ``input_ids`` against the result.

        Returns:
            A ``FlaxSeq2SeqModelOutput`` when ``return_dict`` is true; otherwise
            the decoder outputs followed by the encoder outputs, concatenated.
        """
        encoder_outputs = self.encoder(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=encoder_attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Without `return_dict` the sub-modules do not return output objects, so
        # the attribute accesses below would fail; hand back the raw outputs
        # instead (same convention as FlaxViTGPT2LMModule).
        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
84
+
85
+
86
class FlaxViTGPT2ForConditionalGenerationModule(nn.Module):
    """``FlaxViTGPT2Module`` followed by a linear language-modeling head.

    NOTE(review): the head and its bias are sized by the GPT2 vocabulary. The
    previous revision sized them by the decoder's hidden width, so parameters
    saved from that revision of this class have different shapes.
    """

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxViTGPT2Module(config=self.config, dtype=self.dtype)

        # Logits must range over tokens, so the projection width is the
        # vocabulary size (not the hidden size of the decoder).
        vocab_size = self.config.gpt2_config.vocab_size
        self.lm_head = nn.Dense(
            vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.gpt2_config.initializer_range, self.dtype
            ),
        )
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, vocab_size)
        )

    def _get_encoder_module(self):
        """Return the ViT encoder owned by the wrapped backbone."""
        return self.model.encoder

    def _get_decoder_module(self):
        """Return the GPT2 decoder owned by the wrapped backbone."""
        return self.model.decoder

    def __call__(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        position_ids,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the backbone and project its first output to token logits.

        Returns:
            A ``FlaxSeq2SeqLMOutput`` when ``return_dict`` is true; otherwise a
            tuple whose first element is the logits, followed by the backbone's
            remaining outputs.
        """
        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states)
        lm_logits += self.final_logits_bias

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
152
+
153
class FlaxViTGPT2PreTrainedModel(FlaxPreTrainedModel):
    """Parameter-owning wrapper around a ViT-encoder / GPT2-decoder Flax module.

    Subclasses set ``module_class`` to the module to instantiate. ``__call__``
    and ``encode`` take channels-first images and transpose them to
    channels-last before applying the module.
    """

    config_class = ViTGPT2Config
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTGPT2Config,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Instantiate ``module_class`` and delegate to the base initializer.

        Args:
            config: Composite configuration holding ``vit_config`` and ``gpt2_config``.
            input_shape: Pair ``(pixel_shape, token_shape)`` used to trace the
                module; defaults to one channels-last image of the configured
                size and a single token.
            seed: Forwarded to the base class.
            dtype: Computation dtype forwarded to the module.
            **kwargs: Extra keyword arguments for ``module_class``.
        """
        if input_shape is None:
            input_shape = (
                (1, config.vit_config.image_size, config.vit_config.image_size, 3),
                (1, 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize module parameters from dummy inputs of ``input_shape``."""
        # init input tensors
        pixel_values = jax.random.normal(rng, input_shape[0])

        input_ids = jnp.zeros(input_shape[1], dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            input_ids,
            attention_mask,
            position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        """Create the decoder cache for ``max_length`` decoding steps.

        Args:
            batch_size: Number of sequences decoded in parallel.
            max_length: Length of the dummy decoder input used to size the cache.
            encoder_outputs: Encoder result; its first element is passed to the
                decoder as ``encoder_hidden_states``.

        Returns:
            The unfrozen ``"cache"`` collection produced by initializing the
            decoder with ``init_cache=True``.
        """
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
            input_ids.shape,
        )

        def _decoder_forward(
            module,
            input_ids,
            attention_mask,
            position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids,
                attention_mask,
                position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the image encoder.

        Args:
            pixel_values: Channels-first image batch, the same layout that
                ``__call__`` accepts.
            output_attentions, output_hidden_states, return_dict: Fall back to
                the values on ``self.config`` when ``None``.
            train: When true the encoder runs non-deterministically.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: RNG for dropout, if any.

        Returns:
            The encoder module's outputs.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        # Same layout conversion as `__call__`: channels-first -> channels-last.
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            # Pixels are real-valued; an integer cast would truncate them.
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the decoder against precomputed encoder outputs.

        Missing masks default to all ones. Missing ``position_ids`` default to
        ``0..sequence_length-1`` unless a cache is supplied.

        Raises:
            ValueError: If ``past_key_values`` is given without ``position_ids``.

        Returns:
            The decoder outputs. When ``past_key_values`` is given, the updated
            cache is added under ``"past_key_values"`` (dict output) or inserted
            as the second tuple element (tuple output).
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxGPT2Attention module
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            input_ids,
            attention_mask,
            position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids,
                attention_mask,
                position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        input_ids: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run the full encoder-decoder module.

        Args:
            pixel_values: Channels-first image batch; transposed to
                channels-last before the module is applied.
            input_ids: Decoder token ids (required despite the ``None`` default).
            attention_mask: Defaults to all ones shaped like ``input_ids``.
            position_ids: Defaults to ``0..sequence_length-1`` per row.
            output_attentions, output_hidden_states, return_dict: Fall back to
                the values on ``self.config`` when ``None``.
            train: When true the module runs non-deterministically.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: RNG for dropout, if any.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        if input_ids is None:
            # Fail early with a clear message instead of deep inside jnp.
            raise ValueError("`input_ids` must be provided.")

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # TODO: decide whether decoder inputs should be derived with
        # `shift_tokens_right` when they are not supplied.

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
454
+
455
+
456
class FlaxViTGPT2ForConditionalGeneration(FlaxViTGPT2PreTrainedModel):
    """ViT-encoder / GPT2-decoder model with a language-modeling head.

    Adds LM-aware decoding, the generation hooks, and a constructor that
    assembles the model from separately pretrained ViT and GPT2 checkpoints.
    """

    module_class = FlaxViTGPT2ForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        deterministic: bool = True,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run the decoder and LM head against precomputed encoder outputs.

        Missing masks default to all ones. Missing ``position_ids`` default to
        ``0..sequence_length-1`` unless a cache is supplied.

        Raises:
            ValueError: If ``past_key_values`` is given without ``position_ids``.

        Returns:
            ``FlaxCausalLMOutputWithCrossAttentions`` (dict output) or a tuple
            starting with the logits. When ``past_key_values`` is given, the
            updated cache is attached as ``past_key_values`` or inserted as the
            second tuple element.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxGPT2Attention module
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            input_ids,
            attention_mask,
            position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                input_ids,
                attention_mask,
                position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            # Always project with `lm_head`, exactly as the module's own
            # `__call__` does. The backbone only has `encoder` and `decoder`
            # sub-modules, so there is no `shared` embedding to tie against.
            lm_logits = module.lm_head(hidden_states)

            lm_logits += module.final_logits_bias
            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, outputs = outputs
        else:
            (lm_logits, outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    # NOTE: annotations use `jnp.ndarray` rather than `jnp.DeviceArray`; the
    # latter is evaluated when the class body runs and is absent in current JAX.
    def prepare_inputs_for_generation(
        self,
        input_ids,
        max_length,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the keyword arguments used for cached auto-regressive decoding.

        Returns:
            A dict with ``past_key_values``, ``encoder_outputs``,
            ``encoder_attention_mask``, ``attention_mask`` (static, length
            ``max_length``) and ``position_ids``.
        """
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": encoder_attention_mask,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the position ids by one.

        ``model_kwargs`` is updated in place and returned.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = (
            model_kwargs["position_ids"][:, -1:] + 1
        )
        return model_kwargs

    @classmethod
    def from_vit_gpt2_pretrained(
        cls,
        vit_model_name_or_path: str = None,
        gpt2_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxViTGPT2PreTrainedModel:
        """Assemble a model from separately pretrained ViT and GPT2 weights.

        Keyword arguments prefixed with ``vit_`` / ``gpt2_`` are stripped of the
        prefix and routed to the corresponding sub-model loader; a ready-made
        sub-model can be supplied as ``vit_model`` / ``gpt2_model``. Remaining
        keyword arguments go to the composite config and the model constructor.

        Args:
            vit_model_name_or_path: Checkpoint for the ViT encoder; required
                unless ``vit_model`` is given.
            gpt2_model_name_or_path: Checkpoint for the GPT2 decoder; required
                unless ``gpt2_model`` is given.

        Returns:
            A new model whose encoder and decoder parameters are replaced by
            the loaded sub-models' parameters.
        """
        kwargs_gpt2 = {
            argument[len("gpt2_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("gpt2_")
        }

        kwargs_vit = {
            argument[len("vit_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("vit_")
        }

        # remove gpt2, vit kwargs from kwargs
        for key in kwargs_gpt2.keys():
            del kwargs["gpt2_" + key]
        for key in kwargs_vit.keys():
            del kwargs["vit_" + key]

        # Load and initialize the gpt2 and vit model
        gpt2_model = kwargs_gpt2.pop("model", None)
        if gpt2_model is None:
            assert (
                gpt2_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"

            if "config" not in kwargs_gpt2:
                gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
                kwargs_gpt2["config"] = gpt2_config

            # The decoder must attend to the image encoder's hidden states.
            kwargs_gpt2["config"].add_cross_attention = True
            gpt2_model = FlaxGPT2Model.from_pretrained(
                gpt2_model_name_or_path, *model_args, **kwargs_gpt2
            )

        vit_model = kwargs_vit.pop("model", None)
        if vit_model is None:
            assert (
                vit_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"

            if "config" not in kwargs_vit:
                vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
                kwargs_vit["config"] = vit_config

            vit_model = FlaxViTModel.from_pretrained(
                vit_model_name_or_path, *model_args, **kwargs_vit
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = ViTGPT2Config.from_vit_gpt2_configs(
            vit_model.config, gpt2_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)
        model.params["model"]["encoder"] = vit_model.params
        model.params["model"]["decoder"] = gpt2_model.params

        return model
704
+
vit_gpt2/modeling_flax_vit_gpt2_lm.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+ from jax import lax
8
+ from jax.random import PRNGKey
9
+ from transformers import GPT2Config, FlaxViTModel, ViTConfig
10
+ from transformers.modeling_flax_outputs import (
11
+ FlaxCausalLMOutputWithCrossAttentions,
12
+ FlaxSeq2SeqLMOutput,
13
+ FlaxSeq2SeqModelOutput,
14
+ )
15
+ from transformers.models.bart.modeling_flax_bart import (
16
+ shift_tokens_right,
17
+ )
18
+ from .modeling_flax_gpt2 import (
19
+ FlaxGPT2Module,
20
+ FlaxGPT2Model,
21
+ FlaxGPT2LMHeadModule,
22
+ FlaxGPT2LMHeadModel,
23
+ FlaxPreTrainedModel
24
+ )
25
+ from transformers.models.vit.modeling_flax_vit import FlaxViTModule
26
+
27
+ from .configuration_vit_gpt2 import ViTGPT2Config
28
+
29
+
30
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.

    The last token of each row is dropped, ``decoder_start_token_id`` is placed
    in the first position, and any ``-100`` entries are replaced by
    ``pad_token_id``.
    """
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    # Overwrite the wrapped-around first column with the start token. The
    # `.at[...].set(...)` form replaces `jax.ops.index_update`, which is no
    # longer available in JAX.
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)

    return shifted_input_ids
40
+
41
class FlaxViTGPT2LMModule(nn.Module):
    """ViT image encoder coupled with a GPT2 decoder that carries its own LM head."""

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.encoder = FlaxViTModule(cfg.vit_config, dtype=self.dtype)
        self.decoder = FlaxGPT2LMHeadModule(cfg.gpt2_config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the ViT encoder sub-module."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the GPT2 LM-head decoder sub-module."""
        return self.decoder

    def __call__(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        position_ids,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode the images, then decode the tokens against the encoder output.

        Returns the decoder outputs followed by the encoder outputs when
        ``return_dict`` is false, otherwise a ``FlaxSeq2SeqLMOutput``.
        """
        # Flags forwarded unchanged to both sub-modules.
        common = {
            "deterministic": deterministic,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        enc = self.encoder(pixel_values=pixel_values, **common)
        dec = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=enc[0],
            encoder_attention_mask=encoder_attention_mask,
            **common,
        )

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=dec.logits,
                decoder_hidden_states=dec.decoder_hidden_states,
                decoder_attentions=dec.decoder_attentions,
                cross_attentions=dec.cross_attentions,
                encoder_last_hidden_state=enc.last_hidden_state,
                encoder_hidden_states=enc.hidden_states,
                encoder_attentions=enc.attentions,
            )

        return dec + enc
100
+
101
class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
    """Thin wrapper that nests ``FlaxViTGPT2LMModule`` under the name ``model``."""

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the encoder of the wrapped module."""
        wrapped = self.model
        return wrapped.encoder

    def _get_decoder_module(self):
        """Return the decoder of the wrapped module."""
        wrapped = self.model
        return wrapped.decoder

    def __call__(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        position_ids,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Forward every argument to the wrapped module and return its result."""
        return self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
140
+
141
+
142
class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
    """Shared plumbing for Flax models pairing a ViT encoder with a GPT2 decoder.

    Concrete subclasses set ``module_class``. This class instantiates that
    module, builds dummy inputs for parameter initialisation, and exposes
    ``encode`` / ``decode`` / ``__call__`` entry points that forward to the
    module through ``module.apply``.
    """

    config_class = ViTGPT2Config
    base_model_prefix: str = "model"
    # Must be overridden by subclasses with the Flax module to instantiate.
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTGPT2Config,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Build the module and hand it to ``FlaxPreTrainedModel``.

        Args:
            config: Combined ViT/GPT2 configuration.
            input_shape: Pair ``(pixel_values_shape, input_ids_shape)`` used for
                the dummy initialisation pass. Defaults to one channels-last
                image of ``config.vit_config.image_size`` and a single token.
            seed: Seed forwarded to the parent initialiser.
            dtype: Computation dtype forwarded to the module.
            **kwargs: Extra keyword arguments forwarded to ``module_class``.
        """
        if input_shape is None:
            input_shape = (
                (1, config.vit_config.image_size, config.vit_config.image_size, 3),
                (1, 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def _resolve_output_flags(self, output_attentions, output_hidden_states, return_dict):
        """Replace any ``None`` output option with the value stored in ``self.config``."""
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict
        return output_attentions, output_hidden_states, return_dict

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise module parameters by running ``module.init`` on dummy inputs.

        Args:
            rng: PRNG key; used for the dummy image and split into the
                ``params`` / ``dropout`` streams.
            input_shape: Pair ``(pixel_values_shape, input_ids_shape)``.

        Returns:
            The ``"params"`` collection produced by ``module.init``.
        """
        # Dummy inputs: a random image and all-zero token ids.
        pixel_values = jax.random.normal(rng, input_shape[0])
        input_ids = jnp.zeros(input_shape[1], dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            input_ids,
            attention_mask,
            position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        """Create the decoder's autoregressive cache.

        Args:
            batch_size: Number of sequences decoded in parallel.
            max_length: Length of the dummy token sequence used to size the cache.
            encoder_outputs: Encoder result; its first element is passed to the
                decoder as ``encoder_hidden_states``.

        Returns:
            The unfrozen ``"cache"`` collection created by ``module.init``.
        """
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
            input_ids.shape,
        )

        def _decoder_forward(
            module,
            input_ids,
            attention_mask,
            position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids,
                attention_mask,
                position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the encoder on a batch of images.

        Args:
            pixel_values: Image batch; axis 1 is moved to the last position
                before the encoder is called (presumably channels-first input
                converted to channels-last).
            attention_mask: Accepted for signature compatibility; not used.
            output_attentions: Overrides ``config.output_attentions`` when set.
            output_hidden_states: Overrides ``config.output_hidden_states`` when set.
            return_dict: Overrides ``config.return_dict`` when set.
            train: When ``True`` the encoder runs non-deterministically.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.

        Returns:
            Whatever the encoder module returns.
        """
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            # Pixel values are real-valued; keep them floating point exactly as
            # `__call__` does. Casting to an integer dtype would truncate them.
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the decoder, conditioned on previously computed encoder outputs.

        Args:
            input_ids: Token ids of shape ``(batch, sequence)``.
            encoder_outputs: Encoder result; element 0 is used as the encoder
                hidden states.
            encoder_attention_mask: Mask over encoder positions; defaults to all ones.
            attention_mask: Mask over decoder positions; defaults to all ones.
            position_ids: Decoder position ids; defaults to ``0..sequence-1``.
            past_key_values: Cache from ``init_cache`` or a previous call.
            output_attentions: Overrides ``config.output_attentions`` when set.
            output_hidden_states: Overrides ``config.output_hidden_states`` when set.
            return_dict: Overrides ``config.return_dict`` when set.
            train: When ``True`` the decoder runs non-deterministically.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.

        Returns:
            The decoder outputs, extended with the updated cache when
            ``past_key_values`` was supplied.

        Raises:
            ValueError: If ``past_key_values`` is given without ``position_ids``.
        """
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed the cache is already initialised. It must
        # be marked as mutable so that the attention module can update it.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            input_ids,
            attention_mask,
            position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids,
                attention_mask,
                position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        input_ids: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run the full encoder-decoder forward pass.

        Args:
            pixel_values: Image batch; axis 1 is moved to the last position
                before the module is called.
            input_ids: Decoder token ids of shape ``(batch, sequence)``. Required
                despite the ``None`` default kept for signature compatibility.
            attention_mask: Decoder mask; defaults to all ones.
            position_ids: Decoder position ids; defaults to ``0..sequence-1``.
            output_attentions: Overrides ``config.output_attentions`` when set.
            output_hidden_states: Overrides ``config.output_hidden_states`` when set.
            return_dict: Overrides ``config.return_dict`` when set.
            train: When ``True`` the module runs non-deterministically.
            params: Parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.

        Returns:
            Whatever the wrapped module returns.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            # Everything below derives shapes from input_ids; fail clearly
            # instead of deep inside jnp.
            raise ValueError("`input_ids` must be provided.")

        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # TODO: decide whether decoder inputs should be derived with
        # `shift_tokens_right` instead of being taken as-is.

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
446
+
447
+
448
+ class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
449
+ module_class = FlaxViTGPT2LMForConditionalGenerationModule
450
+ dtype: jnp.dtype = jnp.float32
451
+
452
def decode(
    self,
    input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    attention_mask: Optional[jnp.ndarray] = None,
    position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    deterministic: bool = True,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    """Decode tokens against encoder outputs and return language-model logits.

    Unset output options fall back to ``self.config``; missing masks default to
    all ones and missing position ids to ``0..sequence-1``. When
    ``past_key_values`` is given, the updated cache is attached to the result.

    Raises:
        ValueError: If ``past_key_values`` is given without ``position_ids``.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    encoder_hidden_states = encoder_outputs[0]
    if encoder_attention_mask is None:
        enc_batch, enc_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((enc_batch, enc_length))

    batch_size, sequence_length = input_ids.shape
    token_grid = (batch_size, sequence_length)
    if attention_mask is None:
        attention_mask = jnp.ones(token_grid)

    if position_ids is None:
        if past_key_values is not None:
            raise ValueError(
                "Make sure to provide `position_ids` when passing `past_key_values`."
            )
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], token_grid
        )

    # Only supply a dropout stream when the caller gave a key.
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    # A supplied cache must be both passed in and declared mutable so the
    # attention layers can write the new key/value entries back into it.
    variables = {"params": params or self.params}
    if past_key_values:
        variables["cache"] = past_key_values
    mutable = ["cache"] if past_key_values else False

    def _decoder_forward(
        module,
        input_ids,
        attention_mask,
        position_ids,
        **kwargs,
    ):
        # Return the logits separately alongside the complete decoder output.
        decoder = module._get_decoder_module()
        decoder_out = decoder(
            input_ids,
            attention_mask,
            position_ids,
            **kwargs,
        )
        return decoder_out[0], decoder_out

    applied = self.module.apply(
        variables,
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        position_ids=jnp.array(position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=deterministic,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    if past_key_values is None:
        lm_logits, outputs = applied
    else:
        # With a mutable collection, apply() returns (result, mutated_variables).
        (lm_logits, outputs), past = applied

    if return_dict:
        # NOTE(review): assumes the decoder output exposes `decoder_hidden_states`
        # and `decoder_attentions` - confirm against the decoder module.
        outputs = FlaxCausalLMOutputWithCrossAttentions(
            logits=lm_logits,
            hidden_states=outputs.decoder_hidden_states,
            attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
        )
    else:
        outputs = (lm_logits,) + outputs[1:]

    if past_key_values is None:
        return outputs

    # add updated cache to model output
    if return_dict:
        outputs["past_key_values"] = unfreeze(past["cache"])
        return outputs
    return outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
573
+
574
def prepare_inputs_for_generation(
    self,
    input_ids,
    max_length,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    attention_mask: Optional[jnp.ndarray] = None,
    encoder_outputs=None,
    **kwargs,
):
    """Build the keyword arguments used for the first decoding step.

    Args:
        input_ids: Prompt token ids of shape ``(batch, sequence)``.
        max_length: Length used to size the cache and the attention mask.
        encoder_attention_mask: Passed through unchanged.
        attention_mask: Optional mask for the prompt; when given, position ids
            are derived from its cumulative sum and it is copied into the
            start of the full-length mask.
        encoder_outputs: Encoder result, forwarded to ``init_cache`` and
            passed through.
        **kwargs: Ignored.

    Returns:
        Dict with ``past_key_values``, ``encoder_outputs``,
        ``encoder_attention_mask``, ``attention_mask`` and ``position_ids``.
    """
    # initializing the cache
    batch_size, seq_length = input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
    # Note that usually one would have to put 0's in the attention_mask for
    # x > input_ids.shape[-1] and x < cache_length. But since the decoder uses
    # a causal mask, those positions are masked anyways. Thus we can create a
    # single static attention_mask here, which is more efficient for compilation.
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if attention_mask is not None:
        position_ids = attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(
            extended_attention_mask, attention_mask, (0, 0)
        )
    else:
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )

    return {
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "encoder_attention_mask": encoder_attention_mask,
        "attention_mask": extended_attention_mask,
        "position_ids": position_ids,
    }
608
+
609
def update_inputs_for_generation(self, model_outputs, model_kwargs):
    """Carry generation state from one decoding step to the next.

    Stores the cache from ``model_outputs`` in ``model_kwargs`` and replaces
    ``position_ids`` with the last column incremented by one. ``model_kwargs``
    is modified in place and returned.
    """
    model_kwargs["past_key_values"] = model_outputs.past_key_values
    previous_positions = model_kwargs["position_ids"]
    # Keep only the final position of each row and step it forward.
    model_kwargs["position_ids"] = previous_positions[:, -1:] + 1
    return model_kwargs
615
+
616
@classmethod
def from_vit_gpt2_pretrained(
    cls,
    vit_model_name_or_path: str = None,
    gpt2_model_name_or_path: str = None,
    *model_args,
    **kwargs,
) -> FlaxViTGPT2LMPreTrainedModel:
    """Assemble a model from a pretrained ViT encoder and GPT2 decoder.

    Keyword arguments prefixed with ``gpt2_`` or ``vit_`` are stripped of the
    prefix and routed to the matching sub-model; a ready-made sub-model can be
    supplied as ``gpt2_model`` / ``vit_model``. Remaining keyword arguments go
    to the combined config and to the model constructor.

    Returns:
        A new ``cls`` instance whose encoder and decoder parameters are taken
        from the loaded sub-models.
    """
    gpt2_prefix = "gpt2_"
    vit_prefix = "vit_"

    # Move prefixed options out of `kwargs` into per-sub-model dicts.
    kwargs_gpt2 = {}
    kwargs_vit = {}
    for name in list(kwargs):
        if name.startswith(gpt2_prefix):
            kwargs_gpt2[name[len(gpt2_prefix):]] = kwargs.pop(name)
        elif name.startswith(vit_prefix):
            kwargs_vit[name[len(vit_prefix):]] = kwargs.pop(name)

    # Decoder: use the supplied model or load one from the given checkpoint.
    gpt2_model = kwargs_gpt2.pop("model", None)
    if gpt2_model is None:
        assert (
            gpt2_model_name_or_path is not None
        ), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"

        if "config" not in kwargs_gpt2:
            kwargs_gpt2["config"] = GPT2Config.from_pretrained(gpt2_model_name_or_path)

        # Cross-attention is switched on in the decoder config before loading.
        kwargs_gpt2["config"].add_cross_attention = True
        gpt2_model = FlaxGPT2LMHeadModel.from_pretrained(
            gpt2_model_name_or_path, *model_args, **kwargs_gpt2
        )

    # Encoder: same pattern as the decoder.
    vit_model = kwargs_vit.pop("model", None)
    if vit_model is None:
        assert (
            vit_model_name_or_path is not None
        ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"

        if "config" not in kwargs_vit:
            kwargs_vit["config"] = ViTConfig.from_pretrained(vit_model_name_or_path)

        vit_model = FlaxViTModel.from_pretrained(
            vit_model_name_or_path, *model_args, **kwargs_vit
        )

    # instantiate config with corresponding kwargs
    dtype = kwargs.pop("dtype", jnp.float32)
    config = ViTGPT2Config.from_vit_gpt2_configs(
        vit_model.config, gpt2_model.config, **kwargs
    )

    # init model, then overwrite its sub-model weights with the loaded ones
    model = cls(config, *model_args, dtype=dtype, **kwargs)
    model.params["model"]["encoder"] = vit_model.params
    model.params["model"]["decoder"] = gpt2_model.params

    return model