# Vendored from: Open-Sora / apex/tests/L0/run_transformer/test_microbatches.py
import logging
from typing import List, Optional
from torch.testing._internal import common_utils
# Silence torch's INFO-level chatter so test output stays readable.
logging.getLogger("torch").setLevel(logging.WARNING)
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel.utils import (
    _reconfigure_microbatch_calculator,
    get_micro_batch_size,
    get_num_microbatches,
    get_current_global_batch_size,
    update_num_microbatches,
)
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
# Likewise suppress apex's own logging below WARNING.
logging.getLogger("apex").setLevel(logging.WARNING)
class MicrobatchCalculatorTestBase:
    """Shared test logic for apex's microbatch calculator.

    Subclasses mix this with a distributed test base (NCCL or UCC) that
    provides ``self.world_size`` and ``self.rank``.
    """

    GLOBAL_BATCH_SIZE: int = 1024
    MICRO_BATCH_SIZE: int = 1

    def _test(self, rampup_batch_size: Optional[List[int]]) -> None:
        """Exercise the microbatch calculator for every valid data-parallel size.

        Args:
            rampup_batch_size: ``None`` for a constant calculator, or a
                3-element list ``[start_global_batch_size, batch_size_increment,
                ramp_up_samples]`` for the dynamic (ramp-up) calculator.
        """
        for data_parallel_size in range(1, self.world_size + 1):
            expected_global_batch_size = self.GLOBAL_BATCH_SIZE
            expected_micro_batch_size = self.MICRO_BATCH_SIZE
            if rampup_batch_size:
                # During ramp-up, the initial global batch size is the first
                # element of the rampup spec, not GLOBAL_BATCH_SIZE.
                expected_global_batch_size = rampup_batch_size[0]
            # Only even data-parallel sizes (beyond 1) that evenly divide the
            # world size are valid model-parallel configurations here.
            if data_parallel_size > 1 and data_parallel_size % 2 != 0:
                continue
            if self.world_size % data_parallel_size != 0:
                continue
            msg = f"data_parallel_size: {data_parallel_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=self.world_size // data_parallel_size,
                pipeline_model_parallel_size_=1,
            )
            self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size(), msg=msg)
            _reconfigure_microbatch_calculator(
                self.rank,
                rampup_batch_size,
                self.GLOBAL_BATCH_SIZE,
                self.MICRO_BATCH_SIZE,
                data_parallel_size,
            )
            self.assertEqual(get_micro_batch_size(), expected_micro_batch_size, msg=msg)
            # Integer division: the sizes divide evenly by construction, and
            # get_num_microbatches() returns an int.
            self.assertEqual(
                get_num_microbatches(),
                expected_global_batch_size // expected_micro_batch_size // data_parallel_size,
                msg=msg,
            )
            current_global_batch_size = get_current_global_batch_size()
            self.assertEqual(current_global_batch_size, expected_global_batch_size, msg=msg)
            # Make sure `global_batch_size` equals the final global batch size
            # after a certain number of updates.
            if rampup_batch_size:
                update_num_microbatches(current_global_batch_size)
                for _ in range(100):
                    current_global_batch_size = get_current_global_batch_size()
                    update_num_microbatches(current_global_batch_size)
                self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE, msg=msg)
            parallel_state.destroy_model_parallel()

    def test_constant_microbatch_calculator(self):
        """Calculator with a fixed global batch size (no ramp-up)."""
        self._test(rampup_batch_size=None)

    def test_dynamic_microbatch_calculator(self):
        """Calculator ramping from 256 by steps of 128 over 500 samples."""
        self._test(rampup_batch_size=[256, 128, 500])
class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass
class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass
# Entry point: dispatch to torch's common test runner when executed directly.
if __name__ == "__main__":
    common_utils.run_tests()