Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
from typing import Dict, Optional | |
import importlib | |
import os | |
import numpy as np | |
class AudioTransform(ABC): | |
def from_config_dict(cls, config: Optional[Dict] = None): | |
pass | |
class CompositeAudioTransform(AudioTransform): | |
def _from_config_dict( | |
cls, | |
transform_type, | |
get_audio_transform, | |
composite_cls, | |
config=None, | |
return_empty=False, | |
): | |
_config = {} if config is None else config | |
_transforms = _config.get(f"{transform_type}_transforms") | |
if _transforms is None: | |
if return_empty: | |
_transforms = [] | |
else: | |
return None | |
transforms = [ | |
get_audio_transform(_t).from_config_dict(_config.get(_t)) | |
for _t in _transforms | |
] | |
return composite_cls(transforms) | |
def __init__(self, transforms): | |
self.transforms = [t for t in transforms if t is not None] | |
def __call__(self, x): | |
for t in self.transforms: | |
x = t(x) | |
return x | |
def __repr__(self): | |
format_string = ( | |
[self.__class__.__name__ + "("] | |
+ [f" {t.__repr__()}" for t in self.transforms] | |
+ [")"] | |
) | |
return "\n".join(format_string) | |
def register_audio_transform(name, cls_type, registry, class_names): | |
def register_audio_transform_cls(cls): | |
if name in registry: | |
raise ValueError(f"Cannot register duplicate transform ({name})") | |
if not issubclass(cls, cls_type): | |
raise ValueError( | |
f"Transform ({name}: {cls.__name__}) must extend " | |
f"{cls_type.__name__}" | |
) | |
if cls.__name__ in class_names: | |
raise ValueError( | |
f"Cannot register audio transform with duplicate " | |
f"class name ({cls.__name__})" | |
) | |
registry[name] = cls | |
class_names.add(cls.__name__) | |
return cls | |
return register_audio_transform_cls | |
def import_transforms(transforms_dir, transform_type): | |
for file in os.listdir(transforms_dir): | |
path = os.path.join(transforms_dir, file) | |
if ( | |
not file.startswith("_") | |
and not file.startswith(".") | |
and (file.endswith(".py") or os.path.isdir(path)) | |
): | |
name = file[: file.find(".py")] if file.endswith(".py") else file | |
importlib.import_module( | |
f"fairseq.data.audio.{transform_type}_transforms." + name | |
) | |
# Utility fn for uniform numbers in transforms | |
def rand_uniform(a, b): | |
return np.random.uniform() * (b - a) + a | |