Spaces:
Runtime error
Runtime error
import unittest | |
import torch | |
from torch import nn | |
from torch.nn import Parameter | |
from apex import amp | |
from apex.parallel.LARC import LARC | |
from utils import common_init | |
class MyModel(torch.nn.Module):
    """Tiny module with a single trainable vector, used to drive LARC under amp.

    ``weight0`` is a length-2 float32 parameter allocated on the CUDA device and
    initialised to ``[unique, unique + 1]``. The forward pass returns the scalar
    ``sum(input * weight0)``.
    """

    def __init__(self, unique):
        super().__init__()
        # Shifting arange by ``unique`` lets callers choose the starting values.
        initial = torch.arange(2, device="cuda", dtype=torch.float32) + unique
        self.weight0 = Parameter(initial)

    def forward(self, input):
        """Return the scalar sum of ``input`` multiplied element-wise by ``weight0``."""
        weighted = input * self.weight0
        return torch.sum(weighted)
class TestLARC(unittest.TestCase):
    """Smoke test: one LARC-wrapped SGD step under each apex amp opt level."""

    def setUp(self):
        # All-ones float32 input on the GPU, reused by every opt-level iteration.
        # ``(2,)`` is the explicit 1-tuple shape; ``(2)`` is just the int 2.
        self.x = torch.ones((2,), device="cuda", dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def test_larc_mixed_precision(self):
        """Run zero_grad / forward / scaled backward / step for each opt level.

        Each opt level runs inside its own ``subTest``. A failure therefore names
        the offending level, and the remaining levels are still exercised.
        """
        for opt_level in ["O0", "O1", "O2", "O3"]:
            with self.subTest(opt_level=opt_level):
                # Fresh model and optimizer per level so levels do not share state.
                model = MyModel(1)
                optimizer = LARC(
                    torch.optim.SGD(
                        [{"params": model.parameters(), "lr": 0.25}], momentum=0.125
                    )
                )
                model, optimizer = amp.initialize(
                    model, optimizer, opt_level=opt_level, verbosity=0
                )

                optimizer.zero_grad()
                loss = model(self.x)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # NOTE(review): with dynamic loss scaling, amp may skip a step
                # whose scaled gradients overflow. So only finiteness is checked
                # here, not that the weights moved.
                for param in model.parameters():
                    self.assertTrue(bool(torch.isfinite(param).all()))
if __name__ == "__main__":
    # Load and run every TestCase defined in this script; "__main__" is
    # unittest.main's default module, spelled out here for clarity.
    unittest.main(module="__main__")