# Source: RVC-GUI / main / library / opencl.py
# Repository page residue: uploader "AnhP", commit 1e4a2ab ("Upload 170 files", verified), file size 7.55 kB.
import torch
import platform
import subprocess
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from librosa.util import pad_center
from scipy.signal import get_window
# pytorch_ocl is an optional backend: when it cannot be imported the module
# still loads and every public helper below reports "no device".
try:
    import pytorch_ocl
except Exception:
    # Stay best-effort for any import-time failure of the backend, but do not
    # swallow KeyboardInterrupt / SystemExit the way a bare ``except:`` would.
    pytorch_ocl = None

# True only when the backend module was imported successfully.
torch_available = pytorch_ocl is not None
def check_amd_gpu(gpu):
    """Return True when the adapter description ``gpu`` contains an AMD marker.

    Args:
        gpu: Text describing one video adapter (a line of ``wmic`` or
            ``lspci`` output in this module).

    Returns:
        bool: True if any of the marker substrings occurs in ``gpu``
        (case-sensitive), otherwise False.
    """
    # Every marker has to be tried; returning from inside a ``for`` loop would
    # only ever test the first one ("RX").
    return any(marker in gpu for marker in ("RX", "AMD", "Vega", "Radeon", "FirePro"))
def get_amd_gpu_windows():
    """List AMD video adapters reported by ``wmic`` on Windows.

    Returns:
        list[str]: Stripped adapter names accepted by ``check_amd_gpu``.
        An empty list is returned when the query or its decoding fails.
    """
    try:
        raw = subprocess.check_output("wmic path win32_VideoController get name", shell=True)
        # The first output line is skipped (it is not an adapter entry).
        candidates = raw.decode().split('\n')[1:]
        return [name.strip() for name in candidates if check_amd_gpu(name)]
    except Exception:
        # Detection is best-effort: any failure simply means "no GPU found".
        # ``Exception`` (not a bare except) lets Ctrl-C / SystemExit propagate.
        return []
def get_amd_gpu_linux():
    """List AMD video adapters reported by ``lspci | grep VGA``.

    Returns:
        list[str]: Output lines accepted by ``check_amd_gpu``, returned as
        printed (they are not stripped). An empty list is returned when the
        command fails, finds nothing, or its output cannot be decoded.
    """
    try:
        raw = subprocess.check_output("lspci | grep VGA", shell=True)
        return [line for line in raw.decode().split('\n') if check_amd_gpu(line)]
    except Exception:
        # Detection is best-effort: any failure simply means "no GPU found".
        # ``Exception`` (not a bare except) lets Ctrl-C / SystemExit propagate.
        return []
def get_gpu_list():
    """Return the detected AMD GPUs for the current platform.

    Returns:
        list[str]: Result of the Windows probe when ``platform.system()`` is
        ``"Windows"``, otherwise the result of the Linux probe. Always an
        empty list when the OpenCL backend is unavailable.
    """
    if not torch_available:
        return []
    if platform.system() == "Windows":
        return get_amd_gpu_windows()
    return get_amd_gpu_linux()
def device_count():
    """Return how many AMD GPUs were detected (0 without the OpenCL backend)."""
    if not torch_available:
        return 0
    return len(get_gpu_list())
def device_name(device_id = 0):
    """Return the description of the GPU at index ``device_id``.

    Args:
        device_id: Zero-based index into the detected GPU list.

    Returns:
        str: The matching entry of ``get_gpu_list()``, or ``""`` when the
        OpenCL backend is unavailable or the index is out of range
        (negative indices count as out of range).
    """
    if not torch_available:
        return ""
    # Enumerate once: the bounds check and the lookup then see the same list,
    # and only one probe subprocess is spawned instead of two.
    gpus = get_gpu_list()
    return gpus[device_id] if 0 <= device_id < len(gpus) else ""
def is_available():
    """Return True when the OpenCL backend is importable and a GPU was detected."""
    if not torch_available:
        return False
    return device_count() > 0
