|
import numpy as np |
|
import torchaudio |
|
import torch |
|
|
|
def get_firstchannel_read(path, fs=16000):
    """Load an audio file and return its first channel as a NumPy array.

    The file is read with ``torchaudio.load``. When the file's sample rate
    differs from ``fs`` the waveform is resampled to ``fs`` first.

    Args:
        path: Location of the audio file, passed straight to torchaudio.
        fs: Target sample rate in Hz (defaults to 16000).

    Returns:
        The samples of channel 0 as a NumPy array (after moving the tensor
        to the CPU).
    """
    waveform, native_rate = torchaudio.load(path)

    # Bring the signal to the requested rate only when it actually differs.
    if native_rate != fs:
        waveform = torchaudio.functional.resample(waveform, native_rate, fs)

    # A tensor with more than one dimension carries a leading channel axis;
    # keep only the first channel.
    if len(waveform.shape) >= 2:
        waveform = waveform[0]

    return waveform.cpu().numpy()
|
|
|
def parse_scp(scp, path_list):
    """Append one entry per non-blank line of an scp list file to ``path_list``.

    Every line is split on whitespace. The first field is stored under the
    key ``"inputs"``; when a second field exists it is stored, unchanged and
    still as a string, under ``"duration"``. Fields beyond the second are
    ignored.

    Args:
        scp: Path of the list file to read.
        path_list: List that is extended in place with the parsed dicts.

    Returns:
        None. The result is delivered through ``path_list``.
    """
    with open(scp) as fid:
        for line in fid:
            fields = line.split()
            # A blank or whitespace-only line has no fields; indexing it
            # would raise IndexError, so it is skipped instead.
            if not fields:
                continue
            entry = {"inputs": fields[0]}
            if len(fields) > 1:
                entry["duration"] = fields[1]
            path_list.append(entry)
|
|
|
class DataReader(object):
    """Index-addressable reader over the audio files listed in an scp file.

    ``len(reader)`` is the number of parsed list entries and ``reader[i]``
    loads and peak-normalizes the i-th file. Two optional loaders fill the
    ``utt2spk`` and ``spk2aux`` dictionaries from two-column text files.
    """

    def __init__(self, filename, sample_rate):
        """Parse the list file and remember the target sample rate.

        Args:
            filename: Path of the scp list file handed to ``parse_scp``.
            sample_rate: Sample rate passed to the audio loader for every
                file that is read.
        """
        self.file_list = []
        self.sample_rate = sample_rate
        # get_utt2spk / get_spk2utt write into these dicts, so they must
        # exist before either loader is called.
        self.utt2spk = {}
        self.spk2aux = {}
        parse_scp(filename, self.file_list)

    def extract_feature(self, path):
        """Load one list entry and return it as a peak-normalized tensor.

        Args:
            path: A list entry dict; its ``"inputs"`` value is the audio
                file path.

        Returns:
            A dict with three keys:
            ``"mix"``: float32 torch tensor of shape ``(1, num_samples)``;
            ``"max_norm"``: the peak absolute amplitude the signal was
            divided by (``1`` when the signal is all zeros);
            ``"name"``: the last ``/``-separated path component, cut at its
            first ``"."``.
        """
        wav_path = path["inputs"]
        name = wav_path.split("/")[-1].split(".")[0]

        # NOTE(review): the reshape below indexes data.shape[0] only, so the
        # loader is presumably expected to return a 1-D array - confirm.
        data = get_firstchannel_read(wav_path, fs=self.sample_rate).astype(np.float32)

        # Scale to a peak of 1; an all-zero signal is left untouched to
        # avoid dividing by zero.
        max_norm = np.max(np.abs(data))
        if max_norm == 0:
            max_norm = 1
        data = data / max_norm

        inputs = torch.from_numpy(np.reshape(data, [1, data.shape[0]]))
        return {
            "mix": inputs,
            "max_norm": max_norm,
            "name": name,
        }

    def __len__(self):
        """Return the number of entries parsed from the list file."""
        return len(self.file_list)

    def __getitem__(self, index):
        """Return ``extract_feature`` applied to the ``index``-th list entry."""
        return self.extract_feature(self.file_list[index])

    def get_utt2spk(self, path):
        """Fill ``self.utt2spk`` from a two-column text file.

        Each non-blank line is split on whitespace; the first field is used
        as the key and the second as the value.

        Args:
            path: Path of the mapping file.
        """
        # The context manager closes the file even if a line is malformed.
        with open(path, "r") as fid:
            for line in fid:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no mapping.
                    continue
                utt_path, spk_id = fields[0], fields[1]
                self.utt2spk[utt_path] = spk_id

    def get_spk2utt(self, path):
        """Fill ``self.spk2aux`` from a two-column text file.

        Each non-blank line is split on whitespace; the second field is used
        as the key and the first as the value, so a later line with the same
        second field overwrites the earlier one.

        Args:
            path: Path of the mapping file.
        """
        with open(path, "r") as fid:
            for line in fid:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no mapping.
                    continue
                utt_path, spk_id = fields[0], fields[1]
                self.spk2aux[spk_id] = utt_path
|
|