# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for partitioning."""

import abc
import collections
import dataclasses
import typing
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from absl import logging
import cached_property
from flax import traverse_util
from flax.linen import partitioning as flax_partitioning
import jax
from jax import numpy as jnp
from jax import random
from jax.experimental import PartitionSpec
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit as jax_pjit
import numpy as np
from t5x import train_state as train_state_lib

JaxDevice = jax.lib.xla_client.Device
TpuMesh = Tuple[int, int, int, int]  # (x, y, z, num_cores).
OtherMesh = Tuple[int, int]
HardwareMesh = Union[TpuMesh, OtherMesh]
PyTreeDef = type(jax.tree_structure(None))
TrainState = train_state_lib.TrainState
LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]

if typing.TYPE_CHECKING:  # See b/163639353
  cached_property = property  # pylint: disable=invalid-name
else:
  cached_property = cached_property.cached_property


class AxisNames(tuple):
  """Tuple of strings specifying name for each axis.

  We create a separate class for this so JAX's pytree utilities can distinguish
  it from a tuple that should be treated as a pytree, instead treating it as a
  leaf.
  """

  def __new__(cls, *names):
    return tuple.__new__(AxisNames, names)

  def __repr__(self):
    return 'AxisNames%s' % tuple.__repr__(self)


# pjit wrappers for cpu fallback.
# -----------------------------------------------------------------------------
# TODO(levskaya): upstream this fallback behavior to jax pjit.
def pjit(
    fun: Callable,  # pylint: disable=g-bare-generic
    in_axis_resources,
    out_axis_resources,
    static_argnums: Union[int, Sequence[int]] = (),
    donate_argnums: Union[int, Sequence[int]] = (),
    backend: Optional[str] = None):
  """Wrapper for pjit that calls normal jit on cpu."""
  if jax.devices(backend)[0].platform == 'cpu':
    return jax.jit(
        fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
  else:
    return jax_pjit(
        fun,
        in_axis_resources,
        out_axis_resources,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums)


def with_sharding_constraint(x, axis_resources):
  """Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
  if jax.devices()[0].platform == 'cpu' or not global_mesh_defined():
    return x
  else:
    return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
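

# Illustrative sketch (not part of the original module): how the `pjit`
# wrapper above might be used. On CPU it silently falls back to `jax.jit`, so
# the same call site works in single-device tests and under a multi-device
# Mesh. The function and partition specs below are hypothetical.
def _example_pjit_sharded_matmul():
  """Hedged usage sketch for the `pjit` wrapper; never called by this module."""

  def matmul(x, y):
    return jnp.dot(x, y)

  # Shard the first operand's leading axis over the 'data' mesh axis and
  # replicate everything else; on non-CPU backends callers must invoke the
  # result inside a `Mesh(...)` context.
  return pjit(
      matmul,
      in_axis_resources=(PartitionSpec('data', None), None),
      out_axis_resources=PartitionSpec('data', None))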


# pjit Mesh creation functions.
# -----------------------------------------------------------------------------
def bounds_from_last_device(
    last_device: jax.lib.xla_client.Device) -> HardwareMesh:
  """Get the bound from the given last device."""
  # Must be passed the device at the highest-coordinate corner of the
  # relevant mesh, which is a requirement we know is satisfied by the last
  # device in jax.devices().
  if hasattr(last_device, 'coords'):
    x, y, z = last_device.coords
    return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
  else:
    # On non-TPU platforms, the "mesh" is hosts x devices per host in order
    # to take advantage of faster within-host interconnect.
    return jax.host_count(), jax.local_device_count()


def get_coords(device: jax.lib.xla_client.Device) -> HardwareMesh:
  """Returns the coordinates of the given device."""
  if hasattr(device, 'coords'):
    return (*device.coords, device.core_on_chip)
  return (device.process_index, device.id % jax.local_device_count())


def global_mesh_defined():
  """Checks if global xmap/pjit mesh resource environment is defined."""
  maps_env = jax.experimental.maps.thread_resources.env
  return maps_env.physical_mesh.devices.shape != ()  # pylint: disable=g-explicit-bool-comparison
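

# Illustrative sketch (added for exposition; the concrete numbers are
# assumptions that depend on the slice):
#
#   bounds_from_last_device(jax.devices()[-1])
#     -> e.g. (4, 4, 1, 2) for a TPU v3 4x4 slice with 2 cores per chip, or
#        (num_hosts, num_local_devices) on CPU/GPU.
#
#   get_coords(jax.devices()[0])
#     -> e.g. (0, 0, 0, 0) on TPU, or (process_index, local_id) elsewhere.
#
#   global_mesh_defined()
#     -> True only while a mesh resource environment is active, e.g. inside
#        a `with Mesh(...)` context.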


def get_mesh(model_parallel_submesh: HardwareMesh,
             input_devices: Sequence[JaxDevice] = (),
             input_local_devices: Sequence[JaxDevice] = (),
             tile_by_host_if_needed: bool = True,
             backend: Optional[str] = None) -> Mesh:
  """Construct an xmap/pjit Mesh for the given model-parallel submesh.

  The resulting mesh has two resource axes: 'model', with the provided submesh
  shape, and 'data', which covers the rest of the mesh.

  Args:
    model_parallel_submesh: a HardwareMesh spec, namely (x, y, z, core) on TPU
      for a single model-parallel replica's "tile" in the physical device mesh.
      The first three elements (`x`, `y`, and `z`) should be factors of the pod
      slice; e.g., if you are using df_4x8, then `x` should be a factor of 4
      (one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and
      `z` must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU
      v4 (and maybe later TPUs) that allow 3D slices. `core` is the number of
      cores to use from each TPU node. As communication is usually fastest
      inside the same node, if you need a tile of more than 1 core, you should
      first increase `core`: e.g., for TPU v3, (1, 1, 1, 2) is better than
      (2, 1, 1, 1). To pick a good spec, try a few possible values until you
      get high TPU utilization.
    input_devices: the devices to use; jax.devices() is used if this is not
      set.
    input_local_devices: the local devices to use; jax.local_devices() is used
      if this is not set.
    tile_by_host_if_needed: JAX currently requires that the parts of any
      sharded array that are located on one host's local devices form a single
      contiguous slice. A best effort will be made to achieve this without
      "tiling" the device assignment over hosts (which can reduce XLA
      collective performance). If this flag is True, the device assignment
      will be tiled over hosts if necessary to satisfy this constraint and
      create a buildable mesh; if False, mesh construction will fail instead.
    backend: get devices from the pinned backend, if specified. This is useful
      for explicitly specifying the devices rather than relying on
      jax_platform_name.

  Returns:
    An xmap/pjit Mesh containing the virtual device mesh with data, model axes.
  """
  input_devices = input_devices or jax.devices(backend)
  input_local_devices = input_local_devices or jax.local_devices(0, backend)
  last_device = input_devices[-1]
  global_hardware_mesh = bounds_from_last_device(last_device)
  mesh_ndim = len(global_hardware_mesh)
  local_hardware_mesh = bounds_from_last_device(input_local_devices[-1])
  mesh_err = (
      f'each dimension of the model parallel submesh {model_parallel_submesh} '
      'must be a factor of the corresponding dimension of the global device '
      f'mesh {global_hardware_mesh}')
  assert not any(
      g % m
      for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
  assert not any(
      g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
  devices = np.empty(global_hardware_mesh, dtype=np.object)
  for device in input_devices:
    device_coords = get_coords(device)
    devices[device_coords] = device
  tile_by_host = tile_by_host_if_needed
  if len(global_hardware_mesh) == 4:
    # enable contiguous local chunks without host tiling by making Z major
    global_hardware_mesh = typing.cast(Tuple[int, int, int, int],
                                       global_hardware_mesh)
    model_parallel_submesh = typing.cast(Tuple[int, int, int, int],
                                         model_parallel_submesh)
    gx, gy, gz, gc = global_hardware_mesh
    mx, my, mz, mc = model_parallel_submesh
    if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and
                                            mz == gz > 1):
      logging.info('ensuring YZ plane has a Z-major device order')
      # YZ should be ZY
      assert mc == gc, (mc, gc)
      global_hardware_mesh = gx, gz, gy, gc
      model_parallel_submesh = mx, mz, my, mc
      devices = devices.swapaxes(1, 2)
      tile_by_host = False
    if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and
                                            mz == gz > 1):
      logging.info('ensuring XZ plane has a Z-major device order')
      # XZ should be ZX
      assert mc == gc, (mc, gc)
      global_hardware_mesh = gz, gy, gx, gc
      model_parallel_submesh = mz, my, mx, mc
      devices = devices.swapaxes(0, 2)
      tile_by_host = False
  if tile_by_host:
    logging.warning(
        'Tiling device assignment mesh by hosts, which may lead to '
        'reduced XLA collective performance. To avoid this, modify '
        'the model parallel submesh or run with more tasks per host.')
    tile_err = (
        'to tile the mesh by hosts, each dimension of the model parallel '
        'submesh must be either a factor or a multiple of the corresponding '
        'dimension of the per-host submesh')

    def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
      """Split a global mesh dimension into four tiling components.

      Args:
        g: global mesh bounds dimension size
        m: model-parallel submesh bounds dimension size
        l: local submesh bounds dimension size

      Returns:
        The resulting tuple divides the dimension into the hosts component of
        the data-parallel submesh, the devices component of the data-parallel
        submesh, the hosts component of the model-parallel submesh, and the
        devices component of the model-parallel submesh.
      """
      d = g // m
      if m >= l:
        assert not m % l, tile_err
        return (d, 1, m // l, l)
      else:
        assert not l % m, tile_err
        return (d // (l // m), l // m, 1, m)

    # e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
    dh_dd_mh_md_tups = map(dh_dd_mh_md, global_hardware_mesh,
                           model_parallel_submesh, local_hardware_mesh)
    # reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
    devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t))  # pylint: disable=g-complex-comprehension
    # TODO(jekbradbury): reorder local subgroups for ring locality
    # Transpose to [data_host], [data_device], [model_host], [model_device]
    # block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
    devices = devices.transpose(*(4 * i for i in range(mesh_ndim)),
                                *(4 * i + 1 for i in range(mesh_ndim)),
                                *(4 * i + 2 for i in range(mesh_ndim)),
                                *(4 * i + 3 for i in range(mesh_ndim)))
  else:
    # e.g. [(x_data, x_model), (y_data, y_model), ...]
    model_data_tups = [
        (g // m, m)
        for g, m in zip(global_hardware_mesh, model_parallel_submesh)
    ]
    # reshape to e.g. (x_data, x_model, y_data, y_model...)
    devices = devices.reshape(*(s for t in model_data_tups for s in t))  # pylint: disable=g-complex-comprehension
    # TODO(jekbradbury): reorder small subgroups for ring locality
    # transpose to e.g. (x_data, y_data, ..., x_model, ...)
    devices = devices.transpose(*(2 * i for i in range(mesh_ndim)),
                                *(2 * i + 1 for i in range(mesh_ndim)))
  # reshape to (data, model)
  devices = devices.reshape(-1, np.prod(model_parallel_submesh))
  global_mesh = Mesh(devices, ['data', 'model'])
  logging.info('global_mesh axis_names: %s', global_mesh.axis_names)
  logging.info('global_mesh devices: %s', global_mesh.devices)
  return global_mesh
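

# Illustrative sketch (not part of the original module): building a Mesh with
# a hypothetical 8-device model-parallel tile. Any submesh whose dimensions
# divide the global hardware mesh works; (2, 2, 1, 2) is just an example for
# a TPU v3 slice.
def _example_get_mesh():
  mesh = get_mesh(model_parallel_submesh=(2, 2, 1, 2))
  # 'model' spans the 8-device tile; 'data' covers the remaining devices,
  # i.e. mesh.shape == {'data': jax.device_count() // 8, 'model': 8}.
  return mesh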


def get_cpu_mesh() -> Mesh:
  """Trivial mesh for CPU Testing."""
  devices = np.empty((jax.host_count(), jax.local_device_count()),
                     dtype=np.object)
  for device in jax.devices():
    devices[device.process_index, device.id % jax.local_device_count()] = device
  return Mesh(devices, ['data', 'model'])


def get_gpu_mesh() -> Mesh:
  """Simple mesh for GPUs."""
  devices = np.empty((jax.host_count(), jax.local_device_count()),
                     dtype=np.object)
  for device in jax.devices():
    devices[device.process_index, device.id % jax.local_device_count()] = device
  return Mesh(devices, ['data', 'model'])


def default_mesh(num_partitions: int,
                 model_parallel_submesh: Optional[HardwareMesh] = None,
                 backend: Optional[str] = None) -> Mesh:
  """Attempt to return a default mesh for simple cases.

  Args:
    num_partitions: number of partitions to use; ignored if
      model_parallel_submesh is provided.
    model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
      as the model-parallel device tile.
    backend: get devices from the pinned backend, if specified. This is useful
      for explicitly specifying the devices rather than relying on
      jax_platform_name.

  Returns:
    xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
  """
  last_device = jax.devices(backend)[-1]
  platform = last_device.platform
  device_kind = last_device.device_kind
  bounds = bounds_from_last_device(last_device)

  if model_parallel_submesh:
    return get_mesh(model_parallel_submesh, backend=backend)

  if platform == 'cpu':
    return get_cpu_mesh()
  elif platform == 'gpu':
    return get_gpu_mesh()

  mps = None
  if device_kind in ('TPU v2', 'TPU v3'):
    if num_partitions == 1:
      mps = (1, 1, 1, 1)
    elif num_partitions == 2:
      mps = (1, 1, 1, 2)
    elif num_partitions == 4:
      mps = (2, 1, 1, 2)
    elif num_partitions == 8:
      mps = (2, 2, 1, 2)
    elif num_partitions == 16:
      mps = (4, 2, 1, 2)
  # assume the use of megacore on TPU v4
  elif device_kind == 'TPU v4' and bounds[3] == 1:
    if num_partitions == 1:
      mps = (1, 1, 1, 1)
    elif num_partitions == 2:
      mps = (1, 2, 1, 1)
    elif num_partitions == 4:
      if bounds[0] >= 4:
        mps = (4, 1, 1, 1)
      else:
        mps = (2, 2, 1, 1)
    elif num_partitions == 8:
      if bounds[2] >= 8:
        mps = (1, 1, 8, 1)
      else:
        mps = (4, 2, 1, 1)
    elif num_partitions == 16:
      if bounds[2] >= 16:
        mps = (1, 1, 16, 1)
      elif bounds[0] >= 8:
        mps = (8, 2, 1, 1)
      else:
        mps = (4, 4, 1, 1)

  if mps is None:
    raise ValueError('No default mesh for this configuration: specify '
                     'config.model_parallel_submesh explicitly.')
  return get_mesh(mps, backend=backend)
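

# Illustrative sketch (not part of the original module): on TPU v3,
# `default_mesh(num_partitions=4)` selects the (2, 1, 1, 2) submesh from the
# table above, i.e. two neighboring chips with both cores each, growing the
# `core` dimension before the spatial ones as recommended in `get_mesh`.
def _example_default_mesh():
  return default_mesh(num_partitions=4)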


# Data chunking helper.
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class LocalChunkInfo:
  # The logical slice of an array located on this host's local devices.
  slice: Tuple[slice, ...]
  # A unique index for this host/local chunk among chunks with the same slice.
  replica_id: int


class LocalChunker:
  """Utility class to aid chunking of sharded arrays in multihost settings."""

  def __init__(self, global_mesh: Mesh):
    self.global_mesh = global_mesh
    local_mesh = global_mesh.local_mesh
    first_local_device = local_mesh.devices.reshape(-1)[0]
    host_location = collections.OrderedDict(
        zip(
            global_mesh.shape.keys(),
            list(zip(*np.nonzero(
                global_mesh.devices == first_local_device)))[0]))
    self.num_chunks = collections.OrderedDict()
    self.chunk_ids = collections.OrderedDict()
    self.mesh_axes = list(global_mesh.shape.keys())
    for mesh_axis in self.mesh_axes:
      num_devices_per_chunk = local_mesh.shape[mesh_axis]
      self.num_chunks[mesh_axis] = (
          global_mesh.shape[mesh_axis] // num_devices_per_chunk)
      self.chunk_ids[mesh_axis] = (
          host_location[mesh_axis] // num_devices_per_chunk)

  def get_local_chunk_info(
      self, global_shape: Tuple[int, ...],
      mesh_axes: Sequence[Optional[str]]) -> LocalChunkInfo:
    """Get the local chunk info for a given array shape and sharded axes.

    Args:
      global_shape: the global, unsharded shape of the array to chunk.
      mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
        that specifies which mesh dimensions the array is sharded along.

    Returns:
      LocalChunkInfo containing the logical slices of the array found on this
      host's local devices, as well as the replica index for this chunk among
      chunks with the same slice. The latter is used to determine which host
      should write this chunk during checkpointing.
    """
    local_slice = [slice(None) for dim in global_shape]
    sharded_mesh_axes = set()
    for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
      if not mesh_axis:
        continue
      sharded_mesh_axes.add(mesh_axis)
      if not isinstance(mesh_axis, str):
        raise NotImplementedError('TODO(jekbradbury)')
      chunk_id = self.chunk_ids[mesh_axis]
      chunk_size = size // self.num_chunks[mesh_axis]
      local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
    replicated_mesh_axes = [
        mesh_axis for mesh_axis in self.mesh_axes
        if mesh_axis not in sharded_mesh_axes
    ]
    replica_id = 0
    for mesh_axis in replicated_mesh_axes:
      chunk_id = self.chunk_ids[mesh_axis]
      replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
    return LocalChunkInfo(tuple(local_slice), replica_id)
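

# Illustrative sketch (hypothetical shapes, not part of the original module):
# how LocalChunker maps a globally sharded array onto this host's local slice.
def _example_local_chunker(global_mesh: Mesh) -> LocalChunkInfo:
  chunker = LocalChunker(global_mesh)
  # A [batch=128, hidden=1024] array sharded over 'data' along its first axis
  # and replicated along its second axis.
  info = chunker.get_local_chunk_info((128, 1024), ['data', None])
  # With e.g. 4 data chunks, the host holding chunk 2 sees
  # info.slice == (slice(64, 96), slice(None, None, None)); info.replica_id
  # distinguishes hosts that hold identical slices (used at checkpoint time).
  return info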


def standard_logical_axis_rules(
    activation_partitioning_dims: int = 1,
    parameter_partitioning_dims: int = 1,
    additional_rules: Optional[LogicalAxisRules] = None) -> LogicalAxisRules:
  """Default sharding rules for T5X model in terms of logical axis names.

  Args:
    activation_partitioning_dims: enables 2-D activation sharding when set to 2.
    parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
    additional_rules: additional rules (a sequence of tuples) that will be
      appended to the standard rules.

  Returns:
    Sequence of logical axis rules.
  """
  logging.info(
      '`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d',
      activation_partitioning_dims, parameter_partitioning_dims)

  if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
    rules = [
        ('batch', 'data'),
        ('vocab', 'model'),
        ('embed', None),
        ('mlp', 'model'),
        ('heads', 'model'),
        ('kv', None),
        ('joined_kv', 'model'),  # joined heads+kv dim in 2D attn param layouts
    ]
  elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
    rules = [
        ('batch', 'data'),
        ('vocab', 'model'),
        ('mlp', 'model'),
        ('heads', 'model'),
        ('kv', None),
        ('joined_kv', 'model'),
        ('embed', 'model'),
    ]
  elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
    rules = [
        ('batch', 'data'),
        ('vocab', 'model'),
        ('mlp', 'model'),
        ('heads', 'model'),
        ('kv', None),
        ('joined_kv', 'model'),
        ('embed', 'data'),
    ]
  elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
    rules = [
        ('batch', 'data'),
        ('vocab', 'model'),
        ('mlp', 'model'),
        ('heads', 'model'),
        ('kv', None),
        ('joined_kv', 'model'),
        ('embed', 'model'),
        ('embed', 'data'),
    ]
  else:
    raise ValueError(
        f'`activation_partitioning_dims` = {activation_partitioning_dims} '
        f'`parameter_partitioning_dims` = {parameter_partitioning_dims} '
        'is not supported.')

  # Add the common rules for the replicated logical axes names.
  replicated_rules = [
      ('relpos_buckets', None),
      ('abspos_buckets', None),
      ('length', None),
      ('layers', None),
      ('stack', None),
      ('mlp_activations', None),
  ]
  rules.extend(replicated_rules)

  if additional_rules:
    rules.extend(additional_rules)

  return rules
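

# Illustrative sketch (not part of the original module): resolving logical
# axis names to mesh axes with the rules above via flax's partitioning helper.
def _example_standard_rules():
  rules = standard_logical_axis_rules(
      activation_partitioning_dims=1, parameter_partitioning_dims=1)
  # A hypothetical embedding parameter with logical axes ('vocab', 'embed')
  # maps to PartitionSpec('model', None) under the 1-D parameter rules above.
  return flax_partitioning.logical_to_mesh_axes(('vocab', 'embed'), rules)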


# NB: This needs to be top-level for the jax compilation cache.
def _id_fn(x, ix):
  """Identity function for copying parameters to the devices, sharded."""
  # A pure identity such as `lambda x, *_: x` can get optimized away, so we
  # include a random.split as a cheap function that cannot be optimized away.
  return x, random.split(jnp.array([ix, ix], dtype=jnp.uint32))


@dataclasses.dataclass
class DataLayout:
  """Represents data layout for the partitioned model."""
  batch_size: int
  shard_id: int
  num_shards: int
  is_first_host_in_replica_set: bool


PartitionedCallable = Callable[..., Any]
CompiledPartitionedCallable = Callable[..., Any]


class BasePartitioner(metaclass=abc.ABCMeta):
  """Interface for partitioning computations across hardware devices."""

  def __init__(self,
               num_partitions: Optional[int] = None,
               model_parallel_submesh: Optional[HardwareMesh] = None,
               params_on_devices: bool = True,
               backend: Optional[str] = None):
    """Configures the partitioner.

    Args:
      num_partitions: the number of partitions to use. Ignored if
        `model_parallel_submesh` is provided.
      model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
        as the model-parallel device tile. This submesh is used for the larger
        of the two parameter dimensions, and, if 2-D activation sharding is
        enabled, for the model dimension of activations. The rest of the mesh
        is used for data parallelism and, if 2-D parameter sharding is enabled,
        the other parameter dimension.
      params_on_devices: whether to keep the params on devices; if False,
        params stay in host memory. Note that some partitioners might ignore
        this setting, for example if they don't support storing all params in
        device memory.
      backend: get devices from the pinned backend, if specified. This is
        useful for explicitly specifying the devices rather than relying on
        jax_platform_name.
    """
    if not num_partitions and not model_parallel_submesh:
      raise ValueError('At least one of `num_partitions` or '
                       '`model_parallel_submesh` must be set.')

    if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
      logging.error(
          '`model_parallel_submesh` must be either None or a 4-tuple. Got '
          '`model_parallel_submesh`=%s. A ValueError will be raised beginning '
          'March 1, 2022.', model_parallel_submesh)

    if bool(num_partitions) and bool(model_parallel_submesh):
      logging.error(
          'At most one of `num_partitions` or `model_parallel_submesh` can be '
          'set. Got `num_partitions`=%s and `model_parallel_submesh`=%s. A '
          'ValueError will be raised beginning March 21, 2022.', num_partitions,
          model_parallel_submesh)

    self._num_partitions = num_partitions
    self._model_parallel_submesh = model_parallel_submesh
    self._params_on_devices = params_on_devices
    self._data_axis = 'data'
    self._backend = backend

  @abc.abstractproperty
  def mesh(self) -> Mesh:
    raise NotImplementedError

  @property
  def data_partition_spec(self) -> PartitionSpec:
    return PartitionSpec(self._data_axis)

  def get_data_layout(self,
                      batch_size: Optional[int] = None,
                      host_index: Optional[int] = None) -> DataLayout:
    """Returns filled `DataLayout` based on the partitioned model layout.

    Args:
      batch_size: if set, indicates the requested batch size. An exception
        will be raised if this batch size is not compatible with the layout.
        If not set, the batch size is inferred from the layout.
      host_index: indicates the host index to use for the calculations; if not
        set, the JAX-provided one is used. Should be in the [0, num_hosts)
        interval and the order should match the order of corresponding CPU
        devices in `jax.devices()`.

    Returns:
      Filled `DataLayout` structure.
    """
    if host_index is not None:
      raise NotImplementedError('Explicit host_index is not yet implemented.')

    if self._data_axis is None:
      return DataLayout(
          batch_size=batch_size,
          shard_id=0,
          num_shards=1,
          is_first_host_in_replica_set=(jax.process_index() == 0))
    mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
    batch_size = batch_size or mesh_size
    if batch_size % mesh_size:
      raise ValueError(
          f'Batch size ({batch_size}) must be divisible by corresponding '
          f'mesh size ({mesh_size}).')
    num_shards = self._local_chunker.num_chunks[self._data_axis]
    if batch_size % num_shards:
      raise ValueError(
          f'Batch size ({batch_size}) must be divisible by number of '
          f'replicas ({num_shards}).')
    replica_id = self._local_chunker.get_local_chunk_info(
        (batch_size,), [self._data_axis]).replica_id
    return DataLayout(
        batch_size=batch_size,
        shard_id=self._local_chunker.chunk_ids[self._data_axis],
        num_shards=num_shards,
        is_first_host_in_replica_set=(replica_id == 0))
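
  # Illustrative note (hypothetical numbers, added for exposition): with a
  # 'data' mesh axis of 256 devices spread over 32 hosts and a requested
  # global batch_size of 1024, num_shards is 32, so each host's input pipeline
  # feeds batch_size // num_shards = 32 examples per step, and only the host
  # whose chunk has replica_id == 0 writes that chunk at checkpoint time.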

  def get_local_chunk_info(
      self, global_shape: Tuple[int, ...],
      mesh_axes: Sequence[Optional[str]]) -> LocalChunkInfo:
    """Returns the local chunk info for a given array shape and sharded axes."""
    return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)

  @property
  def params_on_devices(self):
    return self._params_on_devices

  def move_params_to_devices(self, train_state: TrainState,
                             train_state_axes: TrainState) -> TrainState:
    """Moves the optimizer parameters to devices."""
    p_id_fn = self.partition(
        _id_fn,
        in_axis_resources=(train_state_axes, None),
        out_axis_resources=(train_state_axes, None),
        donate_argnums=(0,))
    train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
    return train_state

  @property
  def _local_chunker(self):
    """Returns the chunker that matches the parameters of this partitioner."""
    raise NotImplementedError

  def get_logical_axes(self, train_state: TrainState) -> TrainState:
    """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
    # By default, return None for the logical axes.
    return train_state.restore_state(
        jax.tree_map(lambda x: None, train_state.state_dict()))

  def get_mesh_axes(self, train_state: TrainState) -> TrainState:
    """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
    raise NotImplementedError

  @abc.abstractmethod
  def partition(
      self,
      fn: Callable,  # pylint: disable=g-bare-generic
      in_axis_resources,
      out_axis_resources,
      static_argnums: Union[int, Sequence[int]] = (),
      donate_argnums: Union[int, Sequence[int]] = ()
  ) -> PartitionedCallable:
    """Partitions the computation using partitioner-specific implementation.

    Args:
      fn: the function to partition.
      in_axis_resources: Pytree of structure matching that of arguments to
        `fn`, with all actual arguments replaced by resource assignment
        specifications. It is also valid to specify a pytree prefix (e.g. one
        value in place of a whole subtree), in which case the leaves get
        broadcast to all values in that subtree.
        The valid resource assignment specifications are:
          `None`: the value will be replicated on all devices.
          `PartitionSpec`: a tuple of length at most equal to the rank of the
            partitioned value. Each element can be a `None`, a mesh axis, or a
            tuple of mesh axes, and specifies the set of resources assigned to
            partition the value's dimension matching its position in the spec.
      out_axis_resources: Like `in_axis_resources`, but specifies resource
        assignment for function outputs.
      static_argnums: an optional int or collection of ints that specify which
        positional arguments to treat as static (compile-time constant) in the
        partitioned function.
      donate_argnums: an optional int or collection of ints that specify which
        argument buffers are "donated" to the computation. It is safe to
        donate argument buffers if you no longer need them once the
        computation has finished.

    Returns:
      A partitioned version of the input function.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def compile(self, partitioned_fn: PartitionedCallable,
              *args) -> CompiledPartitionedCallable:
    """Compiles and returns the partitioned function, or the original.

    Args:
      partitioned_fn: The partitioned function.
      *args: Sample arguments to the partitioned function matching the input
        shapes that will be passed to the compiled function.

    Returns:
      The compiled function, or the original if this partitioner does not
      support compilation.
    """
    raise NotImplementedError


class PjittedFnWithContext(PartitionedCallable):
  """Wraps pjitted function to apply the appropriate contexts."""

  def __init__(self,
               pjitted_fn,
               partition_mesh: Mesh,
               logical_axis_rules: flax_partitioning.LogicalRules = ()):
    self._pjitted_fn = pjitted_fn
    self._mesh = partition_mesh
    self._logical_axis_rules = logical_axis_rules

  def __call__(self, *args):
    with Mesh(self._mesh.devices,
              self._mesh.axis_names), flax_partitioning.axis_rules(
                  self._logical_axis_rules):
      return self._pjitted_fn(*args)

  def lower(self, *args):
    with Mesh(self._mesh.devices,
              self._mesh.axis_names), flax_partitioning.axis_rules(
                  self._logical_axis_rules):
      return self._pjitted_fn.lower(*args)


class BasePjitPartitioner(BasePartitioner):
  """Partitioner that uses T5X version of jax.pjit."""

  @cached_property
  def _local_chunker(self) -> LocalChunker:
    return LocalChunker(self.mesh)

  @cached_property
  def mesh(self) -> Mesh:
    return default_mesh(self._num_partitions, self._model_parallel_submesh,
                        self._backend)

  def partition(
      self,
      fn: Callable,  # pylint: disable=g-bare-generic
      in_axis_resources,
      out_axis_resources,
      static_argnums: Union[int, Sequence[int]] = (),
      donate_argnums: Union[int, Sequence[int]] = ()
  ) -> PjittedFnWithContext:
    pjitted = pjit(
        fn,
        in_axis_resources=in_axis_resources,
        out_axis_resources=out_axis_resources,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
        backend=self._backend)

    return PjittedFnWithContext(pjitted, self.mesh)

  def compile(self, partitioned_fn: PjittedFnWithContext,
              *args) -> CompiledPartitionedCallable:
    return partitioned_fn.lower(*args).compile()


class PjitPartitioner(BasePjitPartitioner):
  """Partitioner that uses named axes and jax.pjit."""

  def __init__(self,
               num_partitions: Optional[int] = None,
               model_parallel_submesh: Optional[HardwareMesh] = None,
               params_on_devices: bool = True,
               backend: Optional[str] = None,
               logical_axis_rules: Optional[LogicalAxisRules] = None):
    """PjitPartitioner constructor.

    See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.

    Args:
      num_partitions: an integer that specifies the size of the model parallel
        submesh to be automatically selected for the current topology. See
        `model_parallel_submesh` for details on how this submesh is used.
        Mutually exclusive with `model_parallel_submesh`.
      model_parallel_submesh: a 4-tuple that specifies the `(x, y, z, c)`
        submesh model-parallel device tile, an axis of accelerator parallelism
        orthogonal to data parallelism. Array axes in a model's parameters or
        activations can be sharded over this submesh using axis rules (see
        `logical_axis_rules`) that map them to 'model'. The effective number
        of model sub-partitions is equal to `np.prod(model_parallel_submesh)`
        and must evenly divide the total number of devices (i.e.,
        `jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest
        of the TPU mesh is the data parallel submesh, providing
        `jax.device_count() // np.prod(model_parallel_submesh)` partitions. It
        is used for data (batch) parallelism and to shard other array axes
        that are mapped to 'data'. This argument is mutually exclusive with
        `num_partitions`.
      params_on_devices: whether to keep the params on devices; if False,
        params stay in host memory. Note that some partitioners might ignore
        this setting, for example if they don't support storing all params in
        device memory.
      backend: get devices from the pinned backend, if specified. This is
        useful for explicitly specifying the devices rather than relying on
        jax_platform_name.
      logical_axis_rules: a priority-ordered sequence of KV tuples that maps
        logical axis names to either `None` (not sharded), 'model' (to shard
        across the model-parallel submesh), or 'data' (to shard across the
        data-parallel submesh).
    """
    super().__init__(
        num_partitions=num_partitions,
        model_parallel_submesh=model_parallel_submesh,
        params_on_devices=params_on_devices,
        backend=backend)
    if logical_axis_rules is None:
      logical_axis_rules = standard_logical_axis_rules()
    self._logical_axis_rules = tuple(logical_axis_rules)
    self._data_axis, = flax_partitioning.logical_to_mesh_axes(
        ['batch'], logical_axis_rules)

  def partition(
      self,
      fn: Callable,  # pylint: disable=g-bare-generic
      in_axis_resources,
      out_axis_resources,
      static_argnums: Union[int, Sequence[int]] = (),
      donate_argnums: Union[int, Sequence[int]] = ()
  ) -> PjittedFnWithContext:
    """Partitions the function using jax.pjit."""
    pjitted = pjit(
        fn,
        in_axis_resources=in_axis_resources,
        out_axis_resources=out_axis_resources,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
        backend=self._backend)

    return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)

  @property
  def logical_axis_rules(self):
    """Returns the logical axis rules."""
    return self._logical_axis_rules

  def get_logical_axes(self, train_state: TrainState) -> TrainState:
    """Returns a copy of TrainState with Optional[AxisNames] as leaves."""
    return train_state.as_logical_axes()

  def get_mesh_axes(self, train_state: TrainState) -> TrainState:
    """Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
    logical_axes = self.get_logical_axes(train_state)

    def _logical_to_mesh_axes(param_name, logical_axes):
      if logical_axes is None:
        return None
      elif logical_axes is traverse_util.empty_node:
        return traverse_util.empty_node
      try:
        return flax_partitioning.logical_to_mesh_axes(logical_axes,
                                                      self._logical_axis_rules)
      except ValueError as e:
        raise ValueError(f'Failed to map logical axes for {param_name}') from e

    flat_logical_axes = traverse_util.flatten_dict(
        logical_axes.state_dict(), keep_empty_nodes=True, sep='/')
    flat_mesh_axes = {
        k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
    }

    return logical_axes.restore_state(
        traverse_util.unflatten_dict(flat_mesh_axes, sep='/'))
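

# Illustrative end-to-end sketch (not part of the original module): using
# PjitPartitioner to partition a simple function. The batch size and axis
# resources are assumptions for exposition, not prescribed values.
def _example_pjit_partitioner():
  partitioner = PjitPartitioner(num_partitions=1)
  # The batch must be divisible by the 'data' mesh axis size; 32 is a
  # hypothetical value.
  data_layout = partitioner.get_data_layout(batch_size=32)

  def square(batch):
    return batch * batch

  p_square = partitioner.partition(
      square,
      in_axis_resources=partitioner.data_partition_spec,
      out_axis_resources=partitioner.data_partition_spec)
  return p_square, data_layout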