cognitivess commited on
Commit
b3d5553
·
verified ·
1 Parent(s): 980a0c2

Rename cognitivess_model/modeling_flax_Cognitivess.py to cognitivess_model/modeling_flax_cognitivess.py

Browse files
cognitivess_model/modeling_flax_Cognitivess.py DELETED
@@ -1,744 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Cognitivess and the HuggingFace Inc. team. All rights reserved.
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """Flax Cognitivess model."""
15
-
16
- from functools import partial
17
- from typing import Optional, Tuple
18
-
19
- import flax.linen as nn
20
- import jax
21
- import jax.numpy as jnp
22
- import numpy as np
23
- from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
- from flax.linen import combine_masks, make_causal_mask
25
- from flax.linen.attention import dot_product_attention_weights
26
- from flax.traverse_util import flatten_dict, unflatten_dict
27
- from jax import lax
28
-
29
- from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
30
- from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
31
- from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
32
- from .configuration_Cognitivess import CognitivessConfig
33
-
34
-
35
- logger = logging.get_logger(__name__)
36
-
37
- _CONFIG_FOR_DOC = "CognitivessConfig"
38
- _CHECKPOINT_FOR_DOC = "afmck/testing-Cognitivess-tiny"
39
- _REAL_CHECKPOINT_FOR_DOC = "openlm-research/open_Cognitivess_3b_v2"
40
-
41
- Cognitivess_START_DOCSTRING = r"""
42
-
43
- This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
44
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
45
- etc.)
46
-
47
- This model is also a Flax Linen
48
- [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
49
- regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
50
-
51
- Finally, this model supports inherent JAX features such as:
52
-
53
- - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
54
- - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
55
- - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
56
- - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
57
-
58
- Parameters:
59
- config ([`CognitivessConfig`]): Model configuration class with all the parameters of the model.
60
- Initializing with a config file does not load the weights associated with the model, only the
61
- configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
62
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
63
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16`, or
64
- `jax.numpy.bfloat16`.
65
-
66
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
67
- specified all the computation will be performed with the given `dtype`.
68
-
69
- **Note that this only specifies the dtype of the computation and does not influence the dtype of model
70
- parameters.**
71
-
72
- If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
73
- [`~FlaxPreTrainedModel.to_bf16`].
74
- """
75
-
76
- Cognitivess_INPUTS_DOCSTRING = r"""
77
- Args:
78
- input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
79
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
80
- it.
81
-
82
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
83
- [`PreTrainedTokenizer.__call__`] for details.
84
-
85
- [What are input IDs?](../glossary#input-ids)
86
- attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
87
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
88
-
89
- - 1 for tokens that are **not masked**,
90
- - 0 for tokens that are **masked**.
91
-
92
- [What are attention masks?](../glossary#attention-mask)
93
-
94
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
95
- [`PreTrainedTokenizer.__call__`] for details.
96
-
97
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
98
- `past_key_values`).
99
-
100
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
101
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
102
- information on the default strategy.
103
-
104
- - 1 indicates the head is **not masked**,
105
- - 0 indicates the head is **masked**.
106
- position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
107
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
108
- config.n_positions - 1]`.
109
-
110
- [What are position IDs?](../glossary#position-ids)
111
- past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
112
- Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
113
- auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
114
- output_attentions (`bool`, *optional*):
115
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
116
- tensors for more detail.
117
- output_hidden_states (`bool`, *optional*):
118
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
119
- more detail.
120
- return_dict (`bool`, *optional*):
121
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
122
- """
123
-
124
-
125
- def create_sinusoidal_positions(num_pos, dim):
126
- inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
127
- freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
128
-
129
- emb = np.concatenate((freqs, freqs), axis=-1)
130
- out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
131
- return jnp.array(out[:, :, :num_pos])
132
-
133
-
134
- def rotate_half(tensor):
135
- """Rotates half the hidden dims of the input."""
136
- rotate_half_tensor = jnp.concatenate(
137
- (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1
138
- )
139
- return rotate_half_tensor
140
-
141
-
142
- def apply_rotary_pos_emb(tensor, sin_pos, cos_pos):
143
- return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos)
144
-
145
-
146
- class FlaxCognitivessRMSNorm(nn.Module):
147
- config: CognitivessConfig
148
- dtype: jnp.dtype = jnp.float32
149
-
150
- def setup(self):
151
- self.epsilon = self.config.rms_norm_eps
152
- self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
153
-
154
- def __call__(self, hidden_states):
155
- variance = jnp.asarray(hidden_states, dtype=jnp.float32)
156
- variance = jnp.power(variance, 2)
157
- variance = variance.mean(-1, keepdims=True)
158
- # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt`
159
- hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
160
-
161
- return self.weight * jnp.asarray(hidden_states, dtype=self.dtype)
162
-
163
-
164
- class FlaxCognitivessRotaryEmbedding(nn.Module):
165
- config: CognitivessConfig
166
- dtype: jnp.dtype = jnp.float32
167
-
168
- def setup(self):
169
- head_dim = self.config.hidden_size // self.config.num_attention_heads
170
- self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim)
171
-
172
- def __call__(self, key, query, position_ids):
173
- sincos = self.sincos[position_ids]
174
- sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1)
175
-
176
- key = apply_rotary_pos_emb(key, sin_pos, cos_pos)
177
- query = apply_rotary_pos_emb(query, sin_pos, cos_pos)
178
-
179
- key = jnp.asarray(key, dtype=self.dtype)
180
- query = jnp.asarray(query, dtype=self.dtype)
181
-
182
- return key, query
183
-
184
-
185
- class FlaxCognitivessAttention(nn.Module):
186
- config: CognitivessConfig
187
- dtype: jnp.dtype = jnp.float32
188
- causal: bool = True
189
- is_cross_attention: bool = False
190
-
191
- def setup(self):
192
- config = self.config
193
- self.embed_dim = config.hidden_size
194
- self.num_heads = config.num_attention_heads
195
- self.head_dim = self.embed_dim // self.num_heads
196
- self.num_key_value_heads = config.num_key_value_heads
197
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
198
- self.attention_softmax_in_fp32 = self.dtype is not jnp.float32
199
-
200
- dense = partial(
201
- nn.Dense,
202
- use_bias=config.attention_bias,
203
- dtype=self.dtype,
204
- kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
205
- )
206
-
207
- self.q_proj = dense(self.num_heads * self.head_dim)
208
- self.k_proj = dense(self.num_key_value_heads * self.head_dim)
209
- self.v_proj = dense(self.num_key_value_heads * self.head_dim)
210
- self.o_proj = dense(self.embed_dim)
211
- if (self.head_dim * self.num_heads) != self.embed_dim:
212
- raise ValueError(
213
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.embed_dim}"
214
- f" and `num_heads`: {self.num_heads})."
215
- )
216
-
217
- self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
218
- self.rotary_emb = FlaxCognitivessRotaryEmbedding(config, dtype=self.dtype)
219
-
220
- def _split_heads(self, hidden_states, num_heads):
221
- return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
222
-
223
- def _merge_heads(self, hidden_states):
224
- return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
225
-
226
- @nn.compact
227
- # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
228
- def _concatenate_to_cache(self, key, value, query, attention_mask):
229
- """
230
- This function takes projected key, value states from a single input token and concatenates the states to cached
231
- states from previous steps. This function is slighly adapted from the official Flax repository:
232
- https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
233
- """
234
- # detect if we're initializing by absence of existing cache data.
235
- is_initialized = self.has_variable("cache", "cached_key")
236
- cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
237
- cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
238
- cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
239
-
240
- if is_initialized:
241
- *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
242
- # update key, value caches with our new 1d spatial slices
243
- cur_index = cache_index.value
244
- indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
245
- key = lax.dynamic_update_slice(cached_key.value, key, indices)
246
- value = lax.dynamic_update_slice(cached_value.value, value, indices)
247
- cached_key.value = key
248
- cached_value.value = value
249
- num_updated_cache_vectors = query.shape[1]
250
- cache_index.value = cache_index.value + num_updated_cache_vectors
251
- # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
252
- pad_mask = jnp.broadcast_to(
253
- jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
254
- tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
255
- )
256
- attention_mask = combine_masks(pad_mask, attention_mask)
257
- return key, value, attention_mask
258
-
259
- def __call__(
260
- self,
261
- hidden_states,
262
- attention_mask,
263
- position_ids,
264
- deterministic: bool = True,
265
- init_cache: bool = False,
266
- output_attentions: bool = False,
267
- ):
268
- query = self.q_proj(hidden_states)
269
- key = self.k_proj(hidden_states)
270
- value = self.v_proj(hidden_states)
271
-
272
- query = self._split_heads(query, self.num_heads)
273
- key = self._split_heads(key, self.num_key_value_heads)
274
- value = self._split_heads(value, self.num_key_value_heads)
275
-
276
- key, query = self.rotary_emb(key, query, position_ids)
277
-
278
- query_length, key_length = query.shape[1], key.shape[1]
279
-
280
- if self.has_variable("cache", "cached_key"):
281
- mask_shift = self.variables["cache"]["cache_index"]
282
- max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
283
- causal_mask = lax.dynamic_slice(
284
- self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
285
- )
286
- else:
287
- causal_mask = self.causal_mask[:, :, :query_length, :key_length]
288
-
289
- batch_size = hidden_states.shape[0]
290
- causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
291
-
292
- attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
293
- attention_mask = combine_masks(attention_mask, causal_mask)
294
-
295
- dropout_rng = None
296
- if not deterministic and self.config.attention_dropout > 0.0:
297
- dropout_rng = self.make_rng("dropout")
298
-
299
- # During fast autoregressive decoding, we feed one position at a time,
300
- # and cache the keys and values step by step.
301
- if self.has_variable("cache", "cached_key") or init_cache:
302
- key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
303
-
304
- key = jnp.repeat(key, self.num_key_value_groups, axis=2)
305
- value = jnp.repeat(value, self.num_key_value_groups, axis=2)
306
-
307
- # transform boolean mask into float mask
308
- attention_bias = lax.select(
309
- attention_mask > 0,
310
- jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
311
- jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
312
- )
313
-
314
- # usual dot product attention
315
- attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype
316
- attn_weights = dot_product_attention_weights(
317
- query,
318
- key,
319
- bias=attention_bias,
320
- dropout_rng=dropout_rng,
321
- dropout_rate=self.config.attention_dropout,
322
- deterministic=deterministic,
323
- dtype=attention_dtype,
324
- )
325
-
326
- if self.attention_softmax_in_fp32:
327
- attn_weights = attn_weights.astype(self.dtype)
328
-
329
- attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
330
- attn_output = self._merge_heads(attn_output)
331
- attn_output = self.o_proj(attn_output)
332
-
333
- outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
334
- return outputs
335
-
336
-
337
- class FlaxCognitivessMLP(nn.Module):
338
- config: CognitivessConfig
339
- dtype: jnp.dtype = jnp.float32
340
-
341
- def setup(self):
342
- embed_dim = self.config.hidden_size
343
- inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
344
-
345
- kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
346
- self.act = ACT2FN[self.config.hidden_act]
347
-
348
- self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
349
- self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
350
- self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init)
351
-
352
- def __call__(self, hidden_states):
353
- up_proj_states = self.up_proj(hidden_states)
354
- gate_states = self.act(self.gate_proj(hidden_states))
355
-
356
- hidden_states = self.down_proj(up_proj_states * gate_states)
357
- return hidden_states
358
-
359
-
360
- class FlaxCognitivessDecoderLayer(nn.Module):
361
- config: CognitivessConfig
362
- dtype: jnp.dtype = jnp.float32
363
-
364
- def setup(self):
365
- self.input_layernorm = FlaxCognitivessRMSNorm(self.config, dtype=self.dtype)
366
- self.self_attn = FlaxCognitivessAttention(self.config, dtype=self.dtype)
367
- self.post_attention_layernorm = FlaxCognitivessRMSNorm(self.config, dtype=self.dtype)
368
- self.mlp = FlaxCognitivessMLP(self.config, dtype=self.dtype)
369
-
370
- def __call__(
371
- self,
372
- hidden_states,
373
- attention_mask=None,
374
- position_ids=None,
375
- deterministic: bool = True,
376
- init_cache: bool = False,
377
- output_attentions: bool = False,
378
- ):
379
- residual = hidden_states
380
- hidden_states = self.input_layernorm(hidden_states)
381
- outputs = self.self_attn(
382
- hidden_states,
383
- attention_mask=attention_mask,
384
- position_ids=position_ids,
385
- deterministic=deterministic,
386
- init_cache=init_cache,
387
- output_attentions=output_attentions,
388
- )
389
- # residual connection
390
- attn_output = outputs[0]
391
- hidden_states = residual + attn_output
392
-
393
- residual = hidden_states
394
- hidden_states = self.post_attention_layernorm(hidden_states)
395
- hidden_states = self.mlp(hidden_states)
396
- # residual connection
397
- hidden_states = residual + hidden_states
398
-
399
- return (hidden_states,) + outputs[1:]
400
-
401
-
402
- # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Cognitivess, GPT_NEO->Cognitivess, transformer->model
403
- class FlaxCognitivessPreTrainedModel(FlaxPreTrainedModel):
404
- """
405
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
406
- models.
407
- """
408
-
409
- config_class = CognitivessConfig
410
- base_model_prefix = "model"
411
- module_class: nn.Module = None
412
-
413
- def __init__(
414
- self,
415
- config: CognitivessConfig,
416
- input_shape: Tuple = (1, 1),
417
- seed: int = 0,
418
- dtype: jnp.dtype = jnp.float32,
419
- _do_init: bool = True,
420
- **kwargs,
421
- ):
422
- module = self.module_class(config=config, dtype=dtype, **kwargs)
423
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
424
-
425
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
426
- # init input tensors
427
- input_ids = jnp.zeros(input_shape, dtype="i4")
428
- attention_mask = jnp.ones_like(input_ids)
429
- position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
430
- params_rng, dropout_rng = jax.random.split(rng)
431
- rngs = {"params": params_rng, "dropout": dropout_rng}
432
-
433
- random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
434
-
435
- if params is not None:
436
- random_params = flatten_dict(unfreeze(random_params))
437
- params = flatten_dict(unfreeze(params))
438
- for missing_key in self._missing_keys:
439
- params[missing_key] = random_params[missing_key]
440
- self._missing_keys = set()
441
- return freeze(unflatten_dict(params))
442
- else:
443
- return random_params
444
-
445
- def init_cache(self, batch_size, max_length):
446
- r"""
447
- Args:
448
- batch_size (`int`):
449
- batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
450
- max_length (`int`):
451
- maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
452
- cache.
453
- """
454
- # init input variables to retrieve cache
455
- input_ids = jnp.ones((batch_size, max_length))
456
- attention_mask = jnp.ones_like(input_ids)
457
- position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
458
-
459
- init_variables = self.module.init(
460
- jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
461
- )
462
- return unfreeze(init_variables["cache"])
463
-
464
- @add_start_docstrings_to_model_forward(Cognitivess_INPUTS_DOCSTRING)
465
- def __call__(
466
- self,
467
- input_ids,
468
- attention_mask=None,
469
- position_ids=None,
470
- params: dict = None,
471
- past_key_values: dict = None,
472
- dropout_rng: jax.random.PRNGKey = None,
473
- train: bool = False,
474
- output_attentions: Optional[bool] = None,
475
- output_hidden_states: Optional[bool] = None,
476
- return_dict: Optional[bool] = None,
477
- ):
478
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
479
- output_hidden_states = (
480
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
481
- )
482
- return_dict = return_dict if return_dict is not None else self.config.return_dict
483
-
484
- batch_size, sequence_length = input_ids.shape
485
-
486
- if position_ids is None:
487
- if past_key_values is not None:
488
- raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
489
-
490
- position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
491
-
492
- if attention_mask is None:
493
- attention_mask = jnp.ones((batch_size, sequence_length))
494
-
495
- # Handle any PRNG if needed
496
- rngs = {}
497
- if dropout_rng is not None:
498
- rngs["dropout"] = dropout_rng
499
-
500
- inputs = {"params": params or self.params}
501
-
502
- # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxCognitivessAttention module
503
- if past_key_values:
504
- inputs["cache"] = past_key_values
505
- mutable = ["cache"]
506
- else:
507
- mutable = False
508
-
509
- outputs = self.module.apply(
510
- inputs,
511
- jnp.array(input_ids, dtype="i4"),
512
- jnp.array(attention_mask, dtype="i4"),
513
- jnp.array(position_ids, dtype="i4"),
514
- not train,
515
- False,
516
- output_attentions,
517
- output_hidden_states,
518
- return_dict,
519
- rngs=rngs,
520
- mutable=mutable,
521
- )
522
-
523
- # add updated cache to model output
524
- if past_key_values is not None and return_dict:
525
- outputs, past_key_values = outputs
526
- outputs["past_key_values"] = unfreeze(past_key_values["cache"])
527
- return outputs
528
- elif past_key_values is not None and not return_dict:
529
- outputs, past_key_values = outputs
530
- outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
531
-
532
- return outputs
533
-
534
-
535
- class FlaxCognitivessLayerCollection(nn.Module):
536
- config: CognitivessConfig
537
- dtype: jnp.dtype = jnp.float32
538
-
539
- def setup(self):
540
- self.blocks = [
541
- FlaxCognitivessDecoderLayer(self.config, dtype=self.dtype, name=str(i))
542
- for i in range(self.config.num_hidden_layers)
543
- ]
544
-
545
- def __call__(
546
- self,
547
- hidden_states,
548
- attention_mask=None,
549
- position_ids=None,
550
- deterministic: bool = True,
551
- init_cache: bool = False,
552
- output_attentions: bool = False,
553
- output_hidden_states: bool = False,
554
- return_dict: bool = False,
555
- ):
556
- all_attentions = () if output_attentions else None
557
- all_hidden_states = () if output_hidden_states else None
558
-
559
- for block in self.blocks:
560
- if output_hidden_states:
561
- all_hidden_states += (hidden_states,)
562
- layer_outputs = block(
563
- hidden_states,
564
- attention_mask=attention_mask,
565
- position_ids=position_ids,
566
- deterministic=deterministic,
567
- init_cache=init_cache,
568
- output_attentions=output_attentions,
569
- )
570
- hidden_states = layer_outputs[0]
571
-
572
- if output_attentions:
573
- all_attentions += (layer_outputs[1],)
574
-
575
- # this contains possible `None` values - `FlaxCognitivessModule` will filter them out
576
- outputs = (hidden_states, all_hidden_states, all_attentions)
577
-
578
- return outputs
579
-
580
-
581
- class FlaxCognitivessModule(nn.Module):
582
- config: CognitivessConfig
583
- dtype: jnp.dtype = jnp.float32
584
-
585
- def setup(self):
586
- self.hidden_size = self.config.hidden_size
587
- embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)
588
- self.embed_tokens = nn.Embed(
589
- self.config.vocab_size,
590
- self.hidden_size,
591
- embedding_init=embedding_init,
592
- dtype=self.dtype,
593
- )
594
- self.layers = FlaxCognitivessLayerCollection(self.config, dtype=self.dtype)
595
- self.norm = FlaxCognitivessRMSNorm(self.config, dtype=self.dtype)
596
-
597
- def __call__(
598
- self,
599
- input_ids,
600
- attention_mask=None,
601
- position_ids=None,
602
- deterministic=True,
603
- init_cache: bool = False,
604
- output_attentions: bool = False,
605
- output_hidden_states: bool = False,
606
- return_dict: bool = True,
607
- ):
608
- input_embeds = self.embed_tokens(input_ids.astype("i4"))
609
-
610
- outputs = self.layers(
611
- input_embeds,
612
- position_ids=position_ids,
613
- attention_mask=attention_mask,
614
- deterministic=deterministic,
615
- init_cache=init_cache,
616
- output_attentions=output_attentions,
617
- output_hidden_states=output_hidden_states,
618
- return_dict=return_dict,
619
- )
620
-
621
- hidden_states = outputs[0]
622
- hidden_states = self.norm(hidden_states)
623
-
624
- if output_hidden_states:
625
- all_hidden_states = outputs[1] + (hidden_states,)
626
- outputs = (hidden_states, all_hidden_states) + outputs[2:]
627
- else:
628
- outputs = (hidden_states,) + outputs[1:]
629
-
630
- if not return_dict:
631
- return tuple(v for v in outputs if v is not None)
632
-
633
- return FlaxBaseModelOutput(
634
- last_hidden_state=hidden_states,
635
- hidden_states=outputs[1],
636
- attentions=outputs[-1],
637
- )
638
-
639
-
640
- @add_start_docstrings(
641
- "The bare Cognitivess Model transformer outputting raw hidden-states without any specific head on top.",
642
- Cognitivess_START_DOCSTRING,
643
- )
644
- class FlaxCognitivessModel(FlaxCognitivessPreTrainedModel):
645
- module_class = FlaxCognitivessModule
646
-
647
-
648
- append_call_sample_docstring(
649
- FlaxCognitivessModel,
650
- _CHECKPOINT_FOR_DOC,
651
- FlaxBaseModelOutput,
652
- _CONFIG_FOR_DOC,
653
- real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
654
- )
655
-
656
-
657
- class FlaxCognitivessForCausalLMModule(nn.Module):
658
- config: CognitivessConfig
659
- dtype: jnp.dtype = jnp.float32
660
-
661
- def setup(self):
662
- self.model = FlaxCognitivessModule(self.config, dtype=self.dtype)
663
- self.lm_head = nn.Dense(
664
- self.config.vocab_size,
665
- use_bias=False,
666
- dtype=self.dtype,
667
- kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
668
- )
669
-
670
- def __call__(
671
- self,
672
- input_ids,
673
- attention_mask=None,
674
- position_ids=None,
675
- deterministic: bool = True,
676
- init_cache: bool = False,
677
- output_attentions: bool = False,
678
- output_hidden_states: bool = False,
679
- return_dict: bool = True,
680
- ):
681
- outputs = self.model(
682
- input_ids,
683
- position_ids=position_ids,
684
- attention_mask=attention_mask,
685
- deterministic=deterministic,
686
- init_cache=init_cache,
687
- output_attentions=output_attentions,
688
- output_hidden_states=output_hidden_states,
689
- return_dict=return_dict,
690
- )
691
-
692
- hidden_states = outputs[0]
693
- lm_logits = self.lm_head(hidden_states)
694
-
695
- if not return_dict:
696
- return (lm_logits,) + outputs[1:]
697
-
698
- return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
699
-
700
-
701
- @add_start_docstrings(
702
- """
703
- The Cognitivess Model transformer with a language modeling head (linear layer) on top.
704
- """,
705
- Cognitivess_START_DOCSTRING,
706
- )
707
- # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Cognitivess
708
- class FlaxCognitivessForCausalLM(FlaxCognitivessPreTrainedModel):
709
- module_class = FlaxCognitivessForCausalLMModule
710
-
711
- def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
712
- # initializing the cache
713
- batch_size, seq_length = input_ids.shape
714
-
715
- past_key_values = self.init_cache(batch_size, max_length)
716
- # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
717
- # But since Cognitivess uses a causal mask, those positions are masked anyways.
718
- # Thus we can create a single static attention_mask here, which is more efficient for compilation
719
- extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
720
- if attention_mask is not None:
721
- position_ids = attention_mask.cumsum(axis=-1) - 1
722
- extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
723
- else:
724
- position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
725
-
726
- return {
727
- "past_key_values": past_key_values,
728
- "attention_mask": extended_attention_mask,
729
- "position_ids": position_ids,
730
- }
731
-
732
- def update_inputs_for_generation(self, model_outputs, model_kwargs):
733
- model_kwargs["past_key_values"] = model_outputs.past_key_values
734
- model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
735
- return model_kwargs
736
-
737
-
738
- append_call_sample_docstring(
739
- FlaxCognitivessForCausalLM,
740
- _CHECKPOINT_FOR_DOC,
741
- FlaxCausalLMOutput,
742
- _CONFIG_FOR_DOC,
743
- real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
744
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cognitivess_model/modeling_flax_cognitivess.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from transformers.models.llama.modeling_flax_llama import (
2
+ FlaxLlamaForCausalLM as FlaxCognitivessForCausalLM,
3
+ FlaxLlamaModel as FlaxCognitivessModel,
4
+ )
5
+
6
+ # You can add more specific code or changes here if needed.