# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
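
"""Unit tests for fairseq's BMUF (Blockwise Model-Update Filtering) optimizer
wrapper: two worker processes train the same tiny linear model, periodically
synchronize through FairseqBMUF, and are expected to end up with identical
parameters.
"""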

import argparse
import random
import unittest
from multiprocessing import Manager

import torch
import torch.nn as nn

from fairseq import distributed_utils, optim


class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        return output


def setup_model_loss_criterion(args, rank, is_cuda):
    """Setup model, criterion and optimizer based on input args."""
    args.distributed_rank = rank
    distributed_utils.distributed_init(args)
    torch.manual_seed(1)
    model = Model(args.input_size, args.nb_classes)
    loss_fn = nn.CrossEntropyLoss()
    if is_cuda:
        model = model.cuda()
        loss_fn = loss_fn.cuda()

    optimizer = optim.sgd.SGD(args, model.parameters())
    optimizer = optim.FairseqBMUF(args, optimizer)

    return model, loss_fn, optimizer
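

# One local training step: forward, backward through the BMUF-wrapped
# optimizer, then optimizer.step(). FairseqBMUF keeps its own update counter
# and (roughly) runs the cross-worker parameter sync inside step() once the
# warmup or global_sync_iter boundary is reached, so the loop never has to
# call a sync explicitly.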
def train_step(input, target, model, loss_fn, optimizer):
    """Do forward, backward and parameter update."""
    model.train()
    output = model(input)
    loss = loss_fn(output, target)
    optimizer.backward(loss)
    optimizer.step()
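

# Worker entry point: each spawned rank builds its own model and optimizer,
# trains for `iterations` local steps on random data, then writes its
# flattened parameters into the shared dict so the parent test can compare
# ranks.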
def single_gpu_training(args, rank, iterations, shared_results):
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        torch.cuda.set_device(rank)

    model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda)

    for _ in range(iterations):
        input = torch.randn(1, args.input_size)
        target = torch.empty(args.batch_size, dtype=torch.long).random_(args.nb_classes)
        if is_cuda:
            input = input.cuda()
            target = target.cuda()
        train_step(input, target, model, loss_fn, optimizer)

    results = []
    for param in model.parameters():
        if len(results) == 0:
            results = param.flatten().cpu().data
        else:
            results = torch.cat((results, param.flatten().cpu().data), 0)
    shared_results[rank] = results
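

# Hyperparameters for the BMUF test. Roughly, following Chen & Huo (2016), at
# every sync point the workers' models are averaged and the global model is
# updated with a momentum-filtered block update:
#
#     G_t            = avg_model_t - global_model_{t-1}
#     delta_t        = block_momentum * delta_{t-1} + block_lr * G_t
#     global_model_t = global_model_{t-1} + delta_t
#
# With use_nbm (Nesterov block momentum) workers restart the next block from
# global_model_t + block_momentum * delta_t. warmup_iterations controls an
# initial phase where parameters are simply broadcast (or averaged, when
# average_sync is set), and global_sync_iter is the number of local steps
# between BMUF syncs.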
def setup_args():
    args = argparse.Namespace()
    args.block_momentum = 0.875
    args.block_lr = 0.5
    args.input_size = 5
    args.nb_classes = 2
    args.batch_size = 1
    args.lr = [1e-3]
    args.momentum = 0
    args.weight_decay = 0
    args.warmup_iterations = 0
    args.use_nbm = True
    args.average_sync = True
    args.global_sync_iter = 1
    args.distributed_backend = "gloo"

    args.distributed_world_size = 2
    port = random.randint(10000, 20000)
    args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
    args.distributed_init_host = "localhost"
    args.distributed_port = port + 1
    args.local_world_size = args.distributed_world_size
    return args
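

# Each test spawns `distributed_world_size` worker processes over the gloo
# backend (CUDA devices are used when available) and checks that both workers
# end up with the same parameters after the configured number of local steps.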
class TestBMUF(unittest.TestCase):
    def bmuf_process(self, args, iterations):
        processes = []
        results = Manager().dict()
        ctx = torch.multiprocessing.get_context("spawn")
        for rank in range(args.distributed_world_size):
            p = ctx.Process(
                target=single_gpu_training, args=(args, rank, iterations, results)
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Make sure params on both workers are the same
        assert len(results) == 2
        self.assertAlmostEqual(results[0], results[1])

    def test_bmuf_sync(self):
        # Train model for 1 iteration and do bmuf sync without doing warmup
        args = setup_args()
        iterations = 1
        self.bmuf_process(args, iterations)

    def test_warmup_sync(self):
        # Train model for 20 iterations and do warmup sync without doing bmuf sync
        args = setup_args()
        args.warmup_iterations = 20
        iterations = 20
        self.bmuf_process(args, iterations)

    def test_warmup_sync_bmuf_sync(self):
        # Train model for 25 iterations and do warmup sync after 20 iterations
        # and bmuf sync after 25 iterations
        args = setup_args()
        args.warmup_iterations = 20
        args.global_sync_iter = 5
        iterations = 25
        self.bmuf_process(args, iterations)
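
    # Tensor-aware replacement for unittest's assertAlmostEqual: the flattened
    # parameter vectors from the two workers must have the same shape and
    # agree element-wise to within 1e-4.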
    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)


if __name__ == "__main__":
    unittest.main()