Spaces:
Sleeping
Sleeping
Update lxmert/src/ExplanationGenerator.py
Browse files
lxmert/src/ExplanationGenerator.py
CHANGED
|
@@ -163,7 +163,7 @@ class GeneratorOurs:
|
|
| 163 |
one_hot[0, index] = 1
|
| 164 |
one_hot_vector = one_hot
|
| 165 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 166 |
-
one_hot = torch.sum(one_hot
|
| 167 |
|
| 168 |
model.zero_grad()
|
| 169 |
one_hot.backward(retain_graph=True)
|
|
@@ -400,7 +400,7 @@ class GeneratorBaselines:
|
|
| 400 |
one_hot[0, index] = 1
|
| 401 |
one_hot_vector = one_hot
|
| 402 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 403 |
-
one_hot = torch.sum(one_hot
|
| 404 |
|
| 405 |
model.zero_grad()
|
| 406 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 163 |
one_hot[0, index] = 1
|
| 164 |
one_hot_vector = one_hot
|
| 165 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 166 |
+
one_hot = torch.sum(one_hot * output)
|
| 167 |
|
| 168 |
model.zero_grad()
|
| 169 |
one_hot.backward(retain_graph=True)
|
|
|
|
| 400 |
one_hot[0, index] = 1
|
| 401 |
one_hot_vector = one_hot
|
| 402 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
| 403 |
+
one_hot = torch.sum(one_hot * output)
|
| 404 |
|
| 405 |
model.zero_grad()
|
| 406 |
one_hot.backward(retain_graph=True)
|