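# FFmpeg-based I/O backend: implements info / load / save on top of
# torchaudio.io.StreamReader and torchaudio.io.StreamWriter.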
import os
import re
import sys
from typing import BinaryIO, Optional, Tuple, Union
import torch
import torchaudio
from .backend import Backend
from .common import AudioMetaData
InputType = Union[BinaryIO, str, os.PathLike]
def info_audio(
src: InputType,
format: Optional[str],
buffer_size: int = 4096,
) -> AudioMetaData:
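    """Fetch metadata of the default audio stream of ``src``.

    If the container does not report the number of frames, the stream is
    decoded once and the frames are counted from the resulting waveform.
    """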
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sinfo = s.get_src_stream_info(s.default_audio_stream)
if sinfo.num_frames == 0:
waveform = _load_audio(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo.num_frames
return AudioMetaData(
int(sinfo.sample_rate),
num_frames,
sinfo.num_channels,
sinfo.bits_per_sample,
sinfo.codec.upper(),
)
def _get_load_filter(
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
) -> Optional[str]:
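    """Build the FFmpeg filter description used when loading audio.

    Returns ``None`` when neither trimming nor sample-format conversion is
    needed. For example, ``_get_load_filter(100, 200)`` returns
    ``"atrim=start_sample=100:end_sample=300,aformat=sample_fmts=fltp"``.
    """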
if frame_offset < 0:
raise RuntimeError("Invalid argument: frame_offset must be non-negative. Found: {}".format(frame_offset))
if num_frames == 0 or num_frames < -1:
raise RuntimeError("Invalid argument: num_frames must be -1 or greater than 0. Found: {}".format(num_frames))
# All default values -> no filter
if frame_offset == 0 and num_frames == -1 and not convert:
return None
# Only convert
aformat = "aformat=sample_fmts=fltp"
if frame_offset == 0 and num_frames == -1 and convert:
return aformat
# At least one of frame_offset or num_frames has non-default value
if num_frames > 0:
atrim = "atrim=start_sample={}:end_sample={}".format(frame_offset, frame_offset + num_frames)
else:
atrim = "atrim=start_sample={}".format(frame_offset)
if not convert:
return atrim
return "{},{}".format(atrim, aformat)
def _load_audio(
s: "torchaudio.io.StreamReader",
filter: Optional[str] = None,
channels_first: bool = True,
) -> torch.Tensor:
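    """Decode the whole default audio stream and return it as a Tensor.

    ``filter`` is passed to ``add_audio_stream`` as ``filter_desc``.
    The result has shape (channel, time) when ``channels_first`` is True,
    otherwise (time, channel).
    """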
s.add_audio_stream(-1, -1, filter_desc=filter)
s.process_all_packets()
chunk = s.pop_chunks()[0]
if chunk is None:
raise RuntimeError("Failed to decode audio.")
waveform = chunk._elem
return waveform.T if channels_first else waveform
def load_audio(
src: InputType,
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
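    """Load audio from a path or file-like object via StreamReader.

    Returns the waveform and its sample rate. Illustrative usage
    ("foo.wav" is a placeholder path)::

        waveform, sample_rate = load_audio("foo.wav", frame_offset=8000, num_frames=16000)
    """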
    # When reading from a file-like object, remap the "vorbis" format name to
    # "ogg", which is the name FFmpeg recognizes.
    if hasattr(src, "read") and format == "vorbis":
        format = "ogg"
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.default_audio_stream).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio(s, filter, channels_first)
return waveform, sample_rate
def _get_sample_format(dtype: torch.dtype) -> str:
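    """Map a torch dtype to the corresponding FFmpeg sample format string
    (e.g. ``torch.float32`` -> ``"flt"``)."""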
dtype_to_format = {
torch.uint8: "u8",
torch.int16: "s16",
torch.int32: "s32",
torch.int64: "s64",
torch.float32: "flt",
torch.float64: "dbl",
}
format = dtype_to_format.get(dtype)
if format is None:
raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
return format
def _native_endianness() -> str:
if sys.byteorder == "little":
return "le"
else:
return "be"
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
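    """Pick the FFmpeg PCM encoder for WAV/AMB output.

    For example, the default (no ``encoding``, no ``bits_per_sample``) is
    ``pcm_s16le`` on a little-endian host.
    """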
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
endianness = _native_endianness()
if not encoding:
if not bits_per_sample:
# default to PCM S16
return f"pcm_s16{endianness}"
if bits_per_sample == 8:
return "pcm_u8"
return f"pcm_s{bits_per_sample}{endianness}"
if encoding == "PCM_S":
if not bits_per_sample:
bits_per_sample = 16
if bits_per_sample == 8:
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
return f"pcm_s{bits_per_sample}{endianness}"
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "pcm_u8"
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
if encoding == "PCM_F":
if not bits_per_sample:
bits_per_sample = 32
if bits_per_sample in (32, 64):
return f"pcm_f{bits_per_sample}{endianness}"
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "pcm_mulaw"
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "pcm_alaw"
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
raise ValueError(f"WAV encoding {encoding} is not supported.")
def _get_flac_sample_fmt(bps):
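    """Map ``bits_per_sample`` to the FLAC encoder sample format:
    16 (or ``None``) -> ``"s16"``, 24 -> ``"s32"``."""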
if bps is None or bps == 16:
return "s16"
if bps == 24:
return "s32"
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")
def _parse_save_args(
ext: Optional[str],
format: Optional[str],
encoding: Optional[str],
bps: Optional[int],
):
    # torchaudio's save function accepts the following arguments, which do not
    # map one-to-one to FFmpeg concepts.
#
# - format: audio format
# - bits_per_sample: encoder sample format
# - encoding: such as PCM_U8.
#
    # In FFmpeg, the format is specified with the following three (among others):
    #
    # - muxer: either an audio format or a container format;
    #          this is what we pass to the constructor of StreamWriter.
    # - encoder: the audio encoder used to encode the audio.
    # - encoder sample format: the sample format the encoder uses to encode the audio.
#
    # If the encoder sample format differs from the source sample format,
    # StreamWriter inserts a conversion filter automatically.
#
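    # For illustration (assuming a little-endian host), some example mappings
    # performed by this function:
    #
    #   ext="wav",  format=None,   encoding="PCM_S", bps=24   -> ("wav", "pcm_s24le", None)
    #   ext=None,   format="flac", encoding=None,    bps=24   -> ("flac", None, "s32")
    #   ext=None,   format="mp3",  encoding=None,    bps=None -> ("mp3", None, None)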
def _type(spec):
        # Either the format is exactly the specified one,
        # or the extension matches the spec AND there is no format override.
return format == spec or (format is None and ext == spec)
if _type("wav") or _type("amb"):
        # WAV is special because it supports different encodings through different
        # encoders; each encoder supports only one encoder sample format.
#
        # The amb format is a special case originating from libsox.
        # It is basically the WAV format with a slight modification.
        # https://github.com/chirlu/sox/commit/4a4ea33edbca5972a1ed8933cc3512c7302fa67a#diff-39171191a858add9df87f5f210a34a776ac2c026842ae6db6ce97f5e68836795
        # The modification allows decoders to recognize the file as ambisonic.
        # https://www.ambisonia.com/Members/mleese/file-format-for-b-format/
        # FFmpeg does not recognize amb, but since it is basically WAV, we use the wav muxer.
muxer = "wav"
encoder = _get_encoder_for_wav(encoding, bps)
sample_fmt = None
elif _type("vorbis"):
        # FFmpeg does not recognize the vorbis extension, while libsox used to.
        # For the sake of backward compatibility (and simplicity),
        # we support the case where users want to do save("foo.vorbis").
muxer = "ogg"
encoder = "vorbis"
sample_fmt = None
else:
muxer = format
encoder = None
sample_fmt = None
if _type("flac"):
sample_fmt = _get_flac_sample_fmt(bps)
if _type("ogg"):
sample_fmt = _get_flac_sample_fmt(bps)
return muxer, encoder, sample_fmt
def save_audio(
uri: InputType,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[torchaudio.io.CodecConfig] = None,
) -> None:
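    """Encode ``src`` and write it to ``uri`` using StreamWriter.

    ``format``, ``encoding`` and ``bits_per_sample`` are translated into an
    FFmpeg muxer, encoder and encoder sample format by ``_parse_save_args``.
    """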
ext = None
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
if tokens := str(uri).split(".")[1:]:
ext = tokens[-1].lower()
muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
if channels_first:
src = src.T
s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
format=_get_sample_format(src.dtype),
encoder=encoder,
encoder_format=enc_fmt,
codec_config=compression,
)
with s.open():
s.write_audio_chunk(0, src)
def _map_encoding(encoding: str) -> str:
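    """Map the codec name reported by FFmpeg (e.g. ``PCM_S16LE``) to the
    encoding names used by torchaudio (``PCM_S``, ``PCM_U``, ``PCM_F``,
    ``ULAW``, ``ALAW``). Other codec names are returned unchanged."""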
for dst in ["PCM_S", "PCM_U", "PCM_F"]:
if dst in encoding:
return dst
if encoding == "PCM_MULAW":
return "ULAW"
elif encoding == "PCM_ALAW":
return "ALAW"
return encoding
def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> int:
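    """Infer bits per sample from the codec name (e.g. ``PCM_S24LE`` -> 24,
    A-law/mu-law -> 8), falling back to the value reported by the stream."""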
if m := re.search(r"PCM_\w(\d+)\w*", encoding):
return int(m.group(1))
elif encoding in ["PCM_ALAW", "PCM_MULAW"]:
return 8
return bits_per_sample
class FFmpegBackend(Backend):
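    """Backend implementation that dispatches info / load / save to the
    FFmpeg-based helpers above."""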
@staticmethod
def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
metadata = info_audio(uri, format, buffer_size)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
@staticmethod
def load(
uri: InputType,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
        return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)
@staticmethod
def save(
uri: InputType,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
compression: Optional[Union[torchaudio.io.CodecConfig, float, int]] = None,
) -> None:
if not isinstance(compression, (torchaudio.io.CodecConfig, type(None))):
raise ValueError(
"FFmpeg backend expects non-`None` value for argument `compression` to be of ",
f"type `torchaudio.io.CodecConfig`, but received value of type {type(compression)}",
)
save_audio(
uri,
src,
sample_rate,
channels_first,
format,
encoding,
bits_per_sample,
buffer_size,
compression,
)
@staticmethod
def can_decode(uri: InputType, format: Optional[str]) -> bool:
return True
@staticmethod
def can_encode(uri: InputType, format: Optional[str]) -> bool:
return True