# Source: Open-Sora / apex/tests/distributed/DDP/ddp_race_condition_test.py
# (uploaded by kadirnar in "Upload 494 files", commit 8a42f8f, 2.42 kB)
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os
parser = argparse.ArgumentParser(description='allreduce hook example')
# Older torch.distributed.launch releases pass --local_rank, PyTorch 2.x
# passes --local-rank, and torchrun passes neither but exports LOCAL_RANK.
# Accept both spellings and fall back to the environment variable so every
# process does not silently end up as local rank 0.
parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                    default=int(os.environ.get("LOCAL_RANK", 0)), type=int)
args = parser.parse_args()

# Run distributed only when the launcher reports more than one process;
# an unset WORLD_SIZE is treated as a single process.
args.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1

# Defined on both paths so later code can read it unconditionally.
args.world_size = 1
if args.distributed:
    # Map this process onto one of the visible GPUs by its local rank.
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    dist.init_process_group(backend='nccl', init_method='env://')
    args.world_size = dist.get_world_size()

torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)
class Model(Module):
    """Two-parameter elementwise model used to exercise DDP gradient allreduce.

    ``forward`` computes ``(input * a) * b``, with ``a`` initialised to 1.0
    and ``b`` to 2.0. The gradient w.r.t. ``a`` is therefore ``input * b``
    and the gradient w.r.t. ``b`` is ``input * a``.

    Args:
        numel: number of elements in each 1-D parameter. The default
            matches the original hard-coded 4096*4096.
        device: device for the parameters. ``"cuda"`` means the current
            CUDA device, which is where the legacy ``torch.cuda.FloatTensor``
            constructor allocated.
    """

    def __init__(self, numel=4096 * 4096, device="cuda"):
        super().__init__()
        # torch.full replaces the deprecated torch.cuda.FloatTensor(n).fill_(v)
        # idiom. The dtype is pinned to float32 to match that constructor
        # regardless of the global default dtype.
        self.a = Parameter(torch.full((numel,), 1.0, dtype=torch.float32, device=device))
        self.b = Parameter(torch.full((numel,), 2.0, dtype=torch.float32, device=device))

    def forward(self, input):
        return (input * self.a) * self.b
model = Model()
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)


def info(name, param, val, iteration, world_size):
    """Print and check the summed gradient of ``param`` for one iteration.

    Rank ``r`` fills the input with ``iteration + r``. Each element of the
    local gradient is therefore ``val * (iteration + r)``, where ``val`` is
    the value of the *other* parameter. Averaged over ranks ``0 ..
    world_size-1``, that gives ``val * (iteration + (world_size - 1) / 2)``.
    This assumes DDP averages rather than sums gradients, which is what the
    original two-rank formula ``val * (2*i + 1) / 2`` implied; for
    ``world_size == 2`` the two expressions are identical.

    Returns True when the observed gradient sum equals the expected value.
    """
    expected = val * param.numel() * (iteration + (world_size - 1) / 2.)
    actual = param.grad.sum().item()
    print(name + ": grad.data_ptr() = {}, expected sum {}, got {}".format(
        param.grad.data_ptr(), expected, actual))
    # Exact comparison is kept from the original test. A gradient read
    # before the allreduce has finished would not match the averaged value.
    return expected == actual


# Size the input and place it to match the parameters. fill_ overwrites
# every element before first use, so uninitialised memory is fine.
_param = model.module.a
x = torch.empty(_param.numel(), dtype=torch.float32, device=_param.device)

# Use the global rank and world size so the expected-gradient formula holds
# for any number of ranks, not only the two the original hard-coded.
if dist.is_available() and dist.is_initialized():
    world_size = dist.get_world_size()
    rank = dist.get_rank()
else:
    world_size = 1
    rank = args.local_rank

passed = True
torch.cuda.cudart().cudaProfilerStart()
try:
    for i in range(10):
        x.fill_(i + rank)  # fill x with new values every iteration for sanity
        model.zero_grad()
        out = model(x)
        loss = out.sum()
        # torch.cuda.nvtx.range_push("backward")
        loss.backward()
        # torch.cuda.nvtx.range_pop()

        # torch.cuda.nvtx.range_push("synchronize() + info")
        # NOTE(review): the explicit synchronize appears to be left out on
        # purpose, so the check reads .grad relying only on DDP's own stream
        # handling. Confirm before re-enabling it.
        # torch.cuda.synchronize()
        print("i = {}".format(i))
        # Check both parameters every iteration (no short-circuit) so both
        # diagnostics are always printed.
        if not info("model.a", model.module.a, 2., i, world_size):
            passed = False
        if not info("model.b", model.module.b, 1., i, world_size):
            passed = False
        # torch.cuda.nvtx.range_pop()
finally:
    # Stop the profiler even if an iteration raises.
    torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)