Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -79,7 +79,7 @@ def compute_splade_vector(text: str) -> models.SparseVector:
|
|
| 79 |
with torch.no_grad():
|
| 80 |
output = splade_model(**tokens)
|
| 81 |
|
| 82 |
-
logits, attention_mask = output.logits, tokens
|
| 83 |
relu_log = torch.log(1 + torch.relu(logits))
|
| 84 |
weighted_log = relu_log * attention_mask.unsqueeze(-1)
|
| 85 |
max_val, _ = torch.max(weighted_log, dim=1)
|
|
|
|
| 79 |
with torch.no_grad():
|
| 80 |
output = splade_model(**tokens)
|
| 81 |
|
| 82 |
+
logits, attention_mask = output.logits, tokens['attention_mask']
|
| 83 |
relu_log = torch.log(1 + torch.relu(logits))
|
| 84 |
weighted_log = relu_log * attention_mask.unsqueeze(-1)
|
| 85 |
max_val, _ = torch.max(weighted_log, dim=1)
|