import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class Conv(nn.Module):
    """Unified 2D/3D convolution wrapper.

    In "2d" mode, a frame-wise nn.Conv2d is applied (time is folded into the
    batch). In "3d" mode, a temporally causal nn.Conv3d is applied: padding is
    added only on the left of the time axis, so no output frame depends on
    future frames, and the sequence is processed in fixed-size temporal slices
    to bound peak memory.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, cnn_type="2d", causal_offset=0, temporal_down=False):
        super().__init__()
        self.cnn_type = cnn_type
        # Number of frames processed per slice in the chunked 3D forward pass.
        self.slice_seq_len = 17

        if cnn_type == "2d":
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding)
        if cnn_type == "3d":
            # Stride only spatially unless temporal downsampling is requested.
            if not temporal_down:
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            # Padding is applied manually in forward(), so the conv uses none.
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=0)
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # temporal (causal, left-only)
                padding,                             # height
                padding,                             # width
            )
        self.causal_offset = causal_offset
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        if self.cnn_type == "2d":
            if x.ndim == 5:
                # Video input: fold time into batch, convolve per frame, unfold.
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)
        if self.cnn_type == "3d":
            assert self.stride[0] in (1, 2), "only temporal stride = 1 or 2 is supported"
            xs = []
            # Process the time axis in slices so long videos fit in memory; the
            # step accounts for the temporal stride so slice outputs stay
            # phase-aligned with an unchunked convolution.
            for i in range(0, x.shape[2], self.slice_seq_len + self.stride[0] - 1):
                st = i
                en = min(i + self.slice_seq_len, x.shape[2])
                _x = x[:, :, st:en, :, :]
                if i == 0:
                    # First slice: zero-pad the left of the time axis (causal).
                    _x = F.pad(_x, (self.padding[2], self.padding[2],  # width
                                    self.padding[1], self.padding[1],  # height
                                    self.padding[0], 0))               # temporal
                else:
                    # Later slices: zero-pad, then copy the real trailing frames
                    # of the previous slice into the temporal padding region so
                    # the convolution sees seamless context across boundaries.
                    padding_0 = self.kernel_size[0] - 1
                    _x = F.pad(_x, (self.padding[2], self.padding[2],  # width
                                    self.padding[1], self.padding[1],  # height
                                    padding_0, 0))                     # temporal
                    _x[:, :, :padding_0,
                       self.padding[1]:_x.shape[-2] - self.padding[1],
                       self.padding[2]:_x.shape[-1] - self.padding[2]] += x[:, :, i - padding_0:i, :, :]
                _x = self.conv(_x)
                xs.append(_x)
            try:
                x = torch.cat(xs, dim=2)
            except RuntimeError:
                # GPU out of memory while concatenating: move the slices to
                # pinned CPU memory, concatenate there, then copy back.
                device = x.device
                del x
                xs = [_x.cpu().pin_memory() for _x in xs]
                torch.cuda.empty_cache()
                x = torch.cat(xs, dim=2).to(device=device)
            return x
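

# Illustrative usage (not part of the original source): a minimal smoke test
# exercising both modes on a small random video tensor. All shapes and
# hyperparameters below are arbitrary assumptions chosen only for the demo.
if __name__ == "__main__":
    video = torch.randn(1, 8, 33, 32, 32)  # (B, C, T, H, W)

    # 2D mode: each frame is convolved independently; T is preserved.
    conv2d = Conv(8, 16, kernel_size=3, stride=1, padding=1, cnn_type="2d")
    print(conv2d(video).shape)  # expected: torch.Size([1, 16, 33, 32, 32])

    # 3D causal mode without temporal downsampling: T is preserved because
    # kernel_size - 1 frames of padding are added on the left of time only.
    conv3d = Conv(8, 16, kernel_size=3, stride=1, padding=1, cnn_type="3d")
    print(conv3d(video).shape)  # expected: torch.Size([1, 16, 33, 32, 32])

    # 3D causal mode with temporal downsampling: stride 2 on T, H, and W,
    # so each of those dimensions is roughly halved.
    conv3d_down = Conv(8, 16, kernel_size=3, stride=2, padding=1,
                       cnn_type="3d", temporal_down=True)
    print(conv3d_down(video).shape)  # expected: torch.Size([1, 16, 17, 16, 16])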