import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass
from transformers import PretrainedConfig
from transformers.file_utils import ModelOutput
from dust3r.utils.misc import (
fill_default_args,
freeze_all_params,
is_symmetrized,
interleave,
transpose_to_landscape,
)
from dust3r.heads import head_factory
from dust3r.utils.camera import PoseEncoder
from dust3r.patch_embed import get_patch_embed
import dust3r.utils.path_to_croco # noqa: F401
from models.croco import CroCoNet, CrocoConfig # noqa
from dust3r.blocks import (
Block,
DecoderBlock,
Mlp,
Attention,
CrossAttention,
DropPath,
CustomDecoderBlock,
) # noqa
from accelerate.logging import get_logger

inf = float("inf")
printer = get_logger(__name__, log_level="DEBUG")
@dataclass
class ARCroco3DStereoOutput(ModelOutput):
"""
Custom output class for ARCroco3DStereo.
"""
ress: Optional[List[Any]] = None
views: Optional[List[Any]] = None
def strip_module(state_dict):
"""
Removes the 'module.' prefix from the keys of a state_dict.
Args:
state_dict (dict): The original state_dict with possible 'module.' prefixes.
Returns:
OrderedDict: A new state_dict with 'module.' prefixes removed.
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] if k.startswith("module.") else k
new_state_dict[name] = v
return new_state_dict
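# Illustrative only: checkpoints saved from a DataParallel/DistributedDataParallel
# wrapper prefix every key with "module."; strip_module removes that prefix so the
# weights load into an unwrapped model.
#
#   sd = strip_module(torch.load("ckpt.pth", map_location="cpu")["model"])  # hypothetical path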
def load_model(model_path, device, verbose=True):
if verbose:
print("... loading model from", model_path)
ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    args = ckpt["args"].model.replace(
        "ManyAR_PatchEmbed", "PatchEmbedDust3R"
    )  # ManyAR_PatchEmbed is only needed when aspect ratios are inconsistent
if "landscape_only" not in args:
args = args[:-2] + ", landscape_only=False))"
else:
args = args.replace(" ", "").replace(
"landscape_only=True", "landscape_only=False"
)
assert "landscape_only=False" in args
if verbose:
print(f"instantiating : {args}")
    net = eval(args)  # instantiate from the constructor string stored in the checkpoint (trusted input only)
s = net.load_state_dict(ckpt["model"], strict=False)
if verbose:
print(s)
return net.to(device)
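# Illustrative usage (path is hypothetical): load a training checkpoint and move it
# to a device; load_model rewrites the stored constructor string so that the model
# accepts portrait inputs (landscape_only=False) at inference time.
#
#   model = load_model("checkpoints/arcroco.pth", device="cuda")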
class ARCroco3DStereoConfig(PretrainedConfig):
model_type = "arcroco_3d_stereo"
def __init__(
self,
output_mode="pts3d",
head_type="linear", # or dpt
depth_mode=("exp", -float("inf"), float("inf")),
conf_mode=("exp", 1, float("inf")),
pose_mode=("exp", -float("inf"), float("inf")),
freeze="none",
landscape_only=True,
patch_embed_cls="PatchEmbedDust3R",
ray_enc_depth=2,
state_size=324,
local_mem_size=256,
state_pe="2d",
state_dec_num_heads=16,
depth_head=False,
rgb_head=False,
pose_conf_head=False,
pose_head=False,
**croco_kwargs,
):
super().__init__()
self.output_mode = output_mode
self.head_type = head_type
self.depth_mode = depth_mode
self.conf_mode = conf_mode
self.pose_mode = pose_mode
self.freeze = freeze
self.landscape_only = landscape_only
self.patch_embed_cls = patch_embed_cls
self.ray_enc_depth = ray_enc_depth
self.state_size = state_size
self.state_pe = state_pe
self.state_dec_num_heads = state_dec_num_heads
self.local_mem_size = local_mem_size
self.depth_head = depth_head
self.rgb_head = rgb_head
self.pose_conf_head = pose_conf_head
self.pose_head = pose_head
self.croco_kwargs = croco_kwargs
class LocalMemory(nn.Module):
def __init__(
self,
size,
k_dim,
v_dim,
num_heads,
depth=2,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
norm_mem=True,
rope=None,
) -> None:
super().__init__()
self.v_dim = v_dim
self.proj_q = nn.Linear(k_dim, v_dim)
self.masked_token = nn.Parameter(
torch.randn(1, 1, v_dim) * 0.2, requires_grad=True
)
self.mem = nn.Parameter(
torch.randn(1, size, 2 * v_dim) * 0.2, requires_grad=True
)
self.write_blocks = nn.ModuleList(
[
DecoderBlock(
2 * v_dim,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path,
act_layer=act_layer,
norm_mem=norm_mem,
rope=rope,
)
for _ in range(depth)
]
)
self.read_blocks = nn.ModuleList(
[
DecoderBlock(
2 * v_dim,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path,
act_layer=act_layer,
norm_mem=norm_mem,
rope=rope,
)
for _ in range(depth)
]
)
def update_mem(self, mem, feat_k, feat_v):
"""
mem_k: [B, size, C]
mem_v: [B, size, C]
feat_k: [B, 1, C]
feat_v: [B, 1, C]
"""
feat_k = self.proj_q(feat_k) # [B, 1, C]
feat = torch.cat([feat_k, feat_v], dim=-1)
for blk in self.write_blocks:
mem, _ = blk(mem, feat, None, None)
return mem
def inquire(self, query, mem):
x = self.proj_q(query) # [B, 1, C]
x = torch.cat([x, self.masked_token.expand(x.shape[0], -1, -1)], dim=-1)
for blk in self.read_blocks:
x, _ = blk(x, mem, None, None)
return x[..., -self.v_dim :]
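# A minimal read/write round trip for LocalMemory (illustrative; sizes assumed):
#
#   lm = LocalMemory(size=256, k_dim=1024, v_dim=768, num_heads=12)
#   mem = lm.mem.expand(2, -1, -1)                        # [B=2, 256, 2*768]
#   k, v = torch.randn(2, 1, 1024), torch.randn(2, 1, 768)
#   mem = lm.update_mem(mem, k, v)                        # write one (key, value) pair
#   out = lm.inquire(torch.randn(2, 1, 1024), mem)        # read back -> [2, 1, 768]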
class ARCroco3DStereo(CroCoNet):
config_class = ARCroco3DStereoConfig
base_model_prefix = "arcroco3dstereo"
supports_gradient_checkpointing = True
def __init__(self, config: ARCroco3DStereoConfig):
self.gradient_checkpointing = False
self.fixed_input_length = True
config.croco_kwargs = fill_default_args(
config.croco_kwargs, CrocoConfig.__init__
)
self.config = config
self.patch_embed_cls = config.patch_embed_cls
self.croco_args = config.croco_kwargs
croco_cfg = CrocoConfig(**self.croco_args)
super().__init__(croco_cfg)
self.enc_blocks_ray_map = nn.ModuleList(
[
Block(
self.enc_embed_dim,
16,
4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
rope=self.rope,
)
for _ in range(config.ray_enc_depth)
]
)
self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
self.dec_num_heads = self.croco_args["dec_num_heads"]
self.pose_head_flag = config.pose_head
if self.pose_head_flag:
self.pose_token = nn.Parameter(
torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
)
self.pose_retriever = LocalMemory(
size=config.local_mem_size,
k_dim=self.enc_embed_dim,
v_dim=self.dec_embed_dim,
num_heads=self.dec_num_heads,
mlp_ratio=4,
qkv_bias=True,
attn_drop=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
rope=None,
)
self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
self.state_size = config.state_size
self.state_pe = config.state_pe
self.masked_img_token = nn.Parameter(
torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
)
self.masked_ray_map_token = nn.Parameter(
torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
)
self._set_state_decoder(
self.enc_embed_dim,
self.dec_embed_dim,
config.state_dec_num_heads,
self.dec_depth,
self.croco_args.get("mlp_ratio", None),
self.croco_args.get("norm_layer", None),
self.croco_args.get("norm_im2_in_dec", None),
)
self.set_downstream_head(
config.output_mode,
config.head_type,
config.landscape_only,
config.depth_mode,
config.conf_mode,
config.pose_mode,
config.depth_head,
config.rgb_head,
config.pose_conf_head,
config.pose_head,
**self.croco_args,
)
self.set_freeze(config.freeze)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kw):
if os.path.isfile(pretrained_model_name_or_path):
return load_model(pretrained_model_name_or_path, device="cpu")
else:
try:
                model = super(ARCroco3DStereo, cls).from_pretrained(
                    pretrained_model_name_or_path, **kw
                )
            except TypeError as e:
                raise Exception(
                    f"tried to load {pretrained_model_name_or_path} from huggingface, but failed"
                ) from e
return model
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
self.patch_embed = get_patch_embed(
self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
)
self.patch_embed_ray_map = get_patch_embed(
self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6
)
def _set_decoder(
self,
enc_embed_dim,
dec_embed_dim,
dec_num_heads,
dec_depth,
mlp_ratio,
norm_layer,
norm_im2_in_dec,
):
self.dec_depth = dec_depth
self.dec_embed_dim = dec_embed_dim
self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
self.dec_blocks = nn.ModuleList(
[
DecoderBlock(
dec_embed_dim,
dec_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
norm_mem=norm_im2_in_dec,
rope=self.rope,
)
for i in range(dec_depth)
]
)
self.dec_norm = norm_layer(dec_embed_dim)
def _set_state_decoder(
self,
enc_embed_dim,
dec_embed_dim,
dec_num_heads,
dec_depth,
mlp_ratio,
norm_layer,
norm_im2_in_dec,
):
self.dec_depth_state = dec_depth
self.dec_embed_dim_state = dec_embed_dim
self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
self.dec_blocks_state = nn.ModuleList(
[
DecoderBlock(
dec_embed_dim,
dec_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
norm_mem=norm_im2_in_dec,
rope=self.rope,
)
for i in range(dec_depth)
]
)
self.dec_norm_state = norm_layer(dec_embed_dim)
def load_state_dict(self, ckpt, **kw):
if all(k.startswith("module") for k in ckpt):
ckpt = strip_module(ckpt)
new_ckpt = dict(ckpt)
if not any(k.startswith("dec_blocks_state") for k in ckpt):
for key, value in ckpt.items():
if key.startswith("dec_blocks"):
new_ckpt[key.replace("dec_blocks", "dec_blocks_state")] = value
        try:
            return super().load_state_dict(new_ckpt, **kw)
        except Exception:
            try:
                # retry without the decoder weights, which may differ in shape
                new_new_ckpt = {
                    k: v
                    for k, v in new_ckpt.items()
                    if not k.startswith("dec_blocks")
                    and not k.startswith("dec_norm")
                    and not k.startswith("decoder_embed")
                }
                return super().load_state_dict(new_new_ckpt, **kw)
            except Exception:
                # last resort: keep only tensors whose name and shape match the model
                model_sd = self.state_dict()
                new_new_ckpt = {}
                for key in new_ckpt:
                    if key in model_sd:
                        if new_ckpt[key].size() == model_sd[key].size():
                            new_new_ckpt[key] = new_ckpt[key]
                        else:
                            printer.info(
                                f"Skipping '{key}': size mismatch (ckpt: {new_ckpt[key].size()}, model: {model_sd[key].size()})"
                            )
                    else:
                        printer.info(f"Skipping '{key}': not found in model")
                return super().load_state_dict(new_new_ckpt, **kw)
def set_freeze(self, freeze): # this is for use by downstream models
self.freeze = freeze
to_be_frozen = {
"none": [],
"mask": [self.mask_token] if hasattr(self, "mask_token") else [],
"encoder": [
self.patch_embed,
self.patch_embed_ray_map,
self.masked_img_token,
self.masked_ray_map_token,
self.enc_blocks,
self.enc_blocks_ray_map,
self.enc_norm,
self.enc_norm_ray_map,
],
"encoder_and_head": [
self.patch_embed,
self.patch_embed_ray_map,
self.masked_img_token,
self.masked_ray_map_token,
self.enc_blocks,
self.enc_blocks_ray_map,
self.enc_norm,
self.enc_norm_ray_map,
self.downstream_head,
],
"encoder_and_decoder": [
self.patch_embed,
self.patch_embed_ray_map,
self.masked_img_token,
self.masked_ray_map_token,
self.enc_blocks,
self.enc_blocks_ray_map,
self.enc_norm,
self.enc_norm_ray_map,
self.dec_blocks,
self.dec_blocks_state,
self.pose_retriever,
self.pose_token,
self.register_tokens,
self.decoder_embed_state,
self.decoder_embed,
self.dec_norm,
self.dec_norm_state,
],
"decoder": [
self.dec_blocks,
self.dec_blocks_state,
self.pose_retriever,
self.pose_token,
],
}
freeze_all_params(to_be_frozen[freeze])
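    # Illustrative: freeze the shared encoders while training the decoder and heads.
    # Valid keys: "none", "mask", "encoder", "encoder_and_head",
    # "encoder_and_decoder", "decoder".
    #
    #   model.set_freeze("encoder")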
def _set_prediction_head(self, *args, **kwargs):
"""No prediction head"""
return
def set_downstream_head(
self,
output_mode,
head_type,
landscape_only,
depth_mode,
conf_mode,
pose_mode,
depth_head,
rgb_head,
pose_conf_head,
pose_head,
patch_size,
img_size,
**kw,
):
assert (
img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
), f"{img_size=} must be multiple of {patch_size=}"
self.output_mode = output_mode
self.head_type = head_type
self.depth_mode = depth_mode
self.conf_mode = conf_mode
self.pose_mode = pose_mode
self.downstream_head = head_factory(
head_type,
output_mode,
self,
has_conf=bool(conf_mode),
has_depth=bool(depth_head),
has_rgb=bool(rgb_head),
has_pose_conf=bool(pose_conf_head),
has_pose=bool(pose_head),
)
self.head = transpose_to_landscape(
self.downstream_head, activate=landscape_only
)
def _encode_image(self, image, true_shape):
x, pos = self.patch_embed(image, true_shape=true_shape)
assert self.enc_pos_embed is None
for blk in self.enc_blocks:
if self.gradient_checkpointing and self.training:
x = checkpoint(blk, x, pos, use_reentrant=False)
else:
x = blk(x, pos)
x = self.enc_norm(x)
return [x], pos, None
def _encode_ray_map(self, ray_map, true_shape):
x, pos = self.patch_embed_ray_map(ray_map, true_shape=true_shape)
assert self.enc_pos_embed is None
for blk in self.enc_blocks_ray_map:
if self.gradient_checkpointing and self.training:
x = checkpoint(blk, x, pos, use_reentrant=False)
else:
x = blk(x, pos)
x = self.enc_norm_ray_map(x)
return [x], pos, None
def _encode_state(self, image_tokens, image_pos):
batch_size = image_tokens.shape[0]
state_feat = self.register_tokens(
torch.arange(self.state_size, device=image_pos.device)
)
if self.state_pe == "1d":
state_pos = (
torch.tensor(
[[i, i] for i in range(self.state_size)],
dtype=image_pos.dtype,
device=image_pos.device,
)[None]
.expand(batch_size, -1, -1)
.contiguous()
) # .long()
elif self.state_pe == "2d":
width = int(self.state_size**0.5)
width = width + 1 if width % 2 == 1 else width
state_pos = (
torch.tensor(
[[i // width, i % width] for i in range(self.state_size)],
dtype=image_pos.dtype,
device=image_pos.device,
)[None]
.expand(batch_size, -1, -1)
.contiguous()
)
elif self.state_pe == "none":
state_pos = None
state_feat = state_feat[None].expand(batch_size, -1, -1)
return state_feat, state_pos, None
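    # Illustrative: with state_pe="2d" and state_size=9, width = int(9**0.5) = 3 is
    # rounded up to the even value 4, so token i sits at (i // 4, i % 4):
    #   (0,0) (0,1) (0,2) (0,3) (1,0) (1,1) (1,2) (1,3) (2,0)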
def _encode_views(self, views, img_mask=None, ray_mask=None):
device = views[0]["img"].device
batch_size = views[0]["img"].shape[0]
        if img_mask is None and ray_mask is None:
            img_mask = torch.stack(
                [view["img_mask"] for view in views], dim=0
            )  # Shape: (num_views, batch_size)
            ray_mask = torch.stack(
                [view["ray_mask"] for view in views], dim=0
            )  # Shape: (num_views, batch_size)
imgs = torch.stack(
[view["img"] for view in views], dim=0
) # Shape: (num_views, batch_size, C, H, W)
ray_maps = torch.stack(
[view["ray_map"] for view in views], dim=0
) # Shape: (num_views, batch_size, H, W, C)
shapes = []
for view in views:
if "true_shape" in view:
shapes.append(view["true_shape"])
else:
shape = torch.tensor(view["img"].shape[-2:], device=device)
shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
shapes = torch.stack(shapes, dim=0).to(
imgs.device
) # Shape: (num_views, batch_size, 2)
imgs = imgs.view(
-1, *imgs.shape[2:]
) # Shape: (num_views * batch_size, C, H, W)
ray_maps = ray_maps.view(
-1, *ray_maps.shape[2:]
) # Shape: (num_views * batch_size, H, W, C)
shapes = shapes.view(-1, 2) # Shape: (num_views * batch_size, 2)
img_masks_flat = img_mask.view(-1) # Shape: (num_views * batch_size)
ray_masks_flat = ray_mask.view(-1)
selected_imgs = imgs[img_masks_flat]
selected_shapes = shapes[img_masks_flat]
if selected_imgs.size(0) > 0:
img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
else:
raise NotImplementedError
full_out = [
torch.zeros(
len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
)
for _ in range(len(img_out))
]
full_pos = torch.zeros(
len(views) * batch_size,
*img_pos.shape[1:],
device=img_pos.device,
dtype=img_pos.dtype,
)
for i in range(len(img_out)):
full_out[i][img_masks_flat] += img_out[i]
full_out[i][~img_masks_flat] += self.masked_img_token
full_pos[img_masks_flat] += img_pos
ray_maps = ray_maps.permute(0, 3, 1, 2) # Change shape to (N, C, H, W)
selected_ray_maps = ray_maps[ray_masks_flat]
selected_shapes_ray = shapes[ray_masks_flat]
if selected_ray_maps.size(0) > 0:
ray_out, ray_pos, _ = self._encode_ray_map(
selected_ray_maps, selected_shapes_ray
)
assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
for i in range(len(ray_out)):
full_out[i][ray_masks_flat] += ray_out[i]
full_out[i][~ray_masks_flat] += self.masked_ray_map_token
full_pos[ray_masks_flat] += (
ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
)
        else:
            # no ray maps in this batch: run a dummy ray map through the encoder
            # (scaled by 0 below) so its parameters stay in the autograd graph
            raymaps = torch.zeros(
                1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
            )
ray_mask_flat = torch.zeros_like(img_masks_flat)
ray_mask_flat[:1] = True
ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
for i in range(len(ray_out)):
full_out[i][ray_mask_flat] += ray_out[i] * 0.0
full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
return (
shapes.chunk(len(views), dim=0),
[out.chunk(len(views), dim=0) for out in full_out],
full_pos.chunk(len(views), dim=0),
)
def _decoder(self, f_state, pos_state, f_img, pos_img, f_pose, pos_pose):
final_output = [(f_state, f_img)] # before projection
assert f_state.shape[-1] == self.dec_embed_dim
f_img = self.decoder_embed(f_img)
if self.pose_head_flag:
assert f_pose is not None and pos_pose is not None
f_img = torch.cat([f_pose, f_img], dim=1)
pos_img = torch.cat([pos_pose, pos_img], dim=1)
final_output.append((f_state, f_img))
        # at each depth, state tokens and image tokens cross-attend to each other
        for blk_state, blk_img in zip(self.dec_blocks_state, self.dec_blocks):
if (
self.gradient_checkpointing
and self.training
and torch.is_grad_enabled()
):
f_state, _ = checkpoint(
blk_state,
*final_output[-1][::+1],
pos_state,
pos_img,
use_reentrant=not self.fixed_input_length,
)
f_img, _ = checkpoint(
blk_img,
*final_output[-1][::-1],
pos_img,
pos_state,
use_reentrant=not self.fixed_input_length,
)
else:
f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
final_output.append((f_state, f_img))
del final_output[1] # duplicate with final_output[0]
final_output[-1] = (
self.dec_norm_state(final_output[-1][0]),
self.dec_norm(final_output[-1][1]),
)
return zip(*final_output)
    def _downstream_head(self, decout, img_shape, **kwargs):
        return self.head(decout, img_shape, **kwargs)
def _init_state(self, image_tokens, image_pos):
"""
Current Version: input the first frame img feature and pose to initialize the state feature and pose
"""
state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos)
state_feat = self.decoder_embed_state(state_feat)
return state_feat, state_pos
def _recurrent_rollout(
self,
state_feat,
state_pos,
current_feat,
current_pos,
pose_feat,
pose_pos,
init_state_feat,
img_mask=None,
reset_mask=None,
update=None,
):
new_state_feat, dec = self._decoder(
state_feat, state_pos, current_feat, current_pos, pose_feat, pose_pos
)
new_state_feat = new_state_feat[-1]
return new_state_feat, dec
def _get_img_level_feat(self, feat):
return torch.mean(feat, dim=1, keepdim=True)
def _forward_encoder(self, views):
shape, feat_ls, pos = self._encode_views(views)
feat = feat_ls[-1]
state_feat, state_pos = self._init_state(feat[0], pos[0])
mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
init_state_feat = state_feat.clone()
init_mem = mem.clone()
return (feat, pos, shape), (
init_state_feat,
init_mem,
state_feat,
state_pos,
mem,
)
def _forward_decoder_step(
self,
views,
i,
feat_i,
pos_i,
shape_i,
init_state_feat,
init_mem,
state_feat,
state_pos,
mem,
):
if self.pose_head_flag:
global_img_feat_i = self._get_img_level_feat(feat_i)
if i == 0:
pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
else:
pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
pose_pos_i = -torch.ones(
feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
)
else:
pose_feat_i = None
pose_pos_i = None
new_state_feat, dec = self._recurrent_rollout(
state_feat,
state_pos,
feat_i,
pos_i,
pose_feat_i,
pose_pos_i,
init_state_feat,
img_mask=views[i]["img_mask"],
reset_mask=views[i]["reset"],
update=views[i].get("update", None),
)
        out_pose_feat_i = dec[-1][:, 0:1]  # pose token sits at position 0 of the decoder output
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        head_input = [  # multi-scale decoder features fed to the prediction head
dec[0].float(),
dec[self.dec_depth * 2 // 4][:, 1:].float(),
dec[self.dec_depth * 3 // 4][:, 1:].float(),
dec[self.dec_depth].float(),
]
res = self._downstream_head(head_input, shape_i, pos=pos_i)
img_mask = views[i]["img_mask"]
update = views[i].get("update", None)
if update is not None:
            update_mask = img_mask & update  # write only where an image is given and update is requested
else:
update_mask = img_mask
update_mask = update_mask[:, None, None].float()
state_feat = new_state_feat * update_mask + state_feat * (
1 - update_mask
) # update global state
mem = new_mem * update_mask + mem * (1 - update_mask) # then update local state
reset_mask = views[i]["reset"]
if reset_mask is not None:
reset_mask = reset_mask[:, None, None].float()
state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
mem = init_mem * reset_mask + mem * (1 - reset_mask)
return res, (state_feat, mem)
def _forward_impl(self, views, ret_state=False):
shape, feat_ls, pos = self._encode_views(views)
feat = feat_ls[-1]
state_feat, state_pos = self._init_state(feat[0], pos[0])
mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
init_state_feat = state_feat.clone()
init_mem = mem.clone()
all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
ress = []
for i in range(len(views)):
feat_i = feat[i]
pos_i = pos[i]
if self.pose_head_flag:
global_img_feat_i = self._get_img_level_feat(feat_i)
if i == 0:
pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
else:
pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
pose_pos_i = -torch.ones(
feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
)
else:
pose_feat_i = None
pose_pos_i = None
new_state_feat, dec = self._recurrent_rollout(
state_feat,
state_pos,
feat_i,
pos_i,
pose_feat_i,
pose_pos_i,
init_state_feat,
img_mask=views[i]["img_mask"],
reset_mask=views[i]["reset"],
update=views[i].get("update", None),
)
out_pose_feat_i = dec[-1][:, 0:1]
new_mem = self.pose_retriever.update_mem(
mem, global_img_feat_i, out_pose_feat_i
)
assert len(dec) == self.dec_depth + 1
head_input = [
dec[0].float(),
dec[self.dec_depth * 2 // 4][:, 1:].float(),
dec[self.dec_depth * 3 // 4][:, 1:].float(),
dec[self.dec_depth].float(),
]
res = self._downstream_head(head_input, shape[i], pos=pos_i)
ress.append(res)
img_mask = views[i]["img_mask"]
update = views[i].get("update", None)
if update is not None:
                update_mask = (
                    img_mask & update
                )  # write only where an image is given and update is requested
else:
update_mask = img_mask
update_mask = update_mask[:, None, None].float()
state_feat = new_state_feat * update_mask + state_feat * (
1 - update_mask
) # update global state
mem = new_mem * update_mask + mem * (
1 - update_mask
) # then update local state
reset_mask = views[i]["reset"]
if reset_mask is not None:
reset_mask = reset_mask[:, None, None].float()
state_feat = init_state_feat * reset_mask + state_feat * (
1 - reset_mask
)
mem = init_mem * reset_mask + mem * (1 - reset_mask)
all_state_args.append(
(state_feat, state_pos, init_state_feat, mem, init_mem)
)
if ret_state:
return ress, views, all_state_args
return ress, views
def forward(self, views, ret_state=False):
if ret_state:
ress, views, state_args = self._forward_impl(views, ret_state=ret_state)
return ARCroco3DStereoOutput(ress=ress, views=views), state_args
else:
ress, views = self._forward_impl(views, ret_state=ret_state)
return ARCroco3DStereoOutput(ress=ress, views=views)
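    # Illustrative input layout for forward(); shapes follow _encode_views:
    # each element of `views` is a dict with
    #   "img":        (B, 3, H, W) float tensor
    #   "ray_map":    (B, H, W, 6) float tensor
    #   "img_mask":   (B,) bool - which samples provide an image at this step
    #   "ray_mask":   (B,) bool - which samples provide a ray map at this step
    #   "reset":      (B,) bool or None - reset the recurrent state for these samples
    #   "update":     optional (B,) bool - whether this step writes to the state
    #   "true_shape": optional (B, 2) - per-sample (H, W) before padding
    #
    #   out = model(views)                          # ARCroco3DStereoOutput
    #   out, states = model(views, ret_state=True)  # also return per-step state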
def inference_step(
self, view, state_feat, state_pos, init_state_feat, mem, init_mem
):
batch_size = view["img"].shape[0]
raymaps = []
shapes = []
for j in range(batch_size):
assert view["ray_mask"][j]
raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
raymaps.append(raymap)
shapes.append(
view.get(
"true_shape",
torch.tensor(view["ray_map"].shape[-2:])[None].repeat(
view["ray_map"].shape[0], 1
),
)[[j]]
)
raymaps = torch.cat(raymaps, dim=0)
shape = torch.cat(shapes, dim=0).to(raymaps.device)
        feat_ls, pos, _ = self._encode_ray_map(raymaps, shape)
feat_i = feat_ls[-1]
pos_i = pos
if self.pose_head_flag:
global_img_feat_i = self._get_img_level_feat(feat_i)
pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
pose_pos_i = -torch.ones(
feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
)
else:
pose_feat_i = None
pose_pos_i = None
new_state_feat, dec = self._recurrent_rollout(
state_feat,
state_pos,
feat_i,
pos_i,
pose_feat_i,
pose_pos_i,
init_state_feat,
img_mask=view["img_mask"],
reset_mask=view["reset"],
update=view.get("update", None),
)
out_pose_feat_i = dec[-1][:, 0:1]
new_mem = self.pose_retriever.update_mem(
mem, global_img_feat_i, out_pose_feat_i
)
assert len(dec) == self.dec_depth + 1
head_input = [
dec[0].float(),
dec[self.dec_depth * 2 // 4][:, 1:].float(),
dec[self.dec_depth * 3 // 4][:, 1:].float(),
dec[self.dec_depth].float(),
]
res = self._downstream_head(head_input, shape, pos=pos_i)
return res, view
def forward_recurrent(self, views, device, ret_state=False):
ress = []
all_state_args = []
for i, view in enumerate(views):
device = view["img"].device
batch_size = view["img"].shape[0]
img_mask = view["img_mask"].reshape(
-1, batch_size
) # Shape: (1, batch_size)
ray_mask = view["ray_mask"].reshape(
-1, batch_size
) # Shape: (1, batch_size)
imgs = view["img"].unsqueeze(0) # Shape: (1, batch_size, C, H, W)
ray_maps = view["ray_map"].unsqueeze(
0
) # Shape: (num_views, batch_size, H, W, C)
shapes = (
view["true_shape"].unsqueeze(0)
if "true_shape" in view
else torch.tensor(view["img"].shape[-2:], device=device)
.unsqueeze(0)
.repeat(batch_size, 1)
.unsqueeze(0)
) # Shape: (num_views, batch_size, 2)
imgs = imgs.view(
-1, *imgs.shape[2:]
) # Shape: (num_views * batch_size, C, H, W)
ray_maps = ray_maps.view(
-1, *ray_maps.shape[2:]
) # Shape: (num_views * batch_size, H, W, C)
shapes = shapes.view(-1, 2).to(
imgs.device
) # Shape: (num_views * batch_size, 2)
img_masks_flat = img_mask.view(-1) # Shape: (num_views * batch_size)
ray_masks_flat = ray_mask.view(-1)
selected_imgs = imgs[img_masks_flat]
selected_shapes = shapes[img_masks_flat]
if selected_imgs.size(0) > 0:
img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
else:
img_out, img_pos = None, None
ray_maps = ray_maps.permute(0, 3, 1, 2) # Change shape to (N, C, H, W)
selected_ray_maps = ray_maps[ray_masks_flat]
selected_shapes_ray = shapes[ray_masks_flat]
if selected_ray_maps.size(0) > 0:
ray_out, ray_pos, _ = self._encode_ray_map(
selected_ray_maps, selected_shapes_ray
)
else:
ray_out, ray_pos = None, None
shape = shapes
if img_out is not None and ray_out is None:
feat_i = img_out[-1]
pos_i = img_pos
elif img_out is None and ray_out is not None:
feat_i = ray_out[-1]
pos_i = ray_pos
elif img_out is not None and ray_out is not None:
feat_i = img_out[-1] + ray_out[-1]
pos_i = img_pos
else:
raise NotImplementedError
if i == 0:
state_feat, state_pos = self._init_state(feat_i, pos_i)
mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
init_state_feat = state_feat.clone()
init_mem = mem.clone()
all_state_args.append(
(state_feat, state_pos, init_state_feat, mem, init_mem)
)
if self.pose_head_flag:
global_img_feat_i = self._get_img_level_feat(feat_i)
if i == 0:
pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
else:
pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
pose_pos_i = -torch.ones(
feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
)
else:
pose_feat_i = None
pose_pos_i = None
new_state_feat, dec = self._recurrent_rollout(
state_feat,
state_pos,
feat_i,
pos_i,
pose_feat_i,
pose_pos_i,
init_state_feat,
img_mask=view["img_mask"],
reset_mask=view["reset"],
update=view.get("update", None),
)
out_pose_feat_i = dec[-1][:, 0:1]
new_mem = self.pose_retriever.update_mem(
mem, global_img_feat_i, out_pose_feat_i
)
assert len(dec) == self.dec_depth + 1
head_input = [
dec[0].float(),
dec[self.dec_depth * 2 // 4][:, 1:].float(),
dec[self.dec_depth * 3 // 4][:, 1:].float(),
dec[self.dec_depth].float(),
]
res = self._downstream_head(head_input, shape, pos=pos_i)
ress.append(res)
img_mask = view["img_mask"]
update = view.get("update", None)
if update is not None:
                update_mask = (
                    img_mask & update
                )  # write only where an image is given and update is requested
else:
update_mask = img_mask
update_mask = update_mask[:, None, None].float()
state_feat = new_state_feat * update_mask + state_feat * (
1 - update_mask
) # update global state
mem = new_mem * update_mask + mem * (
1 - update_mask
) # then update local state
reset_mask = view["reset"]
if reset_mask is not None:
reset_mask = reset_mask[:, None, None].float()
state_feat = init_state_feat * reset_mask + state_feat * (
1 - reset_mask
)
mem = init_mem * reset_mask + mem * (1 - reset_mask)
all_state_args.append(
(state_feat, state_pos, init_state_feat, mem, init_mem)
)
if ret_state:
return ress, views, all_state_args
return ress, views
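# Streaming sketch (illustrative): forward_recurrent consumes views one at a time;
# with ret_state=True it also returns the per-step state tuples
# (state_feat, state_pos, init_state_feat, mem, init_mem), which inference_step()
# accepts to decode additional ray-map-only queries against a fixed state.
#
#   ress, views, states = model.forward_recurrent(views, device, ret_state=True)
#   res, view = model.inference_step(query_view, *states[-1])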
if __name__ == "__main__":
print(ARCroco3DStereo.mro())
cfg = ARCroco3DStereoConfig(
state_size=256,
pos_embed="RoPE100",
rgb_head=True,
pose_head=True,
img_size=(224, 224),
head_type="linear",
output_mode="pts3d+pose",
depth_mode=("exp", -inf, inf),
conf_mode=("exp", 1, inf),
pose_mode=("exp", -inf, inf),
enc_embed_dim=1024,
enc_depth=24,
enc_num_heads=16,
dec_embed_dim=768,
dec_depth=12,
dec_num_heads=12,
)
ARCroco3DStereo(cfg)