File size: 5,358 Bytes
1d2ae3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# coding=utf-8
# Copyright 2023 The T5X Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for partitioning."""

from typing import Any, Mapping, MutableMapping, Optional, Tuple

import flax.core
import flax.serialization
import flax.struct
import jax.numpy as jnp
from flax import traverse_util
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning


# Shared frozen empty mapping, built once at import time so it can be reused
# wherever an "empty collection" value is needed.
EMPTY_DICT = flax.core.freeze({})
# Module-level aliases for types defined in `flax.core.scope`, so annotations in
# this module can use the short names.
FrozenDict = flax_scope.FrozenDict
FrozenVariableDict = flax_scope.FrozenVariableDict
MutableVariableDict = flax_scope.MutableVariableDict
VariableDict = flax_scope.VariableDict


def _validate_params_axes(params_axes, params):
    """Check that every parameter path has axis names in `params_axes`.

    Both trees are flattened to "/"-joined key paths. Any path found in
    `params` but not among the axis names derived from `params_axes` is
    reported as missing.

    Args:
        params_axes: Axes metadata, converted with
            `flax_partitioning.get_axis_names`.
        params: The (nested) parameter dict to check against.

    Raises:
        ValueError: If one or more parameter paths have no axis names.
    """
    named_axes = flax_partitioning.get_axis_names(params_axes)
    param_paths = set(traverse_util.flatten_dict(params, sep="/"))
    axes_paths = set(traverse_util.flatten_dict(named_axes, sep="/"))
    missing_params_axes = param_paths - axes_paths
    if missing_params_axes:
        raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")


def _split_variables_and_axes(variables_and_axes: FrozenVariableDict) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
    """Separate variable collections from their `*_axes` companions.

    A top-level key ending in "_axes" is treated as the axes entry for the key
    with that suffix removed; it is stored under the un-suffixed name and
    validated against the matching collection. Every other key is treated as a
    variable collection.

    Args:
        variables_and_axes: Mapping holding collections and, optionally, a
            `<name>_axes` entry for some of them.

    Returns:
        A `(variables, axes)` pair of frozen dicts. `axes` is keyed by the
        collection name without the "_axes" suffix.

    Raises:
        KeyError: If a `<name>_axes` entry has no `<name>` collection.
        ValueError: If a collection has parameters without axis names
            (raised by `_validate_params_axes`).
    """
    suffix = "_axes"
    collections = {}
    collection_axes = {}
    for name, value in variables_and_axes.items():
        if not name.endswith(suffix):
            collections[name] = value
            continue
        # Strip the suffix to recover the name of the collection being described.
        base_name = name[: -len(suffix)]
        collection_axes[base_name] = value
        _validate_params_axes(value, variables_and_axes[base_name])
    return flax.core.freeze(collections), flax.core.freeze(collection_axes)


