Update summarize.py
Browse files- summarize.py +2 -2
summarize.py
CHANGED
@@ -45,9 +45,9 @@ def summarize_and_score(ids, mask, model, tokenizer, **kwargs):
|
|
45 |
input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
|
46 |
attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
|
47 |
|
48 |
-
global_attention_mask = torch.zeros_like(attention_mask)
|
49 |
# put global attention on <s> token
|
50 |
-
global_attention_mask[:, 0] = 1
|
51 |
|
52 |
summary_pred_ids = model.generate(
|
53 |
input_ids,
|
|
|
45 |
input_ids = ids.to("cuda") if torch.cuda.is_available() else ids
|
46 |
attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask
|
47 |
|
48 |
+
#global_attention_mask = torch.zeros_like(attention_mask)
|
49 |
# put global attention on <s> token
|
50 |
+
#global_attention_mask[:, 0] = 1
|
51 |
|
52 |
summary_pred_ids = model.generate(
|
53 |
input_ids,
|