Spaces:
Runtime error
Runtime error
File size: 2,424 Bytes
8a42f8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os
# ---------------------------------------------------------------------------
# Command-line parsing and (optional) distributed initialisation.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='allreduce hook example')
# The local rank can reach the script in three ways, depending on launcher:
#   * ``--local_rank=N``  (torch.distributed.launch before PyTorch 2.0)
#   * ``--local-rank=N``  (torch.distributed.launch from PyTorch 2.0 on)
#   * ``LOCAL_RANK`` environment variable (torchrun; no flag is passed)
# Accept both flag spellings and fall back to the environment variable so
# that each process picks its own GPU regardless of launcher.  An explicit
# flag still takes precedence, and the default stays 0 when nothing is set.
parser.add_argument("--local_rank", "--local-rank", dest="local_rank",
                    default=int(os.environ.get("LOCAL_RANK", 0)), type=int)
args = parser.parse_args()

# The run counts as distributed only when WORLD_SIZE is exported and is
# greater than one; a missing variable means a single process.
args.distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1

if args.distributed:
    # Map the local rank onto a visible GPU and make it the current device
    # before the NCCL process group is created.
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

# Print tensors with 10 digits of precision and seed the RNG with the
# local rank (so each local rank gets a different seed).
torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)
class Model(Module):
    """Minimal two-parameter model computing ``(input * a) * b`` elementwise.

    ``a`` is initialised to all ones and ``b`` to all twos, both as flat
    float32 parameters of ``numel`` elements.

    Args:
        numel: Number of elements in each parameter.  Defaults to
            ``4096 * 4096``, the size used by the original script.
        device: Device on which the parameters are allocated.  Defaults to
            ``"cuda"`` (the current CUDA device), matching the original
            ``torch.cuda.FloatTensor`` allocation.
    """

    def __init__(self, numel=4096 * 4096, device="cuda"):
        super(Model, self).__init__()
        # torch.full replaces the legacy ``torch.cuda.FloatTensor(n).fill_()``
        # constructor; dtype is pinned so the parameters stay float32 even if
        # the process-wide default dtype has been changed.
        self.a = Parameter(
            torch.full((numel,), 1.0, dtype=torch.float32, device=device))
        self.b = Parameter(
            torch.full((numel,), 2.0, dtype=torch.float32, device=device))

    def forward(self, input):
        """Return ``(input * a) * b``; ``input`` must broadcast against ``a``."""
        return (input * self.a) * self.b
# ---------------------------------------------------------------------------
# Test driver: wrap the model in apex DDP, run ten forward/backward steps and
# compare each parameter's gradient sum against a closed-form expectation.
# ---------------------------------------------------------------------------
model = Model()
# Alternative DDP configurations that can be swapped in for the active one:
#   model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
#   model = DDP(model, delay_allreduce=True)
#   model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(
    model,
    message_size=1,
    allreduce_trigger_params=[model.b],
    num_allreduce_streams=3,
)

# Input buffer; its contents are overwritten at the start of every step.
x = torch.cuda.FloatTensor(4096*4096)


def info(name, param, val, step):
    """Print and check the gradient sum of ``param`` for iteration ``step``.

    The expected sum is ``val * 4096 * 4096 * (2 * step + 1) / 2``.
    NOTE(review): this formula presumably assumes gradients averaged over
    exactly two ranks whose inputs are ``step`` and ``step + 1`` -- confirm
    before running with a different world size.

    Returns True when the observed sum equals the expected one exactly.
    """
    expected = val*4096*4096*(2.*step+1)/2.
    actual = param.grad.data.sum().item()
    print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
        param.grad.data_ptr(), expected, actual))
    return (expected == actual)


passed = True
torch.cuda.cudart().cudaProfilerStart()
for step in range(10):
    # Fill x with new values every iteration for sanity.
    x.fill_(step + args.local_rank)
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # Optional NVTX range: torch.cuda.nvtx.range_push("backward") / range_pop()
    loss.backward()

    # Optional: torch.cuda.nvtx.range_push("synchronize() + info") and
    # torch.cuda.synchronize() before the checks; range_pop() after them.
    print("i = {}".format(step))
    # Both parameters are always checked (and printed), even after a failure.
    a_ok = info("model.a", model.module.a, 2., step)
    b_ok = info("model.b", model.module.b, 1., step)
    if not (a_ok and b_ok):
        passed = False
torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)
|