Blaise-g commited on
Commit
ece3bd3
Β·
1 Parent(s): ac2835f

Update summarize.py

Browse files
Files changed (1) hide show
  1. summarize.py +2 -2
summarize.py CHANGED
@@ -45,9 +45,9 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
45
  input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46
  attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
 
48
- global_attention_mask = torch.zeros_like(attention_mask)
49
  # put global attention on <s> token
50
- global_attention_mask[:, 0] = 1
51
 
52
  summary_pred_ids = model.generate(
53
  input_ids,
 
45
  input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
46
  attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
47
 
48
+ #global_attention_mask = torch.zeros_like(attention_mask)
49
  # put global attention on <s> token
50
+ #global_attention_mask[:, 0] = 1
51
 
52
  summary_pred_ids = model.generate(
53
  input_ids,