Upload ./RepCodec/train.py with huggingface_hub
Browse files- RepCodec/train.py +228 -0
RepCodec/train.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
2 |
+
# Copyright (c) Chutong Meng
|
3 |
+
#
|
4 |
+
# This source code is licensed under the CC BY-NC license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import logging
|
10 |
+
|
11 |
+
import os
|
12 |
+
|
13 |
+
# Configure the root logger BEFORE importing project modules, so that any
# module-level logging they do at import time inherits this format/level.
# Level is overridable via the LOGLEVEL environment variable (default INFO).
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
logger = logging.getLogger("repcodec_train")  # init logger before other modules
|
19 |
+
|
20 |
+
import random
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import yaml
|
25 |
+
from torch.utils.data import DataLoader
|
26 |
+
|
27 |
+
from dataloader import ReprDataset, ReprCollater
|
28 |
+
from losses.repr_reconstruct_loss import ReprReconstructLoss
|
29 |
+
from repcodec.RepCodec import RepCodec
|
30 |
+
from trainer.autoencoder import Trainer
|
31 |
+
|
32 |
+
|
33 |
+
class TrainMain:
    """End-to-end driver for RepCodec training.

    Wires together the yaml config, data loaders, model, optimizer/scheduler,
    loss, and Trainer; then runs the training loop and always saves a final
    checkpoint on exit.
    """

    def __init__(self, args):
        # Fix seed and make backends deterministic.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if not torch.cuda.is_available():
            self.device = torch.device('cpu')
            logger.info("device: cpu")
        else:
            self.device = torch.device('cuda:0')  # only supports single gpu for now
            logger.info("device: gpu")
            torch.cuda.manual_seed_all(args.seed)
            if args.disable_cudnn == "False":
                torch.backends.cudnn.benchmark = True

        # Initialize config. FullLoader is acceptable for a trusted local
        # config file; never point this at untrusted input.
        with open(args.config, 'r') as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        # CLI arguments override / extend values from the yaml file.
        self.config.update(vars(args))

        # Initialize the model (experiment) folder: exp_root/tag/.
        expdir = os.path.join(args.exp_root, args.tag)
        os.makedirs(expdir, exist_ok=True)
        self.config["outdir"] = expdir

        # Save the merged config next to the checkpoints for reproducibility.
        with open(os.path.join(expdir, "config.yml"), "w") as f:
            yaml.dump(self.config, f, Dumper=yaml.Dumper)
        for key, value in self.config.items():
            logger.info(f"{key} = {value}")

        # Attributes populated later by the define_*/initialize_* methods.
        self.resume: str = args.resume
        self.data_loader = None
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.trainer = None

        # Initialize batch_length and the dataset root path.
        self.batch_length: int = self.config['batch_length']
        self.data_path: str = self.config['data']['path']

    def initialize_data_loader(self):
        """Build the train/valid datasets and wrap them in DataLoaders."""
        train_set = self._build_dataset("train")
        valid_set = self._build_dataset("valid")
        collater = ReprCollater()

        logger.info(f"The number of training files = {len(train_set)}.")
        logger.info(f"The number of validation files = {len(valid_set)}.")
        dataset = {"train": train_set, "dev": valid_set}
        self._set_data_loader(dataset, collater)

    def define_model_optimizer_scheduler(self):
        """Instantiate the RepCodec model plus optimizer and LR scheduler from config."""
        # model arch
        self.model = {
            "repcodec": RepCodec(**self.config["model_params"]).to(self.device)
        }
        logger.info(f"Model Arch:\n{self.model['repcodec']}")

        # Optimizer class is resolved by name from torch.optim (e.g. "Adam").
        optimizer_class = getattr(
            torch.optim,
            self.config["model_optimizer_type"]
        )
        self.optimizer = {
            "repcodec": optimizer_class(
                self.model["repcodec"].parameters(),
                **self.config["model_optimizer_params"]
            )
        }

        # Scheduler class resolved by name from torch.optim.lr_scheduler (default StepLR).
        scheduler_class = getattr(
            torch.optim.lr_scheduler,
            self.config.get("model_scheduler_type", "StepLR"),
        )
        self.scheduler = {
            "repcodec": scheduler_class(
                optimizer=self.optimizer["repcodec"],
                **self.config["model_scheduler_params"]
            )
        }

    def define_criterion(self):
        """Create the representation-reconstruction loss used by the trainer."""
        self.criterion = {
            "repr_reconstruct_loss": ReprReconstructLoss(
                **self.config.get("repr_reconstruct_loss_params", {}),
            ).to(self.device)
        }

    def define_trainer(self):
        """Assemble the Trainer from the previously built components."""
        self.trainer = Trainer(
            steps=0,
            epochs=0,
            data_loader=self.data_loader,
            model=self.model,
            criterion=self.criterion,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            config=self.config,
            device=self.device
        )

    def initialize_model(self):
        """Load weights: resume a run, warm-start from a pre-trained model, or start fresh."""
        initial = self.config.get("initial", "")
        # NOTE: --resume is declared with nargs="?", so it can be None when the
        # flag is given without a value; os.path.exists(None) raises TypeError,
        # hence the truthiness guards below ("" is also falsy, so the empty
        # defaults fall through to the scratch branch).
        if self.resume and os.path.exists(self.resume):  # resume from trained model
            self.trainer.load_checkpoint(self.resume)
            logger.info(f"Successfully resumed from {self.resume}.")
        elif initial and os.path.exists(initial):  # initialize new model with the pre-trained model
            self.trainer.load_checkpoint(initial, load_only_params=True)
            logger.info(f"Successfully initialize parameters from {initial}.")
        else:
            logger.info("Train from scratch")

    def run(self):
        """Run training; always save a final checkpoint, even on error or Ctrl-C."""
        assert self.trainer is not None
        self.trainer: Trainer
        try:
            logger.info(f"The current training step: {self.trainer.steps}")
            self.trainer.train_max_steps = self.config["train_max_steps"]
            if not self.trainer._check_train_finish():
                self.trainer.run()
        finally:
            # Checkpoint on every exit path (normal finish, interrupt, failure).
            self.trainer.save_checkpoint(
                os.path.join(self.config["outdir"], f"checkpoint-{self.trainer.steps}steps.pkl")
            )
            logger.info(f"Successfully saved checkpoint @ {self.trainer.steps}steps.")

    def _build_dataset(
            self, subset: str
    ) -> ReprDataset:
        """Create a ReprDataset for the given subset name ("train" or "valid")."""
        data_dir = os.path.join(
            self.data_path, self.config['data']['subset'][subset]
        )
        params = {
            "data_dir": data_dir,
            "batch_len": self.batch_length
        }
        return ReprDataset(**params)

    def _set_data_loader(self, dataset, collater):
        """Build the train/dev DataLoaders; the dev loader is kept lightweight."""
        self.data_loader = {
            "train": DataLoader(
                dataset=dataset["train"],
                shuffle=True,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=self.config["num_workers"],
                pin_memory=self.config["pin_memory"],
            ),
            "dev": DataLoader(
                dataset=dataset["dev"],
                shuffle=False,
                collate_fn=collater,
                batch_size=self.config["batch_size"],
                num_workers=0,
                pin_memory=False,  # save some memory. set to True if you have enough memory.
            ),
        }
|
195 |
+
|
196 |
+
|
197 |
+
def train():
    """Command-line entry point: parse arguments, then build and run a training session."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-c", "--config", type=str, required=True,
        help="the path of config yaml file."
    )
    arg_parser.add_argument(
        "--tag", type=str, required=True,
        help="the outputs will be saved to exp_root/tag/"
    )
    arg_parser.add_argument("--exp_root", type=str, default="exp")
    arg_parser.add_argument(
        "--resume", default="", type=str, nargs="?",
        help='checkpoint file path to resume training. (default="")',
    )
    arg_parser.add_argument("--seed", default=1337, type=int)
    arg_parser.add_argument(
        "--disable_cudnn", choices=("True", "False"), default="False",
        help="Disable CUDNN"
    )
    parsed_args = arg_parser.parse_args()

    # Build every component in dependency order, then start training.
    session = TrainMain(parsed_args)
    session.initialize_data_loader()
    session.define_model_optimizer_scheduler()
    session.define_criterion()
    session.define_trainer()
    session.initialize_model()
    session.run()


if __name__ == "__main__":
    train()
|