File size: 11,628 Bytes
864affd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
import os
import re
import sys
from typing import BinaryIO, Optional, Tuple, Union

import torch
import torchaudio

from .backend import Backend
from .common import AudioMetaData

InputType = Union[BinaryIO, str, os.PathLike]


def info_audio(
    src: InputType,
    format: Optional[str],
    buffer_size: int = 4096,
) -> AudioMetaData:
    """Collect metadata for the default audio stream of ``src``.

    Args:
        src: Path-like or file-like source handed to ``StreamReader``.
        format: Format hint forwarded to ``StreamReader`` (may be ``None``).
        buffer_size: Buffer size forwarded to ``StreamReader``.

    Returns:
        ``AudioMetaData`` holding, in order, the integer sample rate, the
        number of frames, the number of channels, the bits per sample and
        the upper-cased codec name.
    """
    reader = torchaudio.io.StreamReader(src, format, None, buffer_size)
    stream_info = reader.get_src_stream_info(reader.default_audio_stream)

    frame_count = stream_info.num_frames
    if frame_count == 0:
        # The stream info reports zero frames: decode the whole stream and
        # measure the resulting waveform instead.
        frame_count = _load_audio(reader).size(1)

    return AudioMetaData(
        int(stream_info.sample_rate),
        frame_count,
        stream_info.num_channels,
        stream_info.bits_per_sample,
        stream_info.codec.upper(),
    )


def _get_load_filter(
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
) -> Optional[str]:
    """Build the FFmpeg filter description used when loading audio.

    Args:
        frame_offset: Index of the first sample to keep; must be >= 0.
        num_frames: Number of samples to keep, or ``-1`` for "until the end".
        convert: When true, append an ``aformat`` stage requesting ``fltp``.

    Returns:
        A comma-separated filter description, or ``None`` when neither
        trimming nor conversion is requested.

    Raises:
        RuntimeError: If ``frame_offset`` is negative, or ``num_frames`` is
            zero or smaller than ``-1``.
    """
    if frame_offset < 0:
        raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
    if num_frames == 0 or num_frames < -1:
        raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))

    stages = []

    # A trim stage is only needed when at least one of the two range
    # arguments departs from its default.
    if frame_offset != 0 or num_frames != -1:
        if num_frames > 0:
            stages.append(
                "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
            )
        else:
            stages.append("atrim=start_sample={}".format(frame_offset))

    if convert:
        stages.append("aformat=sample_fmts=fltp")

    if not stages:
        return None
    return ",".join(stages)


def _load_audio(
    s: "torchaudio.io.StreamReader",
    filter: Optional[str] = None,
    channels_first: bool = True,
) -> torch.Tensor:
    """Decode the audio of ``s`` into a single tensor.

    Args:
        s: Stream reader to decode from. An output audio stream is added to
            it and all of its packets are processed.
        filter: Filter description passed as ``filter_desc`` (may be ``None``).
        channels_first: When true, the decoded tensor is transposed before
            being returned.

    Returns:
        The tensor held by the first popped chunk, transposed if requested.

    Raises:
        RuntimeError: If the first popped chunk is ``None``.
    """
    # NOTE(review): the two -1 arguments presumably request one unbounded
    # chunk -- confirm against StreamReader.add_audio_stream.
    s.add_audio_stream(-1, -1, filter_desc=filter)
    s.process_all_packets()

    decoded = s.pop_chunks()[0]
    if decoded is None:
        raise RuntimeError("Failed to decode audio.")

    samples = decoded._elem
    if channels_first:
        return samples.T
    return samples


