WwYc commited on
Commit
d5812e2
·
verified ·
1 Parent(s): b2c3238

Update lxmert/src/ExplanationGenerator.py

Browse files
Files changed (1) hide show
  1. lxmert/src/ExplanationGenerator.py +1 -1
lxmert/src/ExplanationGenerator.py CHANGED
@@ -317,7 +317,7 @@ class GeneratorOursAblationNoAggregation:
317
  one_hot[0, index] = 1
318
  one_hot_vector = one_hot
319
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
320
- one_hot = torch.sum(one_hot.cuda() * output)
321
 
322
  model.zero_grad()
323
  one_hot.backward(retain_graph=True)
 
317
  one_hot[0, index] = 1
318
  one_hot_vector = one_hot
319
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
320
+ one_hot = torch.sum(one_hot * output)
321
 
322
  model.zero_grad()
323
  one_hot.backward(retain_graph=True)