ALeLacheur commited on
Commit
268fcc9
·
verified ·
1 Parent(s): 8f644ec

Update audio_diffusion_attacks_forhf/src/balancer.py

Browse files
audio_diffusion_attacks_forhf/src/balancer.py CHANGED
@@ -90,6 +90,8 @@ class Balancer:
90
  grads = {}
91
  for name, loss in losses.items():
92
  # Compute partial derivative of the less with respect to the input.
 
 
93
  grad, = autograd.grad(loss, [input], retain_graph=True)
94
  if self.per_batch_item:
95
  # We do not average the gradient over the batch dimension.
 
90
  grads = {}
91
  for name, loss in losses.items():
92
  # Compute partial derivative of the less with respect to the input.
93
+ #Andy added:
94
+ loss.requires_grad = True
95
  grad, = autograd.grad(loss, [input], retain_graph=True)
96
  if self.per_batch_item:
97
  # We do not average the gradient over the batch dimension.