This means that the model will have at least 512 tokens for context when calculating the conditional likelihood of any one token (provided there are 512 preceding tokens available to condition on).
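
The loop below assumes that `model`, `encodings`, and `device` are already defined. A minimal setup might look like the following sketch; the GPT-2 checkpoint and WikiText-2 test set are assumptions here (chosen because `n_positions` is a GPT-2 config attribute), so substitute your own model and corpus:

```python
import torch
from datasets import load_dataset
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Assumed setup: a GPT-2 checkpoint evaluated on the WikiText-2 test set
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
# Concatenate the test set into one long sequence of token ids
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")
```

With that in place, the strided evaluation loop looks like this: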

```python
import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Tokens already scored in a previous window serve as context only
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it
        # internally shifts the labels to the left by 1
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())
```
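
When the loop finishes, `ppl` holds the model's perplexity over the whole sequence. The choice of `stride` is a trade-off: setting it equal to `max_length` reproduces the cheaper non-overlapping evaluation, while smaller strides give each predicted token more preceding context and therefore a better (typically lower) perplexity estimate, at the cost of more forward passes.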