File size: 3,143 Bytes
2d9b22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import logging

import torch
import torch.utils.data

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class TextAudioCollateMultiNSFsid:
    """Collate callable that merges per-sample dicts into zero-padded batch tensors."""

    def __init__(self):
        pass

    def __call__(self, batch):
        """Build one padded batch from a sequence of sample mappings.

        PARAMS
        ------
        batch: non-empty sequence of mappings. Each mapping is read for the keys
            "spec" (2-D tensor, length taken from axis 1),
            "wav_gt" (mapping holding an "array" tensor, length taken from axis 0),
            "hubert_feats" (2-D tensor, length taken from axis 0),
            "f0" and "f0nsf" (1-D tensors, length taken from axis 0).

        RETURNS
        -------
        Tuple of nine tensors, rows ordered by decreasing "spec" length:
            phone_padded   float32 (B, max_phone_len, hubert_feats.shape[1])
            phone_lengths  int32   (B,)
            pitch_padded   int32   (B, max_phone_len)
            pitchf_padded  float32 (B, max_phone_len)
            spec_padded    float32 (B, spec.size(0), max_spec_len)
            spec_lengths   int32   (B,)
            wave_padded    float32 (B, 1, max_wave_len)
            wave_lengths   int32   (B,)
            sid            int32   (B,), always zero
        Feature dimensions are taken from the first sample in ``batch``.
        """
        first = batch[0]
        device = first["spec"].device

        # Every tensor below is allocated on the device of the first sample's "spec".
        with device:
            n_items = len(batch)

            # Batch row order: longest spectrogram first.
            frame_counts = [item["spec"].size(1) for item in batch]
            _, order = torch.sort(
                torch.tensor(frame_counts, dtype=torch.int32),
                dim=0,
                descending=True,
            )
            rows = [batch[src] for src in order.tolist()]

            # Padding targets are the longest sequence of each kind in the batch.
            longest_spec = max(frame_counts)
            longest_wave = max([item["wav_gt"]["array"].size(0) for item in batch])
            longest_phone = max([item["hubert_feats"].size(0) for item in batch])

            spec_lengths = torch.tensor(
                [r["spec"].size(1) for r in rows], dtype=torch.int32
            )
            wave_lengths = torch.tensor(
                [r["wav_gt"]["array"].size(0) for r in rows], dtype=torch.int32
            )
            phone_lengths = torch.tensor(
                [r["hubert_feats"].size(0) for r in rows], dtype=torch.int32
            )

            spec_padded = torch.zeros(
                n_items, first["spec"].size(0), longest_spec, dtype=torch.float32
            )
            wave_padded = torch.zeros(n_items, 1, longest_wave, dtype=torch.float32)
            phone_padded = torch.zeros(
                n_items,
                longest_phone,
                first["hubert_feats"].shape[1],
                dtype=torch.float32,
            )
            pitch_padded = torch.zeros(n_items, longest_phone, dtype=torch.int32)
            pitchf_padded = torch.zeros(n_items, longest_phone, dtype=torch.float32)
            sid = torch.zeros(n_items, dtype=torch.int32)

            # Copy each sample into the leading part of its row; the tail stays zero.
            for slot, sample in enumerate(rows):
                spec = sample["spec"]
                wave = sample["wav_gt"]["array"]
                phone = sample["hubert_feats"]
                pitch = sample["f0"]
                pitchf = sample["f0nsf"]

                spec_padded[slot, :, : spec.size(1)] = spec
                wave_padded[slot, :, : wave.size(0)] = wave
                phone_padded[slot, : phone.size(0), :] = phone
                pitch_padded[slot, : pitch.size(0)] = pitch
                pitchf_padded[slot, : pitchf.size(0)] = pitchf

                # Speaker id is hard-wired to 0 for every row.
                sid[slot] = 0

            return (
                phone_padded,
                phone_lengths,
                pitch_padded,
                pitchf_padded,
                spec_padded,
                spec_lengths,
                wave_padded,
                wave_lengths,
                sid,
            )