Spaces:
Sleeping
Sleeping
Delete training.py
Browse files- training.py +0 -254
training.py
DELETED
@@ -1,254 +0,0 @@
|
|
1 |
-
# ONE EPOCH = one forward pass and one backward pass of all the training examples.
|
2 |
-
#
|
3 |
-
# BATCH SIZE = the number of training examples in one forward/backward pass. The
|
4 |
-
# higher the batch size, the more memory space you'll need.
|
5 |
-
#
|
6 |
-
# NUMBER OF ITERATIONS = number of passes, each pass using [batch size] number of
|
7 |
-
# examples. To be clear, one pass = one forward pass + one backward pass.
|
8 |
-
#
|
9 |
-
# Example: if you have 1000 training examples, and your batch size is 500, then
|
10 |
-
# it will take 2 iterations to complete 1 epoch.
|
11 |
-
|
12 |
-
import os
|
13 |
-
import time
|
14 |
-
import math
|
15 |
-
|
16 |
-
import torch
|
17 |
-
import torch.distributed as dist
|
18 |
-
from torch.utils.data.distributed import DistributedSampler
|
19 |
-
from torch.utils.data import DataLoader
|
20 |
-
from numpy import finfo
|
21 |
-
|
22 |
-
from Tacotron2 import tacotron_2
|
23 |
-
from fp16_optimizer import FP16_Optimizer
|
24 |
-
# from distributed import apply_gradient_allreduce
|
25 |
-
from loss_function import Tacotron2Loss
|
26 |
-
from logger import Tacotron2Logger
|
27 |
-
|
28 |
-
|
29 |
-
def batchnorm_to_float(module):
    """Cast every batch-norm layer found inside ``module`` to FP32.

    :param module: root ``torch.nn.Module`` to inspect (the root itself is checked too).
    :return: the same ``module`` object; matching layers are converted in place.
    """
    batchnorm_base = torch.nn.modules.batchnorm._BatchNorm
    # ``modules()`` yields the root and all of its descendants, so no explicit recursion is needed.
    for submodule in module.modules():
        if isinstance(submodule, batchnorm_base):
            submodule.float()
    return module
|
36 |
-
|
37 |
-
|
38 |
-
def reduce_tensor(tensor, n_gpus):
    """Average ``tensor`` over all ranks of the default process group.

    :param tensor: tensor held by the calling rank; it is not modified.
    :param n_gpus: number of participating ranks, used as the divisor.
    :return: a new tensor holding the element-wise sum over all ranks divided by ``n_gpus``.
    """
    # all_reduce writes its result in place, so operate on a copy of the input.
    averaged = tensor.clone()
    # ``dist.reduce_op`` is a deprecated alias; ``dist.ReduceOp`` is the supported enum.
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= n_gpus
    return averaged
|
47 |
-
|
48 |
-
|
49 |
-
def prepare_directories_and_logger(output_directory, log_directory, rank):
    """Create the output directory and build the logger on rank 0.

    Only rank 0 touches the filesystem and receives a logger; every other
    rank gets ``None``.

    :param output_directory: directory for training outputs; created with mode 0o775 when missing.
    :param log_directory: sub-directory of ``output_directory`` handed to the logger.
    :param rank: rank of the calling process.
    :return: a ``Tacotron2Logger`` on rank 0, otherwise ``None``.
    """
    if rank != 0:
        return None

    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    log_path = os.path.join(output_directory, log_directory)
    return Tacotron2Logger(log_path)
|
59 |
-
|
60 |
-
|
61 |
-
def warm_start_model(checkpoint_path, model):
    """Initialise ``model`` with the weights stored in a checkpoint file.

    Only the ``'state_dict'`` entry of the checkpoint is read.

    :param checkpoint_path: path to an existing checkpoint file.
    :param model: module whose parameters are overwritten.
    :return: the same ``model`` object with the loaded weights.
    :raises AssertionError: if ``checkpoint_path`` is not a file.
    """
    assert os.path.isfile(checkpoint_path)
    banner = "Warm starting model from checkpoint '{}'".format(checkpoint_path)
    print(banner)
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    weights = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    model.load_state_dict(weights)
    return model
