# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import torch.nn as nn

class TemporalModelBase(nn.Module):
    """
    Do not instantiate this class.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels):
        super().__init__()

        # Validate input
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        self.num_joints_out = num_joints_out
        self.filter_widths = filter_widths

        self.drop = nn.Dropout(dropout)
        self.relu = nn.ReLU(inplace=True)

        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)
        self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1)

    def set_bn_momentum(self, momentum):
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as # of frames.
        """
        frames = 0
        for f in self.pad:
            frames += f
        return 1 + 2 * frames
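
    # Worked example (illustration only, not part of the upstream code):
    # with filter_widths=[3, 3, 3], the subclasses below build
    # self.pad == [1, 3, 9], so receptive_field() returns
    # 1 + 2 * (1 + 3 + 9) = 27 frames, i.e. 3 * 3 * 3.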

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.
        The returned value is typically 0 if causal convolutions are disabled,
        otherwise it is half the receptive field.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames
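
    # Worked example (illustration only, assuming filter_widths=[3, 3, 3]
    # and causal=True in TemporalModelOptimized1f below, which sets
    # causal_shift == [1, 1, 1]): frames = 1 + 1*3 + 1*9 = 13,
    # i.e. (receptive_field() - 1) // 2.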

    def forward(self, x):
        assert len(x.shape) == 4
        assert x.shape[-2] == self.num_joints_in
        assert x.shape[-1] == self.in_features

        sz = x.shape[:3]
        # Flatten joints and features into channels, (N, T, J, F) -> (N, T, J*F),
        # then move channels next to the batch dimension for Conv1d: (N, J*F, T)
        x = x.view(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        x = self._forward_blocks(x)

        # Back to (N, T', num_joints_out, 3)
        x = x.permute(0, 2, 1)
        x = x.view(sz[0], -1, self.num_joints_out, 3)

        return x

class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.
    This implementation can be used for all use-cases.
    A usage sketch follows this class definition.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024, dense=False):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output joints (can be different from the input)
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)

        self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []

        self.causal_shift = [(filter_widths[0] // 2) if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
            self.causal_shift.append((filter_widths[i] // 2 * next_dilation) if causal else 0)

            # Each block: a dilated temporal convolution followed by a 1x1
            # convolution, each paired with a BatchNorm1d layer
            layers_conv.append(nn.Conv1d(channels, channels,
                                         filter_widths[i] if not dense else (2 * self.pad[-1] + 1),
                                         dilation=next_dilation if not dense else 1,
                                         bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def _forward_blocks(self, x):
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        for i in range(len(self.pad) - 1):
            pad = self.pad[i + 1]
            shift = self.causal_shift[i + 1]
            # Clip the skip connection so it lines up with the (shorter)
            # output of the dilated convolution below
            res = x[:, :, pad + shift: x.shape[2] - pad + shift]

            x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)
        return x
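
# Usage sketch (illustration only, not part of the upstream file; shapes
# assume the Human3.6M setup with 17 joints and 2D keypoint inputs):
#
#     import torch
#
#     model = TemporalModel(num_joints_in=17, in_features=2, num_joints_out=17,
#                           filter_widths=[3, 3, 3])
#     rf = model.receptive_field()  # 27 frames for filter_widths=[3, 3, 3]
#
#     # Valid (unpadded) convolutions shorten the sequence by rf - 1 frames,
#     # so 100 input frames yield 100 - 26 = 74 output poses:
#     x = torch.randn(8, 100, 17, 2)  # (batch, frames, joints, features)
#     out = model(x)                  # shape (8, 74, 17, 3)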

class TemporalModelOptimized1f(TemporalModelBase):
    """
    3D pose estimation model optimized for single-frame batching, i.e.
    where batches have input length = receptive field, and output length = 1.
    This scenario is only used for training when stride == 1.

    This implementation replaces dilated convolutions with strided convolutions
    to avoid generating unused intermediate results. The weights are interchangeable
    with the reference implementation.
    A usage sketch follows this class definition.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output joints (can be different from the input)
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels)

        self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, filter_widths[0], stride=filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []

        self.causal_shift = [(filter_widths[0] // 2) if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
            self.causal_shift.append((filter_widths[i] // 2) if causal else 0)

            layers_conv.append(nn.Conv1d(channels, channels, filter_widths[i], stride=filter_widths[i], bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def _forward_blocks(self, x):
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        for i in range(len(self.pad) - 1):
            # Subsample the skip connection to match the stride of the next
            # convolution (the strided counterpart of the clipping in TemporalModel)
            res = x[:, :, self.causal_shift[i + 1] + self.filter_widths[i + 1] // 2:: self.filter_widths[i + 1]]

            x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)
        return x
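
# Training-time usage sketch (illustration only, not part of the upstream
# file; shapes assume 17 joints and 2D inputs). Each training example is a
# clip whose length equals the receptive field, producing one output pose:
#
#     import torch
#
#     model_train = TemporalModelOptimized1f(17, 2, 17, filter_widths=[3, 3, 3])
#     rf = model_train.receptive_field()  # 27
#     x = torch.randn(1024, rf, 17, 2)    # (batch, receptive_field, joints, features)
#     out = model_train(x)                # shape (1024, 1, 17, 3)
#
#     # Because the kernel shapes match, trained weights can be copied into
#     # the reference model for full-sequence evaluation:
#     model_eval = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3])
#     model_eval.load_state_dict(model_train.state_dict())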