File size: 2,024 Bytes
d972bc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import numpy as np
import torchaudio
import torch
def get_firstchannel_read(path, fs=16000):
    """Load an audio file and return its first channel at ``fs`` Hz.

    Args:
        path: Path of an audio file readable by ``torchaudio.load``.
        fs: Target sample rate; the signal is resampled when the file's
            native rate differs from it.

    Returns:
        numpy array holding the samples of the first channel.
    """
    wave_data, sr = torchaudio.load(path)
    # Drop the other channels *before* resampling: each channel is resampled
    # independently, so resampling channels that are discarded is wasted work.
    if wave_data.dim() > 1:
        wave_data = wave_data[0, ...]
    if sr != fs:
        wave_data = torchaudio.functional.resample(wave_data, sr, fs)
    return wave_data.cpu().numpy()
def parse_scp(scp, path_list):
    """Parse an scp list file, appending one entry per non-blank line.

    Each line is whitespace-separated: the first field is the audio path and
    the optional second field is a duration, kept as a string. Any further
    fields are ignored.

    Args:
        scp: Path of the scp file to read.
        path_list: List that parsed entries are appended to (mutated in place).
            Each entry is ``{"inputs": path}`` or
            ``{"inputs": path, "duration": duration}``.

    Returns:
        ``path_list`` itself, for convenience.
    """
    with open(scp) as fid:
        for line in fid:
            fields = line.split()
            # Blank / whitespace-only lines (e.g. a trailing newline) carry no
            # entry; indexing fields[0] on them would raise IndexError.
            if not fields:
                continue
            entry = {"inputs": fields[0]}
            if len(fields) > 1:
                entry["duration"] = fields[1]
            path_list.append(entry)
    return path_list
class DataReader(object):
    """Indexable reader over the audio entries listed in an scp file.

    Each item is the first channel of one listed file, resampled to
    ``sample_rate`` and divided by its peak absolute value.
    """

    def __init__(self, filename, sample_rate):
        """Read the entry list from the scp file ``filename``.

        Args:
            filename: Path of the scp file parsed by ``parse_scp``.
            sample_rate: Rate every signal is resampled to on load.
        """
        self.file_list = []
        self.sample_rate = sample_rate
        # Filled by get_utt2spk / get_spk2utt; created up front so the
        # attributes exist even before either loader has been called.
        self.utt2spk = {}
        self.spk2aux = {}
        parse_scp(filename, self.file_list)

    def extract_feature(self, path):
        """Load and peak-normalise one entry of ``self.file_list``.

        Args:
            path: An scp entry dict; its ``"inputs"`` value is the audio path.

        Returns:
            dict with ``"mix"`` (float32 torch tensor of shape ``(1, T)``),
            ``"max_norm"`` (the divisor applied to the signal; 1 when the
            signal is silent or empty) and ``"name"`` (the file name up to
            its first ``.``).
        """
        wav_path = path["inputs"]
        # NOTE(review): keeps only the text before the first '.', so
        # 'a.b.wav' -> 'a'; names containing dots can collide.
        name = wav_path.split("/")[-1].split(".")[0]
        data = get_firstchannel_read(wav_path, fs=self.sample_rate).astype(np.float32)
        # initial=0 lets an empty signal yield 0 instead of raising; for a
        # non-empty signal it cannot change the result since |x| >= 0.
        max_norm = np.max(np.abs(data), initial=0)
        if max_norm == 0:
            # Silent or empty signal: avoid dividing by zero.
            max_norm = 1
        data = data / max_norm
        inputs = torch.from_numpy(np.reshape(data, [1, data.shape[0]]))
        return {
            "mix": inputs,
            "max_norm": max_norm,
            "name": name,
        }

    def __len__(self):
        """Return the number of scp entries."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Return ``extract_feature`` applied to entry ``index``."""
        return self.extract_feature(self.file_list[index])

    @staticmethod
    def _read_pairs(path):
        """Yield the first two whitespace-separated fields of each non-blank line."""
        with open(path, "r") as fid:
            for line in fid:
                fields = line.split()
                if not fields:
                    continue
                yield fields[0], fields[1]

    def get_utt2spk(self, path):
        """Load an utterance -> speaker map from ``path`` into ``self.utt2spk``.

        Each non-blank line is ``<utt_path> <spk_id> ...``; a later line for
        the same utterance overrides an earlier one.
        """
        utt2spk = getattr(self, "utt2spk", None)
        if utt2spk is None:
            # Tolerate instances created without __init__ (e.g. via __new__).
            utt2spk = self.utt2spk = {}
        for utt_path, spk_id in self._read_pairs(path):
            utt2spk[utt_path] = spk_id

    def get_spk2utt(self, path):
        """Load a speaker -> utterance map from ``path`` into ``self.spk2aux``.

        Reads the same ``<utt_path> <spk_id> ...`` line format as
        ``get_utt2spk``. Only one utterance is kept per speaker: the last
        one listed for that speaker.
        """
        spk2aux = getattr(self, "spk2aux", None)
        if spk2aux is None:
            # Tolerate instances created without __init__ (e.g. via __new__).
            spk2aux = self.spk2aux = {}
        for utt_path, spk_id in self._read_pairs(path):
            spk2aux[spk_id] = utt_path
|