Audio-to-Audio
File size: 2,024 Bytes
d972bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os

import numpy as np
import torch
import torchaudio

def get_firstchannel_read(path, fs=16000):
    """Load an audio file and return its first channel as a NumPy array.

    Args:
        path: Audio file path passed straight to ``torchaudio.load``.
        fs: Target sample rate; the audio is resampled when the file's own
            rate differs.

    Returns:
        The waveform row at index 0 of the loaded tensor's leading axis (with
        torchaudio's default channels-first layout, the first channel),
        converted to a NumPy array on the CPU.
    """
    waveform, native_rate = torchaudio.load(path)
    needs_resample = native_rate != fs
    if needs_resample:
        waveform = torchaudio.functional.resample(waveform, native_rate, fs)
    # Keep only the first row when the tensor carries a channel axis.
    if waveform.dim() > 1:
        waveform = waveform[0]
    return waveform.cpu().numpy()

def parse_scp(scp, path_list):
    """Append one entry per non-blank line of the scp file *scp* to *path_list*.

    Each line holds an audio path, optionally followed by a second column that
    is stored (as a string) under ``"duration"``; further columns are ignored.
    Entries are ``{"inputs": <path>}`` or ``{"inputs": <path>, "duration": <col2>}``.

    Args:
        scp: Path of the list file to read.
        path_list: List mutated in place; nothing is returned.
    """
    with open(scp) as fid:
        for line in fid:
            fields = line.strip().split()
            if not fields:
                # Blank / whitespace-only lines (e.g. a trailing empty line)
                # have no path column; skip instead of raising IndexError.
                continue
            entry = {"inputs": fields[0]}
            if len(fields) > 1:
                entry["duration"] = fields[1]
            path_list.append(entry)

class DataReader(object):
    """Indexable dataset over the audio files listed in an scp file.

    Each item is a dict holding the peak-normalized first channel of one file
    (``"mix"``, a float32 tensor of shape ``(1, num_samples)``), the peak value
    that was divided out (``"max_norm"``) and the file's base name without its
    extension (``"name"``). Optional speaker maps can be loaded with
    ``get_utt2spk`` and ``get_spk2utt``.
    """

    def __init__(self, filename, sample_rate):
        """Read the scp list *filename*; audio is loaded at *sample_rate*."""
        self.file_list = []
        self.sample_rate = sample_rate
        # Filled on demand by get_utt2spk / get_spk2utt; created here so those
        # loaders have a dict to write into.
        self.utt2spk = {}
        self.spk2aux = {}
        parse_scp(filename, self.file_list)

    def extract_feature(self, path):
        """Load and peak-normalize one entry of ``file_list``.

        Args:
            path: An scp entry dict; only its ``"inputs"`` key (the audio file
                path) is used.

        Returns:
            dict with ``"mix"`` (float32 tensor of shape ``(1, num_samples)``),
            ``"max_norm"`` (peak absolute amplitude, or 1 for silent or empty
            audio) and ``"name"`` (file name with its final extension removed).
        """
        wav_path = path["inputs"]
        # Strip only the last extension so "spk1.utt3.wav" yields "spk1.utt3"
        # rather than collapsing to "spk1" (which can collide across files).
        name = os.path.splitext(os.path.basename(wav_path))[0]
        data = get_firstchannel_read(wav_path, fs=self.sample_rate).astype(np.float32)
        # np.max raises on an empty array; treat empty audio like silence.
        max_norm = np.max(np.abs(data)) if data.size else 0
        if max_norm == 0:
            # Avoid division by zero for all-zero input.
            max_norm = 1
        data = data / max_norm
        inputs = torch.from_numpy(np.reshape(data, [1, data.shape[0]]))
        return {"mix": inputs, "max_norm": max_norm, "name": name}

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        return self.extract_feature(self.file_list[index])

    @staticmethod
    def _read_pairs(path):
        """Yield the first two columns of every non-blank line of *path*."""
        with open(path, "r") as fid:
            for line in fid:
                fields = line.strip().split()
                if not fields:
                    continue
                yield fields[0], fields[1]

    def get_utt2spk(self, path):
        """Fill ``self.utt2spk`` (utterance path -> speaker id) from *path*.

        The file has ``<utt_path> <spk_id>`` columns; blank lines are skipped
        and extra columns ignored.
        """
        for utt_path, spk_id in self._read_pairs(path):
            self.utt2spk[utt_path] = spk_id

    def get_spk2utt(self, path):
        """Fill ``self.spk2aux`` (speaker id -> utterance path) from *path*.

        Columns are read in the same ``<utt_path> <spk_id>`` order as
        ``get_utt2spk``; when a speaker appears on several lines, the last line
        wins. NOTE(review): presumably that utterance serves as an auxiliary
        reference for the speaker — confirm with callers.
        """
        for utt_path, spk_id in self._read_pairs(path):
            self.spk2aux[spk_id] = utt_path