# guidel's picture
# Duplicate from OFA-Sys/OFA-Generic_Interface
# 8c90e7d
import collections
import io
import json
import librosa
import numpy as np
import soundfile as sf
import time
import torch
from scipy.io.wavfile import read
from .text import SOS_TOK, EOS_TOK
def get_mask_from_lengths(lengths):
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lengths: integer tensor of sequence lengths (expected to be 1-D,
            one entry per sequence).

    Returns:
        Bool tensor whose element ``[i, j]`` is True when ``j < lengths[i]``.
        For 1-D input the shape is ``(len(lengths), max(lengths))``. The mask
        lives on the same device as ``lengths``.
    """
    max_len = torch.max(lengths).item()
    # Create the position indices on the device of `lengths` rather than
    # hard-coding CUDA, so the function also works for CPU tensors and on
    # machines without a GPU.
    ids = torch.arange(0, max_len, dtype=torch.long, device=lengths.device)
    return ids < lengths.unsqueeze(1)
def load_wav_to_torch(full_path, sr=None):
    """Load an audio file with librosa and return it as a float tensor.

    Args:
        full_path: path of the audio file, passed to ``librosa.load``.
        sr: sample rate passed to ``librosa.load`` (``None`` by default).

    Returns:
        Tuple ``(audio, sr)`` where ``audio`` is a ``torch.FloatTensor`` with
        values scaled by 32768 and ``sr`` is the sample rate reported by
        ``librosa.load``.
    """
    samples, sr = librosa.load(full_path, sr=sr)
    # Resampling may push samples slightly outside [-1, 1]; clamp them first.
    clipped = np.clip(samples, -1, 1)
    # Scale to the int16 value range so values match those loaded by scipy.
    scaled = (clipped * 32768.0).astype(np.float32)
    return torch.FloatTensor(scaled), sr
def read_binary_audio(bin_data, tar_sr=None):
    """
    Decode in-memory audio (`bytes` or `uint8` `numpy.ndarray`) into a
    `float32` tensor scaled by 32768.

    ARGS:
        bin_data : encoded audio file contents, wrapped in `io.BytesIO` and
            decoded with `soundfile.read`
        tar_sr (int, optional) : target sample rate; when given and different
            from the file's own rate, the audio is resampled with librosa
    RETURNS:
        data (torch.FloatTensor) : audio of shape (n,) for mono or
            (channels, n) for multi-channel input (the decoded array is
            transposed)
        tar_sr (int) : sample rate of the returned audio
    """
    data, ori_sr = sf.read(io.BytesIO(bin_data), dtype='float32')
    data = data.T
    if (tar_sr is not None) and (ori_sr != tar_sr):
        # Sample rates are passed by keyword: recent librosa releases made
        # them keyword-only, and older releases accept the same names.
        data = librosa.resample(data, orig_sr=ori_sr, target_sr=tar_sr)
    else:
        tar_sr = ori_sr
    # Resampling may overshoot [-1, 1]; clamp before scaling.
    data = np.clip(data, -1, 1)
    data = data * 32768.0
    return torch.FloatTensor(data.astype(np.float32)), tar_sr
def load_filepaths_and_text(filename):
    """Read a UTF-8 JSON-lines file into a list of decoded objects.

    Args:
        filename: path of a text file holding one JSON document per line.

    Returns:
        List with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(filename, encoding='utf-8') as f:
        # Blank or whitespace-only lines (e.g. a trailing empty line) carry
        # no record; skip them instead of letting json.loads fail on "".
        return [json.loads(entry) for entry in map(str.strip, f) if entry]
def to_gpu(x):
    """Make ``x`` contiguous, move it to CUDA when available, wrap in Variable.

    Args:
        x: tensor to prepare.

    Returns:
        ``torch.autograd.Variable`` wrapping the contiguous tensor. The tensor
        is transferred with ``non_blocking=True`` when
        ``torch.cuda.is_available()`` is true; otherwise it stays where it is.
    """
    tensor = x.contiguous()
    if not torch.cuda.is_available():
        return torch.autograd.Variable(tensor)
    return torch.autograd.Variable(tensor.cuda(non_blocking=True))
def load_code_dict(path, add_sos=False, add_eos=False):
    """Load a symbol-to-index mapping from a file with one symbol per line.

    Index 0 is assigned to the pad symbol ``'_'``; the file's lines follow in
    order starting at index 1. ``SOS_TOK`` / ``EOS_TOK`` are appended at the
    next free index when the corresponding flag is set.

    Args:
        path: path of the symbol file; a falsy value yields an empty dict.
        add_sos: append ``SOS_TOK`` after the file's symbols.
        add_eos: append ``EOS_TOK`` after the file's symbols (and after SOS).

    Returns:
        Dict mapping each symbol to its integer index.

    Raises:
        AssertionError: if the resulting indices are not exactly
            ``0 .. len(dict) - 1`` (e.g. a symbol appears more than once).
    """
    if not path:
        return {}

    symbols = ['_']  # '_' for pad
    with open(path, 'r') as handle:
        symbols.extend(raw_line.rstrip() for raw_line in handle)

    code_dict = {}
    for index, symbol in enumerate(symbols):
        code_dict[symbol] = index

    if add_sos:
        code_dict[SOS_TOK] = len(code_dict)
    if add_eos:
        code_dict[EOS_TOK] = len(code_dict)

    # Every index from 0 to len-1 must be used exactly once.
    expected_indices = set(range(len(code_dict)))
    assert(set(code_dict.values()) == expected_indices)
    return code_dict
