This means that the model will have at least 512 tokens for context when calculating the conditional likelihood of any one token (provided there are 512 preceding tokens available to condition on).

```python
import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B.
```
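To see how the strided windows divide a sequence, it can help to trace the index arithmetic on its own, without a model. The following is a minimal sketch, not part of the original example: the helper name `sliding_windows` and the sample sizes (2500 tokens, a 1024-token window) are hypothetical, and it assumes that the loop sets `prev_end_loc = end_loc` after each window and stops once a window reaches the end of the sequence.

```python
def sliding_windows(seq_len, max_length, stride):
    """Yield (begin_loc, end_loc, trg_len) for each strided evaluation window.

    Hypothetical helper: it mirrors only the index arithmetic of a strided
    perplexity loop. trg_len is the number of trailing tokens in the window
    that are scored; the remaining tokens serve as context only.
    """
    prev_end_loc = 0
    for begin_loc in range(0, seq_len, stride):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # tokens not scored by an earlier window
        yield begin_loc, end_loc, trg_len
        # Assumption: the loop advances prev_end_loc and exits at the end of the text.
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break


# Illustrative sizes only: 2500 tokens, a 1024-token window, stride 512.
windows = list(sliding_windows(seq_len=2500, max_length=1024, stride=512))
for begin_loc, end_loc, trg_len in windows:
    context_len = (end_loc - begin_loc) - trg_len
    print(f"window [{begin_loc}, {end_loc}): scores {trg_len} tokens, "
          f"{context_len} context-only tokens")
```

With these sample sizes, each token is scored exactly once, and every window after the first keeps 512 context-only tokens ahead of the tokens it scores. This is the property the masking line `target_ids[:, :-trg_len] = -100` relies on.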