import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

torch.backends.cudnn.benchmark = True

N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
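# NOTE: amp.initialize should be called after the model and optimizer have been
# constructed, but before the model is wrapped with DistributedDataParallel.
# opt_level="O1" enables mixed precision by patching Torch functions to cast
# their inputs to half precision where it is considered numerically safe.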

if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)
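    # Either wrapper averages gradients across processes during backward(), so every
    # process computes the same parameter update.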

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
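    # FOR AMP: amp.scale_loss scales the loss so that small FP16 gradients do not
    # underflow; the gradients are unscaled again before optimizer.step().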
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

if args.local_rank == 0:
    print("final loss = ", loss)