thon | |
import torch | |
from tqdm import tqdm | |
# Sliding-window setup consumed by the evaluation loop that follows.
# Window width, read from the model config's `n_positions` attribute.
max_length = model.config.n_positions
# Step between consecutive window start offsets (see `range(0, seq_len, stride)`).
stride = 512
# Size of `encodings.input_ids` along dimension 1.
seq_len = encodings.input_ids.size(1)
# Accumulator list, starts empty. Presumably collects one loss value per
# window -- the append is outside the lines visible here, TODO confirm.
nlls = []
# End offset of the previous window; the loop subtracts it from `end_loc`
# to obtain `trg_len`.
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)): | |
end_loc = min(begin_loc + max_length, seq_len) | |
trg_len = end_loc - prev_end_loc # may be different from stride on last loop | |
input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device) | |
target_ids = input_ids.clone() | |
target_ids[:, :-trg_len] = -100 | |
with torch.no_grad(): | |
outputs = model(input_ids, labels=target_ids) | |
# loss is calculated using CrossEntropyLoss which averages over valid labels | |
# N.B. |