TabPFN committed on
Commit 6c53143 · 1 Parent(s): e0f96a9

Upload train.py

Files changed (1)
  1. TabPFN/train.py +96 -124
TabPFN/train.py CHANGED
@@ -17,45 +17,54 @@ import priors
  import encoders
  import positional_encodings
  from utils import init_dist
- from torch.cuda.amp import autocast

  class Losses():
      gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
      mse = nn.MSELoss(reduction='none')
-     ce = lambda weight: nn.CrossEntropyLoss(reduction='none', weight=weight)
      bce = nn.BCEWithLogitsLoss(reduction='none')


- def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.2,
            epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False,
            y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup,
            load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, bptt_extra_samples=None, gpu_device='cuda:0',
-           aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, check_is_compatible=True, epoch_callback=None,
-           initializer=None, initialize_with_model=None, train_mixed_precision=False, total_available_time_in_s=None, normalize_labels=True, **model_extra_args
            ):
-     assert (epochs is None) != (total_available_time_in_s is None)
-     start_of_training = time.time()
      device = gpu_device if torch.cuda.is_available() else 'cpu:0'
      print(f'Using {device} device')
      using_dist, rank, device = init_dist(device)
-     bptt_sampler = (lambda: single_eval_pos_gen() + bptt_extra_samples if callable(single_eval_pos_gen) else single_eval_pos_gen + bptt_extra_samples) if bptt_extra_samples is not None else bptt
-     dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, seq_len=bptt_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
-     if dl.fuse_x_y:
-         raise Exception("Illegal parameter")

-     encoder = encoder_generator(dl.num_features+1 if dl.fuse_x_y else dl.num_features, emsize)
-     style_def = next(iter(dl))[0][0]  # This is (style, x, y), target with x and y with batch size

-     style_encoder = style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize) if (style_def is not None) else None
-     n_out = dl.num_outputs
      if isinstance(criterion, nn.GaussianNLLLoss):
-         n_out *= 2
      elif isinstance(criterion, nn.CrossEntropyLoss):
-         n_out *= criterion.weight.shape[0]
      model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout, style_encoder=style_encoder,
-                              y_encoder=y_encoder_generator(dl.num_outputs, emsize), input_normalization=input_normalization,
                               pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt*2),
-                              decoder=decoder, init_method=initializer, **model_extra_args
                               )
      model.criterion = criterion
      if load_weights_from_this_state_dict is not None:
@@ -75,6 +84,7 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
      if using_dist:
          print("Distributed training")
          model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)


      # learning rate
@@ -84,21 +94,25 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
      optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
      scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100)  # when training for fixed time lr schedule takes 100 steps

-     def train_step():
          model.train()  # Turn on the train mode
          total_loss = 0.
          total_positional_losses = 0.
          total_positional_losses_recorded = 0
          before_get_batch = time.time()
          assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
-         valid_batch_steps = 0.0
-         for batch, (data, targets) in enumerate(dl):
              if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1):
                  cm = model.no_sync()
-                 #print(f'p={rank}, no_sync', force=True)
              else:
                  cm = nullcontext()
-                 #print(f'p={rank}, sync', force=True)
              with cm:
                  time_to_get_batch = time.time() - before_get_batch
                  before_forward = time.time()
@@ -107,100 +121,75 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
              else:
                  single_eval_pos = targets.shape[0] - bptt_extra_samples

-                 is_compatible = torch.ones((targets.shape[1])).bool()
-                 if check_is_compatible or normalize_labels:
-                     for b in range(targets.shape[1]):
-                         targets_in_train = torch.unique(targets[:single_eval_pos, b], sorted=True)
-                         targets_in_eval = torch.unique(targets[single_eval_pos:, b], sorted=True)
-
-                         if check_is_compatible:
-                             is_compatible[b] = len(targets_in_train) == len(targets_in_eval) and (targets_in_train == targets_in_eval).all()
-                             is_compatible[b] = is_compatible[b] and len(targets_in_train) > 1
-
-                         # Set targets to range starting from 0 (e.g. targets 0, 2, 5, 2 will be converted to 0, 1, 2, 1)
-                         if normalize_labels:
-                             targets[:, b] = (targets[:, b] > torch.unique(targets[:, b]).unsqueeze(1)).sum(axis=0).unsqueeze(0)
-                 valid_batch_steps += is_compatible.float().mean()
-                 is_compatible = is_compatible.to(device)
-                 #if using_dist and check_is_compatible:
-                 #    print('step share before reduce', curr_step_share, force=True)
-                 #    curr_step_share = curr_step_share.to(device)
-                 #    torch.distributed.all_reduce_multigpu([curr_step_share], op=torch.distributed.ReduceOp.SUM)
-                 #    curr_step_share = curr_step_share.cpu() / torch.distributed.get_world_size()
-                 #    print('step share after reduce', curr_step_share, torch.distributed.get_world_size(), force=True)
-
-                 # If style is set to None, it should not be transferred to device
-                 output = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data) if isinstance(data, tuple) else data.to(device)
-                                , single_eval_pos=single_eval_pos)
-
-                 forward_time = time.time() - before_forward
-
-                 #output, targets = output[:, is_compatible], targets[:, is_compatible]
-
-                 if single_eval_pos is not None:
-                     targets = targets[single_eval_pos:]
-                 if isinstance(criterion, nn.GaussianNLLLoss):
-                     assert output.shape[-1] == 2, \
-                         'need to write a little bit of code to handle multiple regression targets at once'
-
-                     mean_pred = output[..., 0]
-                     var_pred = output[..., 1].abs()
-                     losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
-                 elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
-                     losses = criterion(output.flatten(), targets.to(device).flatten())
-                 elif isinstance(criterion, (nn.CrossEntropyLoss)):
-                     #print(n_out, targets.min(), targets.max(), force=True)
-                     losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
-                 else:
-                     losses = criterion(output.reshape(-1, n_out), targets.to(device).flatten())
-                 losses = losses.view(*output.shape[0:2])
-                 loss = losses.mean(0) @ is_compatible.float() / losses.shape[1]
-                 #loss = torch_nanmean(losses, axis=[0, 1]) * is_compatible.float().mean()
-                 # not sure whether we can go without the nan checks.
-
                  loss.backward()

-                 if ((batch % aggregate_k_gradients == aggregate_k_gradients - 1) and (not check_is_compatible or using_dist))\
-                         or (valid_batch_steps >= aggregate_k_gradients and (check_is_compatible and not using_dist)):
-                     with torch.no_grad():
-                         for p in model.parameters():
-                             if p.grad is not None:
-                                 p.grad.div_(valid_batch_steps)
                      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                      try:
-                         optimizer.step()
                      except:
                          print("Invalid optimization step encountered")
                      optimizer.zero_grad()
-                     valid_batch_steps = 0.0

                  step_time = time.time() - before_forward

                  if not torch.isnan(loss):
-                     total_loss += loss.item()
                      total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
-                         nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)*loss.cpu().detach()

                      total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
                          nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)

              before_get_batch = time.time()
-         return total_loss / steps_per_epoch, (
-             total_positional_losses / total_positional_losses_recorded).tolist(), time_to_get_batch, forward_time, step_time

-     best_val_loss = float("inf")
-     best_model = None
      total_loss = float('inf')
      total_positional_losses = float('inf')
      try:
          for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)):

              epoch_start_time = time.time()
-             if train_mixed_precision:
-                 with autocast():
-                     total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
-             else:
-                 total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time = train_step()
              if hasattr(dl, 'validate') and epoch % validation_period == 0:
                  with torch.no_grad():
                      val_score = dl.validate(model)
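The normalize_labels block removed in the hunk above remaps each column's class labels onto a dense 0..k-1 range before the loss is computed. A small standalone demonstration of that relabelling trick (same expression as in the removed code):

    import torch

    t = torch.tensor([0, 2, 5, 2])
    # for each entry, count how many distinct label values are strictly smaller than it
    dense = (t > torch.unique(t).unsqueeze(1)).sum(0)
    print(dense)  # tensor([0, 1, 2, 1])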
@@ -213,25 +202,23 @@ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=
                      f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
                      f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                      f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
-                     f' forward time {forward_time:5.2f}' + (f'val score {val_score}' if val_score is not None else ''))
                  print('-' * 89)

              # stepping with wallclock time based scheduler
-             current_time = time.time()
              if epoch_callback is not None and rank == 0:
-                 epoch_callback(model, epoch / epochs if total_available_time_in_s is None else  # noqa
-                                (current_time - start_of_training) / total_available_time_in_s  # noqa
-                                )
-             if epochs is None and (current_time - start_of_training) > total_available_time_in_s:  # noqa
-                 break
-             if epochs is None:
-                 scheduler.step((current_time - epoch_start_time) / total_available_time_in_s * 100)
-             else:
-                 scheduler.step()
      except KeyboardInterrupt:
          pass

-     return total_loss, total_positional_losses, model.to('cpu'), dl


  def _parse_args(config_parser, parser):
      # Do we have a config file to parse?
@@ -261,16 +248,17 @@ if __name__ == '__main__':
      parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
      parser.add_argument('--num_buckets', default=100, type=int)
      #parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
-     parser.add_argument("--extra_prior_kwargs_dict", default={'fuse_x_y': False}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
      parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
      parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
-     parser.add_argument('--pos_encoder', default='sinus', type=str, help='Specify depending on the prior.')
      parser.add_argument('--bptt', default=10, type=int)
      parser.add_argument('--epochs', default=200, type=int)
      parser.add_argument('--warmup_epochs', default=50, type=int)
      parser.add_argument('--validation_period', default=10, type=int)
      parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to ')
      parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")

      # these can likely be mostly left at defaults
      parser.add_argument('--emsize', default=512, type=int)  # sometimes even larger is better e.g. 1024
@@ -309,28 +297,12 @@ if __name__ == '__main__':
      min_y = args.__dict__.pop('min_y')
      # criterion = nn.MSELoss(reduction='none')

-     def get_y_sample():
-         dl = prior(num_steps=1, batch_size=args.batch_size * args.steps_per_epoch, seq_len=args.bptt, device=device,
-                    **args.extra_prior_kwargs_dict)
-         y_sample = next(iter(dl))[-1]
-         print(f'Creating Bar distribution with borders from y sample of size {y_sample.numel()}')
-         return y_sample
-
      if loss_function == 'ce':
          criterion = nn.CrossEntropyLoss(reduction='none')
      elif loss_function == 'gaussnll':
          criterion = nn.GaussianNLLLoss(reduction='none', full=True)
      elif loss_function == 'mse':
          criterion = nn.MSELoss(reduction='none')
-     elif loss_function == 'barnll':
-         criterion = BarDistribution(borders=get_bucket_limits(num_buckets, full_range=(min_y, max_y)))
-     elif loss_function == 'adaptivebarnll':
-         borders = get_bucket_limits(num_buckets, ys=get_y_sample(), full_range=(min_y, max_y))
-         criterion = BarDistribution(borders=borders)
-     elif loss_function == 'adaptivefullsupportbarnll':
-         assert min_y is None and max_y is None, "Please do not specify `min_y` and `max_y` with `unboundedadaptivebarnll`."
-         borders = get_bucket_limits(num_buckets, ys=get_y_sample())
-         criterion = FullSupportBarDistribution(borders=borders)
      else:
          raise NotImplementedError(f'loss_function == {loss_function}.')
 
  import encoders
  import positional_encodings
  from utils import init_dist
+ from torch.cuda.amp import autocast, GradScaler
+ from torch import nn

  class Losses():
      gaussian = nn.GaussianNLLLoss(full=True, reduction='none')
      mse = nn.MSELoss(reduction='none')
+     def ce(num_classes):
+         num_classes = num_classes.shape[0] if torch.is_tensor(num_classes) else num_classes
+         return nn.CrossEntropyLoss(reduction='none', weight=torch.ones(num_classes))
      bce = nn.BCEWithLogitsLoss(reduction='none')


+
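A quick sketch of how the new Losses.ce helper could be used; the criterion keeps per-element losses (reduction='none') like the other entries, and the 10-class setup here is only an illustration:

    import torch
    from train import Losses  # assuming this runs alongside TabPFN/train.py

    criterion = Losses.ce(10)                      # same as Losses.ce(torch.ones(10)); weight defaults to all ones
    logits = torch.randn(5, 10)                    # 5 positions, 10 classes
    labels = torch.randint(0, 10, (5,))
    per_position_loss = criterion(logits, labels)  # shape (5,), since reduction='none'
    n_out = criterion.weight.shape[0]              # 10 -- this is what train() later reads off the criterion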
+ def train(priordataloader_class, criterion, encoder_generator, emsize=200, nhid=200, nlayers=6, nhead=2, dropout=0.0,
            epochs=10, steps_per_epoch=100, batch_size=200, bptt=10, lr=None, weight_decay=0.0, warmup_epochs=10, input_normalization=False,
            y_encoder_generator=None, pos_encoder_generator=None, decoder=None, extra_prior_kwargs_dict={}, scheduler=get_cosine_schedule_with_warmup,
            load_weights_from_this_state_dict=None, validation_period=10, single_eval_pos_gen=None, bptt_extra_samples=None, gpu_device='cuda:0',
+           aggregate_k_gradients=1, verbose=True, style_encoder_generator=None, epoch_callback=None,
+           initializer=None, initialize_with_model=None, train_mixed_precision=False, efficient_eval_masking=True, **model_extra_args
            ):
      device = gpu_device if torch.cuda.is_available() else 'cpu:0'
      print(f'Using {device} device')
      using_dist, rank, device = init_dist(device)
+     single_eval_pos_gen = single_eval_pos_gen if callable(single_eval_pos_gen) else lambda: single_eval_pos_gen

+     def eval_pos_seq_len_sampler():
+         single_eval_pos = single_eval_pos_gen()
+         if bptt_extra_samples:
+             return single_eval_pos, single_eval_pos + bptt_extra_samples
+         else:
+             return single_eval_pos, bptt
+     dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, eval_pos_seq_len_sampler=eval_pos_seq_len_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)
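single_eval_pos_gen may be a plain int or a callable; train() normalizes it to a callable, and the sampler above hands a (single_eval_pos, seq_len) pair to the prior dataloader for every batch. A minimal standalone sketch of that behaviour (the uniform sampler is only an illustrative assumption, not part of the commit):

    import random

    bptt, bptt_extra_samples = 10, None
    single_eval_pos_gen = lambda: random.randint(1, bptt - 1)   # hypothetical split sampler

    def eval_pos_seq_len_sampler():
        single_eval_pos = single_eval_pos_gen()
        if bptt_extra_samples:
            return single_eval_pos, single_eval_pos + bptt_extra_samples
        return single_eval_pos, bptt

    eval_pos, seq_len = eval_pos_seq_len_sampler()
    # positions [0, eval_pos) act as training context, [eval_pos, seq_len) are evaluated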
 
+     encoder = encoder_generator(dl.num_features, emsize)
+     style_def = dl.get_test_batch()[0][0]  # the style in batch of the form ((style, x, y), target, single_eval_pos)
+     print(f'Style definition of first 3 examples: {style_def[:3] if style_def is not None else None}')
+     style_encoder = style_encoder_generator(style_def.shape[1], emsize) if (style_def is not None) else None
      if isinstance(criterion, nn.GaussianNLLLoss):
+         n_out = 2
      elif isinstance(criterion, nn.CrossEntropyLoss):
+         n_out = criterion.weight.shape[0]
+     else:
+         n_out = 1
+
      model = TransformerModel(encoder, n_out, emsize, nhead, nhid, nlayers, dropout, style_encoder=style_encoder,
+                              y_encoder=y_encoder_generator(1, emsize), input_normalization=input_normalization,
                               pos_encoder=(pos_encoder_generator or positional_encodings.NoPositionalEncoding)(emsize, bptt*2),
+                              decoder=decoder, init_method=initializer, efficient_eval_masking=efficient_eval_masking, **model_extra_args
                               )
      model.criterion = criterion
      if load_weights_from_this_state_dict is not None:
 
      if using_dist:
          print("Distributed training")
          model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)
+         dl.model = model


      # learning rate

      optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
      scheduler = scheduler(optimizer, warmup_epochs, epochs if epochs is not None else 100)  # when training for fixed time lr schedule takes 100 steps

+     scaler = GradScaler() if train_mixed_precision else None
+
+     # check that everything uses up-to-date APIs
+     utils.check_compatibility(dl)
+
+     def train_epoch():
          model.train()  # Turn on the train mode
          total_loss = 0.
          total_positional_losses = 0.
          total_positional_losses_recorded = 0
+         nan_steps = 0
+         ignore_steps = 0
          before_get_batch = time.time()
          assert len(dl) % aggregate_k_gradients == 0, 'Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it.'
+         for batch, (data, targets, single_eval_pos) in enumerate(dl):
              if using_dist and not (batch % aggregate_k_gradients == aggregate_k_gradients - 1):
                  cm = model.no_sync()
              else:
                  cm = nullcontext()
              with cm:
                  time_to_get_batch = time.time() - before_get_batch
                  before_forward = time.time()

              else:
                  single_eval_pos = targets.shape[0] - bptt_extra_samples

+                 with autocast(enabled=scaler is not None):
+                     # If style is set to None, it should not be transferred to device
+                     output = model(tuple(e.to(device) if torch.is_tensor(e) else e for e in data) if isinstance(data, tuple) else data.to(device)
+                                    , single_eval_pos=single_eval_pos)
+
+                     forward_time = time.time() - before_forward
+
+                     if single_eval_pos is not None:
+                         targets = targets[single_eval_pos:]
+                     if isinstance(criterion, nn.GaussianNLLLoss):
+                         assert output.shape[-1] == 2, \
+                             'need to write a little bit of code to handle multiple regression targets at once'
+
+                         mean_pred = output[..., 0]
+                         var_pred = output[..., 1].abs()
+                         losses = criterion(mean_pred.flatten(), targets.to(device).flatten(), var=var_pred.flatten())
+                     elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
+                         losses = criterion(output.flatten(), targets.to(device).flatten())
+                     elif isinstance(criterion, nn.CrossEntropyLoss):
+                         losses = criterion(output.reshape(-1, n_out), targets.to(device).long().flatten())
+                     else:
+                         losses = criterion(output, targets)
+                     losses = losses.view(*output.shape[0:2])
+                     loss, nan_share = utils.torch_nanmean(losses.mean(0), return_nanshare=True)
+                     loss = loss / aggregate_k_gradients
+
+                 if scaler: loss = scaler.scale(loss)
                  loss.backward()

+                 if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
+                     if scaler: scaler.unscale_(optimizer)
                      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                      try:
+                         if scaler:
+                             scaler.step(optimizer)
+                             scaler.update()
+                         else:
+                             optimizer.step()
                      except:
                          print("Invalid optimization step encountered")
                      optimizer.zero_grad()

                  step_time = time.time() - before_forward

                  if not torch.isnan(loss):
+                     total_loss += losses.mean().cpu().detach().item()
                      total_positional_losses += losses.mean(1).cpu().detach() if single_eval_pos is None else \
+                         nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)*\
+                         losses[:bptt-single_eval_pos].mean().cpu().detach()

                      total_positional_losses_recorded += torch.ones(bptt) if single_eval_pos is None else \
                          nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
+                     nan_steps += nan_share
+                     ignore_steps += (targets == -100).float().mean()

              before_get_batch = time.time()
+         return total_loss / steps_per_epoch, (total_positional_losses / total_positional_losses_recorded).tolist(),\
+             time_to_get_batch, forward_time, step_time, nan_steps.cpu().item()/(batch+1),\
+             ignore_steps.cpu().item()/(batch+1)
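The new training step wraps the forward pass and loss in autocast and routes backward/step through a GradScaler when train_mixed_precision is set, accumulating over aggregate_k_gradients batches. A minimal, self-contained sketch of that pattern on a dummy model (utils.torch_nanmean is repo-specific; the stub below is only an assumption about its behaviour, not its actual code):

    import torch
    from torch import nn
    from torch.cuda.amp import autocast, GradScaler

    def torch_nanmean_stub(x):
        # stand-in for utils.torch_nanmean(..., return_nanshare=True):
        # mean over finite entries plus the fraction of NaN entries (an assumption)
        nan_mask = torch.isnan(x)
        mean = x[~nan_mask].mean() if (~nan_mask).any() else torch.zeros(())
        return mean, nan_mask.float().mean()

    use_amp = torch.cuda.is_available()
    device = 'cuda:0' if use_amp else 'cpu'
    model = nn.Linear(4, 3).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler() if use_amp else None
    aggregate_k_gradients = 2

    for batch in range(4):
        x = torch.randn(8, 4, device=device)
        y = torch.randint(0, 3, (8,), device=device)
        with autocast(enabled=scaler is not None):
            losses = nn.functional.cross_entropy(model(x), y, reduction='none')
            loss, nan_share = torch_nanmean_stub(losses)
            loss = loss / aggregate_k_gradients            # scale for gradient accumulation
        if scaler: loss = scaler.scale(loss)
        loss.backward()
        if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
            if scaler: scaler.unscale_(optimizer)          # so clipping sees true gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()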
 
 
 
      total_loss = float('inf')
      total_positional_losses = float('inf')
      try:
          for epoch in (range(1, epochs + 1) if epochs is not None else itertools.count(1)):

              epoch_start_time = time.time()
+             total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time, nan_share, ignore_share =\
+                 train_epoch()
              if hasattr(dl, 'validate') and epoch % validation_period == 0:
                  with torch.no_grad():
                      val_score = dl.validate(model)

                      f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | '
                      f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                      f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}'
+                     f' forward time {forward_time:5.2f}'
+                     f' nan share {nan_share:5.2f} ignore share (for classification tasks) {ignore_share:5.4f}'
+                     + (f'val score {val_score}' if val_score is not None else ''))
                  print('-' * 89)

              # stepping with wallclock time based scheduler
              if epoch_callback is not None and rank == 0:
+                 epoch_callback(model, epoch / epochs)
+             scheduler.step()
      except KeyboardInterrupt:
          pass

+     if rank == 0:  # trivially true for non-parallel training
+         if isinstance(model, torch.nn.parallel.DistributedDataParallel):
+             model = model.module
+             dl = None
+         return total_loss, total_positional_losses, model.to('cpu'), dl
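epoch_callback now receives the model and a progress fraction epoch / epochs after every epoch, on rank 0 only. A hypothetical callback illustrating that signature (the checkpointing policy is an example, not part of the commit):

    import torch

    def my_epoch_callback(model, progress):
        # progress runs from 1/epochs up to 1.0 over the whole run
        print(f'finished {progress:.0%} of training')
        if progress >= 1.0:
            torch.save(model.state_dict(), 'final_checkpoint.pt')

    # passed through as train(..., epoch_callback=my_epoch_callback)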
 
  def _parse_args(config_parser, parser):
      # Do we have a config file to parse?

      parser.add_argument('--max_y', type=float, help='barnll can only model y in strict ranges, this is the maximum y can take.')
      parser.add_argument('--num_buckets', default=100, type=int)
      #parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
+     parser.add_argument("--extra_prior_kwargs_dict", default={}, dest="extra_prior_kwargs_dict", action=StoreDictKeyPair, nargs="+", metavar="KEY=VAL", help='Specify depending on the prior.')
      parser.add_argument('--encoder', default='linear', type=str, help='Specify depending on the prior.')
      parser.add_argument('--y_encoder', default='linear', type=str, help='Specify depending on the prior. You should specify this if you do not fuse x and y.')
+     parser.add_argument('--pos_encoder', default='none', type=str, help='Specify depending on the prior.')
      parser.add_argument('--bptt', default=10, type=int)
      parser.add_argument('--epochs', default=200, type=int)
      parser.add_argument('--warmup_epochs', default=50, type=int)
      parser.add_argument('--validation_period', default=10, type=int)
      parser.add_argument('--permutation_invariant_max_eval_pos', default=None, type=int, help='Set this to an int to ')
      parser.add_argument('--permutation_invariant_sampling', default='weighted', help="Only relevant if --permutation_invariant_max_eval_pos is set.")
+     parser.add_argument('--train_mixed_precision', action='store_true')

      # these can likely be mostly left at defaults
      parser.add_argument('--emsize', default=512, type=int)  # sometimes even larger is better e.g. 1024
 
      min_y = args.__dict__.pop('min_y')
      # criterion = nn.MSELoss(reduction='none')

      if loss_function == 'ce':
          criterion = nn.CrossEntropyLoss(reduction='none')
      elif loss_function == 'gaussnll':
          criterion = nn.GaussianNLLLoss(reduction='none', full=True)
      elif loss_function == 'mse':
          criterion = nn.MSELoss(reduction='none')
      else:
          raise NotImplementedError(f'loss_function == {loss_function}.')
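For the 'gaussnll' choice, the unreduced criterion built here pairs with the n_out = 2 branch in train(): the model's two outputs per position are read as mean and (absolute) variance. A small standalone check of that pairing, mirroring the loss branch above:

    import torch
    from torch import nn

    criterion = nn.GaussianNLLLoss(reduction='none', full=True)
    output = torch.randn(7, 2)                    # n_out == 2: (mean, var) per position
    mean_pred, var_pred = output[..., 0], output[..., 1].abs()
    targets = torch.randn(7)
    losses = criterion(mean_pred.flatten(), targets.flatten(), var=var_pred.flatten())
    print(losses.shape)                           # per-position losses, shape (7,)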