| import numpy as np | |
| import torchaudio | |
| import torch | |
| def get_firstchannel_read(path, fs=16000): | |
| wave_data, sr = torchaudio.load(path) | |
| if sr != fs: | |
| wave_data = torchaudio.functional.resample(wave_data, sr, fs) | |
| if len(wave_data.shape) > 1: | |
| wave_data = wave_data[0,...] | |
| wave_data = wave_data.cpu().numpy() | |
| return wave_data | |
| def parse_scp(scp, path_list): | |
| with open(scp) as fid: | |
| for line in fid: | |
| tmp = line.strip().split() | |
| if len(tmp) > 1: | |
| path_list.append({"inputs": tmp[0], "duration": tmp[1]}) | |
| else: | |
| path_list.append({"inputs": tmp[0]}) | |
| class DataReader(object): | |
| def __init__(self, filename, sample_rate): | |
| self.file_list = [] | |
| self.sample_rate = sample_rate | |
| parse_scp(filename, self.file_list) | |
| def extract_feature(self, path): | |
| path = path["inputs"] | |
| name = path.split("/")[-1].split(".")[0] | |
| data = get_firstchannel_read(path, fs=self.sample_rate).astype(np.float32) | |
| max_norm = np.max(np.abs(data)) | |
| if max_norm == 0: | |
| max_norm = 1 | |
| data = data / max_norm | |
| inputs = np.reshape(data, [1, data.shape[0]]) | |
| inputs = torch.from_numpy(inputs) | |
| egs = { | |
| "mix": inputs, | |
| "max_norm": max_norm, | |
| "name": name | |
| } | |
| return egs | |
| def __len__(self): | |
| return len(self.file_list) | |
| def __getitem__(self, index): | |
| return self.extract_feature(self.file_list[index]) | |
| def get_utt2spk(self, path): | |
| lines = open(path, "r").readlines() | |
| for line in lines: | |
| line = line.strip().split() | |
| utt_path, spk_id = line[0], line[1] | |
| self.utt2spk[utt_path] = spk_id | |
| def get_spk2utt(self, path): | |
| lines = open(path, "r").readlines() | |
| for line in lines: | |
| line = line.strip().split() | |
| utt_path, spk_id = line[0], line[1] | |
| self.spk2aux[spk_id] = utt_path | |