import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from huggingface_hub import PyTorchModelHubMixin
from .seq import BiGRU
from .deepunet import DeepUnet
from .mel import MelSpectrogram
from .constants import *

logger = logging.getLogger(__name__)


class RMVPE(nn.Module, PyTorchModelHubMixin):
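    """RMVPE pitch estimator.

    Maps a mel spectrogram to a per-frame pitch salience distribution over
    N_CLASS 20-cent bins and decodes it to F0 in Hz with a local weighted
    average over cents (see to_local_average_cents / decode).
    """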
    def __init__(
        self,
        n_blocks: int,
        n_gru: int,
        kernel_size: int,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super().__init__()
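        # Mel spectrogram front end; all parameters come from .constants.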
        self.mel_extractor = MelSpectrogram(
            N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, None, MEL_FMIN, MEL_FMAX
        )
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
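        # 3x3 conv collapses the U-Net features to 3 channels per mel bin.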
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
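        # Frame-wise head: optional BiGRU over time, then a sigmoid over the
        # N_CLASS pitch bins.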
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * N_MELS, 256, n_gru),
                nn.Linear(512, N_CLASS),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )

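        # Each of the 360 bins spans 20 cents; padding by 4 on each side gives
        # the 368 entries used by the 9-bin window in to_local_average_cents.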
        cents_mapping = 20 * np.arange(360) + MAGIC_CONST
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
        self.cents_mapping_torch = torch.from_numpy(self.cents_mapping).to(
            dtype=torch.float32
        )

    def to(self, device):
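        # Keep the cents mapping on the same device as the model parameters.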
        self.cents_mapping_torch = self.cents_mapping_torch.to(device)
        return super().to(device)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
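        # (batch, n_mels, frames) -> (batch, 1, frames, n_mels) for the 2-D U-Net;
        # the head returns (batch, frames, N_CLASS) salience values in [0, 1].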
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x

    def mel2hidden(self, mel: torch.Tensor):
        with torch.no_grad():
            n_frames = mel.shape[2]
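            # Pad the frame axis up to a multiple of 32, run the network, then
            # crop the output back to the original number of frames.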
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            mel = F.pad(mel, (0, n_pad), mode="constant")
            hidden = self(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden: torch.Tensor, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
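        # Cents -> Hz: frames decoded as 0 cents (below threshold) land exactly
        # on 10 Hz and are zeroed out as unvoiced.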
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0
        return f0

    def infer(self, audio: torch.Tensor, thred=0.03, return_tensor=False):
        mel = self.mel_extractor(audio.unsqueeze(0))
        hidden = self.mel2hidden(mel)
        hidden = hidden[0].float()
        f0 = self.decode(hidden, thred=thred)
        if return_tensor:
            return f0
        return f0.cpu().numpy()

    def infer_from_audio(self, audio: np.ndarray, thred=0.03):
        audio = torch.from_numpy(audio).to(next(self.parameters()).device)
        return self.infer(audio, thred=thred)

    def to_local_average_cents(
        self, salience: torch.Tensor, thred=0.05
    ) -> torch.Tensor:
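        # Weighted average of cents over a 9-bin window around each frame's
        # argmax; frames whose peak salience is <= thred are returned as 0.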
        center = torch.argmax(salience, dim=1)
        salience = F.pad(salience, (4, 4))

        center += 4
        batch_indices = torch.arange(salience.shape[0], device=salience.device)

        # Create indices for the 9-point window around each center
        offsets = torch.arange(-4, 5, device=salience.device)
        indices = center.unsqueeze(1) + offsets.unsqueeze(0)

        # Extract values using advanced indexing
        todo_salience = salience[batch_indices.unsqueeze(1), indices]
        todo_cents_mapping = self.cents_mapping_torch[indices]

        product_sum = torch.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = torch.sum(todo_salience, 1)
        divided = product_sum / weight_sum

        maxx = torch.max(salience, 1).values
        divided[maxx <= thred] = 0

        return divided
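

# Usage sketch (illustrative, not part of the model code). "owner/rmvpe" is a
# placeholder repo id; substitute a real checkpoint saved via PyTorchModelHubMixin.
#
#     model = RMVPE.from_pretrained("owner/rmvpe")
#     model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
#     audio = np.zeros(SAMPLE_RATE, dtype=np.float32)   # 1 s of silence at SAMPLE_RATE
#     f0 = model.infer_from_audio(audio, thred=0.03)    # per-frame F0 in Hz, 0 = unvoiced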