class InferenceState(flax.struct.PyTreeNode):
    """State compatible with FlaxOptimTrainState without optimizer state.

    Holds the model parameters, the step counter and any additional mutable
    variable collections, together with the optional axes metadata for the
    parameters and the mutable collections. Optimizer-related accessors raise
    `NotImplementedError`.
    """

    # Step counter; `create` initializes it to 0.
    step: jnp.ndarray
    # The "params" collection of the model variables.
    params: flax_scope.FrozenVariableDict
    # Axes metadata for `params` ("params_axes" collection), if it was provided.
    params_axes: Optional[flax_scope.FrozenVariableDict] = None
    # All remaining non-axes collections of the model variables.
    flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
    # Axes metadata for `flax_mutables`, keyed by collection name; None if absent.
    flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

    @classmethod
    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Build a state from a model's variable collections.

        Args:
            model_variables: Mapping with a "params" collection and, optionally,
                "params_axes" plus further collections with their own
                `<name>_axes` entries. Plain (unfrozen) mappings are accepted
                and are not modified.

        Returns:
            A new state with `step` set to 0.

        Raises:
            KeyError: If "params" is missing, or a `<name>_axes` entry has no
                matching `<name>` collection.
            ValueError: If some parameters have no axis names in the
                corresponding axes collection.
        """
        # Freeze first so `pop` below is the non-mutating `FrozenDict.pop`, which
        # returns `(remaining, value)`. A plain dict's `pop` returns only the
        # value and mutates the caller's mapping, which would break the unpacking.
        model_variables = flax.core.freeze(model_variables)
        other_variables, params = model_variables.pop("params")
        if "params_axes" in other_variables:
            other_variables, params_axes = other_variables.pop("params_axes")
            _validate_params_axes(params_axes, params)
        else:
            params_axes = None

        # Split other_variables into mutables and their corresponding axes.
        flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
        # Normalize "no axes at all" (empty dict) to None.
        flax_mutables_axes = flax_mutables_axes or None
        # Use `cls` so that subclasses get an instance of their own type.
        return cls(
            step=jnp.array(0),
            params=params,
            params_axes=params_axes,
            flax_mutables=flax_mutables,
            flax_mutables_axes=flax_mutables_axes,
        )

    @property
    def param_states(self) -> FrozenVariableDict:
        """The optimizer states of the parameters as a PyTree.

        Raises:
            NotImplementedError: Always; this state has no optimizer states.
        """
        raise NotImplementedError("InferenceState has no optimizer states.")

    def apply_gradient(self, *args, **kwargs) -> "InferenceState":
        """Unsupported for inference; always raises `NotImplementedError`."""
        raise NotImplementedError("InferenceState does not support `apply_gradient`.")

    def state_dict(self) -> MutableMapping[str, Any]:
        """Return the state as a plain nested dict.

        The result has the keys "target" (unfrozen params) and "state"
        (containing "step"). "flax_mutables" is added only when the mutable
        collections are non-empty.
        """
        state_dict = {"target": flax.core.unfreeze(self.params), "state": {"step": self.step}}
        if self.flax_mutables:
            state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
        return state_dict

    def replace_step(self, step: jnp.ndarray) -> "InferenceState":
        """Return a copy of this state with `step` replaced."""
        return self.replace(step=step)

    def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
        """Return a copy of this state with `params` replaced."""
        return self.replace(params=params)

    def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
        """Return a copy of this state with `flax_mutables` replaced."""
        return self.replace(flax_mutables=flax_mutables)

    def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
        """Return a copy of this state populated from `state_dict`.

        Args:
            state_dict: Mapping in the layout produced by `state_dict()`:
                "target", "state"/"step" and optionally "flax_mutables".

        Returns:
            A copy with `params`, `step` and `flax_mutables` taken from
            `state_dict`; `flax_mutables` is empty when the key is absent.
        """
        if "flax_mutables" in state_dict:
            flax_mutables = flax.core.freeze(state_dict["flax_mutables"])
        else:
            flax_mutables = EMPTY_DICT
        return self.replace(
            params=flax.core.freeze(state_dict["target"]),
            step=state_dict["state"]["step"],
            flax_mutables=flax_mutables,
        )

    def as_logical_axes(self) -> "InferenceState":
        """Return a state whose leaves are axis names instead of arrays.

        `params` and `flax_mutables` are replaced by the axis names derived
        from `params_axes` and `flax_mutables_axes`; `step` is None.

        Raises:
            ValueError: If this state was built without `params_axes`.
        """
        # NOTE(review): `get_axis_names` expects an axes-metadata mapping, so fail
        # with an explicit message rather than handing it None.
        if self.params_axes is None:
            raise ValueError("InferenceState has no `params_axes`; cannot derive logical axes.")
        # Set step to None so that when the logical axes are processed by the
        # flax.partitioning.logical_to_mesh_axes function, it will be skipped
        # because jax.tree_map will short circuit and never call the function on
        # the step.
        flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
        return InferenceState(
            step=None,
            params=flax_partitioning.get_axis_names(self.params_axes),
            flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
        )