class STFT(torch.nn.Module):
    """Magnitude STFT computed with explicit, windowed DFT basis matrices.

    Three buffers are registered: ``forward_basis`` (stacked real and
    imaginary DFT rows multiplied by the window), ``inverse_basis`` (the
    pseudo-inverse of the unwindowed stack, with rows scaled by the window)
    and ``fft_window`` itself.
    """

    def __init__(self, filter_length=1024, hop_length=512, win_length=None, window="hann"):
        """Build the basis matrices.

        Args:
            filter_length: Frame size; also the number of columns of the basis.
            hop_length: Step between consecutive frames.
            win_length: Window length before centre-padding to
                ``filter_length``; a falsy value means ``filter_length``.
            window: Window specification forwarded to ``get_window``.

        Raises:
            AssertionError: If ``win_length`` exceeds ``filter_length``.
        """
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.pad_amount = int(self.filter_length / 2)
        # Kept exactly as passed (possibly None); the resolved length used
        # below stays local.
        self.win_length = win_length
        self.hann_window = {}

        # Only the first filter_length / 2 + 1 DFT rows are kept; real and
        # imaginary parts are stacked vertically.
        n_bins = int(self.filter_length / 2 + 1)
        dft_rows = np.fft.fft(np.eye(self.filter_length))[:n_bins, :]
        stacked = np.vstack([np.real(dft_rows), np.imag(dft_rows)])
        forward_basis = torch.FloatTensor(stacked)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(stacked))

        if not win_length:
            win_length = filter_length
        assert filter_length >= win_length

        window_values = get_window(window, win_length, fftbins=True)
        fft_window = torch.from_numpy(pad_center(window_values, size=filter_length)).float()

        # The window scales the columns of the forward basis and the rows of
        # the inverse basis.
        forward_basis = forward_basis * fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis)
        self.register_buffer("inverse_basis", inverse_basis)
        self.register_buffer("fft_window", fft_window)

    def transform(self, input_data, eps):
        """Return the magnitude spectrogram of ``input_data``.

        Args:
            input_data: Tensor framed along dimension 1. The ``unfold`` /
                ``permute(0, 2, 1)`` pair below requires it to be 2-D.
            eps: Value added under the square root.

        Returns:
            Tensor with ``filter_length / 2 + 1`` rows along dimension 1 and
            one column per frame.
        """
        padded = F.pad(input_data, (self.pad_amount, self.pad_amount), mode="reflect")
        frames = padded.unfold(1, self.filter_length, self.hop_length).permute(0, 2, 1)
        spectrum = torch.matmul(self.forward_basis, frames)
        # The first n_bins rows hold the real parts, the rest the imaginary parts.
        n_bins = int(self.filter_length / 2 + 1)
        real_part = spectrum[:, :n_bins, :]
        imag_part = spectrum[:, n_bins:, :]
        return torch.sqrt(real_part**2 + imag_part**2 + eps)
