KingNish committed on
Commit
b49f9aa
·
verified ·
1 Parent(s): 5320d2b

Upload ./RepCodec/trainer/autoencoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. RepCodec/trainer/autoencoder.py +287 -0
RepCodec/trainer/autoencoder.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # Copyright (c) Chutong Meng
3
+ #
4
+ # This source code is licensed under the CC BY-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
7
+
8
+ import logging
9
+ import os
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ from tensorboardX import SummaryWriter
14
+ from tqdm import tqdm
15
+
16
+ logger = logging.getLogger("repcodec_train")
17
+
18
+
19
class Trainer:
    """Step-based training / evaluation driver for the ``repcodec`` model.

    Components are passed in as dictionaries. The entries this class reads
    are ``model["repcodec"]``, ``optimizer["repcodec"]``,
    ``scheduler["repcodec"]``, ``criterion["repr_reconstruct_loss"]``,
    ``data_loader["train"]`` and ``data_loader["dev"]``.

    Config keys read: ``outdir``, ``train_max_steps`` (optional, defaults to
    0), ``lambda_repr_reconstruct_loss``, ``lambda_vq_loss``, ``grad_norm``,
    ``save_interval_steps``, ``eval_interval_steps`` and
    ``log_interval_steps``.
    """

    def __init__(
        self,
        steps: int,
        epochs: int,
        data_loader: dict,
        model: dict,
        criterion: dict,
        optimizer: dict,
        scheduler: dict,
        config: dict,
        device=torch.device("cpu"),
    ):
        """Store the training components and open the TensorBoard writer.

        Args:
            steps: Number of training steps already completed.
            epochs: Number of training epochs already completed.
            data_loader: Mapping with ``"train"`` and ``"dev"`` iterables.
            model: Mapping holding the ``"repcodec"`` module.
            criterion: Mapping holding ``"repr_reconstruct_loss"``.
            optimizer: Mapping holding the ``"repcodec"`` optimizer.
            scheduler: Mapping holding the ``"repcodec"`` LR scheduler.
            config: Training configuration (see the class docstring).
            device: Device every batch is moved to before the forward pass.
        """
        self.steps = steps
        self.epochs = epochs
        self.data_loader = data_loader
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.device = device
        self.writer = SummaryWriter(config["outdir"])
        # Running sums of the recorded values; averaged and reset when logged.
        self.total_train_loss = defaultdict(float)
        self.total_eval_loss = defaultdict(float)
        self.train_max_steps = config.get("train_max_steps", 0)

    def _compute_codec_loss(self, batch, mode):
        """Run the forward pass on ``batch`` and return the total codec loss.

        Records perplexity, VQ loss, reconstruction loss and the summed
        ``codec_loss`` under ``mode`` as a side effect.
        """
        # assumes ``batch`` is a tensor (it must provide ``.to``) -- TODO confirm
        x = batch.to(self.device)

        y_, zq, z, vqloss, perplexity = self.model["repcodec"](x)
        # Recording order is kept stable because it fixes the order in which
        # metrics are later logged.
        self._perplexity(perplexity, mode=mode)
        codec_loss = 0.0
        codec_loss += self._vq_loss(vqloss, mode=mode)
        codec_loss += self._metric_loss(y_, x, mode=mode)

        self._record_loss("codec_loss", codec_loss, mode=mode)
        return codec_loss

    def _train_step(self, batch):
        """Single step of training."""
        codec_loss = self._compute_codec_loss(batch, mode="train")
        self._update_repcodec(codec_loss)

        self.steps += 1
        self.tqdm.update(1)
        self._check_train_finish()

    @torch.no_grad()
    def _eval_step(self, batch):
        """Single step of evaluation (no gradients, no parameter update)."""
        self._compute_codec_loss(batch, mode="eval")

    def run(self):
        """Run training until ``train_max_steps`` is reached."""
        self.finish_train = False
        self.tqdm = tqdm(
            initial=self.steps, total=self.train_max_steps, desc="[train]"
        )
        try:
            # ``finish_train`` starts False, so at least one epoch is entered.
            while not self.finish_train:
                self._train_epoch()
        finally:
            # Close the progress bar even if training raises.
            self.tqdm.close()
        logger.info("Finished training.")

    def save_checkpoint(self, checkpoint_path: str):
        """Save model / optimizer / scheduler state and counters.

        Args:
            checkpoint_path: Destination file. Missing parent directories are
                created; a bare file name is written to the current
                working directory.
        """
        state_dict = {
            "model": {
                "repcodec": self.model["repcodec"].state_dict(),
            },
            "optimizer": {
                "repcodec": self.optimizer["repcodec"].state_dict(),
            },
            "scheduler": {
                "repcodec": self.scheduler["repcodec"].state_dict(),
            },
            "steps": self.steps,
            "epochs": self.epochs,
        }

        # ``dirname`` is "" for a bare file name and ``os.makedirs("")``
        # raises, so only create a directory when there is one to create.
        out_dir = os.path.dirname(checkpoint_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        torch.save(state_dict, checkpoint_path)

    def load_checkpoint(
        self,
        checkpoint_path: str,
        strict: bool = True,
        load_only_params: bool = False,
    ):
        """Restore state written by ``save_checkpoint``.

        Args:
            checkpoint_path: Checkpoint file to read.
            strict: Forwarded to the model's ``load_state_dict``.
            load_only_params: When True only the model weights are loaded;
                step/epoch counters, optimizer and scheduler are left as is.
        """
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        self.model["repcodec"].load_state_dict(
            state_dict["model"]["repcodec"], strict=strict
        )

        if load_only_params:
            return

        self.steps = state_dict["steps"]
        self.epochs = state_dict["epochs"]
        self.optimizer["repcodec"].load_state_dict(
            state_dict["optimizer"]["repcodec"]
        )
        self.scheduler["repcodec"].load_state_dict(
            state_dict["scheduler"]["repcodec"]
        )

    def _train_epoch(self):
        """One epoch of training.

        Raises:
            RuntimeError: If the train data loader yields no batches (training
                could never make progress).
        """
        train_steps_per_epoch = 0
        for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
            # train one step
            self._train_step(batch)

            # check interval
            self._check_log_interval()
            self._check_eval_interval()
            self._check_save_interval()

            # check whether training is finished
            if self.finish_train:
                return

        if train_steps_per_epoch == 0:
            # Without this guard the loop variable would be unbound below.
            raise RuntimeError("Train data loader yielded no batches.")

        # update
        self.epochs += 1
        self.train_steps_per_epoch = train_steps_per_epoch
        # Only epochs longer than 200 steps are logged
        # (presumably to avoid log spam on tiny datasets -- TODO confirm).
        if train_steps_per_epoch > 200:
            logger.info(
                f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
                f"({self.train_steps_per_epoch} steps per epoch)."
            )

    def _eval_epoch(self):
        """One epoch of evaluation over ``data_loader["dev"]``.

        Averages the accumulated eval values over the number of batches,
        logs them and writes them to TensorBoard. Models are switched back to
        train mode and the eval accumulator is reset even if a step raises.
        """
        logger.info(f"(Steps: {self.steps}) Start evaluation.")
        # change mode
        for module in self.model.values():
            module.eval()

        try:
            # calculate loss for each batch
            eval_steps_per_epoch = 0
            for eval_steps_per_epoch, batch in enumerate(
                tqdm(self.data_loader["dev"], desc="[eval]"), 1
            ):
                # eval one step
                self._eval_step(batch)

            logger.info(
                f"(Steps: {self.steps}) Finished evaluation "
                f"({eval_steps_per_epoch} steps per epoch)."
            )

            if eval_steps_per_epoch == 0:
                # Nothing was evaluated; avoid dividing by zero below.
                logger.warning(
                    f"(Steps: {self.steps}) Dev data loader yielded no batches."
                )
                return

            # average loss
            for key in self.total_eval_loss:
                self.total_eval_loss[key] /= eval_steps_per_epoch
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}."
                )

            # record
            self._write_to_tensorboard(self.total_eval_loss)
        finally:
            # reset
            self.total_eval_loss = defaultdict(float)

            # restore mode
            for module in self.model.values():
                module.train()

    def _metric_loss(self, predict_y, natural_y, mode='train'):
        """Return the weighted reconstruction loss and record it under ``mode``."""
        metric_loss = 0.0

        repr_reconstruct_loss = self.criterion["repr_reconstruct_loss"](predict_y, natural_y)
        # Out-of-place so the criterion's output tensor is never mutated.
        repr_reconstruct_loss = (
            repr_reconstruct_loss * self.config["lambda_repr_reconstruct_loss"]
        )
        self._record_loss("reconstruct_loss", repr_reconstruct_loss, mode=mode)
        metric_loss += repr_reconstruct_loss

        return metric_loss

    def _update_repcodec(self, repr_loss):
        """Back-propagate ``repr_loss`` and step the optimizer and scheduler.

        Gradients are clipped to ``config["grad_norm"]`` when it is positive.
        """
        self.optimizer["repcodec"].zero_grad()
        repr_loss.backward()
        if self.config["grad_norm"] > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model["repcodec"].parameters(),
                self.config["grad_norm"],
            )
        self.optimizer["repcodec"].step()
        self.scheduler["repcodec"].step()

    def _record_loss(self, name: str, loss, mode='train'):
        """Add ``loss`` to the running sum stored under ``"<mode>/<name>"``.

        Args:
            name: Metric name WITHOUT the mode prefix.
            loss: Python number or single-element tensor.
            mode: ``'train'`` or ``'eval'``.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if torch.is_tensor(loss):
            loss = loss.item()

        if mode == 'train':
            self.total_train_loss[f"train/{name}"] += loss
        elif mode == 'eval':
            self.total_eval_loss[f"eval/{name}"] += loss
        else:
            raise NotImplementedError(f"Mode ({mode}) is not supported!")

    def _write_to_tensorboard(self, loss):
        """Write every ``key -> value`` of ``loss`` as a scalar at the current step."""
        for key, value in loss.items():
            self.writer.add_scalar(key, value, self.steps)

    def _check_save_interval(self):
        """Save a checkpoint into ``outdir`` every ``save_interval_steps`` steps."""
        if self.steps and (self.steps % self.config["save_interval_steps"] == 0):
            self.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.steps} steps.")

    def _check_eval_interval(self):
        """Run an evaluation epoch every ``eval_interval_steps`` steps."""
        if self.steps % self.config["eval_interval_steps"] == 0:
            self._eval_epoch()

    def _check_log_interval(self):
        """Every ``log_interval_steps`` steps, log and reset the train averages."""
        if self.steps % self.config["log_interval_steps"] == 0:
            for key in self.total_train_loss:
                self.total_train_loss[key] /= self.config["log_interval_steps"]
                logger.info(
                    f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}."
                )
            self._write_to_tensorboard(self.total_train_loss)

            # reset
            self.total_train_loss = defaultdict(float)

    def _check_train_finish(self):
        """Update and return ``finish_train`` (True once ``steps >= train_max_steps``)."""
        self.finish_train = self.steps >= self.train_max_steps
        return self.finish_train

    def _perplexity(self, perplexity, label=None, mode='train'):
        """Record perplexity under ``"<mode>/ppl[_<label>][_<idx>]"``.

        A tensor with more than one element is flattened and each element is
        recorded separately with its index as suffix.
        """
        # ``_record_loss`` adds the ``<mode>/`` prefix itself; adding it here
        # as well produced doubled keys such as ``train/train/ppl``.
        name = f"ppl_{label}" if label else "ppl"
        if torch.numel(perplexity) > 1:
            # Flatten first so nested lists never reach ``_record_loss``.
            for idx, ppl in enumerate(perplexity.reshape(-1).tolist()):
                self._record_loss(f"{name}_{idx}", ppl, mode=mode)
        else:
            self._record_loss(name, perplexity, mode=mode)

    def _vq_loss(self, vqloss, label=None, mode='train'):
        """Sum ``vqloss``, weight it by ``lambda_vq_loss``, record and return it.

        The value is recorded under ``"<mode>/vqloss[_<label>]"``.
        """
        # ``_record_loss`` adds the ``<mode>/`` prefix itself (see _perplexity).
        name = f"vqloss_{label}" if label else "vqloss"
        vqloss = torch.sum(vqloss) * self.config["lambda_vq_loss"]
        self._record_loss(name, vqloss, mode=mode)

        return vqloss