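# Provenance (residue of a Hugging Face file-view page, kept as comments so the
# module parses):
#   uploader: Techt3o
#   identifier shown on page: 1c0fe995ac3bf6c4bc83a727a73c46ab2d045729fb0abd53c4c78cd2b8282877
#   commit: 20ae9ff (verified); displayed file size: 2.72 kB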
from contextlib import ContextDecorator
import torch
import torch.nn.functional as F
all_times = []  # elapsed times of every enabled Timer, in completion order


class Timer(ContextDecorator):
    """Time a code region on the GPU using a pair of CUDA events.

    Usable as a context manager (``with Timer("fwd") as t: ...``) or, through
    ``ContextDecorator``, as a function decorator. When ``enabled`` is False
    the timer is a no-op and never touches CUDA.

    On exit (when enabled) the elapsed time, as returned by
    ``torch.cuda.Event.elapsed_time`` (milliseconds), is printed as
    ``"<name> <time>"``, appended to the module-level ``all_times`` list and
    stored on ``self.elapsed``.
    """

    def __init__(self, name, enabled=True):
        self.name = name
        self.enabled = enabled
        self.elapsed = None  # last measured time; None until a timed exit
        if self.enabled:
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        if self.enabled:
            self.start.record()
        # Return self so ``with Timer(...) as t`` can read ``t.elapsed`` later.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.enabled:
            self.end.record()
            # Both events must have completed on the device before
            # elapsed_time() can be queried, hence the full synchronize.
            torch.cuda.synchronize()
            self.elapsed = self.start.elapsed_time(self.end)
            all_times.append(self.elapsed)  # mutation only; no `global` needed
            print(f"{self.name} {self.elapsed:.03f}")
        # Implicitly returns None (falsy): exceptions from the block propagate.
def coords_grid(b, n, h, w, **kwargs):
    """Build a float pixel-coordinate grid replicated over batch and frames.

    Returns a tensor of shape (b, n, 2, h, w): channel 0 holds each pixel's
    x (column) index and channel 1 its y (row) index. Extra keyword arguments
    (e.g. ``device``) are forwarded to ``torch.arange``.
    """
    rows = torch.arange(0, h, dtype=torch.float, **kwargs)
    cols = torch.arange(0, w, dtype=torch.float, **kwargs)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    xy = torch.stack([grid_x, grid_y])  # x first, then y
    return xy.view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)
def coords_grid_with_index(d, **kwargs):
    """Build (x, y, d) coordinates plus a per-frame index map for ``d``.

    Args:
        d: tensor of shape (b, n, h, w); presumably per-pixel depth or
            disparity for n frames (TODO confirm against callers).
        **kwargs: forwarded to ``torch.arange`` (e.g. ``device``). ``device``
            defaults to ``d.device`` so the generated grids can be stacked
            with ``d`` without a device mismatch.

    Returns:
        coords: (b, n, 3, h, w) tensor stacking x (column), y (row) and ``d``.
        index:  (b, n, 1, h, w) float tensor holding each frame's index 0..n-1.
    """
    b, n, h, w = d.shape
    kwargs.setdefault("device", d.device)

    x = torch.arange(0, w, dtype=torch.float, **kwargs)
    y = torch.arange(0, h, dtype=torch.float, **kwargs)
    y, x = torch.meshgrid(y, x, indexing="ij")
    # reshape (not view): meshgrid outputs are expanded, non-contiguous views.
    y = y.reshape(1, 1, h, w).repeat(b, n, 1, 1)
    x = x.reshape(1, 1, h, w).repeat(b, n, 1, 1)
    coords = torch.stack([x, y, d], dim=2)

    index = torch.arange(0, n, dtype=torch.float, **kwargs)
    index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w)
    return coords, index
def patchify(x, patch_size=3):
    """Extract every dense ``patch_size`` x ``patch_size`` patch from a video.

    Args:
        x: tensor of shape (b, n, c, h, w); may be non-contiguous.
        patch_size: side length of the square patches (stride 1, no padding).

    Returns:
        Tensor of shape (b, n * L, c, patch_size, patch_size), where
        L = (h - patch_size + 1) * (w - patch_size + 1) patches per frame,
        ordered frame by frame, then row-major over patch positions.
    """
    b, n, c, h, w = x.shape
    # reshape (not view) so inputs that are non-contiguous, e.g. after a
    # permute/transpose, are accepted instead of raising.
    frames = x.reshape(b * n, c, h, w)
    cols = F.unfold(frames, patch_size)  # (b*n, c*p*p, L)
    cols = cols.transpose(1, 2)          # (b*n, L, c*p*p)
    return cols.reshape(b, -1, c, patch_size, patch_size)
def pyramidify(fmap, lvls=(1,)):
    """Build a pyramid of average-pooled copies of a per-frame feature map.

    Args:
        fmap: tensor of shape (b, n, c, h, w); may be non-contiguous.
        lvls: iterable of integer downsampling factors. Each level is pooled
            with a lvl x lvl window and stride lvl. Defaults to ``(1,)``, an
            immutable tuple rather than a shared mutable list default.

    Returns:
        List with one tensor per entry of ``lvls``, each shaped
        (b, n, c, h // lvl, w // lvl).
    """
    b, n, c, h, w = fmap.shape
    # avg_pool2d wants 4-D input, so fold frames into the batch axis once;
    # reshape (not view) also accepts non-contiguous feature maps.
    frames = fmap.reshape(b * n, c, h, w)
    return [
        F.avg_pool2d(frames, lvl, stride=lvl).view(b, n, c, h // lvl, w // lvl)
        for lvl in lvls
    ]
def all_pairs_exclusive(n, **kwargs):
    """Return every ordered index pair (i, j) over ``range(n)`` with i != j.

    Args:
        n: number of items.
        **kwargs: forwarded to ``torch.arange`` (e.g. ``device``).

    Returns:
        (ii, jj): two 1-D tensors of length n * (n - 1), listing the pairs in
        row-major order (i ascending, then j ascending).
    """
    idx = torch.arange(n, **kwargs)
    # Explicit "ij" matches torch's current default and avoids the warning
    # emitted when meshgrid is called without `indexing`.
    ii, jj = torch.meshgrid(idx, idx, indexing="ij")
    keep = ii != jj
    # Boolean-mask indexing already yields 1-D tensors.
    return ii[keep], jj[keep]
def set_depth(patches, depth):
    """Write a per-patch depth into channel 2 of ``patches``, in place.

    ``depth[..., None, None]`` is broadcast over the last two axes, so every
    element of a patch's channel-2 plane receives that patch's depth value.
    Presumably patches is (b, N, C>=3, p, p) and depth is (b, N); TODO confirm.

    Returns:
        The same ``patches`` object, mutated.
    """
    depth_channel = patches[..., 2, :, :]  # basic indexing -> view into patches
    depth_channel[...] = depth[..., None, None]
    return patches
def flatmeshgrid(*args, **kwargs):
    """Meshgrid the given tensors and flatten each resulting grid to 1-D.

    Keyword arguments are forwarded to ``torch.meshgrid``. ``indexing``
    defaults to ``"ij"``, which is torch's current default, passed explicitly
    to avoid the warning emitted when it is omitted. Callers may still pass
    ``indexing="xy"``.

    Returns:
        A generator yielding one flattened tensor per grid, in input order.
    """
    kwargs.setdefault("indexing", "ij")
    grid = torch.meshgrid(*args, **kwargs)
    return (g.reshape(-1) for g in grid)