|
67 |
-
|
68 |
-
|
69 |
-
def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore model weights, optimizer state and training progress from a checkpoint.

    :param checkpoint_path: path to an existing checkpoint file.
    :param model: module that receives the saved ``'state_dict'``.
    :param optimizer: optimizer that receives the saved ``'optimizer'`` state.
    :return: tuple ``(model, optimizer, learning_rate, iteration)`` where the last two
        values are read from the checkpoint.
    :raises AssertionError: if ``checkpoint_path`` is not a file.
    """
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    saved = torch.load(checkpoint_path, map_location='cpu')

    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])
    learning_rate, iteration = saved['learning_rate'], saved['iteration']

    print("Loaded checkpoint '{}' from iteration {}".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
|
79 |
-
|
80 |
-
|
81 |
-
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Write the current training state to ``filepath``.

    The saved dictionary has the keys ``'iteration'``, ``'state_dict'``,
    ``'optimizer'`` and ``'learning_rate'``.

    :param model: module whose ``state_dict()`` is stored.
    :param optimizer: optimizer whose ``state_dict()`` is stored.
    :param learning_rate: learning rate value to record.
    :param iteration: iteration counter to record.
    :param filepath: destination passed to ``torch.save``.
    """
    print("Saving model and optimizer state at iteration {} to {}".format(iteration, filepath))
    snapshot = {
        'iteration': iteration,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(snapshot, filepath)
|
87 |
-
|
88 |
-
|
89 |
-
def init_distributed(hyper_params, n_gpus, rank, group_name):
    """Select this process's GPU and join the distributed process group.

    :param hyper_params: mapping providing ``'dist_backend'`` and ``'dist_url'``.
    :param n_gpus: world size handed to the process group.
    :param rank: rank of the calling process.
    :param group_name: group name forwarded to ``init_process_group``.
    :raises AssertionError: if CUDA is not available.
    """
    assert torch.cuda.is_available(), "Distributed mode requires CUDA"
    print("Initializing distributed")

    # Pin the process to one device; the modulo wraps ranks around the locally visible GPUs.
    local_device = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_device)

    # Collect the arguments first, then open the communication channel between ranks.
    group_options = dict(backend=hyper_params['dist_backend'], rank=rank, world_size=n_gpus,
                         init_method=hyper_params['dist_url'], group_name=group_name)
    torch.distributed.init_process_group(**group_options)

    print("Initializing distributed: Done")
|
100 |
-
|
101 |
-
|
102 |
-
def load_model(hyper_params):
    """Build the Tacotron 2 model, optionally converted for FP16 training.

    The model is returned on whatever device ``tacotron_2`` creates it on; no
    ``.cuda()`` move and no gradient all-reduce wrapping is applied here.

    :param hyper_params: mapping of hyper parameters; ``'fp16_run'`` selects half precision.
    :return: the constructed model.
    """
    model = tacotron_2(hyper_params)
    if not hyper_params['fp16_run']:
        return model

    # Half precision everywhere except the batch-norm layers, which are cast back to FP32.
    model = batchnorm_to_float(model.half())
    # Use the most negative value float16 can represent as the attention masking score.
    fp16_lowest = finfo('float16').min
    model.decoder.attention_layer.score_mask_value = float(fp16_lowest)
    return model
|
114 |
-
|
115 |
-
|
116 |
-
def validate(model, criterion, valset, iteration, batch_size, n_gpus, collate_fn, logger, distributed_run, rank):
    """Score the model on the validation set, then print and log the mean loss.

    The model is switched to eval mode for the pass and back to train mode
    afterwards. Rank 0 prints the loss averaged over all batches and passes that
    same average to the logger, together with the last batch's targets and
    predictions and the GST scores of every batch concatenated along dim 0.

    :param model: model exposing ``parse_batch(batch)`` and returning a 5-tuple whose last item is the GST scores.
    :param criterion: callable ``criterion(y_pred, y)`` returning a scalar loss tensor.
    :param valset: validation dataset.
    :param iteration: current training iteration, used for printing and logging.
    :param batch_size: validation batch size.
    :param n_gpus: number of ranks, used to average the loss in distributed runs.
    :param collate_fn: collate function handed to the ``DataLoader``.
    :param logger: object exposing ``log_validation``; only used on rank 0.
    :param distributed_run: whether to shard the data and all-reduce the loss.
    :param rank: rank of the calling process.
    """
    model.eval()

    loss_total = 0.0
    batch_count = 0
    gst_per_batch = []
    y = y_pred = None
    with torch.no_grad():
        # In distributed runs each rank scores its own shard of the validation set.
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        for batch in val_loader:
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            _, _, _, _, gst_scores = y_pred
            gst_per_batch.append(gst_scores)

            loss = criterion(y_pred, y)
            if distributed_run:
                batch_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                batch_loss = loss.item()
            loss_total += batch_loss
            batch_count += 1

    model.train()

    if batch_count == 0:
        # Nothing was scored: there is no mean loss to report and nothing to log.
        if rank == 0:
            print("Validation skipped at iteration {}: the validation set yielded no batches".format(iteration))
        return

    val_loss = loss_total / batch_count  # mean loss over all validation batches
    if rank == 0:
        validation_gst_scores = torch.cat(gst_per_batch, 0)
        print("Validation loss {}: {:9f} ".format(iteration, val_loss))
        # Log the averaged loss (the value printed above), not just the last batch's loss.
        logger.log_validation(val_loss, model, y, y_pred, validation_gst_scores, iteration)
