Gregniuki committed on
Commit 59d0e09 · verified · 1 Parent(s): 3983b32

Delete model/trainer.py

Files changed (1): model/trainer.py +0 -353
model/trainer.py DELETED
@@ -1,353 +0,0 @@
from __future__ import annotations

import gc
import os

import torch
import torchaudio
import wandb
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from ema_pytorch import EMA
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm

from f5_tts.model import CFM
from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
from f5_tts.model.utils import default, exists

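# Editorial note, not part of the original file: `exists` and `default` are
# assumed to follow the usual lucidrains-style helpers, roughly:
#     exists(x)     -> x is not None
#     default(x, d) -> x if exists(x) else d
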
# trainer


class Trainer:
    def __init__(
        self,
        model: CFM,
        epochs,
        learning_rate,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str | None = None,
        log_samples: bool = False,
        last_per_steps=None,
        accelerate_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        bnb_optimizer: bool = False,
        mel_spec_type: str = "vocos",  # "vocos" | "bigvgan"
    ):
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

        if logger == "wandb" and not wandb.api.api_key:
            logger = None
        print(f"Using logger: {logger}")
        self.log_samples = log_samples

        self.accelerator = Accelerator(
            log_with=logger if logger == "wandb" else None,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )

        self.logger = logger
        if self.logger == "wandb":
            if exists(wandb_resume_id):
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}

            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )

        elif self.logger == "tensorboard":
            from torch.utils.tensorboard import SummaryWriter

            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")

        self.model = model

        if self.is_main:
            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
            self.ema_model.to(self.accelerator.device)

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.batch_size = batch_size
        self.batch_size_type = batch_size_type
        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.vocoder_name = mel_spec_type

        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor

        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
        else:
            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                ema_model_state_dict=self.ema_model.state_dict(),
                scheduler_state_dict=self.scheduler.state_dict(),
                step=step,
            )
            if not os.path.exists(self.checkpoint_path):
                os.makedirs(self.checkpoint_path)
            if last:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
                print(f"Saved last checkpoint at step {step}")
            else:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

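    # Illustrative sketch, not part of the original file: a checkpoint written
    # above can be inspected offline with plain torch, e.g.
    #
    #     ckpt = torch.load("ckpts/test_e2-tts/model_last.pt", map_location="cpu")
    #     print(ckpt["step"], list(ckpt.keys()))
    #     # -> step, plus model / optimizer / ema_model / scheduler state dicts
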
    def load_checkpoint(self):
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
            or not os.listdir(self.checkpoint_path)
        ):
            return 0

        self.accelerator.wait_for_everyone()
        if "model_last.pt" in os.listdir(self.checkpoint_path):
            latest_checkpoint = "model_last.pt"
        else:
            latest_checkpoint = sorted(
                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
                key=lambda x: int("".join(filter(str.isdigit, x))),
            )[-1]
        # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")

        # patch for backward compatibility, 305e3ea
        for key in ["ema_model.mel_spec.mel_stft.mel_scale.fb", "ema_model.mel_spec.mel_stft.spectrogram.window"]:
            if key in checkpoint["ema_model_state_dict"]:
                del checkpoint["ema_model_state_dict"][key]

        if self.is_main:
            self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])

        if "step" in checkpoint:
            # patch for backward compatibility, 305e3ea
            for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
                if key in checkpoint["model_state_dict"]:
                    del checkpoint["model_state_dict"][key]

            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
            self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
            if self.scheduler:
                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            step = checkpoint["step"]
        else:
            checkpoint["model_state_dict"] = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ["initted", "step"]
            }
            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
            step = 0

        del checkpoint
        gc.collect()
        return step

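    # Illustrative sketch, not part of the original file: the `else` branch above
    # implies an inference-only checkpoint can be exported from the EMA weights:
    #
    #     ckpt = torch.load("ckpts/test_e2-tts/model_last.pt", map_location="cpu")
    #     state = {
    #         k.replace("ema_model.", ""): v
    #         for k, v in ckpt["ema_model_state_dict"].items()
    #         if k not in ["initted", "step"]
    #     }
    #     torch.save({"model_state_dict": state}, "model_inference.pt")
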
    def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int | None = None):
        if self.log_samples:
            from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef

            vocoder = load_vocoder(vocoder_name=self.vocoder_name)
            target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
            log_samples_path = f"{self.checkpoint_path}/samples"
            os.makedirs(log_samples_path, exist_ok=True)

        if exists(resumable_with_seed):
            generator = torch.Generator()
            generator.manual_seed(resumable_with_seed)
        else:
            generator = None

        if self.batch_size_type == "sample":
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_size=self.batch_size,
                shuffle=True,
                generator=generator,
            )
        elif self.batch_size_type == "frame":
            self.accelerator.even_batches = False
            sampler = SequentialSampler(train_dataset)
            batch_sampler = DynamicBatchSampler(
                sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False
            )
            train_dataloader = DataLoader(
                train_dataset,
                collate_fn=collate_fn,
                num_workers=num_workers,
                pin_memory=True,
                persistent_workers=True,
                batch_sampler=batch_sampler,
            )
        else:
            raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")

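        # How the two batching modes differ (inferred from the arguments above,
        # not stated in the original file): "sample" puts a fixed number of
        # utterances in each batch, while "frame" packs variable-length samples
        # until self.batch_size frames are reached, capped at self.max_samples
        # utterances per batch.
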
        # accelerator.prepare() dispatches batches to devices,
        # so the dataloader length computed above must account for the number of devices
        warmup_steps = (
            self.num_warmup_updates * self.accelerator.num_processes
        )  # keep a fixed number of warmup updates while using accelerate multi-gpu ddp
        # otherwise, with the default split_batches=False, warmup steps would change with num_processes
        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps
        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
        self.scheduler = SequentialLR(
            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
        )
        train_dataloader, self.scheduler = self.accelerator.prepare(
            train_dataloader, self.scheduler
        )  # actual steps = 1 gpu steps / gpus
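        # Worked example with hypothetical numbers: len(train_dataloader)=10_000,
        # epochs=10, grad_accumulation_steps=2, num_warmup_updates=20_000, 2 GPUs:
        #     warmup_steps = 20_000 * 2      = 40_000
        #     total_steps  = 10_000 * 10 / 2 = 50_000
        #     decay_steps  = 50_000 - 40_000 = 10_000
        # i.e. the LR ramps up for the first 40k scheduler steps, then decays linearly.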
        start_step = self.load_checkpoint()
        global_step = start_step

        if exists(resumable_with_seed):
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
        else:
            skipped_epoch = 0

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if exists(resumable_with_seed) and epoch == skipped_epoch:
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["text"]
                    mel_spec = batch["mel"].permute(0, 2, 1)
                    mel_lengths = batch["mel_lengths"]

                    # TODO. add duration predictor training
                    if self.duration_predictor is not None and self.accelerator.is_local_main_process:
                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)

                    loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                    )
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                if self.is_main:
                    self.ema_model.update()

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
                    if self.logger == "tensorboard":
                        self.writer.add_scalar("loss", loss.item(), global_step)
                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)

                progress_bar.set_postfix(step=str(global_step), loss=loss.item())

                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                    self.save_checkpoint(global_step)

                    if self.log_samples and self.accelerator.is_local_main_process:
                        ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
                        torchaudio.save(
                            f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
                        )
                        with torch.inference_mode():
                            generated, _ = self.accelerator.unwrap_model(self.model).sample(
                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
                                text=[text_inputs[0] + [" "] + text_inputs[0]],
                                duration=ref_audio_len * 2,
                                steps=nfe_step,
                                cfg_strength=cfg_strength,
                                sway_sampling_coef=sway_sampling_coef,
                            )
                            generated = generated.to(torch.float32)
                            gen_audio = vocoder.decode(
                                generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
                            )
                        torchaudio.save(
                            f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
                        )

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)

        self.accelerator.end_training()
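
For context, a minimal usage sketch of the deleted Trainer (editorial; the model and dataset construction helpers below are placeholders, not part of this commit):

# Hypothetical usage sketch; construction of the CFM model and dataset is assumed.
from model.trainer import Trainer  # the file deleted by this commit

model = build_cfm()          # placeholder: a configured CFM instance
train_dataset = load_data()  # placeholder: a Dataset yielding {"text", "mel", "mel_lengths"}

trainer = Trainer(
    model,
    epochs=10,
    learning_rate=1e-4,
    batch_size=38400,             # interpreted as frames when batch_size_type="frame"
    batch_size_type="frame",
    grad_accumulation_steps=1,
    save_per_updates=1000,
    checkpoint_path="ckpts/my_run",
    logger=None,                  # or "wandb" / "tensorboard"
)
trainer.train(train_dataset, num_workers=16, resumable_with_seed=666)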