File size: 205 Bytes
e8861c0
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import torch.nn as nn 
import torch.cuda.amp as amp


from src.core import register
import src.misc.dist as dist 


# Public names exported by `from <this module> import *`.
__all__ = ['GradScaler']

# Re-export PyTorch's CUDA automatic-mixed-precision gradient scaler after
# passing it through the project's `register` helper (imported from src.core).
# `amp.GradScaler` is the public alias that torch.cuda.amp re-exports from its
# `grad_scaler` submodule, i.e. the very same class object as
# `amp.grad_scaler.GradScaler`.
# NOTE(review): presumably `register` records the class so it can be built
# from config and returns it unchanged; confirm in src.core.
# NOTE(review): newer PyTorch releases deprecate torch.cuda.amp.GradScaler in
# favour of torch.amp.GradScaler("cuda", ...). That class takes `device` as its
# first positional parameter, so check callers before migrating.
GradScaler = register(amp.GradScaler)