File size: 3,240 Bytes
1e4a2ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import sys
import time
import tqdm
import torch
import librosa
import traceback
import concurrent.futures

import numpy as np
import torch.nn as nn

sys.path.append(os.getcwd())

from main.library.utils import load_audio
from main.app.variables import logger, translations
from main.inference.extracting.setup_path import setup_paths

class RMSEnergyExtractor(nn.Module):
    """Frame-wise RMS energy of a single waveform, computed with ``librosa.feature.rms``.

    The signal is copied to CPU/NumPy for librosa, and the result is moved back to
    the device of the input tensor.

    Args:
        frame_length: analysis window length in samples (forwarded to librosa).
        hop_length: number of samples between successive frames (forwarded to librosa).
        center: whether librosa pads the signal so frames are centered (forwarded).
        pad_mode: padding mode librosa uses when ``center`` is True (forwarded).
    """

    def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode="reflect"):
        super().__init__()
        self.frame_length, self.hop_length = frame_length, hop_length
        self.center, self.pad_mode = center, pad_mode

    def forward(self, x):
        """Compute RMS energy for ``x``.

        Args:
            x: 2-D tensor whose first dimension is exactly 1 (enforced by asserts).

        Returns:
            Tensor on ``x.device`` holding librosa's RMS output with the
            second-to-last (singleton) axis removed.
        """
        assert x.ndim == 2
        assert x.shape[0] == 1

        # Tensors on devices whose name starts with "ocl" are made contiguous
        # before and after the CPU round trip.
        is_ocl = str(x.device).startswith("ocl")
        if is_ocl: x = x.contiguous()

        signal = x.squeeze(0).cpu().numpy()
        energy = librosa.feature.rms(
            y=signal,
            frame_length=self.frame_length,
            hop_length=self.hop_length,
            center=self.center,
            pad_mode=self.pad_mode,
        )
        result = torch.from_numpy(energy)

        if is_ocl: result = result.contiguous()

        # librosa returns a singleton axis just before the frame axis; drop it.
        return result.squeeze(-2).to(x.device)
    
def process_file_rms(files, device, threads):
    """Extract RMS energy features for a batch of audio files on one device.

    Each file is loaded with ``load_audio(file, 16000)``, run through an
    ``RMSEnergyExtractor(frame_length=2048, hop_length=160)`` and saved with
    ``np.save`` as ``<out_dir>/<basename>.npy``. Files whose ``.npy`` output already
    exists are skipped. A failure on one file is logged at debug level and does not
    stop the rest of the batch.

    Args:
        files: sized iterable of ``(audio_path, out_dir)`` pairs.
        device: device string for the extractor. Inputs for devices whose name
            starts with "ocl" are passed to the extractor without ``.to(device)``.
        threads: number of worker threads; values below 1 are clamped to 1.
    """
    threads = max(1, threads)

    module = RMSEnergyExtractor(
        frame_length=2048, hop_length=160, center=True, pad_mode = "reflect"
    ).to(device).eval().float()

    def worker(file_info):
        try:
            file, out_path = file_info
            out_file_path = os.path.join(out_path, os.path.basename(file))

            # np.save appends ".npy" to this path, so that is the file to look for.
            if os.path.exists(out_file_path + ".npy"): return

            with torch.no_grad():
                feats = torch.from_numpy(load_audio(file, 16000)).unsqueeze(0)
                feats = module(feats if device.startswith("ocl") else feats.to(device))

            np.save(out_file_path, feats.float().cpu().numpy(), allow_pickle=False)
        except Exception:
            # Best effort per file: log and continue. Narrowed from a bare ``except``
            # so KeyboardInterrupt/SystemExit are no longer swallowed.
            logger.debug(traceback.format_exc())

    with tqdm.tqdm(total=len(files), ncols=100, unit="p", leave=True) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
            for _ in concurrent.futures.as_completed([executor.submit(worker, f) for f in files]):
                pbar.update(1)

def run_rms_extraction(exp_dir, num_processes, devices, rms_extract):
    """Run RMS energy extraction for every ``.wav`` file of an experiment.

    Does nothing when ``rms_extract`` is falsy. Otherwise the input/output
    directories come from ``setup_paths``. The sorted file list is split round-robin
    across ``devices``, with one process per device. Each process gets
    ``num_processes // len(devices)`` worker threads.

    Exceptions raised inside a device process are logged with their traceback.
    ``concurrent.futures.wait`` alone would silently discard them.

    Args:
        exp_dir: experiment directory passed to ``setup_paths``.
        num_processes: total thread budget shared across devices.
        devices: list of device strings, one process each.
        rms_extract: enables the step when truthy.
    """
    if not rms_extract: return

    wav_path, out_path = setup_paths(exp_dir, rms_extract=rms_extract)
    paths = sorted([(os.path.join(wav_path, file), out_path) for file in os.listdir(wav_path) if file.endswith(".wav")])

    start_time = time.time()
    logger.info(translations["rms_start_extract"].format(num_processes=num_processes))

    num_devices = len(devices)
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_devices) as executor:
        futures = [
            executor.submit(process_file_rms, paths[i::num_devices], devices[i], num_processes // num_devices)
            for i in range(num_devices)
        ]

        for future in concurrent.futures.as_completed(futures):
            exc = future.exception()
            # Surface failures of a whole device process instead of dropping them.
            if exc is not None: logger.error("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))

    logger.info(translations["rms_success_extract"].format(elapsed_time=f"{(time.time() - start_time):.2f}"))