def load_obs_label_dict(path):
    """Load a label-to-index mapping from a file with one label per line.

    Args:
        path: path of the label file; a falsy value yields an empty dict.

    Returns:
        Dict mapping each right-stripped line to its zero-based line number.
        When a label occurs more than once, the last occurrence wins.
    """
    if not path:
        return {}

    label_to_index = {}
    with open(path, 'r') as handle:
        for index, raw_line in enumerate(handle):
            label_to_index[raw_line.rstrip()] = index
    return label_to_index
# A simple timer class inspired by `tnt.TimeMeter`
class CudaTimer:
    """Average GPU time per key, measured with paired CUDA events.

    Each ``start(key)`` / ``stop(key)`` pair records two
    ``torch.cuda.Event`` objects. ``value()`` synchronizes the device, folds
    the recorded pairs into running totals and returns the mean time per
    pair, in seconds, for every key given at construction.
    """

    def __init__(self, keys):
        """Create a timer reporting on ``keys`` (an iterable of hashable names)."""
        self.keys = keys
        self.reset()

    def start(self, key):
        """Record a start event for ``key``; returns ``self`` for chaining."""
        s = torch.cuda.Event(enable_timing=True)
        s.record()
        self.start_events[key].append(s)
        return self

    def stop(self, key):
        """Record an end event for ``key``; returns ``self`` for chaining."""
        e = torch.cuda.Event(enable_timing=True)
        e.record()
        self.end_events[key].append(e)
        return self

    def reset(self):
        """Drop all pending events and accumulated totals; returns ``self``."""
        self.start_events = collections.defaultdict(list)
        self.end_events = collections.defaultdict(list)
        self.running_times = collections.defaultdict(float)
        self.n = collections.defaultdict(int)
        return self

    def value(self):
        """Return ``{key: mean seconds per start/stop pair}`` for ``self.keys``.

        Raises:
            ValueError: if a key has never been timed, or if a key has a
                different number of start and stop events pending.
        """
        self._synchronize()
        return {k: self.running_times[k] / self.n[k] for k in self.keys}

    def _synchronize(self):
        """Wait for the GPU, then fold pending event pairs into the totals."""
        torch.cuda.synchronize()

        # Validate every key before mutating any total, so a failure cannot
        # leave some keys accumulated while their events are still queued
        # (which would double-count them on the next call).
        for k in self.keys:
            starts = self.start_events[k]
            ends = self.end_events[k]
            # Only a key with no samples at all is a division by zero; a key
            # with earlier samples and no new events keeps its previous mean.
            if len(starts) == 0 and self.n[k] == 0:
                raise ValueError("Trying to divide by zero in TimeMeter")
            if len(ends) != len(starts):
                raise ValueError("Call stop before checking value!")

        for k in self.keys:
            starts = self.start_events[k]
            ends = self.end_events[k]
            elapsed_ms = 0
            for start, end in zip(starts, ends):
                elapsed_ms += start.elapsed_time(end)
            # Event.elapsed_time reports milliseconds; totals are in seconds.
            self.running_times[k] += elapsed_ms * 1e-3
            self.n[k] += len(starts)

        self.start_events = collections.defaultdict(list)
        self.end_events = collections.defaultdict(list)
# Used to measure the time taken for multiple events
class Timer:
    """Average host-side time per key over repeated start/stop intervals.

    ``start(key)`` stores a timestamp, ``stop(key)`` adds the elapsed time to
    that key's total and counts one sample, and ``value()`` returns the mean
    duration in seconds for every key given at construction.
    """

    def __init__(self, keys):
        """Create a timer for ``keys`` (an iterable of hashable names)."""
        self.keys = keys
        self.n = {}
        self.running_time = {}
        self.total_time = {}
        self.reset()

    def start(self, key):
        """Begin an interval for ``key``; returns ``self`` for chaining."""
        # perf_counter is monotonic, so measured intervals are not affected
        # by system clock adjustments.
        self.running_time[key] = time.perf_counter()
        return self

    def stop(self, key):
        """End the interval begun by ``start(key)``; returns ``self``.

        ``start(key)`` must have been called first: after ``reset`` or a
        previous ``stop`` the stored timestamp is ``None`` and the
        subtraction below raises ``TypeError``.
        """
        # Accumulate rather than overwrite: value() divides by the number of
        # samples, so the total must be the sum of all intervals.
        self.total_time[key] += time.perf_counter() - self.running_time[key]
        self.n[key] += 1
        self.running_time[key] = None
        return self

    def reset(self):
        """Zero totals and counts for every configured key; returns ``self``."""
        for k in self.keys:
            self.total_time[k] = 0
            self.running_time[k] = None
            self.n[k] = 0
        return self

    def value(self):
        """Return ``{key: mean seconds per interval}`` for ``self.keys``.

        Raises:
            ValueError: if any key has no completed interval.
        """
        vals = {}
        for k in self.keys:
            if self.n[k] == 0:
                raise ValueError("Trying to divide by zero in TimeMeter")
            vals[k] = self.total_time[k] / self.n[k]
        return vals