|
import os
|
|
import sys
|
|
import time
|
|
import tqdm
|
|
import torch
|
|
import librosa
|
|
import traceback
|
|
import concurrent.futures
|
|
|
|
import numpy as np
|
|
import torch.nn as nn
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from main.library.utils import load_audio
|
|
from main.app.variables import logger, translations
|
|
from main.inference.extracting.setup_path import setup_paths
|
|
|
|
class RMSEnergyExtractor(nn.Module):
|
|
def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode = "reflect"):
|
|
super().__init__()
|
|
self.frame_length = frame_length
|
|
self.hop_length = hop_length
|
|
self.center = center
|
|
self.pad_mode = pad_mode
|
|
|
|
def forward(self, x):
|
|
assert x.ndim == 2
|
|
assert x.shape[0] == 1
|
|
|
|
if str(x.device).startswith("ocl"): x = x.contiguous()
|
|
|
|
rms = torch.from_numpy(
|
|
librosa.feature.rms(
|
|
y=x.squeeze(0).cpu().numpy(),
|
|
frame_length=self.frame_length,
|
|
hop_length=self.hop_length,
|
|
center=self.center,
|
|
pad_mode=self.pad_mode
|
|
)
|
|
)
|
|
|
|
return rms.squeeze(-2).to(x.device) if not str(x.device).startswith("ocl") else rms.contiguous().squeeze(-2).to(x.device)
|
|
|
|
def process_file_rms(files, device, threads):
|
|
threads = max(1, threads)
|
|
|
|
module = RMSEnergyExtractor(
|
|
frame_length=2048, hop_length=160, center=True, pad_mode = "reflect"
|
|
).to(device).eval().float()
|
|
|
|
def worker(file_info):
|
|
try:
|
|
file, out_path = file_info
|
|
out_file_path = os.path.join(out_path, os.path.basename(file))
|
|
|
|
if os.path.exists(out_file_path + ".npy"): return
|
|
with torch.no_grad():
|
|
feats = torch.from_numpy(load_audio(file, 16000)).unsqueeze(0)
|
|
feats = module(feats if device.startswith("ocl") else feats.to(device))
|
|
|
|
np.save(out_file_path, feats.float().cpu().numpy(), allow_pickle=False)
|
|
except:
|
|
logger.debug(traceback.format_exc())
|
|
|
|
with tqdm.tqdm(total=len(files), ncols=100, unit="p", leave=True) as pbar:
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
|
|
for _ in concurrent.futures.as_completed([executor.submit(worker, f) for f in files]):
|
|
pbar.update(1)
|
|
|
|
def run_rms_extraction(exp_dir, num_processes, devices, rms_extract):
|
|
if rms_extract:
|
|
wav_path, out_path = setup_paths(exp_dir, rms_extract=rms_extract)
|
|
start_time = time.time()
|
|
paths = sorted([(os.path.join(wav_path, file), out_path) for file in os.listdir(wav_path) if file.endswith(".wav")])
|
|
|
|
start_time = time.time()
|
|
logger.info(translations["rms_start_extract"].format(num_processes=num_processes))
|
|
|
|
with concurrent.futures.ProcessPoolExecutor(max_workers=len(devices)) as executor:
|
|
concurrent.futures.wait([executor.submit(process_file_rms, paths[i::len(devices)], devices[i], num_processes // len(devices)) for i in range(len(devices))])
|
|
|
|
logger.info(translations["rms_success_extract"].format(elapsed_time=f"{(time.time() - start_time):.2f}")) |