# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pjit partitioner with Mixture of Experts overrides."""
from typing import Any, Callable, Optional, Sequence, Union
from absl import logging
from flax import core as flax_core
import jax
import numpy as np
from t5x import adafactor
from t5x import partitioning as t5x_partitioning
from t5x import train_state as train_state_lib
from t5x.contrib.moe import training_utils
DataLayout = t5x_partitioning.DataLayout
FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState
HardwareMesh = t5x_partitioning.HardwareMesh
InferenceState = train_state_lib.InferenceState
LogicalAxisRules = t5x_partitioning.LogicalAxisRules
PartitionSpec = t5x_partitioning.PartitionSpec
Pytree = Any
TrainState = train_state_lib.TrainState
class MoePjitPartitioner(t5x_partitioning.PjitPartitioner):
"""Pjit partitioner with overrides for Mixture of Experts support.
This MoE partitioner has two overrides relative to the default partitioner:
(1) It prepends an 'expert' axis to all MoE optimizer state terms, so that
they are sharded along the 'expert' axis; see get_logical_axes().
(2) In cases where no model parallelism is used and the number of experts is
less than the number of devices, we treat the 'model' axis as a secondary data
axis. This allows us to decouple expert parallelism ('data' mesh axis)
from data parallelism ('data' and 'model' axes).
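Example (illustrative only; the 32-device topology and 8 experts below are
hypothetical values, not defaults of this module):

  # With no model parallelism and fewer experts than devices, the 'model'
  # axis becomes a second data axis of size 32 // 8 = 4.
  partitioner = MoePjitPartitioner(num_experts=8, num_partitions=1)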
"""
def __init__(self,
num_experts: int,
num_partitions: Optional[int] = None,
model_parallel_submesh: Optional[HardwareMesh] = None,
params_on_devices: bool = True,
logical_axis_rules: Optional[LogicalAxisRules] = None,
state_filter_fn: Optional[Callable[[str], bool]] = None):
"""Configures the partitioner.
Args:
num_experts: Total number of experts across all devices.
num_partitions: Specifies the size of the model parallel submesh to be
automatically selected for the current topology. See
`model_parallel_submesh` for details on how this submesh is used.
Mutually exclusive with `model_parallel_submesh`.
model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh
model-parallel device tile -- an axis of accelerator parallelism
orthogonal to data parallelism. See t5x/partitioning.py for details.
This argument is mutually exclusive with `num_partitions`.
params_on_devices: Whether to keep the params on devices. If False, params
stay in the host memory.
logical_axis_rules: A priority-ordered sequence of KV tuples that maps
logical axis names to either `None` (not sharded), 'model' (to shard
across the model-parallel submesh), or 'data' (to shard across the
data-parallel submesh).
state_filter_fn: Function to identify which optimizer state axis rules
should be overridden to be sharded along the 'expert' axis. If None
(default), Adafactor expert sharding overrides are used.
"""
# If True, treat 'model' axis as secondary data axis.
self.two_data_axes = _override_model_axis(num_experts, num_partitions,
model_parallel_submesh)
if self.two_data_axes:
# Override num_partitions to repurpose the 'model' axis as a secondary
# data axis, along which only the batch is sharded. Experts will be
# replicated along this secondary data axis.
num_partitions = jax.device_count() // num_experts
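# Worked example (hypothetical topology): with 32 devices and 8 experts,
# `num_partitions` becomes 32 // 8 = 4, i.e. a ('data'=8, 'model'=4) mesh in
# which experts are sharded 8 ways along 'data' and replicated 4x along
# 'model'.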
# Override the user-specified model parallel submesh. Rely on T5X partitioning
# to determine new submesh from updated `num_partitions`.
logging.info(
'Overriding user specified `model_parallel_submesh`=%s to support '
'expert parallelism for updated `num_partitions`=%d',
model_parallel_submesh, num_partitions)
model_parallel_submesh = None
super().__init__(
num_partitions=num_partitions,
model_parallel_submesh=model_parallel_submesh,
params_on_devices=params_on_devices,
logical_axis_rules=logical_axis_rules)
self._state_filter_fn = state_filter_fn
def get_data_layout(self,
batch_size: Optional[int] = None,
host_index: Optional[int] = None) -> DataLayout:
"""Returns filled `DataLayout` based on the partitioned model layout.
Overrides the default data layout for the case where both mesh axes ('data'
and 'model') are treated as data axes.
Args:
batch_size: If set, indicates the requested batch size. If not set, the
batch size is inferred from the layout.
host_index: Indicates the host index to use for the calculations; if not
set, the JAX-provided one is used. Should be in the [0, num_hosts) interval
and the
order should match the order of corresponding CPU devices in
`jax.devices()`.
Returns:
Filled `DataLayout` structure.
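Example (illustrative; the mesh shape is hypothetical): with a
('data'=4, 'model'=2) mesh in the two-data-axes setup, `mesh_size` is
4 * 2 = 8, so the default `batch_size` is 8 and any requested batch size
must be divisible by 8.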
"""
if self.two_data_axes:
if host_index is not None:
raise NotImplementedError('Explicit host_index is not yet implemented.')
mesh_size = self._local_chunker.global_mesh.shape[
'data'] * self._local_chunker.global_mesh.shape['model']
batch_size = batch_size or mesh_size
if batch_size % mesh_size:
raise ValueError(
f'Batch size ({batch_size}) must be divisible by corresponding '
f'mesh size ({mesh_size}).')
num_shards = self._local_chunker.num_chunks['data']
if batch_size % num_shards:
raise ValueError(
f'Batch size ({batch_size}) must be divisible by number of '
f'replicas ({num_shards}).')
replica_id = self._local_chunker.get_local_chunk_info(
(batch_size,), ('data', 'model')).replica_id
return DataLayout(
batch_size=batch_size,
shard_id=self._local_chunker.chunk_ids['data'],
num_shards=num_shards,
is_first_host_in_replica_set=(replica_id == 0))
else:
return super().get_data_layout(batch_size, host_index)
def get_logical_axes(
self, train_state: Union[FlaxOptimTrainState, InferenceState]
) -> Union[FlaxOptimTrainState, InferenceState]:
"""Returns a copy of TrainState with Optional[AxisNames] as leaves.
Overrides the default logical axes by prepending the 'expert' axis to any
MoE optimizer state terms (identified by self._state_filter_fn) so they are
correctly sharded along the 'expert' axis.
Args:
train_state: Object holding all relevant training or inference state.
Returns:
State object matching structure of input train_state but with axis names
as leaves.
"""
logical_axes = train_state.as_logical_axes()
if isinstance(logical_axes, InferenceState):
# InferenceState does not contain any optimizer state, so we skip all
# expert partitioning overrides.
return logical_axes
else:
train_state: FlaxOptimTrainState
state_filter_fn = (
self._state_filter_fn or _infer_state_filter_fn(train_state))
if state_filter_fn is None:
# No state updates required.
return logical_axes
prepend_expert = lambda x: PartitionSpec( # pylint: disable=g-long-lambda
'expert',) + x if x else PartitionSpec('expert',)
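# For example: prepend_expert(PartitionSpec('mlp',)) yields
# PartitionSpec('expert', 'mlp'), while prepend_expert(None) yields
# PartitionSpec('expert',).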
optimizer_axes = logical_axes._optimizer # pylint: disable=protected-access
state_dict = flax_core.unfreeze(optimizer_axes.state_dict())
state_dict['state']['param_states'] = training_utils.tree_map_with_names(
prepend_expert, state_dict['state']['param_states'], state_filter_fn)
return train_state.restore_state(state_dict)
def partition(
self,
fn: Callable, # pylint: disable=g-bare-generic
in_axis_resources: Pytree,
out_axis_resources: Pytree,
static_argnums: Union[int, Sequence[int]] = (),
donate_argnums: Union[int, Sequence[int]] = ()
) -> t5x_partitioning.PjittedFnWithContext:
"""Partitions the computation using pjit.
Overrides the default pjit partitioning in cases where expert and data axes
are decoupled -- wherein we treat the 'model' axis as a secondary data axis.
Args:
fn: Function to partition.
in_axis_resources: Pytree of structure matching that of arguments to `fn`,
with all actual arguments replaced by resource assignment
specifications.
out_axis_resources: Like `in_axis_resources`, but specifies resource
assignment for function outputs.
static_argnums: Specifies which positional arguments to treat as static
(compile-time constant) in the partitioned function.
donate_argnums: Specifies which argument buffers are "donated" to the
computation.
Returns:
A partitioned version of the input function.
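Example (illustrative; `partitioner`, `train_step` and `train_state_axes`
are placeholder names, not part of this module):

  partitioned_step = partitioner.partition(
      train_step,
      in_axis_resources=(train_state_axes,
                         data_partition_spec(partitioner.two_data_axes)),
      out_axis_resources=train_state_axes,
      donate_argnums=(0,))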
"""
if self.two_data_axes:
# Both axes are used for data parallelism in this case, so we override the
# partition specs.
in_axis_resources = _override_partition_specs(in_axis_resources)
out_axis_resources = _override_partition_specs(out_axis_resources)
pjitted = t5x_partitioning.pjit(
fn,
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources,
static_argnums=static_argnums,
donate_argnums=donate_argnums,
backend=self._backend)
return t5x_partitioning.PjittedFnWithContext(pjitted, self.mesh,
self._logical_axis_rules)
def standard_logical_axis_rules(
num_experts: int,
num_partitions: Optional[int] = None,
model_parallel_submesh: Optional[HardwareMesh] = None,
activation_partitioning_dims: int = 1,
parameter_partitioning_dims: int = 1,
additional_rules: Optional[LogicalAxisRules] = None):
"""Returns partitioning rules for MoE models.
The partitioning rules vary based on whether the expert and data axes need to
be decoupled; see also MoePjitPartitioner for details of when expert and data
axes need to be decoupled.
2D parameter sharding (`parameter_partitioning_dims=2`) is not supported.
Sharding parameters along the 'data' axis will interfere with expert
parallelism, because experts are also partitioned along the 'data' axis.
Args:
num_experts: Total number of experts across all devices.
num_partitions: Size of the model parallel submesh. Model parallelism is
only used if num_partitions > 1. Ignored if model_parallel_submesh
is specified.
model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh
model-parallel device tile -- an axis of accelerator parallelism
orthogonal to data parallelism. Model parallelism is only used if
np.prod(model_parallel_submesh) > 1. Mutually exclusive with
`num_partitions`.
activation_partitioning_dims: Enables 2-D activation sharding when set to 2.
parameter_partitioning_dims: Enables 2-D parameter sharding when set to 2.
additional_rules: Additional rules (a sequence of tuples) that will be
appended to the standard rules.
Returns:
Sequence of logical axis rules.
Raises:
ValueError if parameter_partitioning_dims=2.
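Example (illustrative; the 16-device topology and 8 experts are hypothetical
values): when no model parallelism is requested and there are fewer experts
than devices, the 'model' axis is repurposed as a second data axis, so e.g.

  rules = standard_logical_axis_rules(num_experts=8, num_partitions=1)
  # ('batch', 'data')          -> ('batch', ('data', 'model'))
  # ('expert_replicas', None)  -> ('expert_replicas', 'model')
  # rules mapped to 'model'    -> replicated (None)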
"""
if parameter_partitioning_dims == 2:
raise ValueError('2D parameter sharding (`parameter_partitioning_dims=2`) '
'is not supported for MoE.')
default_rules = t5x_partitioning.standard_logical_axis_rules(
activation_partitioning_dims, parameter_partitioning_dims)
moe_rules = [
('expert', 'data'), # Shard experts along the data axis
('expert_mlp', 'model'), # Expert MLPs partitioned along model axis
('expert_group', None), # Replicated axis for all-to-all constraints
('expert_replicas', None), # Experts replicated along this axis
('unmodeled', None), # Replicated weights
]
standard_rules = list(default_rules) + moe_rules
if additional_rules:
standard_rules.extend(additional_rules)
if _override_model_axis(num_experts, num_partitions, model_parallel_submesh):
overridden_rules = []
for logical_axis, mesh_axis in standard_rules:
if logical_axis == 'batch':
# Because we now treat the 'model' axis as a second data axis, we want
# to shard batches across both axes.
overridden_mesh_axis = ('data', 'model')
elif logical_axis == 'expert_replicas':
# "model" axis is repurposed as a second data axis, along which experts
# are replicated.
overridden_mesh_axis = 'model'
elif mesh_axis == 'model':
# Any weights ordinarily partitioned along the model axis should be
# explicitly replicated.
overridden_mesh_axis = None
else:
overridden_mesh_axis = mesh_axis
overridden_rules.append((logical_axis, overridden_mesh_axis))
return overridden_rules
else:
return standard_rules
def data_partition_spec(two_data_axes: bool) -> PartitionSpec:
"""Returns data partitioning spec.
Args:
two_data_axes: If True, use 'model' axis as secondary data axis. Otherwise,
only use 'data' axis for data sharding.
Returns:
Mesh dependent partition spec.
"""
if two_data_axes:
# Use 'model' axis as secondary data axis. Shard batches across both axes.
return PartitionSpec(('data', 'model'),)
else:
return PartitionSpec('data',)
def _override_model_axis(
num_experts: int, num_partitions: Optional[int],
model_parallel_submesh: Optional[HardwareMesh]) -> bool:
"""Returns true iff there is no model parallelism & num experts < num devices.
Args:
num_experts: Total number of experts across all devices.
num_partitions: Size of the model parallel submesh. Model parallelism is
only used if num_partitions > 1. Mutually exclusive with
`model_parallel_submesh`.
model_parallel_submesh: 4-tuple that specifies the `(x, y, z, c)` submesh
model-parallel device tile -- an axis of accelerator parallelism
orthogonal to data parallelism. Model parallelism is only used if
np.prod(model_parallel_submesh) > 1. Mutually exclusive with
`num_partitions`.
Returns:
True if there is no model parallelism & num experts < num devices; False
otherwise.
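Example (hypothetical values for illustration): on 32 devices with
num_experts=8, num_partitions=1 returns True (experts must be replicated),
whereas num_partitions=4 returns False (model parallelism is in use).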
"""
if (num_partitions is None) == (model_parallel_submesh is None):
raise ValueError(
'One, and only one, of {num_partitions, model_parallel_submesh} must '
'be specified. Received: %s and %s' %
(num_partitions, model_parallel_submesh))
if num_experts == 0 or jax.device_count() <= num_experts:
# No expert replication required. No need to override model mesh axis.
return False
return ((num_partitions is not None and num_partitions <= 1) or
(model_parallel_submesh is not None and
np.prod(model_parallel_submesh) <= 1))
def _override_partition_specs(resources: Pytree):
"""Override axis resources for two data axes setup.
In the two data axes setup, we treat the 'model' axis as a secondary data
axis. To this end, we override any hardcoded, raw partition specs:
- PartitionSpec('data',) -> PartitionSpec(('data', 'model'),)
- PartitionSpec('model',) -> None
There is no need to override any params or optimizer state as these will
inherit the correct specs from the logical axis rules; see
standard_logical_axis_rules().
Args:
resources: Axis resource assignment specifications.
Returns:
Axis resources with partition specs overridden to use 'model' as secondary
'data' axis.
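Example (illustrative): a tuple of resources such as
(PartitionSpec('data',), None, PartitionSpec('model',)) is rewritten to
(PartitionSpec(('data', 'model'),), None, None); all other specs are left
unchanged.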
"""
def _maybe_override_spec(axis_resource: Pytree):
"""Overrides "data" and "model" partition specs; leaves others unchanged."""
if axis_resource == PartitionSpec('data',):
# Shard all batches across both axes.
return PartitionSpec(('data', 'model'),)
elif axis_resource == PartitionSpec('model',):
# No model parallelism.
return None
else:
return axis_resource
if resources is None:
return resources
elif not isinstance(resources, Sequence):
return _maybe_override_spec(resources)
else:
overridden_resources = []
for resource in resources:
overridden_resources.append(_maybe_override_spec(resource))
return tuple(overridden_resources)
def _infer_state_filter_fn(
train_state: FlaxOptimTrainState) -> Optional[Callable[[str], bool]]:
"""Infers relevant regex matching sharded expert model state for optimizer.
Only the Adafactor optimizer is currently supported.
The model state generally inherits the correct partitioning specs from the
model parameters, except in cases where the kernel is factored (`v_col` and
`v_row` terms); see derive_logical_axes():
https://github.com/google-research/t5x/blob/main/t5x/adafactor.py#L591. For
those cases, we use the state_filter_fn to identify the factored kernel terms
that need to be partitioned along the expert axis.
Args:
train_state: Object holding optimizer and optimizer state (parameters).
Returns:
Function to identify which model state is sharded along 'expert' axis.
Raises:
ValueError if optimizer (on train state) is not an Adafactor optimizer.
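Example (the parameter path below is hypothetical): for a factored Adafactor
optimizer, the returned matcher accepts factored expert kernel state such as
'target/encoder/moe_layer/expert/wi/kernel/v_row'; for a non-factored
optimizer the function returns None, since the `v` terms already inherit the
correct specs.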
"""
optimizer = train_state._optimizer # pylint: disable=protected-access
optimizer_def = optimizer.optimizer_def
# TODO(jamesleethorp): Revisit once other T5X optimizers are available.
if not isinstance(optimizer_def, adafactor.Adafactor):
raise ValueError('Inferred MoE overrides are currently only available for '
f'the Adafactor optimizer. Received: {optimizer_def}')
if optimizer_def.hyper_params.factored:
# Factored kernel terms (`v_col` and `v_row`) need to be identified for
# expert sharding.
return training_utils.match_fn(r'.*expert.*/kernel/v_.*')
else:
# Non-factored kernel terms (`v`) inherit the correct specs, so no state
# updates will be required.
return None