# NOTE: header captured from the web file viewer this file was copied from.
# Spaces: Runtime error
# File size: 4,445 Bytes
# 8a42f8f (presumably the revision hash shown by the viewer)
# Module setup for the tensor-parallel random-number tests.
import logging
import torch
from torch.testing._internal import common_utils
# Raise the "torch" logger threshold to WARNING before apex is imported
# (presumably to keep import-time log output quiet -- confirm).
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
# Same threshold for the "apex" logger, applied once apex has been imported.
logging.getLogger("apex").setLevel(logging.WARNING)
class TransformerRandomTestBase:
    """Shared tests for the CUDA RNG helpers in ``tensor_parallel.random``.

    This class is a mixin: it expects the concrete test class it is combined
    with to provide ``self.world_size`` and the ``assert*`` methods. Every test
    is repeated for each tensor-model-parallel size that evenly divides
    ``self.world_size``.
    """

    def test_set_cuda_rng_state(self):
        """Restoring a saved CUDA RNG state must replay the same samples."""
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only sizes that evenly divide the world size are exercised.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # try/finally so a failing assertion cannot leave the model-parallel
            # state initialized for whatever runs next in this process.
            try:
                size, seed = 123, 1234
                torch.cuda.manual_seed(seed)
                # Explicit factory call instead of the legacy
                # ``torch.cuda.FloatTensor`` constructor.
                tensor = torch.empty(size, dtype=torch.float32, device="cuda")

                # Snapshot the generator state and keep a private copy so we
                # can verify later that the snapshot itself is never mutated.
                rng_state = torch.cuda.get_rng_state()
                rng_state_clone = rng_state.clone()

                # Consume random numbers; this advances the live generator.
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_1 = tensor.clone()

                # The snapshot is unchanged, the live state has moved on.
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
                self.assertGreater(
                    torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0,
                    msg=msg,
                )
                new_rng_state = torch.cuda.get_rng_state()
                self.assertGreater(new_rng_state.sub(rng_state).max(), 0, msg=msg)

                # Restore the snapshot (twice) and repeat the same draws; the
                # final tensor must match the first run.
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                tensor_parallel.random._set_cuda_rng_state(rng_state)
                for _ in range(5):
                    torch.randn(size, out=tensor)
                result_2 = tensor.clone()

                self.assertEqual(result_2, result_1, msg=msg)
                # Restoring from the snapshot must not have modified it either.
                self.assertEqual(rng_state.sub(rng_state_clone).max(), 0, msg=msg)
            finally:
                parallel_state.destroy_model_parallel()

    def test_cuda_rng_tracker(self):
        """A forked tracker state must yield its own, independent sample stream."""
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only sizes that evenly divide the world size are exercised.
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # try/finally so a failing assertion cannot leave the tracker entry
            # or the model-parallel state behind for whatever runs next.
            try:
                seed_1, seed_2, size = 1234, 4321, [12, 21]
                # ``size`` is a list here: the legacy ``torch.cuda.FloatTensor``
                # constructor would interpret it as data (a 2-element tensor)
                # rather than as a shape, so allocate with ``torch.empty``.
                tensor = torch.empty(size, dtype=torch.float32, device="cuda")

                # Reference samples: two consecutive draws after seeding the
                # default generator with seed_1 ...
                torch.cuda.manual_seed(seed_1)
                torch.randn(size, out=tensor)
                target_11 = tensor.clone()
                torch.randn(size, out=tensor)
                target_12 = tensor.clone()

                # ... and two consecutive draws after seeding it with seed_2.
                torch.cuda.manual_seed(seed_2)
                torch.randn(size, out=tensor)
                target_21 = tensor.clone()
                torch.randn(size, out=tensor)
                target_22 = tensor.clone()

                # Now seed the default generator with seed_1 and register a
                # tracked state named "test" with seed_2, then interleave draws
                # from the default generator and from the forked state.
                torch.cuda.manual_seed(seed_1)
                tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2)

                torch.randn(size, out=tensor)
                result_11 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_21 = tensor.clone()

                torch.randn(size, out=tensor)
                result_12 = tensor.clone()

                with tensor_parallel.random.get_cuda_rng_tracker().fork("test"):
                    torch.randn(size, out=tensor)
                    result_22 = tensor.clone()

                # Default-generator draws follow the seed_1 sequence and forked
                # draws follow the seed_2 sequence, despite the interleaving.
                self.assertEqual(target_11, result_11, msg=msg)
                self.assertEqual(target_12, result_12, msg=msg)
                self.assertEqual(target_21, result_21, msg=msg)
                self.assertEqual(target_22, result_22, msg=msg)
                # The two streams, and successive forked draws, must differ.
                self.assertNotEqual(result_11, result_21, msg=msg)
                self.assertNotEqual(result_21, result_22, msg=msg)
            finally:
                # Same cleanup order as before: tracker first, then groups.
                tensor_parallel.random.get_cuda_rng_tracker().reset()
                parallel_state.destroy_model_parallel()
class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase):
    """Runs the TransformerRandomTestBase tests on NcclDistributedTestBase."""
class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase):
    """Runs the TransformerRandomTestBase tests on UccDistributedTestBase."""
# When executed as a script, hand control to the test runner from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    common_utils.run_tests()
# | (stray table delimiter left over from the web file viewer)