import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)

class MappingTestBase:
    def test_reduce(self):
        # Try every tensor-model-parallel size that evenly divides the world size.
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # All-reducing a tensor of 50s across the group should multiply
            # each element by the group size.
            t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}")
            expected = torch.full(
                (10, 10, 10, 10),
                50 * tensor_model_parallel_world_size,
                device=f"cuda:{self.rank}",
            )
            self.assertTrue(
                torch.equal(mappings._reduce(t), expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
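
    # For reference, a minimal sketch of what `mappings._reduce` is expected
    # to do: an all-reduce (sum) over the tensor-model-parallel group. This is
    # an assumption about the internals, not something this test asserts:
    #
    #     def _reduce_sketch(t):
    #         if parallel_state.get_tensor_model_parallel_world_size() == 1:
    #             return t  # single-rank group: the reduction is a no-op
    #         torch.distributed.all_reduce(
    #             t, group=parallel_state.get_tensor_model_parallel_group()
    #         )
    #         return t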

    def test_split(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # Concatenate one column block per rank, then check that the split
            # hands back exactly this rank's original block.
            tensors = [
                torch.randn(10, 1)
                for _ in range(tensor_model_parallel_world_size)
            ]
            x = torch.cat(tensors, dim=1)
            out = mappings._split_along_last_dim(x)
            self.assertTrue(
                torch.equal(
                    out, tensors[parallel_state.get_tensor_model_parallel_rank()]
                ),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
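
    # For reference, `mappings._split_along_last_dim` is expected to be a
    # purely local chunk-and-select with no communication, roughly equivalent
    # to the sketch below (an assumption about the internals; `world_size` and
    # `rank` refer to the tensor-model-parallel group):
    #
    #     def _split_sketch(x):
    #         world_size = parallel_state.get_tensor_model_parallel_world_size()
    #         rank = parallel_state.get_tensor_model_parallel_rank()
    #         return torch.chunk(x, world_size, dim=-1)[rank].contiguous()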

    def test_gather(self):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            device = f"cuda:{self.rank}"
            # Each rank contributes its own rank id; the gather should yield
            # [0, 1, ..., world_size - 1] on every rank.
            gathered = mappings._gather_along_last_dim(
                torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()], device=device
                )
            )
            expected = torch.tensor(
                list(range(tensor_model_parallel_world_size)),
                device=device,
            )
            self.assertTrue(
                torch.equal(gathered, expected),
                msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}",
            )
            parallel_state.destroy_model_parallel()
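
    # For reference, `mappings._gather_along_last_dim` is expected to
    # all-gather each rank's shard and concatenate along the last dimension,
    # roughly as sketched below (an assumption about the internals):
    #
    #     def _gather_sketch(t):
    #         world_size = parallel_state.get_tensor_model_parallel_world_size()
    #         shards = [torch.empty_like(t) for _ in range(world_size)]
    #         torch.distributed.all_gather(
    #             shards, t, group=parallel_state.get_tensor_model_parallel_group()
    #         )
    #         return torch.cat(shards, dim=-1).contiguous()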

class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    pass


class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()
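
# NOTE (assumption): NcclDistributedTestBase / UccDistributedTestBase build on
# PyTorch's multi-process distributed test infrastructure, so this file can
# typically be run directly, e.g.:
#
#     python <path-to-this-file>
#
# with worker processes spawned per test across the available GPUs. The exact
# launch mechanics depend on the apex version in use.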