Varun Wadhwa commited on
Commit
376eac5
·
unverified ·
1 Parent(s): cb2cd7f
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -147,7 +147,7 @@ def evaluate_model(model, dataloader, device):
147
 
148
  # Get predictions
149
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
150
- mask = attention_mask.cpu().numpy().astype(bool)
151
 
152
  # Process each sequence in the batch
153
  for i in range(current_batch_size):
 
147
 
148
  # Get predictions
149
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
150
+ mask = attention_mask
151
 
152
  # Process each sequence in the batch
153
  for i in range(current_batch_size):