class GRU(nn.RNNBase):
    """GRU evaluated step by step with elementary tensor operations.

    Parameters are created by ``nn.RNNBase`` in ``"GRU"`` mode; only the
    forward computation is implemented here. Note that ``batch_first``
    defaults to True in this class.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0.0, bidirectional=False, device=None, dtype=None):
        """Forward every option unchanged to ``nn.RNNBase`` with mode "GRU"."""
        super().__init__(
            "GRU",
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bias=bias,
            batch_first=batch_first,
            dropout=dropout,
            bidirectional=bidirectional,
            device=device,
            dtype=dtype,
        )

    @staticmethod
    def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
        """Advance the hidden state ``hx`` by one time step with input ``x``."""
        # Each projection is split along dim 1 into three equal gate chunks.
        in_r, in_z, in_n = F.linear(x, weight_ih, bias_ih).chunk(3, 1)
        hid_r, hid_z, hid_n = F.linear(hx, weight_hh, bias_hh).chunk(3, 1)
        reset = torch.sigmoid(in_r + hid_r)
        update = torch.sigmoid(in_z + hid_z)
        candidate = torch.tanh(in_n + reset * hid_n)
        # Blend between the candidate and the previous state.
        return candidate + update * (hx - candidate)

    def _gru_layer(self, x, hx, weights):
        """Run one direction of one layer over dimension 1 of ``x``.

        Returns:
            tuple: (per-step states stacked along dim 1, state after the last step).
        """
        weight_ih, weight_hh, bias_ih, bias_hh = weights
        state = hx
        steps = []
        for frame in x.unbind(1):
            state = self._gru_cell(frame, state, weight_ih, bias_ih, weight_hh, bias_hh)
            steps.append(state)
        return torch.stack(steps, dim=1), state

    def _gru(self, x, hx):
        """Run all layers and directions; ``hx`` is indexed by layer * directions + direction."""
        if not self.batch_first:
            # Work internally with the step axis at dim 1.
            x = x.permute(1, 0, 2)
        directions = 2 if self.bidirectional else 1
        last_layer = self.num_layers - 1
        final_states = []
        layer_input = x
        for layer in range(self.num_layers):
            fwd_idx = layer * directions
            layer_output, fwd_last = self._gru_layer(layer_input, hx[fwd_idx], self._get_weights(fwd_idx))
            final_states.append(fwd_last)
            if self.bidirectional:
                bwd_idx = fwd_idx + 1
                # The backward direction reads the same layer input reversed
                # along the step axis; its outputs are flipped back afterwards.
                flipped = torch.flip(layer_input, dims=[1])
                bwd_out, bwd_last = self._gru_layer(flipped, hx[bwd_idx], self._get_weights(bwd_idx))
                final_states.append(bwd_last)
                layer_output = torch.cat([layer_output, torch.flip(bwd_out, dims=[1])], dim=2)
            # Dropout is applied between layers only, never after the last one.
            if layer < last_layer and self.dropout > 0:
                layer_output = F.dropout(layer_output, p=self.dropout, training=self.training)
            layer_input = layer_output
        output = layer_input
        h_n = torch.stack(final_states, dim=0)
        if not self.batch_first:
            output = output.permute(1, 0, 2)
        return output, h_n

    def _get_weights(self, layer_idx):
        """Return (weight_ih, weight_hh, bias_ih, bias_hh) for one layer/direction slot.

        Both biases are None when the module was built with ``bias=False``.
        """
        names = self._all_weights[layer_idx]
        weight_ih = getattr(self, names[0])
        weight_hh = getattr(self, names[1])
        if self.bias:
            bias_ih = getattr(self, names[2])
            bias_hh = getattr(self, names[3])
        else:
            bias_ih = bias_hh = None
        return weight_ih, weight_hh, bias_ih, bias_hh

    def forward(self, input, hx=None):
        """Apply the GRU to a 3-D ``input``.

        Args:
            input: 3-D tensor; the batch axis is 0 when ``batch_first`` is
                set, otherwise 1.
            hx: Optional initial hidden state; zeros matching the dtype and
                device of ``input`` are used when omitted.

        Returns:
            tuple: (output sequence, stacked final hidden states).

        Raises:
            ValueError: If ``input`` is not 3-D.
        """
        if input.dim() != 3:
            raise ValueError
        batch_axis = 0 if self.batch_first else 1
        if hx is None:
            directions = 2 if self.bidirectional else 1
            hx = torch.zeros(
                self.num_layers * directions,
                input.size(batch_axis),
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
        self.check_forward_args(input, hx, batch_sizes=None)
        return self._gru(input, hx)
def group_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    """Group normalization written with basic tensor operations.

    Args:
        x: Tensor whose first two dimensions are batch and channels; any
            further dimensions are normalized over together with the
            channels of each group.
        num_groups: Number of channel groups; must divide the channel count.
        weight: Optional per-channel scale holding one value per channel.
        bias: Optional per-channel offset holding one value per channel.
        eps: Value added to the variance before the square root.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        AssertionError: If the channel count is not divisible by ``num_groups``.
    """
    batch, channels = x.shape[:2]
    assert channels % num_groups == 0

    # Split the channel axis into (groups, channels per group).
    grouped_shape = (batch, num_groups, channels // num_groups) + tuple(x.shape[2:])
    grouped = x.view(grouped_shape)

    # Statistics are taken over everything except batch and group.
    reduce_dims = tuple(range(2, grouped.dim()))
    group_mean = grouped.mean(dim=reduce_dims, keepdim=True)
    group_var = grouped.var(dim=reduce_dims, keepdim=True, unbiased=False)
    normalized = ((grouped - group_mean) / torch.sqrt(group_var + eps)).view_as(x)

    # Per-channel parameters broadcast over batch and trailing dimensions.
    affine_shape = (1, channels) + (1,) * (x.dim() - 2)
    if weight is not None:
        normalized = normalized * weight.view(affine_shape)
    if bias is not None:
        normalized = normalized + bias.view(affine_shape)
    return normalized
def script(f, *_args, **_kwargs):
    """Return ``f`` itself after attaching a newly constructed graph object.

    Any extra positional or keyword arguments are accepted and ignored.
    Further down this module the function is installed as
    ``torch.jit.script`` when the OpenCL backend is available.
    """
    new_graph = pytorch_ocl.torch._C.Graph()
    f.graph = new_graph
    return f
# With the OpenCL backend present, swap selected torch entry points for the
# pure-tensor replacements defined in this module. The three assignments are
# independent of each other.
if torch_available:
    torch.jit.script = script
    F.group_norm = group_norm
    nn.GRU = GRU