# NOTE: removed non-Python page-scrape artifacts that preceded this line
# (Hugging Face Spaces banner, file-size note, commit hash, line-number gutter).
import sys
import torch
import torch.nn as nn
import torchvision
sys.path.insert(0, '.') # nopep8
from foleycrafter.models.specvqgan.modules.video_model.resnet import r2plus1d_18
FPS = 15
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
class r2plus1d18KeepTemp(nn.Module):
def __init__(self, pretrained=True):
super().__init__()
self.model = r2plus1d_18(pretrained=pretrained)
self.model.layer2[0].conv1[0][3] = nn.Conv3d(230, 128, kernel_size=(3, 1, 1),
stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
self.model.layer2[0].downsample = nn.Sequential(
nn.Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
nn.BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
self.model.layer3[0].conv1[0][3] = nn.Conv3d(460, 256, kernel_size=(3, 1, 1),
stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
self.model.layer3[0].downsample = nn.Sequential(
nn.Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
nn.BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
self.model.layer4[0].conv1[0][3] = nn.Conv3d(921, 512, kernel_size=(3, 1, 1),
stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
self.model.layer4[0].downsample = nn.Sequential(
nn.Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 2, 2), bias=False),
nn.BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
self.model.avgpool = nn.AdaptiveAvgPool3d((None, 1, 1))
self.model.fc = Identity()
with torch.no_grad():
rand_input = torch.randn((1, 3, 30, 112, 112))
output = self.model(rand_input).detach().cpu()
print('Validate Video feature shape: ', output.shape) # (1, 512, 30)
def forward(self, x):
N = x.shape[0]
return self.model(x).reshape(N, 512, -1)
def eval(self):
return self
def encode(self, c):
info = None, None, c
return c, None, info
def decode(self, c):
return c
def get_input(self, batch, k, drop_cond=False):
x = batch[k].cuda()
x = x.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format) # (N, 3, T, 112, 112)
T = x.shape[2]
if drop_cond:
output = self.model(x) # (N, 512, T)
else:
cond_x = x[:, :, :T//2] # (N, 3, T//2, 112, 112)
x = x[:, :, T//2:] # (N, 3, T//2, 112, 112)
cond_feat = self.model(cond_x) # (N, 512, T//2)
feat = self.model(x) # (N, 512, T//2)
output = torch.cat([cond_feat, feat], dim=-1) # (N, 512, T)
assert output.shape[2] == T
return output
class resnet50(nn.Module):
    """Frozen ResNet-50 image encoder applied frame-by-frame.

    ``get_input`` maps a (N, T, 3, H, W) clip to a (N, 2048, T) feature
    sequence by running each frame through the backbone independently.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=pretrained)
        backbone.fc = nn.Identity()  # strip the classifier head
        # Freeze all weights: this encoder is feature extraction only.
        for p in backbone.parameters():
            p.requires_grad = False
        self.model = backbone

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to (N, 2048) features."""
        batch_size = x.shape[0]
        return self.model(x).reshape(batch_size, 2048)

    def eval(self):
        # Deliberate no-op override: the module is never switched out of
        # training mode.
        return self

    def encode(self, c):
        """VQGAN-style interface shim: features pass through unquantized."""
        return c, None, (None, None, c)

    def decode(self, c):
        """Identity decode, mirroring encode above."""
        return c

    def get_input(self, batch, k, drop_cond=False):
        """Encode batch[k] (N, T, 3, H, W) frame-by-frame into (N, 2048, T)."""
        clip = batch[k].cuda()  # hard-coded CUDA placement — requires a GPU
        clip = clip.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)  # (N, 3, T, H, W)
        n_frames = clip.shape[2]
        # One backbone pass per frame; frames are not batched together so the
        # (training-mode) BatchNorm statistics match the original behavior.
        per_frame = [self.model(clip[:, :, t]) for t in range(n_frames)]
        output = torch.stack(per_frame, dim=-1)  # (N, 2048, T)
        assert output.shape[2] == n_frames
        return output
if __name__ == '__main__':
    # Smoke test: random-initialized encoder on a single 60-frame clip.
    net = r2plus1d18KeepTemp(pretrained=False).cuda()
    dummy_batch = {'input': torch.randn((1, 60, 3, 112, 112))}
    feats = net.get_input(dummy_batch, 'input')
    print(feats.shape)