Upload train.py

TabPFN/train.py  CHANGED  (+96 -124)
@@ -17,45 +17,54 @@ import priors
 import encoders
 import positional_encodings
 from utils import init_dist
-from torch.cuda.amp import autocast
+from torch.cuda.amp import autocast, GradScaler
+from torch import nn
 
 class Losses():
     gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
     mse = nn.MSELoss(reduction='none')
-    ce
+    def ce(num_classes):
+        num_classes = num_classes.shape[0] if torch.is_tensor(num_classes) else num_classes
+        return nn.CrossEntropyLoss(reduction='none', weight=torch.ones(num_classes))
     bce = nn.BCEWithLogitsLoss(reduction='none')
 
 
-
+
+def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.0,
           epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False,
           y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup,
           load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, bptt_extra_samples=None, gpu_device='cuda:0',
-          aggregate_k_gradients=1, verbose=True, style_encoder_generator=None,
-          initializer=None, initialize_with_model=None, train_mixed_precision=False,
+          aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, epoch_callback=None,
+          initializer=None, initialize_with_model=None, train_mixed_precision=False, efficient_eval_masking=True, **model_extra_args
           ):
-    assert (epochs is None) != (total_available_time_in_s is None)
-    start_of_training = time.time()
     device = gpu_device if torch.cuda.is_available() else 'cpu:0'
     print(f'Using {device} device')
     using_dist, rank, device = init_dist(device)
-
-    dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, seq_len=bptt_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
-    if dl.fuse_x_y:
-        raise Exception("Illegal parameter")
+    single_eval_pos_gen = single_eval_pos_gen if callable(single_eval_pos_gen) else lambda: single_eval_pos_gen
 
-
-
+    def eval_pos_seq_len_sampler():
+        single_eval_pos = single_eval_pos_gen()
+        if bptt_extra_samples:
+            return single_eval_pos, single_eval_pos + bptt_extra_samples
+        else:
+            return single_eval_pos, bptt
+    dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, eval_pos_seq_len_sampler=eval_pos_seq_len_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
 
-
-
+    encoder = encoder_generator(dl.num_features, emsize)
+    style_def = dl.get_test_batch()[0][0]  # the style in batch of the form ((style, x, y), target, single_eval_pos)
+    print(f'Style definition of first 3 examples: {style_def[:3] if style_def is not None else None}')
+    style_encoder = style_encoder_generator(style_def.shape[1], emsize) if (style_def is not None) else None
     if isinstance(criterion, nn.GaussianNLLLoss):
-        n_out
+        n_out = 2
     elif isinstance(criterion, nn.CrossEntropyLoss):
-        n_out
+        n_out = criterion.weight.shape[0]
+    else:
+        n_out = 1
+
     model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout, style_encoder=style_encoder,
-                             y_encoder=y_encoder_generator(
+                             y_encoder=y_encoder_generator(1, emsize), input_normalization=input_normalization,
                              pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt*2),
-                             decoder=decoder, init_method=initializer, **model_extra_args
+                             decoder=decoder, init_method=initializer, efficient_eval_masking=efficient_eval_masking, **model_extra_args
                              )
     model.criterion = criterion
     if load_weights_from_this_state_dict is not None:
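Note on this hunk: Losses.ce now encodes the class count in the criterion's weight vector, which train() later reads back as n_out, and the prior dataloader receives an eval_pos_seq_len_sampler instead of a fixed seq_len. A minimal, self-contained sketch of both mechanisms, with a hypothetical uniform sampler standing in for single_eval_pos_gen:

import random
import torch
from torch import nn

def ce(num_classes):  # mirrors the new Losses.ce
    num_classes = num_classes.shape[0] if torch.is_tensor(num_classes) else num_classes
    return nn.CrossEntropyLoss(reduction='none', weight=torch.ones(num_classes))

criterion = ce(10)
n_out = criterion.weight.shape[0]  # 10: how train() infers the model's output dimension

bptt, bptt_extra_samples = 10, None
single_eval_pos_gen = lambda: random.randint(1, bptt - 1)  # illustrative sampler, not the repo's default

def eval_pos_seq_len_sampler():  # mirrors the helper defined inside the new train()
    single_eval_pos = single_eval_pos_gen()
    if bptt_extra_samples:
        return single_eval_pos, single_eval_pos + bptt_extra_samples
    else:
        return single_eval_pos, bptt

print(n_out, eval_pos_seq_len_sampler())  # e.g. 10 (4, 10): positions before 4 are training context, the rest are evaluated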
@@ -75,6 +84,7 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
     if using_dist:
         print("Distributed training")
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
+        dl.model = model
 
 
     # learning rate
@@ -84,21 +94,25 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
     scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100)  # when training for fixed time lr schedule takes 100 steps
 
-
+    scaler = GradScaler() if train_mixed_precision else None
+
+    # check that everything uses up-to-date APIs
+    utils.check_compatibility(dl)
+
+    def train_epoch():
         model.train()  # Turn on the train mode
         total_loss = 0.
         total_positional_losses = 0.
         total_positional_losses_recorded = 0
+        nan_steps = 0
+        ignore_steps = 0
         before_get_batch = time.time()
         assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
-
-        for batch, (data, targets) in enumerate(dl):
+        for batch, (data, targets, single_eval_pos) in enumerate(dl):
             if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1):
                 cm = model.no_sync()
-                #print(f'p={rank}, no_sync', force=True)
             else:
                 cm = nullcontext()
-                #print(f'p={rank}, sync', force=True)
             with cm:
                 time_to_get_batch = time.time() - before_get_batch
                 before_forward = time.time()
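Note on this hunk: train_epoch now wraps the forward pass in torch.cuda.amp autocast, scales the loss with an optional GradScaler, and only steps the optimizer every aggregate_k_gradients batches. A standalone sketch of that update pattern on a toy model (assumes a CUDA device; the model and data here are placeholders, not the repo's):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()  # only used when training in mixed precision
aggregate_k_gradients = 4

for batch in range(16):
    x, y = torch.randn(32, 8, device='cuda'), torch.randn(32, 1, device='cuda')
    with autocast():
        loss = nn.functional.mse_loss(model(x), y) / aggregate_k_gradients  # average over the accumulated batches
    scaler.scale(loss).backward()  # scaled backward pass to avoid fp16 underflow
    if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
        scaler.unscale_(optimizer)  # unscale before clipping, as the diff does
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()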
@@ -107,100 +121,75 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
                 else:
                     single_eval_pos = targets.shape[0] - bptt_extra_samples
 
-
-
-                for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                #output, targets = output[:, is_compatible], targets[:, is_compatible]
-
-                if single_eval_pos is not None:
-                    targets = targets[single_eval_pos:]
-                if isinstance(criterion, nn.GaussianNLLLoss):
-                    assert output.shape[-1] == 2, \
-                        'need to write a little bit of code to handle multiple regression targets at once'
-
-                    mean_pred = output[..., 0]
-                    var_pred = output[..., 1].abs()
-                    losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
-                elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
-                    losses = criterion(output.flatten(), targets.to(device).flatten())
-                elif isinstance(criterion, (nn.CrossEntropyLoss)):
-                    #print(n_out, targets.min(), targets.max(), force=True)
-                    losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
-                else:
-                    losses = criterion(output.reshape(-1, n_out), targets.to(device).flatten())
-                losses = losses.view(*output.shape[0:2])
-                loss = losses.mean(0) @ is_compatible.float() / losses.shape[1]
-                #loss = torch_nanmean(losses, axis=[0, 1]) * is_compatible.float().mean()
-                # not sure whether we can go without the nan checks.
-
+                with autocast(enabled=scaler is not None):
+                    # If style is set to None, it should not be transferred to device
+                    output = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data) if isinstance(data, tuple) else data.to(device)
+                                   , single_eval_pos=single_eval_pos)
+
+                    forward_time = time.time() - before_forward
+
+                    if single_eval_pos is not None:
+                        targets = targets[single_eval_pos:]
+                    if isinstance(criterion, nn.GaussianNLLLoss):
+                        assert output.shape[-1] == 2, \
+                            'need to write a little bit of code to handle multiple regression targets at once'
+
+                        mean_pred = output[..., 0]
+                        var_pred = output[..., 1].abs()
+                        losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
+                    elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
+                        losses = criterion(output.flatten(), targets.to(device).flatten())
+                    elif isinstance(criterion, nn.CrossEntropyLoss):
+                        losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
+                    else:
+                        losses = criterion(output, targets)
+                    losses = losses.view(*output.shape[0:2])
+                    loss, nan_share = utils.torch_nanmean(losses.mean(0), return_nanshare=True)
+                    loss = loss / aggregate_k_gradients
+
+                if scaler: loss = scaler.scale(loss)
                 loss.backward()
 
-                if
-
-                    with torch.no_grad():
-                        for p in model.parameters():
-                            if p.grad is not None:
-                                p.grad.div_(valid_batch_steps)
+                if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
+                    if scaler: scaler.unscale_(optimizer)
                     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                     try:
-
+                        if scaler:
+                            scaler.step(optimizer)
+                            scaler.update()
+                        else:
+                            optimizer.step()
                     except:
                         print("Invalid optimization step encountered")
                     optimizer.zero_grad()
-                    valid_batch_steps = 0.0
 
                 step_time = time.time() - before_forward
 
                 if not torch.isnan(loss):
-                    total_loss +=
+                    total_loss += losses.mean().cpu().detach().item()
                     total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
-                        nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
+                        nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)*\
+                        losses[:bptt-single_eval_pos].mean().cpu().detach()
 
                     total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
                         nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
+                    nan_steps += nan_share
+                    ignore_steps += (targets == -100).float().mean()
+
 
             before_get_batch = time.time()
-        return total_loss / steps_per_epoch, (
-
+        return total_loss / steps_per_epoch, (total_positional_losses / total_positional_losses_recorded).tolist(),\
+            time_to_get_batch, forward_time, step_time, nan_steps.cpu().item()/(batch+1),\
+            ignore_steps.cpu().item()/(batch+1)
 
-    best_val_loss = float("inf")
-    best_model = None
     total_loss = float('inf')
    total_positional_losses = float('inf')
     try:
         for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)):
 
             epoch_start_time = time.time()
-
-
-                total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
-            else:
-                total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
+            total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time, nan_share, ignore_share =\
+                train_epoch()
             if hasattr(dl, 'validate') and epoch % validation_period == 0:
                 with torch.no_grad():
                     val_score = dl.validate(model)
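Note on this hunk: the per-batch loss is now reduced with utils.torch_nanmean(..., return_nanshare=True), so occasional NaN losses from the prior no longer poison the update and their frequency is tracked as nan_share. That helper lives elsewhere in the repo; a sketch of what such a function plausibly does (an assumption, not the repo's code):

import torch

def torch_nanmean(x, return_nanshare=False):
    # Mean over the non-NaN entries, plus the fraction of entries that were NaN.
    nan_mask = torch.isnan(x)
    total = torch.where(nan_mask, torch.zeros_like(x), x).sum()
    mean = total / (~nan_mask).sum().clamp(min=1)
    if return_nanshare:
        return mean, nan_mask.float().mean()
    return mean

losses = torch.tensor([0.7, float('nan'), 0.5, 0.3])
loss, nan_share = torch_nanmean(losses, return_nanshare=True)
print(loss.item(), nan_share.item())  # 0.5 0.25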
@@ -213,25 +202,23 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
                     f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
                     f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                     f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
-                    f' forward time {forward_time:5.2f}'
+                    f' forward time {forward_time:5.2f}'
+                    f' nan share {nan_share:5.2f} ignore share (for classification tasks) {ignore_share:5.4f}'
+                    + (f'val score {val_score}' if val_score is not None else ''))
                 print('-' * 89)
 
             # stepping with wallclock time based scheduler
-            current_time = time.time()
             if epoch_callback is not None and rank == 0:
-                epoch_callback(model, epoch / epochs
-
-                               )
-            if epochs is None and (current_time - start_of_training) > total_available_time_in_s:  # noqa
-                break
-            if epochs is None:
-                scheduler.step((current_time - epoch_start_time) / total_available_time_in_s * 100)
-            else:
-                scheduler.step()
+                epoch_callback(model, epoch / epochs)
+            scheduler.step()
     except KeyboardInterrupt:
         pass
 
-
+    if rank == 0:  # trivially true for non-parallel training
+        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+            model = model.module
+            dl = None
+        return total_loss, total_positional_losses, model.to('cpu'), dl
 
 def _parse_args(config_parser, parser):
     # Do we have a config file to parse?
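Note on the "pos losses" reported above: each batch adds its mean loss only at the position it evaluated (single_eval_pos), and a parallel counter records how often each position was hit, so the printed value is a per-position average. A small self-contained sketch of that bookkeeping with made-up losses:

import torch
from torch import nn

bptt = 5
total_positional_losses = torch.zeros(bptt)
total_positional_losses_recorded = torch.zeros(bptt)

for single_eval_pos, batch_loss in [(2, 0.8), (4, 0.6), (2, 0.4)]:
    total_positional_losses += nn.functional.one_hot(torch.tensor(single_eval_pos), bptt) * batch_loss
    total_positional_losses_recorded += nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)

print((total_positional_losses / total_positional_losses_recorded).tolist())
# approximately [nan, nan, 0.6, nan, 0.6]; positions that were never evaluated stay NaN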
@@ -261,16 +248,17 @@ if __name__ == '__main__':
     parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
     parser.add_argument('--num_buckets', default=100, type=int)
     #parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
-    parser.add_argument("--extra_prior_kwargs_dict", default={
+    parser.add_argument("--extra_prior_kwargs_dict", default={}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
     parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
     parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
-    parser.add_argument('--pos_encoder', default='
+    parser.add_argument('--pos_encoder', default='none', type=str, help='Specify depending on the prior.')
     parser.add_argument('--bptt', default=10, type=int)
     parser.add_argument('--epochs', default=200, type=int)
     parser.add_argument('--warmup_epochs', default=50, type=int)
     parser.add_argument('--validation_period', default=10, type=int)
     parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to ')
     parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")
+    parser.add_argument('--train_mixed_precision', action='store_true')
 
     # these can likely be mostly left at defaults
     parser.add_argument('--emsize', default=512, type=int)  # sometimes even larger is better e.g. 1024
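Note on --extra_prior_kwargs_dict: the new argument relies on a StoreDictKeyPair action that is referenced but not defined in this file. A typical implementation of such an action looks like the sketch below (an assumption about its behaviour; the repo's version may convert value types differently):

import argparse

class StoreDictKeyPair(argparse.Action):
    # Collects "KEY=VAL" tokens into a dict on the parsed namespace; values stay strings here.
    def __call__(self, parser, namespace, values, option_string=None):
        d = {}
        for pair in values:
            key, value = pair.split('=', 1)
            d[key] = value
        setattr(namespace, self.dest, d)

parser = argparse.ArgumentParser()
parser.add_argument("--extra_prior_kwargs_dict", default={}, action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL")
print(parser.parse_args(["--extra_prior_kwargs_dict", "num_features=100"]).extra_prior_kwargs_dict)
# {'num_features': '100'}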
@@ -309,28 +297,12 @@
     min_y = args.__dict__.pop('min_y')
     # criterion = nn.MSELoss(reduction='none')
 
-    def get_y_sample():
-        dl = prior(num_steps=1, batch_size=args.batch_size * args.steps_per_epoch, seq_len=args.bptt, device=device,
-                   **args.extra_prior_kwargs_dict)
-        y_sample = next(iter(dl))[-1]
-        print(f'Creating Bar distribution with borders from y sample of size {y_sample.numel()}')
-        return y_sample
-
     if loss_function == 'ce':
         criterion = nn.CrossEntropyLoss(reduction='none')
     elif loss_function == 'gaussnll':
         criterion = nn.GaussianNLLLoss(reduction='none', full=True)
     elif loss_function == 'mse':
         criterion = nn.MSELoss(reduction='none')
-    elif loss_function == 'barnll':
-        criterion = BarDistribution(borders=get_bucket_limits(num_buckets, full_range=(min_y,max_y)))
-    elif loss_function == 'adaptivebarnll':
-        borders = get_bucket_limits(num_buckets, ys=get_y_sample(), full_range=(min_y,max_y))
-        criterion = BarDistribution(borders=borders)
-    elif loss_function == 'adaptivefullsupportbarnll':
-        assert min_y is None and max_y is None, "Please do not specify `min_y` and `max_y` with `unboundedadaptivebarnll`."
-        borders = get_bucket_limits(num_buckets, ys=get_y_sample())
-        criterion = FullSupportBarDistribution(borders=borders)
     else:
         raise NotImplementedError(f'loss_function == {loss_function}.')
 