caesar-one commited on
Commit
624a902
·
verified ·
1 Parent(s): 39157e5

Upload ConstBERT

Browse files
Files changed (1) hide show
  1. modeling.py +2 -2
modeling.py CHANGED
@@ -21,10 +21,10 @@ class MixedPrecisionManager():
21
  self.activated = activated
22
 
23
  if self.activated:
24
- self.scaler = torch.cuda.amp.GradScaler()
25
 
26
  def context(self):
27
- return torch.cuda.amp.autocast() if self.activated else NullContextManager()
28
 
29
  def backward(self, loss):
30
  if self.activated:
 
21
  self.activated = activated
22
 
23
  if self.activated:
24
+ self.scaler = torch.amp.GradScaler("cuda")
25
 
26
  def context(self):
27
+ return torch.amp.autocast("cuda") if self.activated else NullContextManager()
28
 
29
  def backward(self, loss):
30
  if self.activated: