Techt3o's picture
1c0fe995ac3bf6c4bc83a727a73c46ab2d045729fb0abd53c4c78cd2b8282877
20ae9ff verified
raw
history blame
17.6 kB
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from . import altcorr, fastba, lietorch
from . import projective_ops as pops
from .lietorch import SE3
from .net import VONet
from .patchgraph import PatchGraph
from .utils import *
mp.set_start_method('spawn', True)
autocast = torch.cuda.amp.autocast
Id = SE3.Identity(1, device="cuda")
class DPVO:
def __init__(self, cfg, network, ht=480, wd=640, viz=False):
self.cfg = cfg
self.load_weights(network)
self.is_initialized = False
self.enable_timing = False
torch.set_num_threads(2)
self.M = self.cfg.PATCHES_PER_FRAME
self.N = self.cfg.BUFFER_SIZE
self.ht = ht # image height
self.wd = wd # image width
DIM = self.DIM
RES = self.RES
### state attributes ###
self.tlist = []
self.counter = 0
# keep track of global-BA calls
self.ran_global_ba = np.zeros(100000, dtype=bool)
ht = ht // RES
wd = wd // RES
# dummy image for visualization
self.image_ = torch.zeros(self.ht, self.wd, 3, dtype=torch.uint8, device="cpu")
### network attributes ###
if self.cfg.MIXED_PRECISION:
self.kwargs = kwargs = {"device": "cuda", "dtype": torch.half}
else:
self.kwargs = kwargs = {"device": "cuda", "dtype": torch.float}
### frame memory size ###
self.pmem = self.mem = 36 # 32 was too small given default settings
if self.cfg.LOOP_CLOSURE:
self.last_global_ba = -1000 # keep track of time since last global opt
self.pmem = self.cfg.MAX_EDGE_AGE # patch memory
self.imap_ = torch.zeros(self.pmem, self.M, DIM, **kwargs)
self.gmap_ = torch.zeros(self.pmem, self.M, 128, self.P, self.P, **kwargs)
self.pg = PatchGraph(self.cfg, self.P, self.DIM, self.pmem, **kwargs)
# classic backend
if self.cfg.CLASSIC_LOOP_CLOSURE:
self.load_long_term_loop_closure()
self.fmap1_ = torch.zeros(1, self.mem, 128, ht // 1, wd // 1, **kwargs)
self.fmap2_ = torch.zeros(1, self.mem, 128, ht // 4, wd // 4, **kwargs)
# feature pyramid
self.pyramid = (self.fmap1_, self.fmap2_)
self.viewer = None
if viz:
self.start_viewer()
def load_long_term_loop_closure(self):
try:
from .loop_closure.long_term import LongTermLoopClosure
self.long_term_lc = LongTermLoopClosure(self.cfg, self.pg)
except ModuleNotFoundError as e:
self.cfg.CLASSIC_LOOP_CLOSURE = False
print(f"WARNING: {e}")
def load_weights(self, network):
# load network from checkpoint file
if isinstance(network, str):
from collections import OrderedDict
state_dict = torch.load(network)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if "update.lmbda" not in k:
new_state_dict[k.replace('module.', '')] = v
self.network = VONet()
self.network.load_state_dict(new_state_dict)
else:
self.network = network
# steal network attributes
self.DIM = self.network.DIM
self.RES = self.network.RES
self.P = self.network.P
self.network.cuda()
self.network.eval()
def start_viewer(self):
from dpviewer import Viewer
intrinsics_ = torch.zeros(1, 4, dtype=torch.float32, device="cuda")
self.viewer = Viewer(
self.image_,
self.pg.poses_,
self.pg.points_,
self.pg.colors_,
intrinsics_)
@property
def poses(self):
return self.pg.poses_.view(1, self.N, 7)
@property
def patches(self):
return self.pg.patches_.view(1, self.N*self.M, 3, 3, 3)
@property
def intrinsics(self):
return self.pg.intrinsics_.view(1, self.N, 4)
@property
def ix(self):
return self.pg.index_.view(-1)
@property
def imap(self):
return self.imap_.view(1, self.pmem * self.M, self.DIM)
@property
def gmap(self):
return self.gmap_.view(1, self.pmem * self.M, 128, 3, 3)
@property
def n(self):
return self.pg.n
@n.setter
def n(self, val):
self.pg.n = val
@property
def m(self):
return self.pg.m
@m.setter
def m(self, val):
self.pg.m = val
def get_pose(self, t):
if t in self.traj:
return SE3(self.traj[t])
t0, dP = self.pg.delta[t]
return dP * self.get_pose(t0)
def terminate(self):
if self.cfg.CLASSIC_LOOP_CLOSURE:
self.long_term_lc.terminate(self.n)
if self.cfg.LOOP_CLOSURE:
self.append_factors(*self.pg.edges_loop())
for _ in range(12):
self.ran_global_ba[self.n] = False
self.update()
""" interpolate missing poses """
self.traj = {}
for i in range(self.n):
self.traj[self.pg.tstamps_[i]] = self.pg.poses_[i]
poses = [self.get_pose(t) for t in range(self.counter)]
poses = lietorch.stack(poses, dim=0)
poses = poses.inv().data.cpu().numpy()
tstamps = np.array(self.tlist, dtype=np.float64)
if self.viewer is not None:
self.viewer.join()
# Poses: x y z qx qy qz qw
return poses, tstamps
def corr(self, coords, indicies=None):
""" local correlation volume """
ii, jj = indicies if indicies is not None else (self.pg.kk, self.pg.jj)
ii1 = ii % (self.M * self.pmem)
jj1 = jj % (self.mem)
corr1 = altcorr.corr(self.gmap, self.pyramid[0], coords / 1, ii1, jj1, 3)
corr2 = altcorr.corr(self.gmap, self.pyramid[1], coords / 4, ii1, jj1, 3)
return torch.stack([corr1, corr2], -1).view(1, len(ii), -1)
def reproject(self, indicies=None):
""" reproject patch k from i -> j """
(ii, jj, kk) = indicies if indicies is not None else (self.pg.ii, self.pg.jj, self.pg.kk)
coords = pops.transform(SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk)
return coords.permute(0, 1, 4, 2, 3).contiguous()
def append_factors(self, ii, jj):
self.pg.jj = torch.cat([self.pg.jj, jj])
self.pg.kk = torch.cat([self.pg.kk, ii])
self.pg.ii = torch.cat([self.pg.ii, self.ix[ii]])
net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
self.pg.net = torch.cat([self.pg.net, net], dim=1)
def remove_factors(self, m, store: bool):
assert self.pg.ii.numel() == self.pg.weight.shape[1]
if store:
self.pg.ii_inac = torch.cat((self.pg.ii_inac, self.pg.ii[m]))
self.pg.jj_inac = torch.cat((self.pg.jj_inac, self.pg.jj[m]))
self.pg.kk_inac = torch.cat((self.pg.kk_inac, self.pg.kk[m]))
self.pg.weight_inac = torch.cat((self.pg.weight_inac, self.pg.weight[:,m]), dim=1)
self.pg.target_inac = torch.cat((self.pg.target_inac, self.pg.target[:,m]), dim=1)
self.pg.weight = self.pg.weight[:,~m]
self.pg.target = self.pg.target[:,~m]
self.pg.ii = self.pg.ii[~m]
self.pg.jj = self.pg.jj[~m]
self.pg.kk = self.pg.kk[~m]
self.pg.net = self.pg.net[:,~m]
assert self.pg.ii.numel() == self.pg.weight.shape[1]
def motion_probe(self):
""" kinda hacky way to ensure enough motion for initialization """
kk = torch.arange(self.m-self.M, self.m, device="cuda")
jj = self.n * torch.ones_like(kk)
ii = self.ix[kk]
net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
coords = self.reproject(indicies=(ii, jj, kk))
with autocast(enabled=self.cfg.MIXED_PRECISION):
corr = self.corr(coords, indicies=(kk, jj))
ctx = self.imap[:,kk % (self.M * self.pmem)]
net, (delta, weight, _) = \
self.network.update(net, ctx, corr, None, ii, jj, kk)
return torch.quantile(delta.norm(dim=-1).float(), 0.5)
def motionmag(self, i, j):
k = (self.pg.ii == i) & (self.pg.jj == j)
ii = self.pg.ii[k]
jj = self.pg.jj[k]
kk = self.pg.kk[k]
flow, _ = pops.flow_mag(SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk, beta=0.5)
return flow.mean().item()
def keyframe(self):
i = self.n - self.cfg.KEYFRAME_INDEX - 1
j = self.n - self.cfg.KEYFRAME_INDEX + 1
m = self.motionmag(i, j) + self.motionmag(j, i)
if m / 2 < self.cfg.KEYFRAME_THRESH:
k = self.n - self.cfg.KEYFRAME_INDEX
t0 = self.pg.tstamps_[k-1]
t1 = self.pg.tstamps_[k]
dP = SE3(self.pg.poses_[k]) * SE3(self.pg.poses_[k-1]).inv()
self.pg.delta[t1] = (t0, dP)
to_remove = (self.pg.ii == k) | (self.pg.jj == k)
self.remove_factors(to_remove, store=False)
self.pg.kk[self.pg.ii > k] -= self.M
self.pg.ii[self.pg.ii > k] -= 1
self.pg.jj[self.pg.jj > k] -= 1
for i in range(k, self.n-1):
self.pg.tstamps_[i] = self.pg.tstamps_[i+1]
self.pg.colors_[i] = self.pg.colors_[i+1]
self.pg.poses_[i] = self.pg.poses_[i+1]
self.pg.patches_[i] = self.pg.patches_[i+1]
self.pg.intrinsics_[i] = self.pg.intrinsics_[i+1]
self.imap_[i % self.pmem] = self.imap_[(i+1) % self.pmem]
self.gmap_[i % self.pmem] = self.gmap_[(i+1) % self.pmem]
self.fmap1_[0,i%self.mem] = self.fmap1_[0,(i+1)%self.mem]
self.fmap2_[0,i%self.mem] = self.fmap2_[0,(i+1)%self.mem]
self.n -= 1
self.m-= self.M
if self.cfg.CLASSIC_LOOP_CLOSURE:
self.long_term_lc.keyframe(k)
to_remove = self.ix[self.pg.kk] < self.n - self.cfg.REMOVAL_WINDOW # Remove edges falling outside the optimization window
if self.cfg.LOOP_CLOSURE:
# ...unless they are being used for loop closure
lc_edges = ((self.pg.jj - self.pg.ii) > 30) & (self.pg.jj > (self.n - self.cfg.OPTIMIZATION_WINDOW))
to_remove = to_remove & ~lc_edges
self.remove_factors(to_remove, store=True)
def __run_global_BA(self):
""" Global bundle adjustment
Includes both active and inactive edges """
full_target = torch.cat((self.pg.target_inac, self.pg.target), dim=1)
full_weight = torch.cat((self.pg.weight_inac, self.pg.weight), dim=1)
full_ii = torch.cat((self.pg.ii_inac, self.pg.ii))
full_jj = torch.cat((self.pg.jj_inac, self.pg.jj))
full_kk = torch.cat((self.pg.kk_inac, self.pg.kk))
self.pg.normalize()
lmbda = torch.as_tensor([1e-4], device="cuda")
t0 = self.pg.ii.min().item()
fastba.BA(self.poses, self.patches, self.intrinsics,
full_target, full_weight, lmbda, full_ii, full_jj, full_kk, t0, self.n, M=self.M, iterations=2, eff_impl=True)
self.ran_global_ba[self.n] = True
def update(self):
with Timer("other", enabled=self.enable_timing):
coords = self.reproject()
with autocast(enabled=True):
corr = self.corr(coords)
ctx = self.imap[:, self.pg.kk % (self.M * self.pmem)]
self.pg.net, (delta, weight, _) = \
self.network.update(self.pg.net, ctx, corr, None, self.pg.ii, self.pg.jj, self.pg.kk)
lmbda = torch.as_tensor([1e-4], device="cuda")
weight = weight.float()
target = coords[...,self.P//2,self.P//2] + delta.float()
self.pg.target = target
self.pg.weight = weight
with Timer("BA", enabled=self.enable_timing):
try:
# run global bundle adjustment if there exist long-range edges
if (self.pg.ii < self.n - self.cfg.REMOVAL_WINDOW - 1).any() and not self.ran_global_ba[self.n]:
self.__run_global_BA()
else:
t0 = self.n - self.cfg.OPTIMIZATION_WINDOW if self.is_initialized else 1
t0 = max(t0, 1)
fastba.BA(self.poses, self.patches, self.intrinsics,
target, weight, lmbda, self.pg.ii, self.pg.jj, self.pg.kk, t0, self.n, M=self.M, iterations=2, eff_impl=False)
except:
print("Warning BA failed...")
points = pops.point_cloud(SE3(self.poses), self.patches[:, :self.m], self.intrinsics, self.ix[:self.m])
points = (points[...,1,1,:3] / points[...,1,1,3:]).reshape(-1, 3)
self.pg.points_[:len(points)] = points[:]
def __edges_forw(self):
r=self.cfg.PATCH_LIFETIME
t0 = self.M * max((self.n - r), 0)
t1 = self.M * max((self.n - 1), 0)
return flatmeshgrid(
torch.arange(t0, t1, device="cuda"),
torch.arange(self.n-1, self.n, device="cuda"), indexing='ij')
def __edges_back(self):
r=self.cfg.PATCH_LIFETIME
t0 = self.M * max((self.n - 1), 0)
t1 = self.M * max((self.n - 0), 0)
return flatmeshgrid(torch.arange(t0, t1, device="cuda"),
torch.arange(max(self.n-r, 0), self.n, device="cuda"), indexing='ij')
def __call__(self, tstamp, image, intrinsics):
""" track new frame """
if self.cfg.CLASSIC_LOOP_CLOSURE:
self.long_term_lc(image, self.n)
if (self.n+1) >= self.N:
raise Exception(f'The buffer size is too small. You can increase it using "--opts BUFFER_SIZE={self.N*2}"')
if self.viewer is not None:
self.viewer.update_image(image.contiguous())
image = 2 * (image[None,None] / 255.0) - 0.5
with autocast(enabled=self.cfg.MIXED_PRECISION):
fmap, gmap, imap, patches, _, clr = \
self.network.patchify(image,
patches_per_image=self.cfg.PATCHES_PER_FRAME,
centroid_sel_strat=self.cfg.CENTROID_SEL_STRAT,
return_color=True)
### update state attributes ###
self.tlist.append(tstamp)
self.pg.tstamps_[self.n] = self.counter
self.pg.intrinsics_[self.n] = intrinsics / self.RES
# color info for visualization
clr = (clr[0,:,[2,1,0]] + 0.5) * (255.0 / 2)
self.pg.colors_[self.n] = clr.to(torch.uint8)
self.pg.index_[self.n + 1] = self.n + 1
self.pg.index_map_[self.n + 1] = self.m + self.M
if self.n > 1:
if self.cfg.MOTION_MODEL == 'DAMPED_LINEAR':
P1 = SE3(self.pg.poses_[self.n-1])
P2 = SE3(self.pg.poses_[self.n-2])
# To deal with varying camera hz
*_, a,b,c = [1]*3 + self.tlist
fac = (c-b) / (b-a)
xi = self.cfg.MOTION_DAMPING * fac * (P1 * P2.inv()).log()
tvec_qvec = (SE3.exp(xi) * P1).data
self.pg.poses_[self.n] = tvec_qvec
else:
tvec_qvec = self.poses[self.n-1]
self.pg.poses_[self.n] = tvec_qvec
# TODO better depth initialization
patches[:,:,2] = torch.rand_like(patches[:,:,2,0,0,None,None])
if self.is_initialized:
s = torch.median(self.pg.patches_[self.n-3:self.n,:,2])
patches[:,:,2] = s
self.pg.patches_[self.n] = patches
### update network attributes ###
self.imap_[self.n % self.pmem] = imap.squeeze()
self.gmap_[self.n % self.pmem] = gmap.squeeze()
self.fmap1_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 1, 1)
self.fmap2_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 4, 4)
self.counter += 1
if self.n > 0 and not self.is_initialized:
if self.motion_probe() < 2.0:
self.pg.delta[self.counter - 1] = (self.counter - 2, Id[0])
return
self.n += 1
self.m += self.M
if self.cfg.LOOP_CLOSURE:
if self.n - self.last_global_ba >= self.cfg.GLOBAL_OPT_FREQ:
""" Add loop closure factors """
lii, ljj = self.pg.edges_loop()
if lii.numel() > 0:
self.last_global_ba = self.n
self.append_factors(lii, ljj)
# Add forward and backward factors
self.append_factors(*self.__edges_forw())
self.append_factors(*self.__edges_back())
if self.n == 8 and not self.is_initialized:
self.is_initialized = True
for itr in range(12):
self.update()
elif self.is_initialized:
self.update()
self.keyframe()
if self.cfg.CLASSIC_LOOP_CLOSURE:
self.long_term_lc.attempt_loop_closure(self.n)
self.long_term_lc.lc_callback()