def load_audio(
    src: InputType,
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
    buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
    """Decode audio from ``src`` and return it with its sample rate.

    Args:
        src: Path-like or file-like source handed to ``StreamReader``.
        frame_offset: First sample to keep (see ``_get_load_filter``).
        num_frames: Number of samples to keep, ``-1`` for all.
        convert: Whether to append the sample-format conversion filter.
        channels_first: Whether the returned tensor is transposed so that
            channels come first.
        format: Format hint forwarded to ``StreamReader``.
        buffer_size: Buffer size forwarded to ``StreamReader``.

    Returns:
        ``(waveform, sample_rate)`` where ``sample_rate`` is an ``int``.
    """
    reading_file_object = hasattr(src, "read")
    if reading_file_object and format == "vorbis":
        # File-like sources declared as "vorbis" are opened as "ogg".
        format = "ogg"

    reader = torchaudio.io.StreamReader(src, format, None, buffer_size)
    stream_info = reader.get_src_stream_info(reader.default_audio_stream)
    sample_rate = int(stream_info.sample_rate)

    filter_desc = _get_load_filter(frame_offset, num_frames, convert)
    return _load_audio(reader, filter_desc, channels_first), sample_rate


def _get_sample_format(dtype: torch.dtype) -> str:
    """Return the sample-format name that corresponds to a tensor dtype.

    Args:
        dtype: One of ``uint8``, ``int16``, ``int32``, ``int64``,
            ``float32`` or ``float64``.

    Returns:
        The matching format string (``"u8"``, ``"s16"``, ``"s32"``,
        ``"s64"``, ``"flt"`` or ``"dbl"``).

    Raises:
        ValueError: If ``dtype`` has no entry in the table.
    """
    dtype_to_format = {
        torch.uint8: "u8",
        torch.int16: "s16",
        torch.int32: "s32",
        torch.int64: "s64",
        torch.float32: "flt",
        torch.float64: "dbl",
    }
    if dtype not in dtype_to_format:
        raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
    return dtype_to_format[dtype]


def _native_endianness() -> str:
    """Return ``"le"`` on a little-endian host and ``"be"`` otherwise."""
    return "le" if sys.byteorder == "little" else "be"


def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
    """Pick the encoder name for a WAV file.

    Args:
        encoding: One of ``"PCM_S"``, ``"PCM_U"``, ``"PCM_F"``, ``"ULAW"``,
            ``"ALAW"``, or a falsy value to choose from ``bits_per_sample``
            alone.
        bits_per_sample: ``None`` or one of 8, 16, 24, 32, 64.

    Returns:
        The encoder name, e.g. ``"pcm_u8"`` or ``"pcm_s16"`` followed by the
        host endianness suffix.

    Raises:
        ValueError: If ``bits_per_sample`` is not an accepted value, is not
            valid for the requested ``encoding``, or ``encoding`` is unknown.
    """
    if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
        raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
    endianness = _native_endianness()

    if not encoding:
        # No explicit encoding: signed 16-bit by default, unsigned for
        # 8-bit, signed for every other width.
        if not bits_per_sample:
            return f"pcm_s16{endianness}"
        if bits_per_sample == 8:
            return "pcm_u8"
        return f"pcm_s{bits_per_sample}{endianness}"

    if encoding == "PCM_S":
        bits_per_sample = bits_per_sample or 16
        if bits_per_sample == 8:
            raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
        return f"pcm_s{bits_per_sample}{endianness}"

    if encoding == "PCM_F":
        bits_per_sample = bits_per_sample or 32
        if bits_per_sample not in (32, 64):
            raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
        return f"pcm_f{bits_per_sample}{endianness}"

    # The remaining encodings all accept only an unset or 8-bit sample width.
    eight_bit_only = (
        ("PCM_U", "pcm_u8", "For WAV unsigned PCM, only 8-bit encoding is supported."),
        ("ULAW", "pcm_mulaw", "For WAV PCM mu-law, only 8-bit encoding is supported."),
        ("ALAW", "pcm_alaw", "For WAV PCM A-law, only 8-bit encoding is supported."),
    )
    for label, encoder, complaint in eight_bit_only:
        if encoding == label:
            if bits_per_sample in (None, 8):
                return encoder
            raise ValueError(complaint)

    raise ValueError(f"WAV encoding {encoding} is not supported.")


def _get_flac_sample_fmt(bps):
    """Map a FLAC ``bits_per_sample`` value to an encoder sample format.

    Args:
        bps: ``None``, 16 or 24.

    Returns:
        ``"s16"`` for ``None`` or 16, ``"s32"`` for 24.

    Raises:
        ValueError: For any other value.
    """
    if bps == 24:
        return "s32"
    if bps is not None and bps != 16:
        raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
    return "s16"


def _parse_save_args(
    ext: Optional[str],
    format: Optional[str],
    encoding: Optional[str],
    bps: Optional[int],
):
    """Translate torchaudio-style save arguments into FFmpeg terms.

    torchaudio's save function takes ``format`` (audio format),
    ``bits_per_sample`` (encoder sample format) and ``encoding`` (such as
    PCM_U8). These do not map one-to-one onto FFmpeg, which instead works
    with:

    - a muxer: an audio or container format, the one passed to the
      ``StreamWriter`` constructor;
    - an encoder: the audio encoder used to encode the audio;
    - an encoder sample format: the format the encoder encodes with. When it
      differs from the source sample format, ``StreamWriter`` inserts a
      filter automatically.

    Args:
        ext: Lower-cased file extension, or ``None``.
        format: Explicit format override, or ``None``.
        encoding: Requested encoding; only consulted for WAV/AMB.
        bps: Requested bits per sample.

    Returns:
        ``(muxer, encoder, sample_fmt)``; any element may be ``None``.
    """
    # An explicit format always wins; the extension only counts when no
    # format override was given.
    target = ext if format is None else format

    if target in ("wav", "amb"):
        # WAV is special: it supports several encodings through different
        # encoders, and each encoder supports a single sample format.
        #
        # "amb" is a special case inherited from libsox. It is essentially
        # WAV with a slight modification:
        # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
        # The format exists so that decoders recognise the file as ambisonic:
        # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
        # FFmpeg does not know "amb", so it is written with the WAV muxer.
        return "wav", _get_encoder_for_wav(encoding, bps), None

    if target == "vorbis":
        # FFmpeg does not recognise a "vorbis" extension, whereas libsox
        # did. For backward compatibility (and simplicity),
        # save("foo.vorbis") is still supported.
        return "ogg", "vorbis", None

    # Everything else: pass the format through as the muxer and let the
    # encoder be chosen by default. flac and ogg both take their sample
    # format from the FLAC bits-per-sample rule.
    sample_fmt = _get_flac_sample_fmt(bps) if target in ("flac", "ogg") else None
    return format, None, sample_fmt


def save_audio(
    uri: InputType,
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
    buffer_size: int = 4096,
    compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
    """Encode ``src`` and write it to ``uri`` through ``StreamWriter``.

    Args:
        uri: Destination path, or an object with a ``write`` attribute.
        src: Audio tensor to write.
        sample_rate: Sample rate passed to the audio stream.
        channels_first: When true, ``src`` is transposed before writing.
        format: Format override; mandatory when ``uri`` is a file object.
        encoding: Encoding forwarded to ``_parse_save_args``.
        bits_per_sample: Bit depth forwarded to ``_parse_save_args``.
        buffer_size: Buffer size forwarded to ``StreamWriter``.
        compression: Codec configuration passed as ``codec_config``.

    Raises:
        RuntimeError: If ``uri`` is a file object and ``format`` is ``None``.
    """
    ext = None
    writing_to_file_object = hasattr(uri, "write")
    if writing_to_file_object:
        if format is None:
            raise RuntimeError("'format' is required when saving to file object.")
    else:
        uri = os.path.normpath(uri)
        # Whatever follows the last "." (if any) is taken as the extension.
        _, dot, suffix = str(uri).rpartition(".")
        if dot:
            ext = suffix.lower()

    muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)

    # After the optional transpose the last dimension is used as the
    # channel count.
    frames = src.T if channels_first else src

    writer = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
    writer.add_audio_stream(
        sample_rate,
        num_channels=frames.size(-1),
        format=_get_sample_format(frames.dtype),
        encoder=encoder,
        encoder_format=enc_fmt,
        codec_config=compression,
    )
    with writer.open():
        writer.write_audio_chunk(0, frames)


def _map_encoding(encoding: str) -> str:
    """Map an upper-cased codec name to the short encoding label.

    Args:
        encoding: Upper-cased codec name, e.g. ``"PCM_S16LE"``.

    Returns:
        ``"PCM_S"``, ``"PCM_U"`` or ``"PCM_F"`` when ``encoding`` starts with
        that prefix; ``"ULAW"`` / ``"ALAW"`` for ``"PCM_MULAW"`` /
        ``"PCM_ALAW"``; otherwise ``encoding`` unchanged.
    """
    for family in ("PCM_S", "PCM_U", "PCM_F"):
        # Prefix match on purpose: a substring test would also hit ADPCM
        # codec names such as "ADPCM_SWF", which contain "PCM_S" but are
        # not signed PCM.
        if encoding.startswith(family):
            return family
    law_labels = {"PCM_MULAW": "ULAW", "PCM_ALAW": "ALAW"}
    return law_labels.get(encoding, encoding)


def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> int:
    """Derive the bit depth from a codec name, falling back to a given value.

    Args:
        encoding: Upper-cased codec name, e.g. ``"PCM_S16LE"``.
        bits_per_sample: Value to return when the name carries no bit depth.

    Returns:
        The number embedded in a ``PCM_<letter><digits>`` name, 8 for
        ``"PCM_ALAW"`` / ``"PCM_MULAW"``, otherwise ``bits_per_sample``.
    """
    # Anchored at the start of the name: an unanchored search would match
    # inside ADPCM names and turn "ADPCM_G726" into a bit depth of 726.
    if m := re.match(r"PCM_\w(\d+)", encoding):
        return int(m.group(1))
    if encoding in ("PCM_ALAW", "PCM_MULAW"):
        return 8
    return bits_per_sample


class FFmpegBackend(Backend):
    """Backend that serves info/load/save through this module's FFmpeg helpers."""

    @staticmethod
    def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
        """Return metadata for ``uri`` with normalised encoding fields.

        Args:
            uri: Path-like or file-like source.
            format: Format hint forwarded to ``info_audio``.
            buffer_size: Buffer size forwarded to ``info_audio``.

        Returns:
            The ``AudioMetaData`` from ``info_audio`` with ``bits_per_sample``
            and ``encoding`` rewritten by ``_get_bits_per_sample`` and
            ``_map_encoding``.
        """
        metadata = info_audio(uri, format, buffer_size)
        # The bit depth is parsed from the raw codec name, so it has to be
        # derived before ``encoding`` is replaced by its short label.
        metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
        metadata.encoding = _map_encoding(metadata.encoding)
        return metadata

    @staticmethod
    def load(
        uri: InputType,
        frame_offset: int = 0,
        num_frames: int = -1,
        normalize: bool = True,
        channels_first: bool = True,
        format: Optional[str] = None,
        buffer_size: int = 4096,
    ) -> Tuple[torch.Tensor, int]:
        """Load audio from ``uri``.

        Args:
            uri: Path-like or file-like source.
            frame_offset: First sample to keep.
            num_frames: Number of samples to keep, ``-1`` for all.
            normalize: Passed to ``load_audio`` as its ``convert`` argument.
            channels_first: Whether the result is transposed to channels first.
            format: Format hint.
            buffer_size: Buffer size forwarded to ``load_audio``.

        Returns:
            ``(waveform, sample_rate)`` as produced by ``load_audio``.
        """
        # buffer_size is forwarded explicitly; omitting it would silently
        # fall back to load_audio's default regardless of the caller's value.
        return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)

    @staticmethod
    def save(
        uri: InputType,
        src: torch.Tensor,
        sample_rate: int,
        channels_first: bool = True,
        format: Optional[str] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
        buffer_size: int = 4096,
        compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
    ) -> None:
        """Save ``src`` to ``uri``.

        Args:
            uri: Destination path or file-like object.
            src: Audio tensor to write.
            sample_rate: Sample rate of ``src``.
            channels_first: Whether ``src`` has channels as its first dimension.
            format: Format override.
            encoding: Requested encoding.
            bits_per_sample: Requested bit depth.
            buffer_size: Buffer size forwarded to ``save_audio``.
            compression: ``None`` or a ``torchaudio.io.CodecConfig``.

        Raises:
            ValueError: If ``compression`` is neither ``None`` nor a
                ``torchaudio.io.CodecConfig``.
        """
        if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
            # The two literals are concatenated into one message; a comma
            # between them would make the exception carry a 2-tuple instead.
            raise ValueError(
                "FFmpeg backend expects non-`None` value for argument `compression` to be of "
                f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}"
            )
        save_audio(
            uri,
            src,
            sample_rate,
            channels_first,
            format,
            encoding,
            bits_per_sample,
            buffer_size,
            compression,
        )

    @staticmethod
    def can_decode(uri: InputType, format: Optional[str]) -> bool:
        """Always report that decoding is possible."""
        return True

    @staticmethod
    def can_encode(uri: InputType, format: Optional[str]) -> bool:
        """Always report that encoding is possible."""
        return True