AnhP's picture
Upload 170 files
1e4a2ab verified
raw
history blame
3.24 kB
import os
import sys
import time
import tqdm
import torch
import librosa
import traceback
import concurrent.futures
import numpy as np
import torch.nn as nn
sys.path.append(os.getcwd())
from main.library.utils import load_audio
from main.app.variables import logger, translations
from main.inference.extracting.setup_path import setup_paths
class RMSEnergyExtractor(nn.Module):
    """Frame-wise RMS energy of a single waveform, computed with librosa.

    The module has no learnable parameters; it wraps ``librosa.feature.rms``
    so it can be moved and called like other torch modules. The computation
    itself runs on CPU in NumPy, and the result is moved back to the input
    tensor's device.

    Args:
        frame_length: Samples per analysis frame.
        hop_length: Samples between the starts of successive frames.
        center: Passed to librosa; whether the signal is padded so frames
            are centered.
        pad_mode: Padding mode passed to librosa (used when ``center`` is True).
    """

    def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode="reflect"):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode

    def forward(self, x):
        """Return the RMS energy curve of ``x``.

        Args:
            x: Tensor of shape ``(1, n_samples)``.

        Returns:
            1-D tensor of per-frame RMS values, on ``x.device``.

        Raises:
            ValueError: If ``x`` is not 2-D with a leading dimension of 1.
        """
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, and a batched input would then be mis-processed.
        if x.ndim != 2 or x.shape[0] != 1:
            raise ValueError(
                f"expected a tensor of shape (1, n_samples), got {tuple(x.shape)}"
            )

        # NOTE(review): "ocl" devices (presumably an OpenCL backend) are given
        # contiguous tensors before/after transfer — confirm this is required.
        on_ocl = str(x.device).startswith("ocl")
        if on_ocl:
            x = x.contiguous()

        rms = torch.from_numpy(
            librosa.feature.rms(
                y=x.squeeze(0).cpu().numpy(),
                frame_length=self.frame_length,
                hop_length=self.hop_length,
                center=self.center,
                pad_mode=self.pad_mode,
            )
        )
        if on_ocl:
            rms = rms.contiguous()

        # librosa returns shape (1, n_frames) for a 1-D signal; drop the
        # singleton axis.
        return rms.squeeze(-2).to(x.device)
def process_file_rms(files, device, threads):
    """Extract and save RMS energy for ``files`` on one device with a thread pool.

    Each result is written to ``<out_dir>/<basename>.npy``. Files whose output
    already exists are skipped, so re-running only processes missing outputs.
    Per-file failures are logged at debug level and do not stop the batch.

    Args:
        files: Sequence of ``(audio_path, out_dir)`` pairs.
        device: Device string the extractor is moved to (e.g. ``"cpu"``).
        threads: Number of worker threads; values below 1 are clamped to 1.
    """
    threads = max(1, threads)
    module = RMSEnergyExtractor(
        frame_length=2048, hop_length=160, center=True, pad_mode="reflect"
    ).to(device).eval().float()

    # For "ocl" devices the input tensor is not moved to the device before the
    # module is called; the module handles device placement itself.
    keep_input_on_cpu = device.startswith("ocl")

    def worker(file_info):
        # Best-effort per file: any error is logged and the next file proceeds.
        try:
            file, out_path = file_info
            out_file_path = os.path.join(out_path, os.path.basename(file))
            # np.save appends ".npy" to the name, so this is the file it would write.
            if os.path.exists(out_file_path + ".npy"):
                return

            with torch.no_grad():
                # NOTE(review): load_audio presumably returns a 1-D array resampled
                # to 16 kHz — confirm in main.library.utils.
                audio = torch.from_numpy(load_audio(file, 16000)).unsqueeze(0)
                feats = module(audio if keep_input_on_cpu else audio.to(device))
                np.save(out_file_path, feats.float().cpu().numpy(), allow_pickle=False)
        except Exception:
            # Narrowed from a bare ``except:`` so SystemExit/KeyboardInterrupt
            # are no longer swallowed.
            logger.debug(traceback.format_exc())

    with tqdm.tqdm(total=len(files), ncols=100, unit="p", leave=True) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
            pending = [executor.submit(worker, f) for f in files]
            for _ in concurrent.futures.as_completed(pending):
                pbar.update(1)
def run_rms_extraction(exp_dir, num_processes, devices, rms_extract):
    """Extract RMS energy for every ``.wav`` file of an experiment.

    Files are sorted and split round-robin across ``devices``. One worker
    process per device runs ``process_file_rms`` with
    ``num_processes // len(devices)`` threads.

    Args:
        exp_dir: Experiment directory passed to ``setup_paths``.
        num_processes: Total thread budget, divided evenly between devices.
        devices: Device strings; one worker process is started per entry.
        rms_extract: When falsy the function returns immediately.

    Raises:
        Re-raises any exception raised inside a worker process (for example
        while building the extractor on its device), instead of dropping it
        and logging success.
    """
    if not rms_extract:
        return

    wav_path, out_path = setup_paths(exp_dir, rms_extract=rms_extract)
    paths = sorted(
        (os.path.join(wav_path, name), out_path)
        for name in os.listdir(wav_path)
        if name.endswith(".wav")
    )

    start_time = time.time()
    logger.info(translations["rms_start_extract"].format(num_processes=num_processes))

    n_devices = len(devices)
    with concurrent.futures.ProcessPoolExecutor(max_workers=n_devices) as executor:
        futures = [
            executor.submit(
                process_file_rms, paths[i::n_devices], device, num_processes // n_devices
            )
            for i, device in enumerate(devices)
        ]
        # concurrent.futures.wait() never raises worker errors; result() does.
        for future in futures:
            future.result()

    logger.info(
        translations["rms_success_extract"].format(
            elapsed_time=f"{(time.time() - start_time):.2f}"
        )
    )