import torch as t
import torch.nn as nn

from resnet import Resnet1D
from utils.torch_utils import assert_shape

class EncoderConvBlock(nn.Module):
    """Downsampling tower for one encoder level: `down_t` strided Conv1d
    stages, each followed by a Resnet1D stack, then a 3x1 projection to
    `output_emb_width` channels."""

    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv,
                 dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
                 res_scale=False):
        super().__init__()
        blocks = []
        # Kernel/padding chosen so each strided conv divides T by stride_t.
        filter_t, pad_t = stride_t * 2, stride_t // 2
        if down_t > 0:
            for i in range(down_t):
                block = nn.Sequential(
                    # nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t, padding_mode='replicate'),
                    nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t),
                    Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale),
                )
                blocks.append(block)
            block = nn.Conv1d(width, output_emb_width, 3, 1, 1)
            # block = nn.Conv1d(width, output_emb_width, 3, 1, 1, padding_mode='replicate')
            blocks.append(block)
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

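# A minimal shape sketch (hypothetical sizes; assumes resnet.py provides
# Resnet1D with the signature used above). Each of the `down_t` strided convs
# divides the time axis by `stride_t`, so (N, input_emb_width, T) maps to
# (N, output_emb_width, T // stride_t ** down_t):
#
#     block = EncoderConvBlock(input_emb_width=64, output_emb_width=64,
#                              down_t=3, stride_t=2, width=32, depth=4, m_conv=1.0)
#     y = block(t.zeros(1, 64, 8192))
#     y.shape  # (1, 64, 1024), since 8192 // 2 ** 3 == 1024
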
class DecoderConvBlock(nn.Module):
    """Upsampling tower for one decoder level: a 3x1 input projection, then
    `down_t` stages of Resnet1D followed by a strided ConvTranspose1d that
    mirrors the encoder's downsampling."""

    def __init__(self, input_emb_width, output_emb_width, down_t,
                 stride_t, width, depth, m_conv, dilation_growth_rate=1,
                 dilation_cycle=None, zero_out=False, res_scale=False,
                 reverse_decoder_dilation=False, checkpoint_res=False):
        super().__init__()
        blocks = []
        if down_t > 0:
            filter_t, pad_t = stride_t * 2, stride_t // 2
            block = nn.Conv1d(output_emb_width, width, 3, 1, 1)
            # block = nn.Conv1d(output_emb_width, width, 3, 1, 1, padding_mode='replicate')
            blocks.append(block)
            for i in range(down_t):
                block = nn.Sequential(
                    Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle,
                             zero_out=zero_out, res_scale=res_scale,
                             reverse_dilation=reverse_decoder_dilation, checkpoint_res=checkpoint_res),
                    nn.ConvTranspose1d(width, input_emb_width if i == (down_t - 1) else width,
                                       filter_t, stride_t, pad_t),
                )
                blocks.append(block)
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)

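# The mirror of the encoder sketch above (hypothetical sizes): each
# ConvTranspose1d with kernel 2 * stride_t, stride stride_t, and padding
# stride_t // 2 multiplies the time axis by stride_t, so
# (N, output_emb_width, T) maps back to (N, input_emb_width, T * stride_t ** down_t):
#
#     block = DecoderConvBlock(input_emb_width=64, output_emb_width=64,
#                              down_t=3, stride_t=2, width=32, depth=4, m_conv=1.0)
#     y = block(t.zeros(1, 64, 1024))
#     y.shape  # (1, 64, 8192)
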
class Encoder(nn.Module):
    """Stack of `levels` EncoderConvBlocks; forward returns one latent per
    level, each further downsampled in time than the last."""

    def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
                 strides_t, **block_kwargs):
        super().__init__()
        self.input_emb_width = input_emb_width
        self.output_emb_width = output_emb_width
        self.levels = levels
        self.downs_t = downs_t
        self.strides_t = strides_t

        # 'reverse_decoder_dilation' is a decoder-only option, so drop it
        # before passing the shared kwargs to the encoder blocks.
        block_kwargs_copy = dict(**block_kwargs)
        if 'reverse_decoder_dilation' in block_kwargs_copy:
            del block_kwargs_copy['reverse_decoder_dilation']

        level_block = lambda level, down_t, stride_t: EncoderConvBlock(
            input_emb_width if level == 0 else output_emb_width,
            output_emb_width, down_t, stride_t, **block_kwargs_copy)
        self.level_blocks = nn.ModuleList()
        iterator = zip(list(range(self.levels)), downs_t, strides_t)
        for level, down_t, stride_t in iterator:
            self.level_blocks.append(level_block(level, down_t, stride_t))

    def forward(self, x):
        N, T = x.shape[0], x.shape[-1]
        emb = self.input_emb_width
        assert_shape(x, (N, emb, T))
        xs = []

        # Downsampling: T shrinks by stride_t ** down_t at each level (64, 32, ...).
        iterator = zip(list(range(self.levels)), self.downs_t, self.strides_t)
        for level, down_t, stride_t in iterator:
            level_block = self.level_blocks[level]
            x = level_block(x)
            emb, T = self.output_emb_width, T // (stride_t ** down_t)
            assert_shape(x, (N, emb, T))
            xs.append(x)

        return xs

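# With Jukebox-style settings downs_t=(3, 2, 2) and strides_t=(2, 2, 2) (an
# assumption, not read from this file), the three levels compress T by 8x,
# 32x, and 128x, so encoder(x) returns latents of length T // 8, T // 32,
# and T // 128.
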
class Decoder(nn.Module):
    """Stack of `levels` DecoderConvBlocks applied top-down; optionally sums
    in the skip latent from the next-lower level after each upsampling stage,
    then projects back to `input_emb_width` channels."""

    def __init__(self, input_emb_width, output_emb_width, levels, downs_t,
                 strides_t, **block_kwargs):
        super().__init__()
        self.input_emb_width = input_emb_width
        self.output_emb_width = output_emb_width
        self.levels = levels
        self.downs_t = downs_t
        self.strides_t = strides_t

        level_block = lambda level, down_t, stride_t: DecoderConvBlock(
            output_emb_width, output_emb_width, down_t, stride_t, **block_kwargs)
        self.level_blocks = nn.ModuleList()
        iterator = zip(list(range(self.levels)), downs_t, strides_t)
        for level, down_t, stride_t in iterator:
            self.level_blocks.append(level_block(level, down_t, stride_t))

        self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1)
        # self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1, padding_mode='replicate')

    def forward(self, xs, all_levels=True):
        if all_levels:
            assert len(xs) == self.levels
        else:
            assert len(xs) == 1

        # Start from the most compressed (top) level and upsample back down.
        x = xs[-1]
        N, T = x.shape[0], x.shape[-1]
        emb = self.output_emb_width
        assert_shape(x, (N, emb, T))

        # Upsampling: T grows by stride_t ** down_t at each level (32, 64, ...).
        iterator = reversed(list(zip(list(range(self.levels)), self.downs_t, self.strides_t)))
        for level, down_t, stride_t in iterator:
            level_block = self.level_blocks[level]
            x = level_block(x)
            emb, T = self.output_emb_width, T * (stride_t ** down_t)
            assert_shape(x, (N, emb, T))

            # Add the skip latent from the next-lower level, if provided.
            if level != 0 and all_levels:
                x = x + xs[level - 1]

        x = self.out(x)
        return x

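if __name__ == '__main__':
    # Minimal round-trip smoke test (hypothetical sizes; assumes resnet.py and
    # utils.torch_utils are importable as above). Encoder latents shrink by
    # stride_t ** down_t per level; the Decoder restores the original length.
    block_kwargs = dict(width=32, depth=4, m_conv=1.0)
    encoder = Encoder(64, 64, levels=3, downs_t=(3, 2, 2), strides_t=(2, 2, 2), **block_kwargs)
    decoder = Decoder(64, 64, levels=3, downs_t=(3, 2, 2), strides_t=(2, 2, 2), **block_kwargs)
    x = t.zeros(2, 64, 8192)
    xs = encoder(x)
    print([tuple(z.shape) for z in xs])  # [(2, 64, 1024), (2, 64, 256), (2, 64, 64)]
    x_hat = decoder(xs)
    assert_shape(x_hat, (2, 64, 8192))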