Spaces:
Sleeping
Sleeping
File size: 3,899 Bytes
801501a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from ..custom_types import *
from abc import ABC
import math
def torch_no_grad(func):
    """Decorate ``func`` so that every call runs inside ``torch.no_grad()``.

    All positional and keyword arguments are forwarded unchanged and the
    wrapped function's return value is passed through.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that carries ``func``'s metadata (``__name__``, ``__doc__``,
        ``__wrapped__``, ...) via ``functools.wraps``.
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(func)
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)

    return wrapper
class Model(nn.Module, ABC):
    """Base ``nn.Module`` whose persistence is delegated to an injected callback.

    ``save_model`` starts out as ``None``. The owner of the model must assign
    a callable before ``save`` is used; it is invoked as
    ``save_model(self, **kwargs)``.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Callback used by ``save``; called with the model as first argument.
        self.save_model: Union[None, Callable[..., None]] = None

    def save(self, **kwargs):
        """Persist the model by forwarding ``self`` and ``kwargs`` to ``save_model``.

        Raises:
            TypeError: if no ``save_model`` callback has been assigned (same
                exception type as calling ``None``, but with a clear message).
        """
        if self.save_model is None:
            raise TypeError(
                "Model.save() called before a save_model callback was assigned")
        self.save_model(self, **kwargs)
class Concatenate(nn.Module):
    """Module wrapper around ``torch.cat`` with a fixed concatenation axis."""

    def __init__(self, dim):
        super().__init__()
        # Axis handed to ``torch.cat`` on every forward pass.
        self.dim = dim

    def forward(self, x):
        """Concatenate the sequence of tensors ``x`` along ``self.dim``."""
        axis = self.dim
        joined = torch.cat(x, dim=axis)
        return joined
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with a shape fixed at construction."""

    def __init__(self, *shape):
        super().__init__()
        # Target shape, unpacked into ``Tensor.view`` (may contain -1).
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the stored shape (no copy; ``view`` rules apply)."""
        target_shape = self.shape
        reshaped = x.view(*target_shape)
        return reshaped
class Transpose(nn.Module):
    """Module that swaps two fixed dimensions of its input tensor."""

    def __init__(self, dim0, dim1):
        super().__init__()
        # The two axes exchanged by ``Tensor.transpose``.
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        """Return ``x`` with dimensions ``dim0`` and ``dim1`` swapped."""
        first_axis, second_axis = self.dim0, self.dim1
        return x.transpose(first_axis, second_axis)
class Dummy(nn.Module):
    """Placeholder module: ignores its constructor arguments and forwards
    the first positional input unchanged."""

    def __init__(self, *args):
        super().__init__()

    def forward(self, *args):
        """Return the first positional argument as-is (extra ones are ignored)."""
        first_input = args[0]
        return first_input
