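The script below is a minimal example of single-node, multi-GPU mixed-precision training with NVIDIA Apex: the model and optimizer are prepared with amp.initialize (opt_level "O1"), and when the script runs under torch.distributed.launch the model is additionally wrapped in apex.parallel.DistributedDataParallel. Each process trains a toy linear model on its own fixed fake batch.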
import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
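    # NCCL is the backend recommended for multi-GPU CUDA training; with one
    # process per GPU, local_rank selects the GPU this process will use.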
torch.backends.cudnn.benchmark = True

N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
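# With opt_level "O1", Amp patches eligible ops to run in FP16 while model weights
# stay in FP32, and dynamic loss scaling is enabled by default.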
if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)
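    # Either wrapper allreduces and averages gradients across processes during
    # backward(), so every process applies the same parameter update.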
loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

if args.local_rank == 0:
    print("final loss = ", loss)
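To run it, launch one process per GPU with torch.distributed.launch, for example python -m torch.distributed.launch --nproc_per_node=2 distributed_example.py (the GPU count and filename here are placeholders). Each process computes the loss on its own fake batch, so the value printed at the end is rank 0's local loss.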