KingNish committed on
Commit 2923eb3 · verified · 1 Parent(s): 088ba8b

Upload ./RepCodec/train.py with huggingface_hub

Files changed (1):
  1. RepCodec/train.py +228 -0
RepCodec/train.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) ByteDance, Inc. and its affiliates.
+ # Copyright (c) Chutong Meng
+ #
+ # This source code is licensed under the CC BY-NC license found in the
+ # LICENSE file in the root directory of this source tree.
+ # Based on AudioDec (https://github.com/facebookresearch/AudioDec)
+
+ import argparse
+ import logging
+
+ import os
+
+ logging.basicConfig(
+     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=os.environ.get("LOGLEVEL", "INFO").upper(),
+ )
+ logger = logging.getLogger("repcodec_train")  # init logger before other modules
+
+ import random
+
+ import numpy as np
+ import torch
+ import yaml
+ from torch.utils.data import DataLoader
+
+ from dataloader import ReprDataset, ReprCollater
+ from losses.repr_reconstruct_loss import ReprReconstructLoss
+ from repcodec.RepCodec import RepCodec
+ from trainer.autoencoder import Trainer
+
+
+ class TrainMain:
+     def __init__(self, args):
+         # fix seed and make backends deterministic
+         random.seed(args.seed)
+         np.random.seed(args.seed)
+         torch.manual_seed(args.seed)
+         if not torch.cuda.is_available():
+             self.device = torch.device('cpu')
+             logger.info("device: cpu")
+         else:
+             self.device = torch.device('cuda:0')  # only supports single gpu for now
+             logger.info("device: gpu")
+             torch.cuda.manual_seed_all(args.seed)
+             if args.disable_cudnn == "False":
+                 torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest kernels
+
+         # initialize config
+         with open(args.config, 'r') as f:
+             self.config = yaml.load(f, Loader=yaml.FullLoader)
+         self.config.update(vars(args))
+
+         # initialize model folder
+         expdir = os.path.join(args.exp_root, args.tag)
+         os.makedirs(expdir, exist_ok=True)
+         self.config["outdir"] = expdir
+
+         # save config
+         with open(os.path.join(expdir, "config.yml"), "w") as f:
+             yaml.dump(self.config, f, Dumper=yaml.Dumper)
+         for key, value in self.config.items():
+             logger.info(f"{key} = {value}")
+
+         # initialize attributes
+         self.resume: str = args.resume
+         self.data_loader = None
+         self.model = None
+         self.optimizer = None
+         self.scheduler = None
+         self.criterion = None
+         self.trainer = None
+
+         # initialize batch_length
+         self.batch_length: int = self.config['batch_length']
+         self.data_path: str = self.config['data']['path']
+
+     def initialize_data_loader(self):
+         train_set = self._build_dataset("train")
+         valid_set = self._build_dataset("valid")
+         collater = ReprCollater()
+
+         logger.info(f"The number of training files = {len(train_set)}.")
+         logger.info(f"The number of validation files = {len(valid_set)}.")
+         dataset = {"train": train_set, "dev": valid_set}
+         self._set_data_loader(dataset, collater)
+
+     def define_model_optimizer_scheduler(self):
+         # model arch
+         self.model = {
+             "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
+         }
+         logger.info(f"Model Arch:\n{self.model['repcodec']}")
+
+         # optimizer
+         optimizer_class = getattr(
+             torch.optim,
+             self.config["model_optimizer_type"]
+         )
+         self.optimizer = {
+             "repcodec": optimizer_class(
+                 self.model["repcodec"].parameters(),
+                 **self.config["model_optimizer_params"]
+             )
+         }
+
+         # scheduler
+         scheduler_class = getattr(
+             torch.optim.lr_scheduler,
+             self.config.get("model_scheduler_type", "StepLR"),
+         )
+         self.scheduler = {
+             "repcodec": scheduler_class(
+                 optimizer=self.optimizer["repcodec"],
+                 **self.config["model_scheduler_params"]
+             )
+         }
+
+     def define_criterion(self):
+         self.criterion = {
+             "repr_reconstruct_loss": ReprReconstructLoss(
+                 **self.config.get("repr_reconstruct_loss_params", {}),
+             ).to(self.device)
+         }
+
+     def define_trainer(self):
+         self.trainer = Trainer(
+             steps=0,
+             epochs=0,
+             data_loader=self.data_loader,
+             model=self.model,
+             criterion=self.criterion,
+             optimizer=self.optimizer,
+             scheduler=self.scheduler,
+             config=self.config,
+             device=self.device
+         )
+
+     def initialize_model(self):
+         initial = self.config.get("initial", "")
+         if os.path.exists(self.resume):  # resume from a trained model
+             self.trainer.load_checkpoint(self.resume)
+             logger.info(f"Successfully resumed from {self.resume}.")
+         elif os.path.exists(initial):  # initialize a new model from pre-trained parameters
+             self.trainer.load_checkpoint(initial, load_only_params=True)
+             logger.info(f"Successfully initialized parameters from {initial}.")
+         else:
+             logger.info("Training from scratch.")
+
+     def run(self):
+         assert self.trainer is not None
+         self.trainer: Trainer
+         try:
+             logger.info(f"The current training step: {self.trainer.steps}")
+             self.trainer.train_max_steps = self.config["train_max_steps"]
+             if not self.trainer._check_train_finish():
+                 self.trainer.run()
+         finally:
+             self.trainer.save_checkpoint(
+                 os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
+             )
+             logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")
+
+     def _build_dataset(
+             self, subset: str
+     ) -> ReprDataset:
+         data_dir = os.path.join(
+             self.data_path, self.config['data']['subset'][subset]
+         )
+         params = {
+             "data_dir": data_dir,
+             "batch_len": self.batch_length
+         }
+         return ReprDataset(**params)
+
+     def _set_data_loader(self, dataset, collater):
+         self.data_loader = {
+             "train": DataLoader(
+                 dataset=dataset["train"],
+                 shuffle=True,
+                 collate_fn=collater,
+                 batch_size=self.config["batch_size"],
+                 num_workers=self.config["num_workers"],
+                 pin_memory=self.config["pin_memory"],
+             ),
+             "dev": DataLoader(
+                 dataset=dataset["dev"],
+                 shuffle=False,
+                 collate_fn=collater,
+                 batch_size=self.config["batch_size"],
+                 num_workers=0,
+                 pin_memory=False,  # save some memory. set to True if you have enough memory.
+             ),
+         }
+
+
+ def train():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c", "--config", type=str, required=True,
+         help="the path of the config yaml file."
+     )
+     parser.add_argument(
+         "--tag", type=str, required=True,
+         help="the outputs will be saved to exp_root/tag/"
+     )
+     parser.add_argument(
+         "--exp_root", type=str, default="exp"
+     )
+     parser.add_argument(
+         "--resume", default="", type=str, nargs="?",
+         help='checkpoint file path to resume training. (default="")',
+     )
+     parser.add_argument("--seed", default=1337, type=int)
+     parser.add_argument("--disable_cudnn", choices=("True", "False"), default="False", help="disable cuDNN")
+     args = parser.parse_args()
+
+     train_main = TrainMain(args)
+     train_main.initialize_data_loader()
+     train_main.define_model_optimizer_scheduler()
+     train_main.define_criterion()
+     train_main.define_trainer()
+     train_main.initialize_model()
+     train_main.run()
+
+
+ if __name__ == '__main__':
+     train()
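
Going by the argparse block at the bottom of the diff, the script is launched along these lines (the config path and tag below are illustrative, not part of the commit):

    python RepCodec/train.py \
        -c configs/repcodec.yaml \
        --tag my_run \
        --exp_root exp

The dumped config.yml and the checkpoint-<N>steps.pkl files then land in exp/my_run/.

The YAML config has to supply every key the script looks up. Here is a minimal sketch of that structure, inferred purely from the self.config accesses in train.py; all values are placeholders, not official RepCodec settings:

    # sketch of the config keys train.py reads; every value is a placeholder
    batch_length: 96                  # frames per training sample
    batch_size: 32
    num_workers: 4
    pin_memory: true
    train_max_steps: 200000
    data:
      path: /path/to/extracted/representations
      subset:
        train: train                  # subdirectory under data.path
        valid: valid
    model_params: {}                  # kwargs forwarded to RepCodec(...)
    model_optimizer_type: Adam        # any class name under torch.optim
    model_optimizer_params:
      lr: 1.0e-4
    model_scheduler_type: StepLR      # any class name under torch.optim.lr_scheduler
    model_scheduler_params:
      step_size: 100000
    repr_reconstruct_loss_params: {}  # kwargs forwarded to ReprReconstructLoss(...)

An optional top-level key, initial, may point at a pre-trained checkpoint; initialize_model() then loads only its parameters instead of resuming the full trainer state.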