import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops_exts import check_shape, rearrange_many
from rotary_embedding_torch import RotaryEmbedding

from .transformer_utils import BaseTemperalPointModel

class SinusoidalPosEmb(nn.Module):
    """Standard sinusoidal embedding of scalar inputs (e.g. diffusion timesteps)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        # (b,) -> (b, dim): sine on the first half of the channels, cosine on the second
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
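
# Illustrative sketch (not part of the original module): the embedding maps a 1-D
# batch of timesteps to a (batch, dim) tensor. The sizes below are example choices.
def _demo_sinusoidal_pos_emb():
    timesteps = torch.arange(8).float()      # a batch of 8 scalar timesteps
    emb = SinusoidalPosEmb(128)(timesteps)   # -> torch.Size([8, 128])
    return emb.shape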

class RelativePositionBias(nn.Module):
    """T5-style relative position bias that is added to the attention logits."""

    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128):
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        # distances beyond max_exact fall into logarithmically sized buckets up to max_distance
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
        ).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        q_pos = torch.arange(n, dtype = torch.long, device = device)
        k_pos = torch.arange(n, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')
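
# Illustrative sketch (assumption, not in the original file): the bias has shape
# (heads, n, n) and is broadcast-added to attention logits of shape (..., heads, n, n).
def _demo_relative_position_bias():
    bias = RelativePositionBias(heads=4)(16, device=torch.device('cpu'))
    return bias.shape                        # -> torch.Size([4, 16, 16])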

def exists(x):
    return x is not None


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.gamma + self.beta


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class EinopsToAndFrom(nn.Module):
    """Rearranges the input with einops, applies fn, then rearranges back."""

    def __init__(self, from_einops, to_einops, fn):
        super().__init__()
        self.from_einops = from_einops
        self.to_einops = to_einops
        self.fn = fn

    def forward(self, x, **kwargs):
        shape = x.shape
        reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
        x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
        x = self.fn(x, **kwargs)
        x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
        return x
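
# Illustrative sketch (assumption): EinopsToAndFrom lets a single attention module be
# applied over different axes by permuting them around the call; nn.Identity() stands
# in for the wrapped module here just to show the shape round-trip.
def _demo_einops_to_and_from():
    x = torch.randn(2, 10, 5, 64)            # (batch, frames, points, channels)
    over_frames = EinopsToAndFrom('b f p c', 'b p f c', nn.Identity())
    return over_frames(x).shape              # unchanged: torch.Size([2, 10, 5, 64])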

class Attention(nn.Module):
    def __init__(
        self, dim, heads=4, attn_head_dim=None, casual_attn=False, rotary_emb = None):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        self.casual_attn = casual_attn

        if attn_head_dim is not None:
            head_dim = attn_head_dim

        all_head_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5

        self.to_qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        self.proj = nn.Linear(all_head_dim, dim)
        self.rotary_emb = rotary_emb

    def forward(self, x, pos_bias = None):
        N, device = x.shape[-2], x.device
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.num_heads)

        q = q * self.scale

        # rotary embeddings inject relative position directly into q and k
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)

        if exists(pos_bias):
            sim = sim + pos_bias

        if self.casual_attn:
            # lower-triangular mask: each position attends only to itself and earlier positions
            mask = torch.tril(torch.ones(sim.size(-2), sim.size(-1))).to(device)
            sim = sim.masked_fill(mask[..., :, :] == 0, float('-inf'))

        attn = sim.softmax(dim = -1)
        x = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        x = rearrange(x, '... h n d -> ... n (h d)')
        x = self.proj(x)
        return x
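
# Illustrative sketch (assumption): temporal attention over a (batch, frames, dim)
# sequence, with the relative-position bias broadcast over the batch dimension.
def _demo_attention_with_bias():
    x = torch.randn(2, 16, 128)              # (batch, frames, dim)
    attn = Attention(128, heads=4)
    bias = RelativePositionBias(heads=4)(16, device=x.device)
    return attn(x, pos_bias=bias).shape      # -> torch.Size([2, 16, 128])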

class Block(nn.Module):
    def __init__(self, dim, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim, dim_out)
        self.norm = LayerNorm(dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)

        if exists(scale_shift):
            x = self.norm(x)
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        return self.act(x)


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, cond_dim=None):
        super().__init__()
        # projects the conditioning vector to a per-channel (scale, shift) pair
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim_out * 2)
        ) if exists(cond_dim) else None

        self.block1 = Block(dim, dim_out)
        self.block2 = Block(dim_out, dim_out)

    def forward(self, x, cond_emb=None):
        scale_shift = None
        if exists(self.mlp):
            assert exists(cond_emb), 'conditioning embedding must be passed in'
            cond_emb = self.mlp(cond_emb)
            # cond_emb = rearrange(cond_emb, 'b f c -> b f 1 c')
            scale_shift = cond_emb.chunk(2, dim=-1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + x
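
# Illustrative sketch (assumption): ResnetBlock applies FiLM-style conditioning, turning
# the conditioning vector into a per-channel scale and shift. Sizes are example choices
# (dim=128, cond_dim=640, i.e. 512 audio features plus a 128-d timestep embedding).
def _demo_resnet_block():
    x = torch.randn(2, 16, 128)              # (batch, frames, dim)
    cond = torch.randn(2, 16, 640)           # per-frame conditioning
    block = ResnetBlock(128, 128, cond_dim=640)
    return block(x, cond).shape              # -> torch.Size([2, 16, 128])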

class SimpleTransModel(BaseTemperalPointModel):
    """
    A simple model that processes a sequence of per-frame feature vectors with a stack
    of conditioned MLP blocks followed by temporal attention.
    """

    def get_layers(self):
        # self.input_projection = nn.Linear(
        #     in_features=51,
        #     out_features=self.dim
        # )
        self.input_projection = nn.Linear(
            in_features=70,
            out_features=self.dim
        )

        cond_dim = 512 + self.timestep_embed_dim

        num_head = self.dim // 64
        rotary_emb = RotaryEmbedding(min(32, num_head))
        # realistically will not be able to generate that many frames of video... yet
        self.time_rel_pos_bias = RelativePositionBias(heads=num_head, max_distance=128)

        temporal_casual_attn = lambda dim: Attention(dim, heads=num_head, casual_attn=False, rotary_emb=rotary_emb)

        cond_block = partial(ResnetBlock, cond_dim=cond_dim)

        layers = nn.ModuleList([])
        for _ in range(self.num_layers):
            layers.append(nn.ModuleList([
                cond_block(self.dim, self.dim),
                cond_block(self.dim, self.dim),
                Residual(PreNorm(self.dim, temporal_casual_attn(self.dim)))
            ]))
        return layers

    def forward(self, inputs: torch.Tensor, timesteps: torch.Tensor, context=None):
        """
        Apply the model to an input batch.
        :param inputs: a [batch x frames x channels] tensor of per-frame features.
        :param timesteps: a 1-D batch of timesteps.
        :param context: per-frame conditioning features, concatenated to the timestep embedding.
        """
        # Prepare inputs
        batch, num_frames, channels = inputs.size()
        device = inputs.device
        # assert channels == 3

        # Positional encoding of point coords
        # inputs = rearrange(inputs, 'b f p c -> (b f) p c')
        # pos_emb = self.positional_encoding(inputs)
        x = self.input_projection(inputs)
        # x = rearrange(x, '(b f) p c -> b f p c', b=batch)

        t_emb = self.time_mlp(timesteps) if exists(self.time_mlp) else None
        t_emb = t_emb[:, None, :].expand(-1, num_frames, -1)  # b f c
        if context is not None:
            t_emb = torch.cat([t_emb, context], -1)

        time_rel_pos_bias = self.time_rel_pos_bias(num_frames, device=device)

        for block1, block2, temporal_casual_attn in self.layers:
            x = block1(x, t_emb)
            x = block2(x, t_emb)
            x = temporal_casual_attn(x, pos_bias=time_rel_pos_bias)

        # Project back to the output dimension
        x = self.output_projection(x)
        return x
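
# Illustrative sketch (assumption): one SimpleTransModel layer applied by hand, since the
# input/output projections and time MLP live in BaseTemperalPointModel (not defined in this
# file). Dimensions below (dim=128, cond_dim=640, 16 frames) are example choices only.
def _demo_trans_model_layer():
    x = torch.randn(2, 16, 128)              # projected inputs, (b, f, dim)
    t_emb = torch.randn(2, 16, 640)          # timestep embedding concatenated with context
    block1 = ResnetBlock(128, 128, cond_dim=640)
    block2 = ResnetBlock(128, 128, cond_dim=640)
    attn = Residual(PreNorm(128, Attention(128, heads=2)))
    bias = RelativePositionBias(heads=2)(16, device=x.device)
    x = block1(x, t_emb)
    x = block2(x, t_emb)
    x = attn(x, pos_bias=bias)
    return x.shape                           # -> torch.Size([2, 16, 128])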

class SimpleTemperalPointModel(BaseTemperalPointModel):
    """
    A simple model that processes a sequence of keypoint sets with conditioned MLP blocks
    plus spatial (per-frame) and temporal (per-keypoint) attention.
    """

    def get_layers(self):
        audio_dim = 512
        cond_dim = audio_dim + self.timestep_embed_dim

        num_head = 4
        rotary_emb = RotaryEmbedding(min(32, num_head))
        # realistically will not be able to generate that many frames of video... yet
        self.time_rel_pos_bias = RelativePositionBias(heads=num_head, max_distance=128)

        # attention across frames for each keypoint
        temporal_casual_attn = lambda dim: EinopsToAndFrom('b f p c', 'b p f c', Attention(dim, heads=num_head, casual_attn=False, rotary_emb = rotary_emb))
        # attention across keypoints within each frame
        spatial_kp_attn = lambda dim: EinopsToAndFrom('b f p c', 'b f p c', Attention(dim, heads=num_head))

        cond_block = partial(ResnetBlock, cond_dim=cond_dim)

        layers = nn.ModuleList([])
        for _ in range(self.num_layers):
            layers.append(nn.ModuleList([
                cond_block(self.dim, self.dim),
                cond_block(self.dim, self.dim),
                Residual(PreNorm(self.dim, spatial_kp_attn(self.dim))),
                Residual(PreNorm(self.dim, temporal_casual_attn(self.dim)))
            ]))
        return layers

    def forward(self, inputs: torch.Tensor, timesteps: torch.Tensor, context=None):
        """
        Apply the model to an input batch.
        :param inputs: a [batch x frames x points x channels] tensor of keypoints.
        :param timesteps: a 1-D batch of timesteps.
        :param context: per-frame conditioning features, concatenated to the timestep embedding.
        """
        # Prepare inputs
        batch, num_frames, num_points, channels = inputs.size()
        device = inputs.device
        # assert channels == 3

        # Positional encoding of point coords
        inputs = rearrange(inputs, 'b f p c -> (b f) p c')
        pos_emb = self.positional_encoding(inputs)
        x = self.input_projection(torch.cat([inputs, pos_emb], -1))
        x = rearrange(x, '(b f) p c -> b f p c', b=batch)

        t_emb = self.time_mlp(timesteps) if exists(self.time_mlp) else None
        t_emb = t_emb[:, None, :].expand(-1, num_frames, -1)  # b f c
        if context is not None:
            t_emb = torch.cat([t_emb, context], -1)

        time_rel_pos_bias = self.time_rel_pos_bias(num_frames, device=device)

        for block1, block2, spatial_kp_attn, temporal_casual_attn in self.layers:
            x = block1(x, t_emb)
            x = block2(x, t_emb)
            x = spatial_kp_attn(x)
            x = temporal_casual_attn(x, pos_bias=time_rel_pos_bias)

        # Project back to the output dimension
        x = self.output_projection(x)
        return x
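
# Illustrative sketch (assumption): the two attention wrappers used above act on a
# (batch, frames, points, channels) tensor; spatial attention mixes keypoints within a
# frame, temporal attention mixes frames per keypoint. Sizes are example choices only.
def _demo_spatial_temporal_attention():
    x = torch.randn(2, 16, 5, 128)           # (b, f, p, c)
    spatial = Residual(PreNorm(128, EinopsToAndFrom('b f p c', 'b f p c', Attention(128, heads=4))))
    temporal = Residual(PreNorm(128, EinopsToAndFrom('b f p c', 'b p f c', Attention(128, heads=4))))
    bias = RelativePositionBias(heads=4)(16, device=x.device)
    x = spatial(x)
    x = temporal(x, pos_bias=bias)
    return x.shape                           # -> torch.Size([2, 16, 5, 128])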