import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
class MappingTestBase:
    def test_reduce(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # Reducing a constant tensor sums it across the tensor-parallel group,
            # so every rank should see 50 * group size.
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_parallel_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(
                torch.equal(mappings._reduce(t), expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_split(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # Concatenate one chunk per tensor-parallel rank along the last
            # dimension; the split should hand each rank back its own chunk.
            tensors = [
                torch.randn(10, 1)
                for _ in range(tensor_model_parallel_world_size)
            ]
            x = torch.cat(tensors, 1)
            out = mappings._split_along_last_dim(x)
            self.assertTrue(
                torch.equal(
                    out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                ),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()

    def test_gather(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            device = f"cuda:{self.rank}"
            # Each rank contributes its own rank id; gathering along the last
            # dimension should yield [0, 1, ..., group size - 1] on every rank.
            gathered = mappings._gather_along_last_dim(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()], device=device
                )
            )
            expected = torch.tensor(
                [rank for rank in range(tensor_model_parallel_world_size)],
                device=device,
            )
            self.assertTrue(
                torch.equal(gathered, expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()


# Bind the shared test cases to the NCCL and UCC distributed backends.
class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    pass


class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()