```python
import torch
from tqdm import tqdm

max_length = model.config.n_positions
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B.
