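"""Distributed tests for apex.transformer's microbatch calculator.

Covers both the constant case (fixed global batch size) and the dynamic
case (global batch size ramped up over training), run over the NCCL and
UCC backends.
"""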
import logging
from typing import List, Optional

from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import (
    _reconfigure_microbatch_calculator,
    get_micro_batch_size,
    get_num_microbatches,
    get_current_global_batch_size,
    update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)

class MicrobatchCalculatorTestBase:
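    """Shared test scenarios, mixed into the backend-specific test cases below."""
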
    GLOBAL_BATCH_SIZE: int = 1024
    MICRO_BATCH_SIZE: int = 1
    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
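        """Exercise the microbatch calculator for every valid data-parallel size.

        ``rampup_batch_size``, when given, is assumed to follow the
        ``[start_global_batch_size, batch_size_increment, ramp_up_samples]``
        layout suggested by the unpacking below.
        """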
        for data_parallel_size in range(1, self.world_size + 1):
            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
            if rampup_batch_size:
                # During ramp-up, the current global batch size starts at
                # rampup_batch_size[0] rather than at GLOBAL_BATCH_SIZE.
                expected_global_batch_size = rampup_batch_size[0]
                # Unpacked for documentation; only the start size affects
                # the assertions below.
                num_consumed_samples = 0
                step_of_global_batch_size = rampup_batch_size[1]
                threshold = rampup_batch_size[2]
            # Skip configurations this test does not cover: odd data-parallel
            # sizes (other than 1) and sizes that do not divide the world size.
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            msg = f"data_parallel_size: {data_parallel_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size(), msg=msg)

            # Rebuild the global microbatch calculator for this parallel configuration.
            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )
            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size, msg=msg)
            self.assertEqual(
                get_num_microbatches(),
                expected_global_batch_size / expected_micro_batch_size / data_parallel_size,
                msg=msg,
            )
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size, msg=msg)
            # Make sure `global_batch_size` equals the final global batch size
            # after a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for _ in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                current_global_batch_size = get_current_global_batch_size()
                self.assertEqual(current_global_batch_size, self.GLOBAL_BATCH_SIZE, msg=msg)
            parallel_state.destroy_model_parallel()

    def test_constant_microbatch_calculator(self):
        self._test(rampup_batch_size=None)
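
    # Assuming the [start, increment, ramp-up samples] layout above, the global
    # batch size below should climb 256 -> 384 -> 512 -> ... -> 1024 as samples
    # are consumed, ending at GLOBAL_BATCH_SIZE.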
    def test_dynamic_microbatch_calculator(self):
        self._test(rampup_batch_size=[256, 128, 500])
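
# Concrete test cases: run the shared scenarios over each distributed backend.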
class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass
class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass

if __name__ == "__main__":
    common_utils.run_tests()