KingNish committed on
Commit f5135ef · verified · 1 Parent(s): a82df09

Upload ./RepCodec/examples/data2vec_audio.py with huggingface_hub

Files changed (1):
  1. RepCodec/examples/data2vec_audio.py +541 -0
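
The commit message says the file was uploaded with huggingface_hub. For context only, below is a minimal sketch of that kind of upload using the library's standard HfApi.upload_file call; the repo_id and authentication details are hypothetical and not taken from this commit.

from huggingface_hub import HfApi

api = HfApi()  # assumes the user is already authenticated, e.g. via `huggingface-cli login`
api.upload_file(
    path_or_fileobj="./RepCodec/examples/data2vec_audio.py",  # local file to upload
    path_in_repo="RepCodec/examples/data2vec_audio.py",       # destination path inside the repo
    repo_id="KingNish/<target-repo>",                         # hypothetical target repository
)
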
RepCodec/examples/data2vec_audio.py ADDED
@@ -0,0 +1,541 @@
+ # Copyright (c) ByteDance, Inc. and its affiliates.
+ # Copyright (c) Chutong Meng
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ # Based on fairseq (https://github.com/facebookresearch/fairseq)
+
+ # ref: https://github.com/facebookresearch/fairseq/blob/main/examples/data2vec/models/data2vec_audio.py
+
+ import logging
+ import math
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ from omegaconf import II
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+
+ from fairseq.modules import EMAModule, EMAModuleConfig
+ from fairseq.data.data_utils import compute_mask_indices
+ from fairseq.models import BaseFairseqModel, register_model
+ from fairseq.models.wav2vec import (
+     ConvFeatureExtractionModel,
+     Wav2Vec2Config,
+     TransformerEncoder,
+ )
+ from fairseq.modules import (
+     GradMultiply,
+     LayerNorm,
+ )
+ from fairseq.utils import index_put
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class Data2VecAudioConfig(Wav2Vec2Config):
+
+     loss_beta: float = field(
+         default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
+     )
+     loss_scale: Optional[float] = field(
+         default=None,
+         metadata={
+             "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
+         },
+     )
+     average_top_k_layers: int = field(
+         default=8, metadata={"help": "how many layers to average"}
+     )
+
+     layer_norm_target_layer: bool = False
+     instance_norm_target_layer: bool = False
+     instance_norm_targets: bool = False
+     layer_norm_targets: bool = False
+     batch_norm_target_layer: bool = False
+     group_norm_target_layer: bool = False
+
+     ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
+     ema_end_decay: float = field(
+         default=0.9999, metadata={"help": "final ema decay rate"}
+     )
+
+     # when to finish annealing ema decay rate
+     ema_anneal_end_step: int = II("optimization.max_update")
+
+     ema_transformer_only: bool = field(
+         default=True,
+         metadata={"help": "whether to momentum update only the transformer"},
+     )
+     ema_layers_only: bool = field(
+         default=True,
+         metadata={"help": "whether to momentum update only the transformer layers"},
+     )
+
+     max_update: int = II("optimization.max_update")
+
+     min_target_var: float = field(
+         default=0.1, metadata={"help": "stop training if target var falls below this"}
+     )
+     min_pred_var: float = field(
+         default=0.01,
+         metadata={"help": "stop training if prediction var falls below this"},
+     )
+
+
+ def get_annealed_rate(start, end, curr_step, total_steps):
+     r = end - start
+     pct_remaining = 1 - curr_step / total_steps
+     return end - r * pct_remaining
+
+
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
+ class Data2VecAudioModel(BaseFairseqModel):
+     def __init__(self, cfg: Data2VecAudioConfig):
+         super().__init__()
+         self.cfg = cfg
+
+         feature_enc_layers = eval(cfg.conv_feature_layers)
+         self.extractor_embed = feature_enc_layers[-1][0]
+
+         self.ema = None
+         self.embed = cfg.encoder_embed_dim
+
+         self.average_top_k_layers = cfg.average_top_k_layers
+         self.loss_beta = cfg.loss_beta
+         self.loss_scale = cfg.loss_scale
+
+         self.feature_extractor = ConvFeatureExtractionModel(
+             conv_layers=feature_enc_layers,
+             dropout=0.0,
+             mode=cfg.extractor_mode,
+             conv_bias=cfg.conv_bias,
+         )
+
+         self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
+
+         self.mask_prob = cfg.mask_prob
+         self.mask_selection = cfg.mask_selection
+         self.mask_other = cfg.mask_other
+         self.mask_length = cfg.mask_length
+         self.no_mask_overlap = cfg.no_mask_overlap
+         self.mask_min_space = cfg.mask_min_space
+
+         self.mask_channel_prob = cfg.mask_channel_prob
+         self.mask_channel_before = cfg.mask_channel_before
+         self.mask_channel_selection = cfg.mask_channel_selection
+         self.mask_channel_other = cfg.mask_channel_other
+         self.mask_channel_length = cfg.mask_channel_length
+         self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+         self.mask_channel_min_space = cfg.mask_channel_min_space
+
+         self.dropout_input = nn.Dropout(cfg.dropout_input)
+         self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+         self.feature_grad_mult = cfg.feature_grad_mult
+
+         self.mask_emb = nn.Parameter(
+             torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+         )
+
+         self.encoder = TransformerEncoder(cfg)
+         self.layer_norm = LayerNorm(self.extractor_embed)
+
+         self.final_proj = nn.Linear(self.embed, self.embed)
+
+         self.num_updates = 0
+
+     def make_ema_teacher(self):
+         ema_config = EMAModuleConfig(
+             ema_decay=self.cfg.ema_decay,
+             ema_fp32=True,
+         )
+         skip_keys = set()
+         if self.cfg.ema_layers_only:
+             self.cfg.ema_transformer_only = True
+             for k, _ in self.encoder.pos_conv.named_parameters():
+                 skip_keys.add(f"pos_conv.{k}")
+
+         self.ema = EMAModule(
+             self.encoder if self.cfg.ema_transformer_only else self,
+             ema_config,
+             skip_keys=skip_keys,
+         )
+
+     def set_num_updates(self, num_updates):
+         super().set_num_updates(num_updates)
+
+         if self.ema is None and self.final_proj is not None:
+             logger.info(f"making ema teacher")
+             self.make_ema_teacher()
+         elif self.training and self.ema is not None:
+             if self.cfg.ema_decay != self.cfg.ema_end_decay:
+                 if num_updates >= self.cfg.ema_anneal_end_step:
+                     decay = self.cfg.ema_end_decay
+                 else:
+                     decay = get_annealed_rate(
+                         self.cfg.ema_decay,
+                         self.cfg.ema_end_decay,
+                         num_updates,
+                         self.cfg.ema_anneal_end_step,
+                     )
+                 self.ema.set_decay(decay)
+             if self.ema.get_decay() < 1:
+                 self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
+
+         self.num_updates = num_updates
+
+     def state_dict(self, destination=None, prefix="", keep_vars=False):
+         state = super().state_dict(destination, prefix, keep_vars)
+
+         if self.ema is not None:
+             state[prefix + "_ema"] = self.ema.fp32_params
+
+         return state
+
+     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+         if self.ema is not None:
+             k = prefix + "_ema"
+             assert k in state_dict
+             self.ema.restore(state_dict[k], True)
+             del state_dict[k]
+         return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+     @classmethod
+     def build_model(cls, cfg: Data2VecAudioConfig, task=None):
+         """Build a new model instance."""
+
+         return cls(cfg)
+
+     def apply_mask(
+         self,
+         x,
+         padding_mask,
+         mask_indices=None,
+         mask_channel_indices=None,
+     ):
+         B, T, C = x.shape
+
+         if self.mask_channel_prob > 0 and self.mask_channel_before:
+             mask_channel_indices = compute_mask_indices(
+                 (B, C),
+                 None,
+                 self.mask_channel_prob,
+                 self.mask_channel_length,
+                 self.mask_channel_selection,
+                 self.mask_channel_other,
+                 no_overlap=self.no_mask_channel_overlap,
+                 min_space=self.mask_channel_min_space,
+             )
+             mask_channel_indices = (
+                 torch.from_numpy(mask_channel_indices)
+                 .to(x.device)
+                 .unsqueeze(1)
+                 .expand(-1, T, -1)
+             )
+             x[mask_channel_indices] = 0
+
+         if self.mask_prob > 0:
+             if mask_indices is None:
+                 mask_indices = compute_mask_indices(
+                     (B, T),
+                     padding_mask,
+                     self.mask_prob,
+                     self.mask_length,
+                     self.mask_selection,
+                     self.mask_other,
+                     min_masks=1,
+                     no_overlap=self.no_mask_overlap,
+                     min_space=self.mask_min_space,
+                     require_same_masks=self.cfg.require_same_masks,
+                     mask_dropout=self.cfg.mask_dropout,
+                 )
+                 mask_indices = torch.from_numpy(mask_indices).to(x.device)
+             x = index_put(x, mask_indices, self.mask_emb)
+         else:
+             mask_indices = None
+
+         if self.mask_channel_prob > 0 and not self.mask_channel_before:
+             if mask_channel_indices is None:
+                 mask_channel_indices = compute_mask_indices(
+                     (B, C),
+                     None,
+                     self.mask_channel_prob,
+                     self.mask_channel_length,
+                     self.mask_channel_selection,
+                     self.mask_channel_other,
+                     no_overlap=self.no_mask_channel_overlap,
+                     min_space=self.mask_channel_min_space,
+                 )
+                 mask_channel_indices = (
+                     torch.from_numpy(mask_channel_indices)
+                     .to(x.device)
+                     .unsqueeze(1)
+                     .expand(-1, T, -1)
+                 )
+             x = index_put(x, mask_channel_indices, 0)
+
+         return x, mask_indices
+
+     def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
+         """
+         Computes the output length of the convolutional layers
+         """
+
+         def _conv_out_length(input_length, kernel_size, stride):
+             return torch.floor((input_length - kernel_size) / stride + 1)
+
+         conv_cfg_list = eval(self.cfg.conv_feature_layers)
+
+         for i in range(len(conv_cfg_list)):
+             input_lengths = _conv_out_length(
+                 input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
+             )
+
+         return input_lengths.to(torch.long)
+
+     def forward(
+         self,
+         source,
+         padding_mask=None,
+         mask=True,
+         features_only=False,
+         layer=None,
+         mask_indices=None,
+         mask_channel_indices=None,
+         padding_count=None,
+     ):
+         features = source
+
+         if self.feature_grad_mult > 0:
+             features = self.feature_extractor(features)
+             if self.feature_grad_mult != 1.0:
+                 features = GradMultiply.apply(features, self.feature_grad_mult)
+         else:
+             with torch.no_grad():
+                 features = self.feature_extractor(features)
+
+         features = features.transpose(1, 2)
+
+         features = self.layer_norm(features)
+
+         orig_padding_mask = padding_mask
+
+         if padding_mask is not None and padding_mask.any():
+             input_lengths = (1 - padding_mask.long()).sum(-1)
+             # apply conv formula to get real output_lengths
+             output_lengths = self._get_feat_extract_output_lengths(input_lengths)
+
+             padding_mask = torch.zeros(
+                 features.shape[:2], dtype=features.dtype, device=features.device
+             )
+
+             # these two operations makes sure that all values
+             # before the output lengths indices are attended to
+             padding_mask[
+                 (
+                     torch.arange(padding_mask.shape[0], device=padding_mask.device),
+                     output_lengths - 1,
+                 )
+             ] = 1
+             padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
+         else:
+             padding_mask = None
+
+         if self.post_extract_proj is not None:
+             features = self.post_extract_proj(features)
+
+         pre_encoder_features = None
+         if self.cfg.ema_transformer_only:
+             pre_encoder_features = features.clone()
+
+         features = self.dropout_input(features)
+
+         if mask:
+             x, mask_indices = self.apply_mask(
+                 features,
+                 padding_mask,
+                 mask_indices=mask_indices,
+                 mask_channel_indices=mask_channel_indices,
+             )
+         else:
+             x = features
+             mask_indices = None
+
+         x, layer_results = self.encoder(
+             x,
+             padding_mask=padding_mask,
+             layer=layer,
+         )
+
+         if features_only:
+             return {
+                 "x": x,
+                 "padding_mask": padding_mask,
+                 "layer_results": layer_results,
+             }
+
+         result = {
+             "losses": {},
+         }
+
+         with torch.no_grad():
+             self.ema.model.eval()
+
+             if self.cfg.ema_transformer_only:
+                 y, layer_results = self.ema.model.extract_features(
+                     pre_encoder_features,
+                     padding_mask=padding_mask,
+                     min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
+                 )
+                 y = {
+                     "x": y,
+                     "padding_mask": padding_mask,
+                     "layer_results": layer_results,
+                 }
+             else:
+                 y = self.ema.model.extract_features(
+                     source=source,
+                     padding_mask=orig_padding_mask,
+                     mask=False,
+                 )
+
+             target_layer_results = [l[2] for l in y["layer_results"]]
+
+             permuted = False
+             if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
+                 target_layer_results = [
+                     tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
+                 ]
+                 permuted = True
+
+             if self.cfg.batch_norm_target_layer:
+                 target_layer_results = [
+                     F.batch_norm(
+                         tl.float(), running_mean=None, running_var=None, training=True
+                     )
+                     for tl in target_layer_results
+                 ]
+
+             if self.cfg.instance_norm_target_layer:
+                 target_layer_results = [
+                     F.instance_norm(tl.float()) for tl in target_layer_results
+                 ]
+
+             if permuted:
+                 target_layer_results = [
+                     tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
+                 ]
+
+             if self.cfg.group_norm_target_layer:
+                 target_layer_results = [
+                     F.layer_norm(tl.float(), tl.shape[-2:])
+                     for tl in target_layer_results
+                 ]
+
+             if self.cfg.layer_norm_target_layer:
+                 target_layer_results = [
+                     F.layer_norm(tl.float(), tl.shape[-1:])
+                     for tl in target_layer_results
+                 ]
+
+             y = sum(target_layer_results) / len(target_layer_results)
+
+             if self.cfg.layer_norm_targets:
+                 y = F.layer_norm(y.float(), y.shape[-1:])
+
+             if self.cfg.instance_norm_targets:
+                 y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
+
+             if not permuted:
+                 y = y.transpose(0, 1)
+
+             y = y[mask_indices]
+
+         x = x[mask_indices]
+         x = self.final_proj(x)
+
+         sz = x.size(-1)
+
+         if self.loss_beta == 0:
+             loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
+         else:
+             loss = F.smooth_l1_loss(
+                 x.float(), y.float(), reduction="none", beta=self.loss_beta
+             ).sum(dim=-1)
+
+         if self.loss_scale is not None:
+             scale = self.loss_scale
+         else:
+             scale = 1 / math.sqrt(sz)
+
+         result["losses"]["regression"] = loss.sum() * scale
+
+         if "sample_size" not in result:
+             result["sample_size"] = loss.numel()
+
+         with torch.no_grad():
+             result["target_var"] = self.compute_var(y)
+             result["pred_var"] = self.compute_var(x.float())
+
+         if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
+             logger.error(
+                 f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
+             )
+             raise Exception(
+                 f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
+             )
+         if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
+             logger.error(
+                 f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
+             )
+             raise Exception(
+                 f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
+             )
+
+         if self.ema is not None:
+             result["ema_decay"] = self.ema.get_decay() * 1000
+
+         return result
+
+     @staticmethod
+     def compute_var(y):
+         y = y.view(-1, y.size(-1))
+         if dist.is_initialized():
+             zc = torch.tensor(y.size(0)).cuda()
+             zs = y.sum(dim=0)
+             zss = (y ** 2).sum(dim=0)
+
+             dist.all_reduce(zc)
+             dist.all_reduce(zs)
+             dist.all_reduce(zss)
+
+             var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
+             return torch.sqrt(var + 1e-6).mean()
+         else:
+             return torch.sqrt(y.var(dim=0) + 1e-6).mean()
+
+     def extract_features(
+         self, source, padding_mask, mask=False, layer=None
+     ):
+         res = self.forward(
+             source,
+             padding_mask,
+             mask=mask,
+             features_only=True,
+             layer=layer,
+         )
+         return res
+
+     def remove_pretraining_modules(self, last_layer=None):
+         self.final_proj = None
+         self.ema = None
+         if last_layer is not None:
+             self.encoder.layers = nn.ModuleList(
+                 l for i, l in enumerate(self.encoder.layers) if i <= last_layer
+             )