class SineLayer(nn.Module):
    """
    From the siren repository
    https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb

    A linear layer followed by ``sin(omega_0 * .)``. Only the weight matrix
    gets the SIREN initialisation; the bias keeps ``nn.Linear``'s default.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0          # frequency multiplier inside the sine
        self.is_first = is_first        # selects the first-layer init range
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.output_channels = out_features
        self.init_weights()

    def init_weights(self):
        """Redraw ``linear.weight`` from U(-bound, bound).

        ``bound`` is ``1 / in_features`` for the first layer and
        ``sqrt(6 / in_features) / omega_0`` otherwise.
        """
        fan_in = self.in_features
        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = np.sqrt(6 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Apply the affine map, then ``sin(omega_0 * .)`` elementwise."""
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation between consecutive layers.

    ``ch`` lists the feature sizes, e.g. ``[3, 64, 64, 1]`` builds three linear
    layers (3->64, 64->64, 64->1). The activation is inserted after every
    linear layer except the last, so the final output is not activated.
    """

    def forward(self, x, *_):
        # Extra positional arguments are accepted and ignored.
        return self.net(x)

    def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU,
                 weight_norm=False):
        """
        Args:
            ch: feature sizes, input first and output last.
            act: activation *class* (or zero/keyword-arg factory) returning an
                ``nn.Module``; instantiated in-place when it supports ``inplace``.
            weight_norm: wrap each linear layer with ``nn.utils.weight_norm``.
        """
        super(MLP, self).__init__()
        layers = []
        num_linear = len(ch) - 1
        for i in range(num_linear):
            linear = nn.Linear(ch[i], ch[i + 1])
            if weight_norm:
                linear = nn.utils.weight_norm(linear)
            layers.append(linear)
            if i < num_linear - 1:
                layers.append(self._build_activation(act))
        self.net = nn.Sequential(*layers)

    @staticmethod
    def _build_activation(act):
        """Instantiate ``act``, in-place when its constructor accepts ``inplace``."""
        # ``inplace`` is passed by keyword: the positional ``act(True)`` form
        # silently set LeakyReLU's negative_slope (making it linear) or
        # Hardtanh's min_val, and crashed for argument-less activations such
        # as Tanh/Sigmoid or for GELU (approximate=True).
        try:
            return act(inplace=True)
        except TypeError:
            return act()
class GMAttend(nn.Module):
    """Self-attention over the third axis of a 4-D input, with a learned residual gate.

    ``forward`` expects ``x`` of shape ``(b, g, n, hidden_dim)`` (fixed by the
    einsum subscripts and the linear layers) and returns
    ``gamma * softmax(Q K^T) V + x``. ``gamma`` starts at zero, so a freshly
    constructed module returns its input unchanged.
    """

    def __init__(self, hidden_dim: int, scale_scores: bool = False):
        """
        Args:
            hidden_dim: size of the last input axis; queries/keys use
                ``hidden_dim // 8`` features.
            scale_scores: when True, multiply the attention logits by
                ``1 / sqrt(key_dim)`` (standard scaled dot-product attention).
                Defaults to False to keep the original, unscaled numerics that
                existing checkpoints were trained with.
        """
        super(GMAttend, self).__init__()
        self.key_dim = hidden_dim // 8
        self.query_w = nn.Linear(hidden_dim, self.key_dim)
        self.key_w = nn.Linear(hidden_dim, self.key_dim)
        self.value_w = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=3)
        # Residual gate, zero-initialised.
        self.gamma = nn.Parameter(torch.zeros(1))
        # 0-dim tensor; previously computed but never applied. It is now used
        # only when ``scale_scores`` is enabled.
        self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32))
        self.scale_scores = scale_scores

    def forward(self, x):
        queries = self.query_w(x)
        keys = self.key_w(x)
        vals = self.value_w(x)
        # (b, g, q, f) x (b, g, k, f) -> (b, g, q, k) attention logits.
        scores = torch.einsum('bgqf,bgkf->bgqk', queries, keys)
        if self.scale_scores:
            scores = scores * self.scale
        attention = self.softmax(scores)
        # Weighted sum of values for every query position.
        out = torch.einsum('bgvf,bgqv->bgqf', vals, attention)
        return self.gamma * out + x
def recursive_to(item, device):
    """Move every tensor in a (possibly nested) list/tuple with ``.to(device)``.

    Args:
        item: an instance of ``T`` (the project's tensor type alias, presumably
            ``torch.Tensor``), a plain ``list``/``tuple`` nesting of such
            values, or anything else.
        device: forwarded verbatim to ``Tensor.to``.

    Returns:
        The converted tensor, a new ``list`` (tuples are returned as lists, as
        before) with each element processed recursively, or ``item`` itself
        for any other type.
    """
    # isinstance so tensor subclasses such as nn.Parameter are moved as well;
    # ``type(item) is T`` silently skipped them.
    if isinstance(item, T):
        return item.to(device)
    # Exact type check kept on purpose: namedtuples and other subclasses are
    # still passed through untouched, matching the previous behaviour.
    if type(item) in (tuple, list):
        return [recursive_to(element, device) for element in item]
    return item
|