kadirnar's picture
Upload 494 files
8a42f8f verified
raw
history blame
1.34 kB
import unittest
import torch
from torch import nn
from torch.nn import Parameter
from apex import amp
from apex.parallel.LARC import LARC
from utils import common_init
class MyModel(torch.nn.Module):
def __init__(self, unique):
super(MyModel, self).__init__()
self.weight0 = Parameter(
unique + torch.arange(2, device="cuda", dtype=torch.float32)
)
def forward(self, input):
return (input * self.weight0).sum()
class TestLARC(unittest.TestCase):
def setUp(self):
self.x = torch.ones((2), device="cuda", dtype=torch.float32)
common_init(self)
def tearDown(self):
pass
def test_larc_mixed_precision(self):
for opt_level in ["O0", "O1", "O2", "O3"]:
model = MyModel(1)
optimizer = LARC(
torch.optim.SGD(
[{"params": model.parameters(), "lr": 0.25}], momentum=0.125
)
)
model, optimizer = amp.initialize(
model, optimizer, opt_level=opt_level, verbosity=0
)
optimizer.zero_grad()
loss = model(self.x)
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
if __name__ == "__main__":
unittest.main()