import logging
import unittest
import typing
import torch
import torch.nn as nn
from torch.testing._internal import common_utils
from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
# N.B.(mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are small enough that TF32 matmul
# can be imprecise enough to make `self.assertEqual` fail.
torch.backends.cuda.matmul.allow_tf32 = False
class TensorParallelLayerTestBase:
BATCH_SIZE: int = 8
SEQUENCE_LENGTH: int = 128
VOCAB_SIZE: int = 1024
HIDDEN_SIZE: int = 256
INPUT_SIZE_COEFF: int = 256
OUTPUT_SIZE_COEFF: int = 256
SEED: int = 123456
@property
def tensor_shape(self) -> typing.Sequence[int]:
return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE]
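    # NOTE: Activation tensors in these tests use Megatron's
    # (sequence, batch, hidden) layout.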
@torch.no_grad()
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
def test_all_gather_parity(self) -> None:
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15")
from torch.distributed.distributed_c10d import all_gather, _all_gather_base # NOQA
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
with torch.no_grad():
tensor = tensor_model_parallel_rank * torch.ones(
self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
numel = tensor.numel()
numel_gathered = tensor_model_parallel_world_size * numel
gathered = torch.empty(
torch.Size((numel_gathered,)),
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
chunks = [
gathered[i * numel : (i + 1) * numel]
for i in range(tensor_model_parallel_world_size)
]
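                # NOTE: `chunks` are views into `gathered`, so `all_gather`
                # fills the flat buffer in place and it can be compared
                # directly against `_all_gather_base`'s output.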
all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())
gathered_for_base = torch.empty(
torch.Size((numel_gathered,)),
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
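                # `_all_gather_base` gathers every rank's tensor into one
                # contiguous buffer instead of a list of per-rank tensors.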
_all_gather_base(
gathered_for_base,
tensor,
group=parallel_state.get_tensor_model_parallel_group(),
)
msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
self.assertEqual(gathered, gathered_for_base, msg=msg)
parallel_state.destroy_model_parallel()
@torch.no_grad()
@unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
def test_reduce_scatter_parity(self) -> None:
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15")
from torch.distributed.distributed_c10d import reduce_scatter, _reduce_scatter_base # NOQA
for tensor_model_parallel_world_size in range(2, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
with torch.no_grad():
input = torch.cat([
i * torch.ones(self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
for i in range(tensor_model_parallel_world_size)
])
input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)]
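                # The chunks are cloned so the later check that
                # `reduce_scatter` leaves its inputs untouched is meaningful.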
output = torch.empty(
self.tensor_shape,
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
reduce_scatter(
output, input_list,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_for_base = torch.empty(
self.tensor_shape,
device=cur_tensor_model_device,
dtype=torch.float32,
requires_grad=False,
)
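                # `_reduce_scatter_base` takes the flat concatenated input
                # directly instead of a list of chunks.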
_reduce_scatter_base(
output_for_base,
input,
group=parallel_state.get_tensor_model_parallel_group(),
)
msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
self.assertEqual(output, output_for_base, msg=msg)
self.assertEqual(input, torch.cat(input_list), msg=msg)
parallel_state.destroy_model_parallel()
def test_parallel_embedding(self) -> None:
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(self.SEED + 1)
input_tensor = torch.randint(
0,
self.VOCAB_SIZE,
(
self.BATCH_SIZE,
self.SEQUENCE_LENGTH,
),
device="cuda",
)
loss_weight = torch.randn(
(
self.BATCH_SIZE,
self.SEQUENCE_LENGTH,
self.HIDDEN_SIZE,
),
device="cuda",
)
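            # NOTE: Inputs are drawn under SEED + 1 so that both embeddings
            # below, each initialized under SEED, see identical parameters.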
set_random_seed(self.SEED)
embedding_torch = nn.Embedding(
self.VOCAB_SIZE,
self.HIDDEN_SIZE,
).cuda()
output_torch = embedding_torch(input_tensor)
loss_torch = torch.mul(output_torch, loss_weight).sum()
loss_torch.backward()
            # N.B.(mkozuki): With affine weight initialization on GPU,
            # it is difficult to stay numerically consistent with nn.Embedding,
            # so `use_cpu_initialization` is turned on.
set_random_seed(self.SEED)
embedding_vocab_parallel = layers.VocabParallelEmbedding(
self.VOCAB_SIZE,
self.HIDDEN_SIZE,
init_method=nn.init.normal_,
use_cpu_initialization=True,
).cuda()
output_vocab_parallel = embedding_vocab_parallel(input_tensor)
loss_vocab_parallel = torch.mul(
output_vocab_parallel, loss_weight
).sum()
loss_vocab_parallel.backward()
msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
self.assertEqual(output_torch, output_vocab_parallel, msg=msg)
self.assertEqual(loss_torch, loss_vocab_parallel, msg=msg)
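            # The reference weight grad is split along the vocab dimension;
            # each rank should hold the shard matching its tensor-parallel rank.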
splitted_weight_torch = torch.split(
embedding_torch.weight.grad,
self.VOCAB_SIZE
// tensor_model_parallel_world_size,
0,
)[parallel_state.get_tensor_model_parallel_rank()]
self.assertEqual(
splitted_weight_torch, embedding_vocab_parallel.weight.grad, msg=msg,
)
parallel_state.destroy_model_parallel()
def _affine_weight_init_test_impl(
self, init_device: str, is_column_parallel: bool
) -> None:
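        # Column-parallel shards the weight's output dim (0);
        # row-parallel shards its input dim (1).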
dim = int(not is_column_parallel)
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size
)
input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size
weight_shape = (
(self.OUTPUT_SIZE_COEFF, input_size)
if is_column_parallel
else (output_size, self.INPUT_SIZE_COEFF)
)
weight = torch.empty(weight_shape)
set_random_seed(self.SEED)
sharding_dim_size = (
self.OUTPUT_SIZE_COEFF
if is_column_parallel
else self.INPUT_SIZE_COEFF
)
if init_device == "cpu":
layers._initialize_affine_weight_cpu(
weight,
output_size,
input_size,
sharding_dim_size,
dim,
nn.init.normal_,
params_dtype=torch.float32,
)
else:
layers._initialize_affine_weight_gpu(
weight, torch.nn.init.normal_, dim
)
# Target
set_random_seed(self.SEED)
if init_device == "cpu":
main_weight = torch.empty(output_size, input_size)
nn.init.normal_(main_weight)
curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
parallel_state.get_tensor_model_parallel_rank()
]
else:
curr_weight = torch.empty(*weight_shape)
nn.init.normal_(curr_weight)
self.assertEqual(
curr_weight, weight, msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}")
parallel_state.destroy_model_parallel()
def test_affine_weight_init_column_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)
def test_affine_weight_init_column_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)
def test_affine_weight_init_row_parallel_cpu(self) -> None:
self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)
def test_affine_weight_init_row_parallel_gpu(self) -> None:
self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)
def test_row_parallel_linear(self) -> None:
self._row_parallel_linear_test_impl(False, False, False)
def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None:
self._row_parallel_linear_test_impl(True, False, False)
def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None:
self._row_parallel_linear_test_impl(True, True, False)
    # Fails on native UCC and torch-ucc: UCC does not support reduce-scatter.
@unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
def test_row_parallel_linear_sequence_parallel(self) -> None:
self._row_parallel_linear_test_impl(False, False, True)
# TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl`
# Note that `input_is_parallel` is unique to `RowParallelLinear` which could make the merge complicated.
def _row_parallel_linear_test_impl(
self,
gradient_accumulation_fusion: bool,
accumulation_in_fp16: bool,
sequence_parallel_enabled: bool,
) -> None:
tensor_shape = (
self.SEQUENCE_LENGTH,
self.BATCH_SIZE,
self.HIDDEN_SIZE,
)
for tensor_model_parallel_world_size in range(
1 + int(sequence_parallel_enabled), self.world_size + 1
):
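            # Sequence parallelism needs at least two ranks, so the sweep
            # starts at 2 when it is enabled.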
if self.world_size % tensor_model_parallel_world_size:
continue
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
set_random_seed(self.SEED)
linear = layers.RowParallelLinear(
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
gradient_accumulation_fusion=gradient_accumulation_fusion,
accumulation_in_fp16=accumulation_in_fp16,
sequence_parallel_enabled=sequence_parallel_enabled,
# n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True`
# by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\
# db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204
input_is_parallel=True,
).cuda()
if accumulation_in_fp16:
linear = linear.half()
# Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled.
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.zeros_like(linear.weight)
msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
with torch.no_grad():
orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda")
orig_loss_weight = torch.randn(tensor_shape, device="cuda")
input_tensor = orig_input_tensor.chunk(
chunks=tensor_model_parallel_world_size,
dim=2,
)[parallel_state.get_tensor_model_parallel_rank()].contiguous()
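                # RowParallelLinear is fed an already-split input
                # (`input_is_parallel=True`), hence the chunk along the
                # hidden (last) dimension.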
if sequence_parallel_enabled:
loss_weight = orig_loss_weight.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()]
else:
loss_weight = orig_loss_weight
if accumulation_in_fp16:
orig_input_tensor = orig_input_tensor.half()
input_tensor = input_tensor.half()
loss_weight = loss_weight.half()
input_tensor.requires_grad_()
output, _ = linear(input_tensor)
loss = torch.mul(output, loss_weight).sum()
loss.backward()
self.assertIsNotNone(input_tensor.grad, msg=msg)
ref_linear = nn.Linear(
in_features=self.HIDDEN_SIZE,
out_features=self.HIDDEN_SIZE,
bias=False,
device="cuda",
)
with torch.no_grad():
dldy = orig_loss_weight.clone()
x = orig_input_tensor.clone()
ref_linear.weight.copy_(linear.master_weight)
if accumulation_in_fp16:
ref_linear = ref_linear.half()
x.requires_grad_()
expected_output = ref_linear(x)
expected_loss = torch.mul(expected_output, dldy).sum()
expected_loss.backward()
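            # Output parity is checked only in fp32 below; fp16 accumulation
            # introduces rounding that the strict `assertEqual` would likely flag.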
if not accumulation_in_fp16:
if sequence_parallel_enabled:
self.assertEqual(
x=output,
y=expected_output.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
msg=msg,
)
else:
self.assertEqual(
x=output,
y=expected_output,
msg=msg,
)
grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
            # NOTE(mkozuki): Numerical errors seem to be amplified by tensor model parallelism.
if tensor_model_parallel_world_size == 1:
self.assertEqual(
x=getattr(linear.weight, grad_attr_name),
y=ref_linear.weight.grad.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
msg=msg,
)
parallel_state.destroy_model_parallel()
def test_column_parallel_linear(self):
self._column_parallel_linear_test_impl(False, False, False, False)
def test_column_parallel_linear_async(self):
self._column_parallel_linear_test_impl(True, False, False, False)
def test_column_parallel_linear_gradient_accumulation_fusion(self):
self._column_parallel_linear_test_impl(False, True, False, False)
def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self):
self._column_parallel_linear_test_impl(False, True, True, False)
def test_column_parallel_linear_sequence_parallel(self):
if self.DISTRIBUTED_BACKEND == "ucc":
self.skipTest("Backward's reduce_scatter fails. as of 2022/06/15")
self._column_parallel_linear_test_impl(False, False, False, True)
    @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
def test_column_parallel_linear_exception(self):
with self.assertRaisesRegex(
RuntimeError,
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.",
):
self._column_parallel_linear_test_impl(True, False, False, True)
def _column_parallel_linear_test_impl(
self,
async_tensor_model_parallel_allreduce: bool,
gradient_accumulation_fusion: bool,
accumulation_in_fp16: bool,
sequence_parallel_enabled: bool,
):
for tensor_model_parallel_world_size in range(1, self.world_size + 1):
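            # NOTE: with a single tensor-parallel rank apex disables the async
            # allreduce internally, so the RuntimeError covered by
            # `test_column_parallel_linear_exception` would not fire; skip that case.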
if async_tensor_model_parallel_allreduce and sequence_parallel_enabled:
if tensor_model_parallel_world_size == 1:
continue
if self.world_size % tensor_model_parallel_world_size:
continue
msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=tensor_model_parallel_world_size,
)
input_tensor_shape = self.tensor_shape
expected_output_shape = self.tensor_shape
            # With sequence parallelism, `gather_output` is disabled, i.e., the
            # matmul output is not gathered along the feature/hidden (last) dimension.
if sequence_parallel_enabled:
expected_output_shape[-1] //= tensor_model_parallel_world_size
            # The tensor layout is (sequence length, batch size, hidden size).
set_random_seed(self.SEED)
linear = layers.ColumnParallelLinear(
self.HIDDEN_SIZE,
self.HIDDEN_SIZE,
bias=False,
keep_master_weight_for_test=True,
params_dtype=torch.float32,
use_cpu_initialization=True,
gather_output=not sequence_parallel_enabled,
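                # n.b.: the constructor flag is a negation (`no_async_...`),
                # so the test's parameter is passed inverted.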
no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce,
gradient_accumulation_fusion=gradient_accumulation_fusion,
accumulation_in_fp16=accumulation_in_fp16,
sequence_parallel_enabled=sequence_parallel_enabled,
).cuda()
if accumulation_in_fp16:
linear = linear.half()
# Simulate the situation where fusion of weight grad calculation and gradient accumulation happens.
if gradient_accumulation_fusion:
with torch.no_grad():
linear.weight.main_grad = torch.zeros_like(linear.weight)
orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True)
if accumulation_in_fp16:
orig_input_tensor = orig_input_tensor.half()
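            # With sequence parallelism the input is pre-split along the
            # sequence (first) dimension; otherwise the full tensor is used.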
if sequence_parallel_enabled:
input_tensor = list(
orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0)
)[parallel_state.get_tensor_model_parallel_rank()]
else:
input_tensor = orig_input_tensor
output, _ = linear(input_tensor)
            # The dimension order is expected to be (sequence, batch, hidden).
self.assertEqual(output.shape, expected_output_shape, msg=msg)
orig_loss_weight = torch.randn(input_tensor_shape, device="cuda")
if accumulation_in_fp16:
orig_loss_weight = orig_loss_weight.half()
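            # Without `gather_output`, each rank's output covers only its
            # hidden shard, so dL/dy is chunked along the last dim to match.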
if sequence_parallel_enabled:
loss_weight = orig_loss_weight.chunk(
tensor_model_parallel_world_size, dim=2,
)[parallel_state.get_tensor_model_parallel_rank()]
else:
loss_weight = orig_loss_weight
loss = torch.mul(output, loss_weight).sum()
loss.backward()
with torch.no_grad():
dldy = orig_loss_weight.clone()
x = orig_input_tensor.clone()
ref_linear = nn.Linear(
in_features=self.HIDDEN_SIZE,
out_features=self.HIDDEN_SIZE,
bias=False,
device="cuda",
)
if accumulation_in_fp16:
ref_linear = ref_linear.half()
# NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set.
ref_linear.weight.copy_(linear.master_weight)
x.requires_grad_()
expected_output = ref_linear(x)
if sequence_parallel_enabled:
chunk = expected_output.chunk(
tensor_model_parallel_world_size,
dim=2,
)[parallel_state.get_tensor_model_parallel_rank()]
self.assertEqual(
x=output,
y=chunk,
msg=msg,
)
else:
self.assertEqual(
x=output,
y=expected_output,
msg=msg,
)
expected_loss = torch.mul(expected_output, dldy).sum()
expected_loss.backward()
grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
            # NOTE(mkozuki): Numerical errors seem to be amplified by tensor model parallelism.
if tensor_model_parallel_world_size == 1:
self.assertEqual(
x=getattr(linear.weight, grad_attr_name),
y=ref_linear.weight.grad.chunk(
chunks=tensor_model_parallel_world_size,
dim=0,
)[parallel_state.get_tensor_model_parallel_rank()],
msg=msg,
)
parallel_state.destroy_model_parallel()
class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
pass
class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
pass
if __name__ == "__main__":
common_utils.run_tests()