Justin44 commited on
Commit
4c62e5e
·
verified ·
1 Parent(s): ee5d451

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -79,7 +79,7 @@ def compute_splade_vector(text: str) -> models.SparseVector:
79
  with torch.no_grad():
80
  output = splade_model(**tokens)
81
 
82
- logits, attention_mask = output.logits, tokens.attention_mask
83
  relu_log = torch.log(1 + torch.relu(logits))
84
  weighted_log = relu_log * attention_mask.unsqueeze(-1)
85
  max_val, _ = torch.max(weighted_log, dim=1)
 
79
  with torch.no_grad():
80
  output = splade_model(**tokens)
81
 
82
+ logits, attention_mask = output.logits, tokens['attention_mask']
83
  relu_log = torch.log(1 + torch.relu(logits))
84
  weighted_log = relu_log * attention_mask.unsqueeze(-1)
85
  max_val, _ = torch.max(weighted_log, dim=1)