import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
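
# `Conv` dispatches between a plain 2D convolution and a causal 3D
# convolution. In 3D mode the time axis is processed in fixed-length
# slices with manual causal padding, so long videos can be convolved
# without materializing the fully padded tensor at once.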
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, cnn_type="2d", causal_offset=0, temporal_down=False):
        super().__init__()
        self.cnn_type = cnn_type
        # Frames per temporal slice in the 3D path (see forward()).
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding)
        elif cnn_type == "3d":
            # Downsample along time only when temporal_down is set.
            if not temporal_down:
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            # Padding is applied manually in forward() to keep the
            # convolution causal in time, so the conv itself uses padding=0.
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=0)
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # temporal (causal) padding
                padding,                             # height padding
                padding,                             # width padding
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size
    def forward(self, x):
        if self.cnn_type == "2d":
            if x.ndim == 5:
                # Fold time into the batch dimension and apply the 2D conv
                # frame by frame.
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)

        if self.cnn_type == "3d":
            assert self.stride[0] == 1 or self.stride[0] == 2, \
                "only temporal stride = 1 or 2 are supported"
            xs = []
            # Convolve the video in temporal slices to bound peak memory.
            for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
                st = i
                en = min(i + self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    # First slice: zero-pad on the left of the time axis
                    # (causal padding) and symmetrically in space.
                    _x = F.pad(_x, (self.padding[2], self.padding[2],  # width
                                    self.padding[1], self.padding[1],  # height
                                    self.padding[0], 0))               # temporal
                else:
                    # Later slices: zero-pad, then overwrite the temporal
                    # padding with the trailing frames of the previous slice
                    # so the convolution stays causal across slice boundaries.
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(_x, (self.padding[2], self.padding[2],  # width
                                    self.padding[1], self.padding[1],  # height
                                    padding_0, 0))                     # temporal
                    _x[:, :, :padding_0,
                       self.padding[1]:_x.shape[-2] - self.padding[1],
                       self.padding[2]:_x.shape[-1] - self.padding[2]] += x[:, :, i - padding_0:i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except RuntimeError:
                # GPU concatenation can run out of memory for long sequences;
                # fall back to concatenating on the CPU, then move the result
                # back to the original device.
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat(xs, dim=2).to(device=device)
            return x
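

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch of how this wrapper might be exercised. It is
# not part of the original file; the shapes and channel counts below are
# illustrative assumptions only.
if __name__ == "__main__":
    video = torch.randn(1, 3, 9, 32, 32)  # dummy (B, C, T, H, W) clip

    # 2D mode: time is folded into the batch and the conv runs per frame.
    conv2d = Conv(3, 8, kernel_size=3, stride=1, padding=1, cnn_type="2d")
    print(conv2d(video).shape)  # expected: torch.Size([1, 8, 9, 32, 32])

    # 3D causal mode: left-only temporal padding keeps the conv causal.
    conv3d = Conv(3, 8, kernel_size=3, stride=1, padding=1, cnn_type="3d")
    print(conv3d(video).shape)  # expected: torch.Size([1, 8, 9, 32, 32])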