AlexK-PL committed on
Commit 48aeee0 · 1 Parent(s): 4d5330a

Upload training.py

Files changed (1)
  1. training.py +254 -0
training.py ADDED
@@ -0,0 +1,254 @@
# ONE EPOCH = one forward pass and one backward pass of all the training examples.
#
# BATCH SIZE = the number of training examples in one forward/backward pass. The
# higher the batch size, the more memory space you'll need.
#
# NUMBER OF ITERATIONS = number of passes, each pass using [batch size] number of
# examples. To be clear, one pass = one forward pass + one backward pass.
#
# Example: if you have 1000 training examples, and your batch size is 500, then
# it will take 2 iterations to complete 1 epoch.

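# A minimal sketch of the arithmetic described above, using the hypothetical numbers from the
# example (1000 training examples, batch size 500):
import math
_example_iterations_per_epoch = math.ceil(1000 / 500)  # -> 2 iterations to complete 1 epoch
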
import os
import time
import math

import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from numpy import finfo

from Tacotron2 import tacotron_2
from fp16_optimizer import FP16_Optimizer
# from distributed import apply_gradient_allreduce
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger


def batchnorm_to_float(module):
    """Converts batch norm modules to FP32"""
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    for child in module.children():
        batchnorm_to_float(child)
    return module


def reduce_tensor(tensor, n_gpus):
    # The clone is recorded in the computation graph, so gradients propagating to the cloned tensor
    # will also propagate to the original tensor.
    rt = tensor.clone()
    # Each rank holds its own tensor; all_reduce sums the tensors from all ranks and leaves the sum
    # on every rank. Dividing by n_gpus then gives the average over all ranks (one rank per GPU process):
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    rt /= n_gpus
    return rt


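# A rough, self-contained sketch of the all_reduce-then-divide pattern used in reduce_tensor above.
# It creates a single-process "gloo" group only so the sketch can run without multiple GPUs; in real
# training there is one rank per GPU process, and dividing the summed tensor by the world size gives
# the mean across ranks. (dist.ReduceOp.SUM is the current spelling of the older dist.reduce_op.SUM.)
if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    demo_loss = torch.tensor(2.5)                    # stand-in for a per-rank loss value
    summed = demo_loss.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)    # every rank ends up with the sum over all ranks
    mean_loss = summed / dist.get_world_size()       # -> 2.5 here, since there is only one rank
    print("mean loss across ranks:", mean_loss.item())
    dist.destroy_process_group()
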
def prepare_directories_and_logger(output_directory, log_directory, rank):
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
        # logger = None
    else:
        logger = None
    return logger


def warm_start_model(checkpoint_path, model):
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)


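# A tiny round-trip sketch of the checkpoint dictionary written and read by the two helpers above.
# The Linear model, Adam optimizer and file name below are placeholders, purely for illustration.
if __name__ == "__main__":
    demo_model = torch.nn.Linear(4, 2)
    demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-3)
    save_checkpoint(demo_model, demo_optimizer, learning_rate=1e-3, iteration=1000,
                    filepath="checkpoint_demo.pt")
    demo_model, demo_optimizer, lr, it = load_checkpoint("checkpoint_demo.pt", demo_model, demo_optimizer)
    print(it, lr)  # -> 1000 0.001
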
def init_distributed(hyper_params, n_gpus, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA"
    print("Initializing distributed")
    # Set CUDA device so everything is done on the right GPU
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    torch.distributed.init_process_group(backend=hyper_params['dist_backend'], rank=rank, world_size=n_gpus,
                                         init_method=hyper_params['dist_url'], group_name=group_name)

    print("Initializing distributed: Done")


def load_model(hyper_params):
    # according to the documentation, it is recommended to move a model to GPU before constructing the optimizer
    # model = tacotron_2(hyper_params).cuda()
    model = tacotron_2(hyper_params)
    if hyper_params['fp16_run']:  # converts everything into half type (16 bits)
        model = batchnorm_to_float(model.half())
        model.decoder.attention_layer.score_mask_value = float(finfo('float16').min)

    # if hyper_params['distributed_run']:
    #     model = apply_gradient_allreduce(model)

    return model


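# A minimal sketch of the ordering the comment in load_model refers to: move the model to its device
# first, then construct the optimizer over the (already-moved) parameters. The Linear model, learning
# rate and weight decay below are placeholders, not values taken from hyper_params.
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    demo_net = torch.nn.Linear(8, 8).to(device)                    # move parameters first
    demo_opt = torch.optim.Adam(demo_net.parameters(), lr=1e-3,    # then hand them to the optimizer
                                weight_decay=1e-6)
    print(next(demo_net.parameters()).device)
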
def validate(model, criterion, valset, iteration, batch_size, n_gpus, collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""

    # Switch to eval() because this is an evaluation stage, not training
    model.eval()
    # Disable gradient tracking for the whole validation pass
    with torch.no_grad():
        # Sampler that restricts data loading to a subset of the dataset, so in a distributed run each
        # rank validates on its own shard of the validation set
        val_sampler = DistributedSampler(valset) if distributed_run else None
        # data loader wrapper for the validation data (same as for the training data)
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            _, _, _, _, gst_scores = y_pred
            if i == 0:
                validation_gst_scores = gst_scores
            else:
                validation_gst_scores = torch.cat((validation_gst_scores, gst_scores), 0)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()  # gets the plain float value with item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)  # averaged val_loss over all batches

    model.train()
    if rank == 0:
        print("Validation loss {}: {:9f} ".format(iteration, val_loss))  # I changed this
        # print("GST scores of the validation set: {}".format(validation_gst_scores.shape))
        logger.log_validation(val_loss, model, y, y_pred, validation_gst_scores, iteration)


# ------------------------------------------- MAIN TRAINING METHOD -------------------------------------------------- #

def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, rank, group_name,
          hyper_params, train_loader, valset, collate_fn):
    """Training and validation method with logging results to tensorboard and stdout

    :param output_directory (string): directory to save checkpoints
    :param log_directory (string): directory to save tensorboard logs
    :param checkpoint_path (string): checkpoint path
    :param warm_start (bool): load only the model weights from the checkpoint
    :param n_gpus (int): number of gpus
    :param rank (int): rank of current gpu
    :param group_name (string): name of the distributed process group
    :param hyper_params (object dictionary): dictionary with all hyper parameters
    :param train_loader (DataLoader): data loader over the training set
    :param valset (Dataset): validation dataset
    :param collate_fn (callable): collate function used to batch the data
    """

    # Check whether this is a distributed run
    if hyper_params['distributed_run']:
        init_distributed(hyper_params, n_gpus, rank, group_name)

    # set the same fixed seed to reproduce the same results every time we train
    torch.manual_seed(hyper_params['seed'])
    torch.cuda.manual_seed(hyper_params['seed'])

    model = load_model(hyper_params)
    learning_rate = hyper_params['learning_rate']
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=hyper_params['weight_decay'])

    if hyper_params['fp16_run']:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=hyper_params['dynamic_loss_scaling'])

    # Define the criterion of the loss function, i.e. the training objective.
    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(output_directory, log_directory, rank)
    # logger = ''

    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            # Warm-start the model from a saved checkpoint when we have stored parameters and don't want to start from 0
            model = warm_start_model(checkpoint_path, model)
        else:
            # CHECK THIS OUT!!!
            model, optimizer, _learning_rate, iteration = load_checkpoint(checkpoint_path, model, optimizer)
            if hyper_params['use_saved_learning_rate']:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    # Set this so all modules and regularization layers know this is the training stage:
    model.train()

    # MAIN LOOP
    for epoch in range(epoch_offset, hyper_params['epochs']):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            # CHECK THIS OUT!!!
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            model.zero_grad()
            input_data, output_target = model.parse_batch(batch)
            output_predicted = model(input_data)

            loss = criterion(output_predicted, output_target)

            if hyper_params['distributed_run']:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()

            if hyper_params['fp16_run']:
                optimizer.backward(loss)  # the FP16 optimizer wraps the backward pass and applies loss scaling
                grad_norm = optimizer.clip_fp32_grads(hyper_params['grad_clip_thresh'])
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hyper_params['grad_clip_thresh'])

            # Performs a single optimization step (parameter update)
            optimizer.step()
            # This boolean flags a gradient overflow when running with the fp16 optimizer
            overflow = optimizer.overflow if hyper_params['fp16_run'] else False

            # Skip logging if the loss overflowed or is NaN (and only log from rank 0)
            if not overflow and not math.isnan(reduced_loss) and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(iteration, reduced_loss,
                                                                                grad_norm, duration))
                # logs training information of the current iteration
                logger.log_training(reduced_loss, grad_norm, learning_rate, duration, iteration)

            # Every iters_per_checkpoint steps, validate the model with its updated parameters and save a checkpoint
            if not overflow and (iteration % hyper_params['iters_per_checkpoint'] == 0):
                validate(model, criterion, valset, iteration, hyper_params['batch_size'], n_gpus, collate_fn,
                         logger, hyper_params['distributed_run'], rank)
                if rank == 0:
                    checkpoint_path = os.path.join(output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)

            iteration += 1