# Wendyellé Abubakrh Alban NYANTUDRE
# deleted parent dir resemble-enhance
# 689d78f
# raw
# history blame
# 3.54 kB
import argparse
import random
import time
from pathlib import Path
import torch
import torchaudio
from tqdm import tqdm
from .inference import denoise, enhance
def _build_parser():
    """Create the argument parser for the batch denoise/enhance command line."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
    parser.add_argument("out_dir", type=Path, help="Output folder")
    parser.add_argument(
        "--run_dir",
        type=Path,
        default=None,
        help="Path to the enhancer run folder, if None, use the default model",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default=".wav",
        help="Audio file suffix",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation, recommended to use CUDA",
    )
    parser.add_argument(
        "--denoise_only",
        action="store_true",
        help="Only apply denoising without enhancement",
    )
    parser.add_argument(
        "--lambd",
        type=float,
        default=1.0,
        help="Denoise strength for enhancement (0.0 to 1.0)",
    )
    parser.add_argument(
        "--tau",
        type=float,
        default=0.5,
        help="CFM prior temperature (0.0 to 1.0)",
    )
    parser.add_argument(
        "--solver",
        type=str,
        default="midpoint",
        choices=["midpoint", "rk4", "euler"],
        help="Numerical solver to use",
    )
    parser.add_argument(
        "--nfe",
        type=int,
        default=64,
        help="Number of function evaluations",
    )
    parser.add_argument(
        "--parallel_mode",
        action="store_true",
        help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
    )
    return parser


def _resolve_device(requested):
    """Return ``requested``, falling back to ``"cpu"`` when CUDA is requested but unavailable.

    Indexed device strings such as ``"cuda:0"`` are treated the same as plain
    ``"cuda"`` so they also get the fallback instead of failing later.
    """
    if requested.startswith("cuda") and not torch.cuda.is_available():
        print(f"CUDA is not available but --device is set to {requested}, using CPU instead")
        return "cpu"
    return requested


@torch.inference_mode()
def main():
    """Denoise or enhance every matching audio file under ``in_dir`` into ``out_dir``.

    Command-line arguments are read from ``sys.argv``. Files are discovered
    recursively by suffix, and each output is written at the same relative
    path under ``out_dir``. With ``--parallel_mode`` the file order is
    shuffled and files whose output already exists are skipped. A summary
    with the number of files actually written and the elapsed time is
    printed at the end.

    Returns:
        None. Returns early, after printing a notice, when no input file
        matches the suffix.
    """
    args = _build_parser().parse_args()
    device = _resolve_device(args.device)

    start_time = time.perf_counter()

    paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))

    if args.parallel_mode:
        random.shuffle(paths)

    if not paths:
        print(f"No {args.suffix} files found in the following path: {args.in_dir}")
        return

    # Count only the files actually written, so the summary stays accurate
    # when parallel mode skips outputs that already exist.
    n_processed = 0

    pbar = tqdm(paths)
    for path in pbar:
        out_path = args.out_dir / path.relative_to(args.in_dir)
        if args.parallel_mode and out_path.exists():
            continue
        pbar.set_description(f"Processing {out_path}")

        dwav, sr = torchaudio.load(path)
        # Average over the first dimension; presumably this is the channel
        # axis (torchaudio's default layout), giving a mono signal.
        dwav = dwav.mean(0)

        if args.denoise_only:
            hwav, sr = denoise(
                dwav=dwav,
                sr=sr,
                device=device,
                run_dir=args.run_dir,
            )
        else:
            hwav, sr = enhance(
                dwav=dwav,
                sr=sr,
                device=device,
                nfe=args.nfe,
                solver=args.solver,
                lambd=args.lambd,
                tau=args.tau,
                run_dir=args.run_dir,
            )

        out_path.parent.mkdir(parents=True, exist_ok=True)
        # hwav[None] adds a leading dimension before saving.
        torchaudio.save(out_path, hwav[None], sr)
        n_processed += 1

    # Cool emoji effect saying the job is done
    elapsed_time = time.perf_counter() - start_time
    print(f"🌟 Enhancement done! {n_processed} files processed in {elapsed_time:.2f}s")
# Run the command-line entry point only when this module is executed directly,
# not when it is imported by other code.
if __name__ == "__main__":
    main()