|
152 |
-
|
153 |
-
|
154 |
-
# ------------------------------------------- MAIN TRAINING METHOD -------------------------------------------------- #
|
155 |
-
|
156 |
-
def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, rank, group_name,
          hyper_params, train_loader, valset, collate_fn):
    """Train the model, validating and checkpointing at a fixed iteration interval.

    Progress is printed to stdout and, on rank 0, sent to the logger.

    :param output_directory (string): directory to save checkpoints
    :param log_directory (string): directory to save tensorboard logs
    :param checkpoint_path (string): checkpoint to start from, or None to start from scratch
    :param warm_start (bool): load only the model weights from the checkpoint, keeping a fresh
        optimizer and iteration counter
    :param n_gpus (int): number of gpus
    :param rank (int): rank of current gpu
    :param group_name (string): name of the distributed process group
    :param hyper_params (object dictionary): dictionary with all hyper parameters
    :param train_loader: iterable yielding the training batches
    :param valset: validation dataset handed to ``validate``
    :param collate_fn: collate function handed to ``validate``
    """
    distributed = hyper_params['distributed_run']
    if distributed:
        init_distributed(hyper_params, n_gpus, rank, group_name)

    # Same seed on every run so that results are reproducible.
    torch.manual_seed(hyper_params['seed'])
    torch.cuda.manual_seed(hyper_params['seed'])

    model = load_model(hyper_params)
    learning_rate = hyper_params['learning_rate']
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=hyper_params['weight_decay'])

    fp16 = hyper_params['fp16_run']
    if fp16:
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=hyper_params['dynamic_loss_scaling'])

    criterion = Tacotron2Loss()
    logger = prepare_directories_and_logger(output_directory, log_directory, rank)

    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            # Weights only: optimizer state and iteration counter start fresh.
            model = warm_start_model(checkpoint_path, model)
        else:
            # Full resume: weights, optimizer state and iteration counter.
            model, optimizer, saved_learning_rate, iteration = load_checkpoint(checkpoint_path, model, optimizer)
            if hyper_params['use_saved_learning_rate']:
                learning_rate = saved_learning_rate
            iteration += 1  # continue with the iteration after the saved one
            epoch_offset = max(0, int(iteration / len(train_loader)))

    # Put every module (dropout, batch norm, ...) into training mode.
    model.train()

    for epoch in range(epoch_offset, hyper_params['epochs']):
        print("Epoch: {}".format(epoch))
        for batch in train_loader:
            step_started = time.perf_counter()

            # Re-apply the learning rate to every parameter group before each step.
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

            model.zero_grad()
            input_data, output_target = model.parse_batch(batch)
            output_predicted = model(input_data)
            loss = criterion(output_predicted, output_target)

            # Plain float of the loss, averaged over ranks in distributed runs.
            reduced_loss = reduce_tensor(loss.data, n_gpus).item() if distributed else loss.item()

            if fp16:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_fp32_grads(hyper_params['grad_clip_thresh'])
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hyper_params['grad_clip_thresh'])

            optimizer.step()
            # Only the FP16 optimizer reports overflow; FP32 runs never skip a step.
            overflow = optimizer.overflow if fp16 else False

            if not overflow:
                if rank == 0 and not math.isnan(reduced_loss):
                    duration = time.perf_counter() - step_started
                    print("Train loss {} {:.6f} Grand Norm {:.6f} {:.2f}s/it".format(iteration, reduced_loss,
                                                                                      grad_norm, duration))
                    logger.log_training(reduced_loss, grad_norm, learning_rate, duration, iteration)

                # Validate on every rank, but let only rank 0 write the checkpoint.
                if iteration % hyper_params['iters_per_checkpoint'] == 0:
                    validate(model, criterion, valset, iteration, hyper_params['batch_size'], n_gpus, collate_fn,
                             logger, distributed, rank)
                    if rank == 0:
                        save_path = os.path.join(output_directory, "checkpoint_{}".format(iteration))
                        save_checkpoint(model, optimizer, learning_rate, iteration, save_path)

            iteration += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|