import platform
import subprocess

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from librosa.util import pad_center
from scipy.signal import get_window

try:
    import pytorch_ocl  # OpenCL backend for torch (optional dependency)
except ImportError:
    pytorch_ocl = None

torch_available = pytorch_ocl is not None

def check_amd_gpu(gpu):
    """Return True if a GPU name string looks like an AMD device."""
    # The original looped over the markers but returned inside the loop,
    # so only "RX" was ever tested; any() checks all of them.
    return any(marker in gpu for marker in ("RX", "AMD", "Vega", "Radeon", "FirePro"))

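# Example (illustrative): names reported by WMI/lspci typically contain one
# of the markers above, e.g.:
#
#     >>> check_amd_gpu("Advanced Micro Devices Radeon RX 580")
#     True
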
def get_amd_gpu_windows():
    """List AMD GPU names on Windows by querying WMI."""
    try:
        output = subprocess.check_output("wmic path win32_VideoController get name", shell=True).decode()
        return [gpu.strip() for gpu in output.split("\n")[1:] if check_amd_gpu(gpu)]
    except Exception:
        return []

def get_amd_gpu_linux():
    """List AMD GPU names on Linux by scanning lspci output."""
    try:
        output = subprocess.check_output("lspci | grep VGA", shell=True).decode()
        return [gpu for gpu in output.split("\n") if check_amd_gpu(gpu)]
    except Exception:
        return []

def get_gpu_list():
    """Return the AMD GPUs visible on this machine, or [] without pytorch_ocl."""
    if not torch_available:
        return []
    return get_amd_gpu_windows() if platform.system() == "Windows" else get_amd_gpu_linux()


def device_count():
    return len(get_gpu_list()) if torch_available else 0


def device_name(device_id=0):
    return get_gpu_list()[device_id] if 0 <= device_id < device_count() else ""


def is_available():
    return device_count() > 0

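# Example (illustrative): how the detection helpers above fit together; the
# device name shown is hypothetical and depends on the local hardware.
#
#     >>> get_gpu_list()
#     ['Radeon RX 7900 XTX']
#     >>> device_count(), device_name(0), is_available()
#     (1, 'Radeon RX 7900 XTX', True)
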
class STFT(torch.nn.Module):
    """Magnitude STFT built from an explicit Fourier basis and matmul,
    avoiding torch.stft (which the OpenCL backend may not provide)."""

    def __init__(self, filter_length=1024, hop_length=512, win_length=None, window="hann"):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.pad_amount = int(self.filter_length / 2)
        self.hann_window = {}

        # Build the DFT basis explicitly: rows are the real parts of the
        # first `cutoff` bins followed by their imaginary parts.
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int(self.filter_length / 2 + 1)
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
        forward_basis = torch.FloatTensor(fourier_basis)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))

        # Normalize win_length before storing it (the original stored the
        # possibly-None argument and only fixed the local variable).
        if not win_length:
            win_length = filter_length
        assert filter_length >= win_length
        self.win_length = win_length

        # Bake the analysis window into both bases.
        fft_window = torch.from_numpy(pad_center(get_window(window, win_length, fftbins=True), size=filter_length)).float()
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def transform(self, input_data, eps):
        """Return the magnitude spectrogram of a batch of 1-D signals."""
        # Reflect-pad so frames are centered, then frame the signal with
        # unfold and project every frame onto the windowed Fourier basis.
        input_data = F.pad(input_data, (self.pad_amount, self.pad_amount), mode="reflect")
        frames = input_data.unfold(1, self.filter_length, self.hop_length).permute(0, 2, 1)
        forward_transform = torch.matmul(self.forward_basis, frames)
        cutoff = int(self.filter_length / 2 + 1)

        # Rows [:cutoff] hold the real parts, rows [cutoff:] the imaginary
        # parts; eps keeps the gradient of sqrt finite at zero magnitude.
        return torch.sqrt(forward_transform[:, :cutoff, :] ** 2 + forward_transform[:, cutoff:, :] ** 2 + eps)

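# Example (illustrative): computing a magnitude spectrogram on CPU. The
# shapes assume the default filter_length=1024 (513 frequency bins).
#
#     >>> stft = STFT(filter_length=1024, hop_length=256)
#     >>> mag = stft.transform(torch.randn(2, 16000), eps=1e-9)
#     >>> mag.shape
#     torch.Size([2, 513, 63])
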
class GRU(nn.RNNBase):
    """Drop-in replacement for nn.GRU built from elementwise ops, for
    backends that lack the fused GRU kernels. Note that, unlike nn.GRU,
    batch_first defaults to True here."""

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0.0, bidirectional=False, device=None, dtype=None):
        super().__init__("GRU", input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, device=device, dtype=dtype)

    @staticmethod
    def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
        # Standard GRU cell:
        #   r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)
        #   z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)
        #   n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
        #   h' = (1 - z) * n + z * h
        gate_x = F.linear(x, weight_ih, bias_ih)
        gate_h = F.linear(hx, weight_hh, bias_hh)

        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + resetgate * h_n)

        # Equivalent to (1 - z) * n + z * h, written to save one multiply.
        hy = newgate + inputgate * (hx - newgate)
        return hy

    def _gru_layer(self, x, hx, weights):
        """Run one direction of one layer over a (batch, time, features) input."""
        weight_ih, weight_hh, bias_ih, bias_hh = weights
        outputs = []

        for x_t in x.unbind(1):
            hx = self._gru_cell(x_t, hx, weight_ih, bias_ih, weight_hh, bias_hh)
            outputs.append(hx)

        return torch.stack(outputs, dim=1), hx

    def _gru(self, x, hx):
        # Work internally in (batch, time, features) order.
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        num_directions = 2 if self.bidirectional else 1

        h_n = []
        output_fwd, output_bwd = x, x

        for layer in range(self.num_layers):
            fwd_idx = layer * num_directions
            bwd_idx = fwd_idx + 1 if self.bidirectional else None

            weights_fwd = self._get_weights(fwd_idx)
            h_fwd = hx[fwd_idx]

            out_fwd, h_out_fwd = self._gru_layer(output_fwd, h_fwd, weights_fwd)
            h_n.append(h_out_fwd)

            if self.bidirectional:
                weights_bwd = self._get_weights(bwd_idx)
                h_bwd = hx[bwd_idx]

                # The backward direction sees the sequence reversed; flip its
                # outputs back so both directions align in time before concat.
                reversed_input = torch.flip(output_bwd, dims=[1])
                out_bwd, h_out_bwd = self._gru_layer(reversed_input, h_bwd, weights_bwd)

                out_bwd = torch.flip(out_bwd, dims=[1])
                h_n.append(h_out_bwd)

                output_fwd = torch.cat([out_fwd, out_bwd], dim=2)
                output_bwd = output_fwd
            else:
                output_fwd = out_fwd

            # Inter-layer dropout, as in nn.GRU (never after the last layer).
            if layer < self.num_layers - 1 and self.dropout > 0:
                output_fwd = F.dropout(output_fwd, p=self.dropout, training=self.training)
                if self.bidirectional:
                    output_bwd = output_fwd

        output = output_fwd
        h_n = torch.stack(h_n, dim=0)

        if not self.batch_first:
            output = output.permute(1, 0, 2)
        return output, h_n

    def _get_weights(self, layer_idx):
        # _all_weights holds parameter *names* per layer/direction; resolve
        # them to the actual tensors. Bias entries exist only when bias=True.
        weights = self._all_weights[layer_idx]

        weight_ih = getattr(self, weights[0])
        weight_hh = getattr(self, weights[1])

        bias_ih = getattr(self, weights[2]) if self.bias else None
        bias_hh = getattr(self, weights[3]) if self.bias else None

        return weight_ih, weight_hh, bias_ih, bias_hh

    def forward(self, input, hx=None):
        if input.dim() != 3:
            raise ValueError(f"GRU expects a 3-D input tensor, got {input.dim()}-D")

        batch_size = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1

        # Default to a zero initial hidden state, one slot per layer/direction.
        if hx is None:
            hx = torch.zeros(self.num_layers * num_directions, batch_size, self.hidden_size, dtype=input.dtype, device=input.device)

        self.check_forward_args(input, hx, batch_sizes=None)
        return self._gru(input, hx)

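# Example (illustrative): the unfused GRU mirrors nn.GRU's interface, but
# remember that batch_first defaults to True here.
#
#     >>> gru = GRU(input_size=80, hidden_size=128, num_layers=2, bidirectional=True)
#     >>> out, h_n = gru(torch.randn(4, 50, 80))
#     >>> out.shape, h_n.shape
#     (torch.Size([4, 50, 256]), torch.Size([4, 4, 128]))
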
def group_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    """Pure-PyTorch replacement for F.group_norm (no fused native kernel)."""
    N, C = x.shape[:2]
    assert C % num_groups == 0

    # Split channels into groups and normalize over each group's channels
    # and all spatial dims.
    shape = (N, num_groups, C // num_groups) + x.shape[2:]
    x_reshaped = x.view(shape)

    dims = (2,) + tuple(range(3, x_reshaped.dim()))
    mean = x_reshaped.mean(dim=dims, keepdim=True)
    var = x_reshaped.var(dim=dims, keepdim=True, unbiased=False)

    x_norm = (x_reshaped - mean) / torch.sqrt(var + eps)
    x_norm = x_norm.view_as(x)

    # Per-channel affine parameters, broadcast over batch and spatial dims.
    if weight is not None:
        weight = weight.view(1, C, *([1] * (x.dim() - 2)))
        x_norm = x_norm * weight

    if bias is not None:
        bias = bias.view(1, C, *([1] * (x.dim() - 2)))
        x_norm = x_norm + bias

    return x_norm

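# Example (illustrative): this matches torch's native group norm on CPU up
# to floating-point tolerance.
#
#     >>> x = torch.randn(2, 6, 8, 8)
#     >>> w, b = torch.ones(6), torch.zeros(6)
#     >>> torch.allclose(group_norm(x, 3, w, b), F.group_norm(x, 3, w, b), atol=1e-5)
#     True
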
def script(f, *_, **__):
    # No-op stand-in for torch.jit.script: returns the function unchanged,
    # attaching an empty graph for callers that inspect the .graph attribute.
    f.graph = pytorch_ocl.torch._C.Graph()
    return f

if torch_available:
    # Patch in the pure-PyTorch fallbacks so models run on the OpenCL
    # backend without fused GRU/group-norm kernels or TorchScript.
    nn.GRU = GRU
    F.group_norm = group_norm
    torch.jit.script = script
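
# Minimal smoke test of the shims; runs on CPU so it works with or without
# pytorch_ocl (with it installed, tensors can be moved to the backend's
# device, e.g. "ocl:0" under pytorch_ocl's naming).
if __name__ == "__main__":
    print("AMD GPUs detected:", get_gpu_list())
    stft = STFT(filter_length=1024, hop_length=256)
    print("STFT magnitude:", stft.transform(torch.randn(1, 16000), eps=1e-9).shape)
    gru = GRU(input_size=32, hidden_size=64)
    print("GRU output:", gru(torch.randn(2, 10, 32))[0].shape)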