from typing import Tuple

import torch
import torch.nn as nn
import torchaudio


class AttPool(nn.Module):
    """Attention-Pooling module that estimates the attention score.



    Args:

        input_dim (int): Input feature dimension.

        att_dim (int): Attention Tensor dimension.

    """

    def __init__(self, input_dim: int, att_dim: int):
        super(AttPool, self).__init__()

        self.linear1 = nn.Linear(input_dim, 1)
        self.linear2 = nn.Linear(input_dim, att_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and pooling.



        Args:

            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.



        Returns:

            (torch.Tensor): Attention score with dimensions `(batch, att_dim)`.

        """

        att = self.linear1(x)  # (batch, time, 1)
        att = att.transpose(2, 1)  # (batch, 1, time)
        att = nn.functional.softmax(att, dim=2)
        x = torch.matmul(att, x).squeeze(1)  # (batch, input_dim)
        x = self.linear2(x)  # (batch, att_dim)
        return x
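
# Usage sketch for AttPool (a minimal illustration; input_dim=64 and att_dim=5
# match what squim_subjective_base() below ends up configuring, while the batch
# and time sizes are made-up assumptions):
#
#     pool = AttPool(input_dim=64, att_dim=5)
#     feats = torch.randn(2, 100, 64)   # (batch=2, time=100, feature_dim=64)
#     out = pool(feats)                 # shape (2, 5): one att_dim vector per item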


class Predictor(nn.Module):
    """Prediction module that apply pooling and attention, then predict subjective metric scores.



    Args:

        input_dim (int): Input feature dimension.

        att_dim (int): Attention Tensor dimension.

    """

    def __init__(self, input_dim: int, att_dim: int):
        super(Predictor, self).__init__()
        self.att_pool_layer = AttPool(input_dim, att_dim)
        self.att_dim = att_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict subjective evaluation metric score.



        Args:

            x (torch.Tensor): Input Tensor with dimensions `(batch, time, feature_dim)`.



        Returns:

            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.

        """
        x = self.att_pool_layer(x)
        x = nn.functional.softmax(x, dim=1)
        B = torch.linspace(0, 4, steps=self.att_dim, device=x.device)
        x = (x * B).sum(dim=1)
        return x
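
# Worked example of the scoring step above, assuming the default att_dim=5
# used by squim_subjective_base() below:
#
#     B = torch.linspace(0, 4, steps=5)                   # tensor([0., 1., 2., 3., 4.])
#     probs = torch.tensor([[0.1, 0.1, 0.2, 0.3, 0.3]])   # softmax output; rows sum to 1
#     score = (probs * B).sum(dim=1)                       # tensor([2.6000])
#
# The score is an expectation over the bins of B, so it always falls in [0, 4].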


class SquimSubjective(nn.Module):
    """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **subjective** metric scores

    for speech enhancement (e.g., Mean Opinion Score (MOS)). The model is adopted from *NORESQA-MOS*

    :cite:`manocha2022speech` which predicts MOS scores given the input speech and a non-matching reference.



    Args:

        ssl_model (torch.nn.Module): The self-supervised learning model for feature extraction.

        projector (torch.nn.Module): Projection layer that projects SSL feature to a lower dimension.

        predictor (torch.nn.Module): Predict the subjective scores.

    """

    def __init__(self, ssl_model: nn.Module, projector: nn.Module, predictor: nn.Module):
        super(SquimSubjective, self).__init__()
        self.ssl_model = ssl_model
        self.projector = projector
        self.predictor = predictor

    def _align_shapes(self, waveform: torch.Tensor, reference: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cut or pad the reference Tensor to make it aligned with waveform Tensor.



        Args:

            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.

            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.



        Returns:

            (torch.Tensor, torch.Tensor): The aligned waveform and reference Tensors

                with same dimensions `(batch, time)`.

        """
        T_waveform = waveform.shape[-1]
        T_reference = reference.shape[-1]
        if T_reference < T_waveform:
            num_padding = T_waveform // T_reference + 1
            reference = torch.cat([reference for _ in range(num_padding)], dim=1)
        return waveform, reference[:, :T_waveform]
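
    # Worked example of the alignment above, with toy lengths (illustrative only):
    #
    #     waveform  = torch.zeros(1, 8)
    #     reference = torch.arange(3.0).unsqueeze(0)    # (1, 3)
    #     # num_padding = 8 // 3 + 1 = 3 copies -> length 9, trimmed back to 8:
    #     # aligned reference is [[0., 1., 2., 0., 1., 2., 0., 1.]]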

    def forward(self, waveform: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Predict subjective evaluation metric score.



        Args:

            waveform (torch.Tensor): Input waveform for evaluation. Tensor with dimensions `(batch, time)`.

            reference (torch.Tensor): Non-matching clean reference. Tensor with dimensions `(batch, time_ref)`.



        Returns:

            (torch.Tensor): Subjective metric score. Tensor with dimensions `(batch,)`.

        """
        waveform, reference = self._align_shapes(waveform, reference)
        waveform = self.projector(self.ssl_model.extract_features(waveform)[0][-1])
        reference = self.projector(self.ssl_model.extract_features(reference)[0][-1])
        concat = torch.cat((reference, waveform), dim=2)
        score_diff = self.predictor(concat)  # Score difference compared to the reference
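        # score_diff lies in [0, 4] (an expectation over the Predictor's score
        # bins), so the value returned below lands in the standard MOS range [1, 5].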
        return 5 - score_diff


def squim_subjective_model(
    ssl_type: str,
    feat_dim: int,
    proj_dim: int,
    att_dim: int,
) -> SquimSubjective:
    """Build a custome :class:`torchaudio.prototype.models.SquimSubjective` model.



    Args:

        ssl_type (str): Type of self-supervised learning (SSL) models.

            Must be one of ["wav2vec2_base", "wav2vec2_large"].

        feat_dim (int): Feature dimension of the SSL feature representation.

        proj_dim (int): Output dimension of projection layer.

        att_dim (int): Dimension of attention scores.

    """
    ssl_model = getattr(torchaudio.models, ssl_type)()
    projector = nn.Linear(feat_dim, proj_dim)
    predictor = Predictor(proj_dim * 2, att_dim)
    return SquimSubjective(ssl_model, projector, predictor)


def squim_subjective_base() -> SquimSubjective:
    """Build :class:`torchaudio.prototype.models.SquimSubjective` model with default arguments."""
    return squim_subjective_model(
        ssl_type="wav2vec2_base",
        feat_dim=768,
        proj_dim=32,
        att_dim=5,
    )
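

if __name__ == "__main__":
    # End-to-end usage sketch. The random tensors below stand in for real audio,
    # and a freshly constructed model has untrained weights, so the printed
    # score is only meaningful once trained weights are loaded from a checkpoint.
    model = squim_subjective_base().eval()
    waveform = torch.randn(1, 16000)  # (batch, time): speech to evaluate
    reference = torch.randn(1, 16000)  # (batch, time_ref): non-matching clean reference
    with torch.no_grad():
        mos = model(waveform, reference)  # shape (1,), roughly in [1, 5]
    print(mos)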