Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import functools | |
| import sys | |
| import unittest | |
| import torch | |
| from fairseq.distributed import utils as dist_utils | |
| from .utils import objects_are_equal, spawn_and_init | |
| class DistributedTest(unittest.TestCase): | |
| def setUp(self): | |
| if not torch.cuda.is_available(): | |
| raise unittest.SkipTest("CUDA not available, skipping test") | |
| if sys.platform == "win32": | |
| raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") | |
| if torch.cuda.device_count() < 2: | |
| raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") | |
| class TestBroadcastObject(DistributedTest): | |
| def test_str(self): | |
| spawn_and_init( | |
| functools.partial( | |
| TestBroadcastObject._test_broadcast_object, "hello world" | |
| ), | |
| world_size=2, | |
| ) | |
| def test_tensor(self): | |
| spawn_and_init( | |
| functools.partial( | |
| TestBroadcastObject._test_broadcast_object, | |
| torch.rand(5), | |
| ), | |
| world_size=2, | |
| ) | |
| def test_complex(self): | |
| spawn_and_init( | |
| functools.partial( | |
| TestBroadcastObject._test_broadcast_object, | |
| { | |
| "a": "1", | |
| "b": [2, torch.rand(2, 3), 3], | |
| "c": (torch.rand(2, 3), 4), | |
| "d": {5, torch.rand(5)}, | |
| "e": torch.rand(5), | |
| "f": torch.rand(5).int().cuda(), | |
| }, | |
| ), | |
| world_size=2, | |
| ) | |
| def _test_broadcast_object(ref_obj, rank, group): | |
| obj = dist_utils.broadcast_object( | |
| ref_obj if rank == 0 else None, src_rank=0, group=group | |
| ) | |
| assert objects_are_equal(ref_obj, obj) | |
| class TestAllGatherList(DistributedTest): | |
| def test_str_equality(self): | |
| spawn_and_init( | |
| functools.partial( | |
| TestAllGatherList._test_all_gather_list_equality, | |
| "hello world", | |
| ), | |
| world_size=2, | |
| ) | |
| def test_tensor_equality(self): | |
| spawn_and_init( | |
| functools.partial( | |
| TestAllGatherList._test_all_gather_list_equality, | |
| torch.rand(5), | |
| ), | |
| world_size=2, | |
| ) | |
| def test_complex_equality(self): | |
| spawn_and_init( | |
| functools.partial( | |
| TestAllGatherList._test_all_gather_list_equality, | |
| { | |
| "a": "1", | |
| "b": [2, torch.rand(2, 3), 3], | |
| "c": (torch.rand(2, 3), 4), | |
| "d": {5, torch.rand(5)}, | |
| "e": torch.rand(5), | |
| "f": torch.rand(5).int(), | |
| }, | |
| ), | |
| world_size=2, | |
| ) | |
| def _test_all_gather_list_equality(ref_obj, rank, group): | |
| objs = dist_utils.all_gather_list(ref_obj, group) | |
| for obj in objs: | |
| assert objects_are_equal(ref_obj, obj) | |
| def test_rank_tensor(self): | |
| spawn_and_init( | |
| TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2 | |
| ) | |
| def _test_all_gather_list_rank_tensor(rank, group): | |
| obj = torch.tensor([rank]) | |
| objs = dist_utils.all_gather_list(obj, group) | |
| for i, obj in enumerate(objs): | |
| assert obj.item() == i | |
| if __name__ == "__main__": | |
| unittest.main() | |