from dataclasses import dataclass

from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective

@dataclass
class SquimObjectiveBundle:
    """Data class that bundles associated information to use pretrained

    :py:class:`~torchaudio.models.SquimObjective` model.



    This class provides interfaces for instantiating the pretrained model along with

    the information necessary to retrieve pretrained weights and additional data

    to be used with the model.



    Torchaudio library instantiates objects of this class, each of which represents

    a different pretrained model. Client code should access pretrained models via these

    instances.



    This bundle can estimate objective metric scores for speech enhancement, such as STOI, PESQ, Si-SDR.

    A typical use case would be a flow like `waveform -> list of scores`. Please see below for the code example.



    Example: Estimate the objective metric scores for the input waveform.

        >>> import torch

        >>> import torchaudio

        >>> from torchaudio.pipelines import SQUIM_OBJECTIVE as bundle

        >>>

        >>> # Load the SquimObjective bundle

        >>> model = bundle.get_model()

        Downloading: "https://download.pytorch.org/torchaudio/models/squim_objective_dns2020.pth"

        100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 28.2M/28.2M [00:03<00:00, 9.24MB/s]

        >>>

        >>> # Resample audio to the expected sampling rate

        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

        >>>

        >>> # Estimate objective metric scores

        >>> scores = model(waveform)

        >>> print(f"STOI: {scores[0].item()},  PESQ: {scores[1].item()}, SI-SDR: {scores[2].item()}.")

    """  # noqa: E501

    _path: str
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> SquimObjective:
        """Construct the SquimObjective model, and load the pretrained weight.



        The weight file is downloaded from the internet and cached with

        :func:`torch.hub.load_state_dict_from_url`



        Args:

            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.



        Returns:

            Variation of :py:class:`~torchaudio.models.SquimObjective`.

        """
        model = squim_objective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.



        :type: float

        """
        return self._sample_rate


SQUIM_OBJECTIVE = SquimObjectiveBundle(
    "squim_objective_dns2020.pth",
    _sample_rate=16000,
)
SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using the approach described in
    :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.

    The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
    The weights are under the `Creative Commons Attribution 4.0 International License
    <https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/LICENSE>`__.

    Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
    """


@dataclass
class SquimSubjectiveBundle:
    """Data class that bundles associated information to use pretrained

    :py:class:`~torchaudio.models.SquimSubjective` model.



    This class provides interfaces for instantiating the pretrained model along with

    the information necessary to retrieve pretrained weights and additional data

    to be used with the model.



    Torchaudio library instantiates objects of this class, each of which represents

    a different pretrained model. Client code should access pretrained models via these

    instances.



    This bundle can estimate subjective metric scores for speech enhancement, such as MOS.

    A typical use case would be a flow like `waveform -> score`. Please see below for the code example.



    Example: Estimate the subjective metric scores for the input waveform.

        >>> import torch

        >>> import torchaudio

        >>> from torchaudio.pipelines import SQUIM_SUBJECTIVE as bundle

        >>>

        >>> # Load the SquimSubjective bundle

        >>> model = bundle.get_model()

        Downloading: "https://download.pytorch.org/torchaudio/models/squim_subjective_bvcc_daps.pth"

        100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 360M/360M [00:09<00:00, 41.1MB/s]

        >>>

        >>> # Resample audio to the expected sampling rate

        >>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

        >>> # Use a clean reference (doesn't need to be the reference for the waveform) as the second input

        >>> reference = torchaudio.functional.resample(reference, sample_rate, bundle.sample_rate)

        >>>

        >>> # Estimate subjective metric scores

        >>> score = model(waveform, reference)

        >>> print(f"MOS: {score}.")

    """  # noqa: E501

    _path: str
    _sample_rate: float

    def _get_state_dict(self, dl_kwargs):
        url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
        state_dict = load_state_dict_from_url(url, **dl_kwargs)
        return state_dict

    def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
        """Construct the SquimSubjective model, and load the pretrained weight.



        The weight file is downloaded from the internet and cached with

        :func:`torch.hub.load_state_dict_from_url`



        Args:

            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.



        Returns:

            Variation of :py:class:`~torchaudio.models.SquimObjective`.

        """
        model = squim_subjective_base()
        model.load_state_dict(self._get_state_dict(dl_kwargs))
        model.eval()
        return model

    @property
    def sample_rate(self):
        """Sample rate of the audio that the model is trained on.



        :type: float

        """
        return self._sample_rate


SQUIM_SUBJECTIVE = SquimSubjectiveBundle(
    "squim_subjective_bvcc_daps.pth",
    _sample_rate=16000,
)
SQUIM_SUBJECTIVE.__doc__ = """SquimSubjective pipeline trained
    as described in :cite:`manocha2022speech` and :cite:`kumar2023torchaudio`
    on the *BVCC* :cite:`cooper2021voices` and *DAPS* :cite:`mysore2014can` datasets.

    The underlying model is constructed by :py:func:`torchaudio.models.squim_subjective_base`.
    The weights are under the `Creative Commons Attribution Non Commercial 4.0 International License
    <https://zenodo.org/record/4660670#.ZBtWPOxuerN>`__.

    Please refer to :py:class:`SquimSubjectiveBundle` for usage instructions.
    """