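# AMD / OpenCL support helpers for the pytorch_ocl backend: GPU discovery functions
# plus a matmul-based STFT module and pure-PyTorch replacements for nn.GRU,
# F.group_norm and torch.jit.script, so models can run without ops the backend
# may not provide.
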
import torch
import platform
import subprocess

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from librosa.util import pad_center
from scipy.signal import get_window

try:
    import pytorch_ocl
except ImportError:
    pytorch_ocl = None

torch_available = pytorch_ocl is not None

def check_amd_gpu(gpu):
    # Treat the adapter as an AMD GPU if its name contains any known AMD keyword.
    return any(keyword in gpu for keyword in ("RX", "AMD", "Vega", "Radeon", "FirePro"))

def get_amd_gpu_windows():
    try:
        output = subprocess.check_output("wmic path win32_VideoController get name", shell=True).decode()
        return [gpu.strip() for gpu in output.split("\n")[1:] if check_amd_gpu(gpu)]
    except Exception:
        return []

def get_amd_gpu_linux():
    try:
        output = subprocess.check_output("lspci | grep VGA", shell=True).decode()
        return [gpu for gpu in output.split("\n") if check_amd_gpu(gpu)]
    except Exception:
        return []

def get_gpu_list():
    return (get_amd_gpu_windows() if platform.system() == "Windows" else get_amd_gpu_linux()) if torch_available else []

def device_count():
    return len(get_gpu_list()) if torch_available else 0

def device_name(device_id=0):
    return (get_gpu_list()[device_id] if 0 <= device_id < device_count() else "") if torch_available else ""

def is_available():
    return device_count() > 0 if torch_available else False
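
# Usage sketch (hypothetical values): these helpers mirror the torch.cuda-style
# device API, e.g.
#   if is_available():
#       print(device_count(), device_name(0))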

class STFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=512, win_length=None, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.pad_amount = int(self.filter_length / 2)
        self.win_length = win_length
        self.hann_window = {}

        # Real and imaginary parts of the DFT matrix, used as a fixed linear projection.
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int(self.filter_length / 2 + 1)
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))

        if win_length is None or not win_length: win_length = filter_length
        assert filter_length >= win_length

        # Analysis window applied to both the forward and inverse bases.
        fft_window = torch.from_numpy(pad_center(get_window(window, win_length, fftbins=True), size=filter_length)).float()
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def transform(self, input_data, eps):
        # Reflect-pad, frame the signal, project onto the DFT basis and return the magnitude spectrogram.
        input_data = F.pad(input_data, (self.pad_amount, self.pad_amount), mode="reflect")
        forward_transform = torch.matmul(self.forward_basis, input_data.unfold(1, self.filter_length, self.hop_length).permute(0, 2, 1))
        cutoff = int(self.filter_length / 2 + 1)
        return torch.sqrt(forward_transform[:, :cutoff, :] ** 2 + forward_transform[:, cutoff:, :] ** 2 + eps)
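
# Usage sketch (hypothetical values): for a batch of mono audio shaped (batch, samples),
#   stft = STFT(filter_length=1024, hop_length=512)
#   magnitude = stft.transform(audio, eps=1e-9)  # -> (batch, filter_length // 2 + 1, frames)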

class GRU(nn.RNNBase):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0.0, bidirectional=False, device=None, dtype=None):
        super().__init__("GRU", input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, device=device, dtype=dtype)

    @staticmethod
    def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
        # One GRU time step: reset, update and candidate gates from the input and previous hidden state.
        gate_x = F.linear(x, weight_ih, bias_ih)
        gate_h = F.linear(hx, weight_hh, bias_hh)

        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + resetgate * h_n)

        hy = newgate + inputgate * (hx - newgate)
        return hy

    def _gru_layer(self, x, hx, weights):
        # Run one direction of one layer over the whole sequence.
        weight_ih, weight_hh, bias_ih, bias_hh = weights
        outputs = []

        for x_t in x.unbind(1):
            hx = self._gru_cell(x_t, hx, weight_ih, bias_ih, weight_hh, bias_hh)
            outputs.append(hx)

        return torch.stack(outputs, dim=1), hx

    def _gru(self, x, hx):
        if not self.batch_first: x = x.permute(1, 0, 2)

        num_directions = 2 if self.bidirectional else 1
        h_n = []
        output_fwd, output_bwd = x, x

        for layer in range(self.num_layers):
            fwd_idx = layer * num_directions
            bwd_idx = fwd_idx + 1 if self.bidirectional else None

            weights_fwd = self._get_weights(fwd_idx)
            h_fwd = hx[fwd_idx]

            out_fwd, h_out_fwd = self._gru_layer(output_fwd, h_fwd, weights_fwd)
            h_n.append(h_out_fwd)

            if self.bidirectional:
                # The backward direction sees the time-reversed sequence; its output is flipped back before concatenation.
                weights_bwd = self._get_weights(bwd_idx)
                h_bwd = hx[bwd_idx]

                reversed_input = torch.flip(output_bwd, dims=[1])
                out_bwd, h_out_bwd = self._gru_layer(reversed_input, h_bwd, weights_bwd)
                out_bwd = torch.flip(out_bwd, dims=[1])
                h_n.append(h_out_bwd)

                output_fwd = torch.cat([out_fwd, out_bwd], dim=2)
                output_bwd = output_fwd
            else: output_fwd = out_fwd

            if layer < self.num_layers - 1 and self.dropout > 0:
                output_fwd = F.dropout(output_fwd, p=self.dropout, training=self.training)
                if self.bidirectional: output_bwd = output_fwd

        output = output_fwd
        h_n = torch.stack(h_n, dim=0)

        if not self.batch_first: output = output.permute(1, 0, 2)
        return output, h_n

    def _get_weights(self, layer_idx):
        weights = self._all_weights[layer_idx]

        weight_ih = getattr(self, weights[0])
        weight_hh = getattr(self, weights[1])
        bias_ih = getattr(self, weights[2]) if self.bias else None
        bias_hh = getattr(self, weights[3]) if self.bias else None

        return weight_ih, weight_hh, bias_ih, bias_hh

    def forward(self, input, hx=None):
        if input.dim() != 3: raise ValueError(f"GRU expects a 3-D input tensor, got {input.dim()}-D")

        batch_size = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1

        if hx is None: hx = torch.zeros(self.num_layers * num_directions, batch_size, self.hidden_size, dtype=input.dtype, device=input.device)

        self.check_forward_args(input, hx, batch_sizes=None)
        return self._gru(input, hx)
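
# Note: unlike the stock nn.GRU, forward only accepts an unpacked 3-D tensor
# (no PackedSequence support); otherwise it is intended as a drop-in replacement
# built from elementary tensor ops.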

def group_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    # Pure-PyTorch replacement for F.group_norm: normalize each group of channels, then apply the optional affine parameters.
    N, C = x.shape[:2]
    assert C % num_groups == 0

    shape = (N, num_groups, C // num_groups) + x.shape[2:]
    x_reshaped = x.view(shape)

    dims = (2,) + tuple(range(3, x_reshaped.dim()))
    mean = x_reshaped.mean(dim=dims, keepdim=True)
    var = x_reshaped.var(dim=dims, keepdim=True, unbiased=False)

    x_norm = (x_reshaped - mean) / torch.sqrt(var + eps)
    x_norm = x_norm.view_as(x)

    if weight is not None:
        weight = weight.view(1, C, *([1] * (x.dim() - 2)))
        x_norm = x_norm * weight

    if bias is not None:
        bias = bias.view(1, C, *([1] * (x.dim() - 2)))
        x_norm = x_norm + bias

    return x_norm
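
# Usage sketch (hypothetical values): intended to match F.group_norm on dense inputs, e.g.
#   y = group_norm(torch.randn(2, 8, 16), num_groups=4)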

def script(f, *_, **__):
    # Stand-in for torch.jit.script: return the function unscripted, with a fresh empty graph attached.
    f.graph = pytorch_ocl.torch._C.Graph()
    return f
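
# When pytorch_ocl is importable, route the stock building blocks through the
# pure-PyTorch implementations above so downstream models that use nn.GRU,
# F.group_norm or torch.jit.script run on the OpenCL backend.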
if torch_available:
    nn.GRU = GRU
    F.group_norm = group_norm
    torch.jit.script = script