import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass
from transformers import PretrainedConfig
from transformers.file_utils import ModelOutput
from dust3r.utils.misc import (
    fill_default_args,
    freeze_all_params,
    is_symmetrized,
    interleave,
    transpose_to_landscape,
)
from dust3r.heads import head_factory
from dust3r.utils.camera import PoseEncoder
from dust3r.patch_embed import get_patch_embed
import dust3r.utils.path_to_croco  # noqa: F401
from models.croco import CroCoNet, CrocoConfig  # noqa
from dust3r.blocks import (
    Block,
    DecoderBlock,
    Mlp,
    Attention,
    CrossAttention,
    DropPath,
    CustomDecoderBlock,
)  # noqa

inf = float("inf")

from accelerate.logging import get_logger

printer = get_logger(__name__, log_level="DEBUG")


@dataclass
class ARCroco3DStereoOutput(ModelOutput):
    """
    Custom output class for ARCroco3DStereo.
    """

    ress: Optional[List[Any]] = None
    views: Optional[List[Any]] = None


def strip_module(state_dict):
    """
    Removes the 'module.' prefix from the keys of a state_dict.

    Args:
        state_dict (dict): The original state_dict with possible 'module.' prefixes.

    Returns:
        OrderedDict: A new state_dict with 'module.' prefixes removed.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith("module.") else k
        new_state_dict[name] = v
    return new_state_dict
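
# Hedged usage sketch (hypothetical tensors, mirroring the key prefix saved by
# torch.nn.parallel.DistributedDataParallel):
#
#   ckpt = {"module.enc_norm.weight": torch.ones(3), "enc_norm.bias": torch.zeros(3)}
#   clean = strip_module(ckpt)
#   list(clean)  # -> ["enc_norm.weight", "enc_norm.bias"]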


def load_model(model_path, device, verbose=True):
    if verbose:
        print("... loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
    args = ckpt["args"].model.replace(
        "ManyAR_PatchEmbed", "PatchEmbedDust3R"
    )  # ManyAR is only needed when aspect ratios are inconsistent
    if "landscape_only" not in args:
        args = args[:-2] + ", landscape_only=False))"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)
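
# Note: `ckpt["args"].model` is assumed to be an eval-able constructor string
# ending in "))", e.g. (hypothetical):
#
#   "ARCroco3DStereo(ARCroco3DStereoConfig(..., landscape_only=True))"
#
# which is why the branch above can drop the final two characters and splice
# ", landscape_only=False))" back on when the flag is absent.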


class ARCroco3DStereoConfig(PretrainedConfig):
    model_type = "arcroco_3d_stereo"

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",  # or "dpt"
        depth_mode=("exp", -float("inf"), float("inf")),
        conf_mode=("exp", 1, float("inf")),
        pose_mode=("exp", -float("inf"), float("inf")),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",
        ray_enc_depth=2,
        state_size=324,
        local_mem_size=256,
        state_pe="2d",
        state_dec_num_heads=16,
        depth_head=False,
        rgb_head=False,
        pose_conf_head=False,
        pose_head=False,
        **croco_kwargs,
    ):
        super().__init__()
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        self.freeze = freeze
        self.landscape_only = landscape_only
        self.patch_embed_cls = patch_embed_cls
        self.ray_enc_depth = ray_enc_depth
        self.state_size = state_size
        self.state_pe = state_pe
        self.state_dec_num_heads = state_dec_num_heads
        self.local_mem_size = local_mem_size
        self.depth_head = depth_head
        self.rgb_head = rgb_head
        self.pose_conf_head = pose_conf_head
        self.pose_head = pose_head
        self.croco_kwargs = croco_kwargs


class LocalMemory(nn.Module):
    def __init__(
        self,
        size,
        k_dim,
        v_dim,
        num_heads,
        depth=2,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_mem=True,
        rope=None,
    ) -> None:
        super().__init__()
        self.v_dim = v_dim
        self.proj_q = nn.Linear(k_dim, v_dim)
        self.masked_token = nn.Parameter(
            torch.randn(1, 1, v_dim) * 0.2, requires_grad=True
        )
        self.mem = nn.Parameter(
            torch.randn(1, size, 2 * v_dim) * 0.2, requires_grad=True
        )
        self.write_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    2 * v_dim,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    attn_drop=attn_drop,
                    drop=drop,
                    drop_path=drop_path,
                    act_layer=act_layer,
                    norm_mem=norm_mem,
                    rope=rope,
                )
                for _ in range(depth)
            ]
        )
        self.read_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    2 * v_dim,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    norm_layer=norm_layer,
                    attn_drop=attn_drop,
                    drop=drop,
                    drop_path=drop_path,
                    act_layer=act_layer,
                    norm_mem=norm_mem,
                    rope=rope,
                )
                for _ in range(depth)
            ]
        )
    def update_mem(self, mem, feat_k, feat_v):
        """
        Write one key/value pair into the memory.

        mem:    [B, size, 2*v_dim]  concatenated key/value memory bank
        feat_k: [B, 1, k_dim]       key, projected to v_dim by proj_q
        feat_v: [B, 1, v_dim]       value
        """
        feat_k = self.proj_q(feat_k)  # [B, 1, v_dim]
        feat = torch.cat([feat_k, feat_v], dim=-1)
        for blk in self.write_blocks:
            mem, _ = blk(mem, feat, None, None)
        return mem

    def inquire(self, query, mem):
        x = self.proj_q(query)  # [B, 1, v_dim]
        x = torch.cat([x, self.masked_token.expand(x.shape[0], -1, -1)], dim=-1)
        for blk in self.read_blocks:
            x, _ = blk(x, mem, None, None)
        return x[..., -self.v_dim :]
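
    # Read path in brief: the projected query key is paired with a learned
    # placeholder (masked_token) standing in for the unknown value half;
    # cross-attention over the memory fills that half in, and only the value
    # half is returned. Shape-check sketch with hypothetical sizes (not a
    # trained configuration):
    #
    #   bank = LocalMemory(size=16, k_dim=32, v_dim=64, num_heads=4)
    #   mem = bank.mem.expand(2, -1, -1)                                  # [2, 16, 128]
    #   mem = bank.update_mem(mem, torch.randn(2, 1, 32), torch.randn(2, 1, 64))
    #   val = bank.inquire(torch.randn(2, 1, 32), mem)                    # [2, 1, 64]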


class ARCroco3DStereo(CroCoNet):
    config_class = ARCroco3DStereoConfig
    base_model_prefix = "arcroco3dstereo"
    supports_gradient_checkpointing = True

    def __init__(self, config: ARCroco3DStereoConfig):
        self.gradient_checkpointing = False
        self.fixed_input_length = True
        config.croco_kwargs = fill_default_args(
            config.croco_kwargs, CrocoConfig.__init__
        )
        self.config = config
        self.patch_embed_cls = config.patch_embed_cls
        self.croco_args = config.croco_kwargs
        croco_cfg = CrocoConfig(**self.croco_args)
        super().__init__(croco_cfg)
        self.enc_blocks_ray_map = nn.ModuleList(
            [
                Block(
                    self.enc_embed_dim,
                    16,
                    4,
                    qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-6),
                    rope=self.rope,
                )
                for _ in range(config.ray_enc_depth)
            ]
        )
        self.enc_norm_ray_map = nn.LayerNorm(self.enc_embed_dim, eps=1e-6)
        self.dec_num_heads = self.croco_args["dec_num_heads"]
        self.pose_head_flag = config.pose_head
        if self.pose_head_flag:
            self.pose_token = nn.Parameter(
                torch.randn(1, 1, self.dec_embed_dim) * 0.02, requires_grad=True
            )
            self.pose_retriever = LocalMemory(
                size=config.local_mem_size,
                k_dim=self.enc_embed_dim,
                v_dim=self.dec_embed_dim,
                num_heads=self.dec_num_heads,
                mlp_ratio=4,
                qkv_bias=True,
                attn_drop=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                rope=None,
            )
        self.register_tokens = nn.Embedding(config.state_size, self.enc_embed_dim)
        self.state_size = config.state_size
        self.state_pe = config.state_pe
        self.masked_img_token = nn.Parameter(
            torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
        )
        self.masked_ray_map_token = nn.Parameter(
            torch.randn(1, self.enc_embed_dim) * 0.02, requires_grad=True
        )
        self._set_state_decoder(
            self.enc_embed_dim,
            self.dec_embed_dim,
            config.state_dec_num_heads,
            self.dec_depth,
            self.croco_args.get("mlp_ratio", None),
            self.croco_args.get("norm_layer", None),
            self.croco_args.get("norm_im2_in_dec", None),
        )
        self.set_downstream_head(
            config.output_mode,
            config.head_type,
            config.landscape_only,
            config.depth_mode,
            config.conf_mode,
            config.pose_mode,
            config.depth_head,
            config.rgb_head,
            config.pose_conf_head,
            config.pose_head,
            **self.croco_args,
        )
        self.set_freeze(config.freeze)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        else:
            try:
                model = super(ARCroco3DStereo, cls).from_pretrained(
                    pretrained_model_name_or_path, **kw
                )
            except TypeError as e:
                raise Exception(
                    f"tried to load {pretrained_model_name_or_path} from huggingface, but failed"
                ) from e
            return model
    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=3
        )
        self.patch_embed_ray_map = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim, in_chans=6
        )

    def _set_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_state_decoder(
        self,
        enc_embed_dim,
        dec_embed_dim,
        dec_num_heads,
        dec_depth,
        mlp_ratio,
        norm_layer,
        norm_im2_in_dec,
    ):
        self.dec_depth_state = dec_depth
        self.dec_embed_dim_state = dec_embed_dim
        self.decoder_embed_state = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        self.dec_blocks_state = nn.ModuleList(
            [
                DecoderBlock(
                    dec_embed_dim,
                    dec_num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    norm_mem=norm_im2_in_dec,
                    rope=self.rope,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm_state = norm_layer(dec_embed_dim)
    def load_state_dict(self, ckpt, **kw):
        if all(k.startswith("module") for k in ckpt):
            ckpt = strip_module(ckpt)
        new_ckpt = dict(ckpt)
        # Older checkpoints have no state decoder: seed it from the image decoder.
        if not any(k.startswith("dec_blocks_state") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks_state")] = value
        try:
            return super().load_state_dict(new_ckpt, **kw)
        except Exception:
            # Fallback 1: drop all decoder weights and load the rest.
            try:
                new_new_ckpt = {
                    k: v
                    for k, v in new_ckpt.items()
                    if not k.startswith("dec_blocks")
                    and not k.startswith("dec_norm")
                    and not k.startswith("decoder_embed")
                }
                return super().load_state_dict(new_new_ckpt, **kw)
            except Exception:
                # Fallback 2: keep only keys present in the model with matching shapes.
                new_new_ckpt = {}
                for key in new_ckpt:
                    if key in self.state_dict():
                        if new_ckpt[key].size() == self.state_dict()[key].size():
                            new_new_ckpt[key] = new_ckpt[key]
                        else:
                            printer.info(
                                f"Skipping '{key}': size mismatch (ckpt: {new_ckpt[key].size()}, model: {self.state_dict()[key].size()})"
                            )
                    else:
                        printer.info(f"Skipping '{key}': not found in model")
                return super().load_state_dict(new_new_ckpt, **kw)
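
    # Key-remap sketch (hypothetical checkpoint keys): a checkpoint holding
    # "dec_blocks.0.attn.qkv.weight" but no "dec_blocks_state.*" entries is
    # loaded with that tensor duplicated under
    # "dec_blocks_state.0.attn.qkv.weight", so both decoders start from the
    # same pretrained weights.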
    def set_freeze(self, freeze):  # this is for use by downstream models
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token] if hasattr(self, "mask_token") else [],
            "encoder": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
            ],
            "encoder_and_head": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
                self.downstream_head,
            ],
            "encoder_and_decoder": [
                self.patch_embed,
                self.patch_embed_ray_map,
                self.masked_img_token,
                self.masked_ray_map_token,
                self.enc_blocks,
                self.enc_blocks_ray_map,
                self.enc_norm,
                self.enc_norm_ray_map,
                self.dec_blocks,
                self.dec_blocks_state,
                self.pose_retriever,
                self.pose_token,
                self.register_tokens,
                self.decoder_embed_state,
                self.decoder_embed,
                self.dec_norm,
                self.dec_norm_state,
            ],
            "decoder": [
                self.dec_blocks,
                self.dec_blocks_state,
                self.pose_retriever,
                self.pose_token,
            ],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head"""
        return
    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        pose_mode,
        depth_head,
        rgb_head,
        pose_conf_head,
        pose_head,
        patch_size,
        img_size,
        **kw,
    ):
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be a multiple of {patch_size=}"
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        self.pose_mode = pose_mode
        self.downstream_head = head_factory(
            head_type,
            output_mode,
            self,
            has_conf=bool(conf_mode),
            has_depth=bool(depth_head),
            has_rgb=bool(rgb_head),
            has_pose_conf=bool(pose_conf_head),
            has_pose=bool(pose_head),
        )
        self.head = transpose_to_landscape(
            self.downstream_head, activate=landscape_only
        )
    def _encode_image(self, image, true_shape):
        x, pos = self.patch_embed(image, true_shape=true_shape)
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm(x)
        return [x], pos, None

    def _encode_ray_map(self, ray_map, true_shape):
        x, pos = self.patch_embed_ray_map(ray_map, true_shape=true_shape)
        assert self.enc_pos_embed is None
        for blk in self.enc_blocks_ray_map:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(blk, x, pos, use_reentrant=False)
            else:
                x = blk(x, pos)
        x = self.enc_norm_ray_map(x)
        return [x], pos, None
    def _encode_state(self, image_tokens, image_pos):
        batch_size = image_tokens.shape[0]
        state_feat = self.register_tokens(
            torch.arange(self.state_size, device=image_pos.device)
        )
        if self.state_pe == "1d":
            state_pos = (
                torch.tensor(
                    [[i, i] for i in range(self.state_size)],
                    dtype=image_pos.dtype,
                    device=image_pos.device,
                )[None]
                .expand(batch_size, -1, -1)
                .contiguous()
            )
        elif self.state_pe == "2d":
            width = int(self.state_size**0.5)
            width = width + 1 if width % 2 == 1 else width  # round up to an even width
            state_pos = (
                torch.tensor(
                    [[i // width, i % width] for i in range(self.state_size)],
                    dtype=image_pos.dtype,
                    device=image_pos.device,
                )[None]
                .expand(batch_size, -1, -1)
                .contiguous()
            )
        elif self.state_pe == "none":
            state_pos = None
        state_feat = state_feat[None].expand(batch_size, -1, -1)
        return state_feat, state_pos, None
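
    # Positional-encoding sketch: with the default state_size=324, "2d" lays the
    # state tokens on an 18x18 grid (sqrt(324) = 18, already even), so token i
    # gets position (i // 18, i % 18); "1d" uses the degenerate diagonal (i, i).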
    def _encode_views(self, views, img_mask=None, ray_mask=None):
        device = views[0]["img"].device
        batch_size = views[0]["img"].shape[0]
        given = img_mask is not None or ray_mask is not None
        if not given:
            img_mask = torch.stack(
                [view["img_mask"] for view in views], dim=0
            )  # Shape: (num_views, batch_size)
            ray_mask = torch.stack(
                [view["ray_mask"] for view in views], dim=0
            )  # Shape: (num_views, batch_size)
        imgs = torch.stack(
            [view["img"] for view in views], dim=0
        )  # Shape: (num_views, batch_size, C, H, W)
        ray_maps = torch.stack(
            [view["ray_map"] for view in views], dim=0
        )  # Shape: (num_views, batch_size, H, W, C)
        shapes = []
        for view in views:
            if "true_shape" in view:
                shapes.append(view["true_shape"])
            else:
                shape = torch.tensor(view["img"].shape[-2:], device=device)
                shapes.append(shape.unsqueeze(0).repeat(batch_size, 1))
        shapes = torch.stack(shapes, dim=0).to(
            imgs.device
        )  # Shape: (num_views, batch_size, 2)
        imgs = imgs.view(
            -1, *imgs.shape[2:]
        )  # Shape: (num_views * batch_size, C, H, W)
        ray_maps = ray_maps.view(
            -1, *ray_maps.shape[2:]
        )  # Shape: (num_views * batch_size, H, W, C)
        shapes = shapes.view(-1, 2)  # Shape: (num_views * batch_size, 2)
        img_masks_flat = img_mask.view(-1)  # Shape: (num_views * batch_size,)
        ray_masks_flat = ray_mask.view(-1)
        selected_imgs = imgs[img_masks_flat]
        selected_shapes = shapes[img_masks_flat]
        if selected_imgs.size(0) > 0:
            img_out, img_pos, _ = self._encode_image(selected_imgs, selected_shapes)
        else:
            raise NotImplementedError
        full_out = [
            torch.zeros(
                len(views) * batch_size, *img_out[0].shape[1:], device=img_out[0].device
            )
            for _ in range(len(img_out))
        ]
        full_pos = torch.zeros(
            len(views) * batch_size,
            *img_pos.shape[1:],
            device=img_pos.device,
            dtype=img_pos.dtype,
        )
        for i in range(len(img_out)):
            full_out[i][img_masks_flat] += img_out[i]
            full_out[i][~img_masks_flat] += self.masked_img_token
        full_pos[img_masks_flat] += img_pos
        ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
        selected_ray_maps = ray_maps[ray_masks_flat]
        selected_shapes_ray = shapes[ray_masks_flat]
        if selected_ray_maps.size(0) > 0:
            ray_out, ray_pos, _ = self._encode_ray_map(
                selected_ray_maps, selected_shapes_ray
            )
            assert len(ray_out) == len(full_out), f"{len(ray_out)}, {len(full_out)}"
            for i in range(len(ray_out)):
                full_out[i][ray_masks_flat] += ray_out[i]
                full_out[i][~ray_masks_flat] += self.masked_ray_map_token
            full_pos[ray_masks_flat] += (
                ray_pos * (~img_masks_flat[ray_masks_flat][:, None, None]).long()
            )
        else:
            # No ray maps in this batch: run the ray-map encoder on a dummy input
            # and scale its output by 0.0 so its parameters stay in the autograd
            # graph (avoids unused-parameter errors under DDP) without affecting
            # the features.
            raymaps = torch.zeros(
                1, 6, imgs[0].shape[-2], imgs[0].shape[-1], device=img_out[0].device
            )
            ray_mask_flat = torch.zeros_like(img_masks_flat)
            ray_mask_flat[:1] = True
            ray_out, ray_pos, _ = self._encode_ray_map(raymaps, shapes[ray_mask_flat])
            for i in range(len(ray_out)):
                full_out[i][ray_mask_flat] += ray_out[i] * 0.0
                full_out[i][~ray_mask_flat] += self.masked_ray_map_token * 0.0
        return (
            shapes.chunk(len(views), dim=0),
            [out.chunk(len(views), dim=0) for out in full_out],
            full_pos.chunk(len(views), dim=0),
        )
    def _decoder(self, f_state, pos_state, f_img, pos_img, f_pose, pos_pose):
        final_output = [(f_state, f_img)]  # before projection
        assert f_state.shape[-1] == self.dec_embed_dim
        f_img = self.decoder_embed(f_img)
        if self.pose_head_flag:
            assert f_pose is not None and pos_pose is not None
            f_img = torch.cat([f_pose, f_img], dim=1)
            pos_img = torch.cat([pos_pose, pos_img], dim=1)
        final_output.append((f_state, f_img))
        for blk_state, blk_img in zip(self.dec_blocks_state, self.dec_blocks):
            if (
                self.gradient_checkpointing
                and self.training
                and torch.is_grad_enabled()
            ):
                f_state, _ = checkpoint(
                    blk_state,
                    *final_output[-1][::+1],
                    pos_state,
                    pos_img,
                    use_reentrant=not self.fixed_input_length,
                )
                f_img, _ = checkpoint(
                    blk_img,
                    *final_output[-1][::-1],
                    pos_img,
                    pos_state,
                    use_reentrant=not self.fixed_input_length,
                )
            else:
                f_state, _ = blk_state(*final_output[-1][::+1], pos_state, pos_img)
                f_img, _ = blk_img(*final_output[-1][::-1], pos_img, pos_state)
            final_output.append((f_state, f_img))
        del final_output[1]  # duplicate of final_output[0]
        final_output[-1] = (
            self.dec_norm_state(final_output[-1][0]),
            self.dec_norm(final_output[-1][1]),
        )
        return zip(*final_output)
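
    # The two stacks run in lockstep: each state block attends to the previous
    # image tokens and vice versa ([::+1] passes (state, img), [::-1] passes
    # (img, state)). The return value unzips into two tuples of length
    # dec_depth + 1: per-layer state features and per-layer image features.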
    def _downstream_head(self, decout, img_shape, **kwargs):
        return self.head(decout, img_shape, **kwargs)
    def _init_state(self, image_tokens, image_pos):
        """
        Initialize the state feature and pose from the first frame's image
        features.
        """
        state_feat, state_pos, _ = self._encode_state(image_tokens, image_pos)
        state_feat = self.decoder_embed_state(state_feat)
        return state_feat, state_pos

    def _recurrent_rollout(
        self,
        state_feat,
        state_pos,
        current_feat,
        current_pos,
        pose_feat,
        pose_pos,
        init_state_feat,
        img_mask=None,
        reset_mask=None,
        update=None,
    ):
        new_state_feat, dec = self._decoder(
            state_feat, state_pos, current_feat, current_pos, pose_feat, pose_pos
        )
        new_state_feat = new_state_feat[-1]
        return new_state_feat, dec

    def _get_img_level_feat(self, feat):
        return torch.mean(feat, dim=1, keepdim=True)
    def _forward_encoder(self, views):
        shape, feat_ls, pos = self._encode_views(views)
        feat = feat_ls[-1]
        state_feat, state_pos = self._init_state(feat[0], pos[0])
        mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
        init_state_feat = state_feat.clone()
        init_mem = mem.clone()
        return (feat, pos, shape), (
            init_state_feat,
            init_mem,
            state_feat,
            state_pos,
            mem,
        )
    def _forward_decoder_step(
        self,
        views,
        i,
        feat_i,
        pos_i,
        shape_i,
        init_state_feat,
        init_mem,
        state_feat,
        state_pos,
        mem,
    ):
        if self.pose_head_flag:
            global_img_feat_i = self._get_img_level_feat(feat_i)
            if i == 0:
                pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
            else:
                pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
            pose_pos_i = -torch.ones(
                feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
            )
        else:
            pose_feat_i = None
            pose_pos_i = None
        new_state_feat, dec = self._recurrent_rollout(
            state_feat,
            state_pos,
            feat_i,
            pos_i,
            pose_feat_i,
            pose_pos_i,
            init_state_feat,
            img_mask=views[i]["img_mask"],
            reset_mask=views[i]["reset"],
            update=views[i].get("update", None),
        )
        out_pose_feat_i = dec[-1][:, 0:1]
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        head_input = [
            dec[0].float(),
            dec[self.dec_depth * 2 // 4][:, 1:].float(),
            dec[self.dec_depth * 3 // 4][:, 1:].float(),
            dec[self.dec_depth].float(),
        ]
        res = self._downstream_head(head_input, shape_i, pos=pos_i)
        img_mask = views[i]["img_mask"]
        update = views[i].get("update", None)
        if update is not None:
            update_mask = img_mask & update  # only update where both masks are set
        else:
            update_mask = img_mask
        update_mask = update_mask[:, None, None].float()
        state_feat = new_state_feat * update_mask + state_feat * (
            1 - update_mask
        )  # update the global state
        mem = new_mem * update_mask + mem * (1 - update_mask)  # then the local memory
        reset_mask = views[i]["reset"]
        if reset_mask is not None:
            reset_mask = reset_mask[:, None, None].float()
            state_feat = init_state_feat * reset_mask + state_feat * (1 - reset_mask)
            mem = init_mem * reset_mask + mem * (1 - reset_mask)
        return res, (state_feat, mem)
    def _forward_impl(self, views, ret_state=False):
        shape, feat_ls, pos = self._encode_views(views)
        feat = feat_ls[-1]
        state_feat, state_pos = self._init_state(feat[0], pos[0])
        mem = self.pose_retriever.mem.expand(feat[0].shape[0], -1, -1)
        init_state_feat = state_feat.clone()
        init_mem = mem.clone()
        all_state_args = [(state_feat, state_pos, init_state_feat, mem, init_mem)]
        ress = []
        for i in range(len(views)):
            feat_i = feat[i]
            pos_i = pos[i]
            if self.pose_head_flag:
                global_img_feat_i = self._get_img_level_feat(feat_i)
                if i == 0:
                    pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
                else:
                    pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
                pose_pos_i = -torch.ones(
                    feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
                )
            else:
                pose_feat_i = None
                pose_pos_i = None
            new_state_feat, dec = self._recurrent_rollout(
                state_feat,
                state_pos,
                feat_i,
                pos_i,
                pose_feat_i,
                pose_pos_i,
                init_state_feat,
                img_mask=views[i]["img_mask"],
                reset_mask=views[i]["reset"],
                update=views[i].get("update", None),
            )
            out_pose_feat_i = dec[-1][:, 0:1]
            new_mem = self.pose_retriever.update_mem(
                mem, global_img_feat_i, out_pose_feat_i
            )
            assert len(dec) == self.dec_depth + 1
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, 1:].float(),
                dec[self.dec_depth * 3 // 4][:, 1:].float(),
                dec[self.dec_depth].float(),
            ]
            res = self._downstream_head(head_input, shape[i], pos=pos_i)
            ress.append(res)
            img_mask = views[i]["img_mask"]
            update = views[i].get("update", None)
            if update is not None:
                update_mask = img_mask & update  # only update where both masks are set
            else:
                update_mask = img_mask
            update_mask = update_mask[:, None, None].float()
            state_feat = new_state_feat * update_mask + state_feat * (
                1 - update_mask
            )  # update the global state
            mem = new_mem * update_mask + mem * (
                1 - update_mask
            )  # then the local memory
            reset_mask = views[i]["reset"]
            if reset_mask is not None:
                reset_mask = reset_mask[:, None, None].float()
                state_feat = init_state_feat * reset_mask + state_feat * (
                    1 - reset_mask
                )
                mem = init_mem * reset_mask + mem * (1 - reset_mask)
            all_state_args.append(
                (state_feat, state_pos, init_state_feat, mem, init_mem)
            )
        if ret_state:
            return ress, views, all_state_args
        return ress, views
    def forward(self, views, ret_state=False):
        if ret_state:
            ress, views, state_args = self._forward_impl(views, ret_state=ret_state)
            return ARCroco3DStereoOutput(ress=ress, views=views), state_args
        else:
            ress, views = self._forward_impl(views, ret_state=ret_state)
            return ARCroco3DStereoOutput(ress=ress, views=views)
    def inference_step(
        self, view, state_feat, state_pos, init_state_feat, mem, init_mem
    ):
        batch_size = view["img"].shape[0]
        raymaps = []
        shapes = []
        for j in range(batch_size):
            assert view["ray_mask"][j]
            raymap = view["ray_map"][[j]].permute(0, 3, 1, 2)
            raymaps.append(raymap)
            shapes.append(
                view.get(
                    "true_shape",
                    torch.tensor(view["ray_map"].shape[-2:])[None].repeat(
                        view["ray_map"].shape[0], 1
                    ),
                )[[j]]
            )
        raymaps = torch.cat(raymaps, dim=0)
        shape = torch.cat(shapes, dim=0).to(raymaps.device)
        feat_ls, pos, _ = self._encode_ray_map(raymaps, shape)
        feat_i = feat_ls[-1]
        pos_i = pos
        if self.pose_head_flag:
            global_img_feat_i = self._get_img_level_feat(feat_i)
            pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
            pose_pos_i = -torch.ones(
                feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
            )
        else:
            pose_feat_i = None
            pose_pos_i = None
        new_state_feat, dec = self._recurrent_rollout(
            state_feat,
            state_pos,
            feat_i,
            pos_i,
            pose_feat_i,
            pose_pos_i,
            init_state_feat,
            img_mask=view["img_mask"],
            reset_mask=view["reset"],
            update=view.get("update", None),
        )
        out_pose_feat_i = dec[-1][:, 0:1]
        new_mem = self.pose_retriever.update_mem(
            mem, global_img_feat_i, out_pose_feat_i
        )
        assert len(dec) == self.dec_depth + 1
        head_input = [
            dec[0].float(),
            dec[self.dec_depth * 2 // 4][:, 1:].float(),
            dec[self.dec_depth * 3 // 4][:, 1:].float(),
            dec[self.dec_depth].float(),
        ]
        res = self._downstream_head(head_input, shape, pos=pos_i)
        return res, view
    def forward_recurrent(self, views, device, ret_state=False):
        ress = []
        all_state_args = []
        for i, view in enumerate(views):
            device = view["img"].device
            batch_size = view["img"].shape[0]
            img_mask = view["img_mask"].reshape(
                -1, batch_size
            )  # Shape: (1, batch_size)
            ray_mask = view["ray_mask"].reshape(
                -1, batch_size
            )  # Shape: (1, batch_size)
            imgs = view["img"].unsqueeze(0)  # Shape: (1, batch_size, C, H, W)
            ray_maps = view["ray_map"].unsqueeze(0)  # Shape: (1, batch_size, H, W, C)
            shapes = (
                view["true_shape"].unsqueeze(0)
                if "true_shape" in view
                else torch.tensor(view["img"].shape[-2:], device=device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
                .unsqueeze(0)
            )  # Shape: (1, batch_size, 2)
            imgs = imgs.view(-1, *imgs.shape[2:])  # Shape: (batch_size, C, H, W)
            ray_maps = ray_maps.view(
                -1, *ray_maps.shape[2:]
            )  # Shape: (batch_size, H, W, C)
            shapes = shapes.view(-1, 2).to(imgs.device)  # Shape: (batch_size, 2)
            img_masks_flat = img_mask.view(-1)  # Shape: (batch_size,)
            ray_masks_flat = ray_mask.view(-1)
            selected_imgs = imgs[img_masks_flat]
            selected_shapes = shapes[img_masks_flat]
            if selected_imgs.size(0) > 0:
                img_out, img_pos, _ = self._encode_image(
                    selected_imgs, selected_shapes
                )
            else:
                img_out, img_pos = None, None
            ray_maps = ray_maps.permute(0, 3, 1, 2)  # Change shape to (N, C, H, W)
            selected_ray_maps = ray_maps[ray_masks_flat]
            selected_shapes_ray = shapes[ray_masks_flat]
            if selected_ray_maps.size(0) > 0:
                ray_out, ray_pos, _ = self._encode_ray_map(
                    selected_ray_maps, selected_shapes_ray
                )
            else:
                ray_out, ray_pos = None, None
            shape = shapes
            if img_out is not None and ray_out is None:
                feat_i = img_out[-1]
                pos_i = img_pos
            elif img_out is None and ray_out is not None:
                feat_i = ray_out[-1]
                pos_i = ray_pos
            elif img_out is not None and ray_out is not None:
                feat_i = img_out[-1] + ray_out[-1]
                pos_i = img_pos
            else:
                raise NotImplementedError
            if i == 0:
                state_feat, state_pos = self._init_state(feat_i, pos_i)
                mem = self.pose_retriever.mem.expand(feat_i.shape[0], -1, -1)
                init_state_feat = state_feat.clone()
                init_mem = mem.clone()
                all_state_args.append(
                    (state_feat, state_pos, init_state_feat, mem, init_mem)
                )
            if self.pose_head_flag:
                global_img_feat_i = self._get_img_level_feat(feat_i)
                if i == 0:
                    pose_feat_i = self.pose_token.expand(feat_i.shape[0], -1, -1)
                else:
                    pose_feat_i = self.pose_retriever.inquire(global_img_feat_i, mem)
                pose_pos_i = -torch.ones(
                    feat_i.shape[0], 1, 2, device=feat_i.device, dtype=pos_i.dtype
                )
            else:
                pose_feat_i = None
                pose_pos_i = None
            new_state_feat, dec = self._recurrent_rollout(
                state_feat,
                state_pos,
                feat_i,
                pos_i,
                pose_feat_i,
                pose_pos_i,
                init_state_feat,
                img_mask=view["img_mask"],
                reset_mask=view["reset"],
                update=view.get("update", None),
            )
            out_pose_feat_i = dec[-1][:, 0:1]
            new_mem = self.pose_retriever.update_mem(
                mem, global_img_feat_i, out_pose_feat_i
            )
            assert len(dec) == self.dec_depth + 1
            head_input = [
                dec[0].float(),
                dec[self.dec_depth * 2 // 4][:, 1:].float(),
                dec[self.dec_depth * 3 // 4][:, 1:].float(),
                dec[self.dec_depth].float(),
            ]
            res = self._downstream_head(head_input, shape, pos=pos_i)
            ress.append(res)
            img_mask = view["img_mask"]
            update = view.get("update", None)
            if update is not None:
                update_mask = img_mask & update  # only update where both masks are set
            else:
                update_mask = img_mask
            update_mask = update_mask[:, None, None].float()
            state_feat = new_state_feat * update_mask + state_feat * (
                1 - update_mask
            )  # update the global state
            mem = new_mem * update_mask + mem * (
                1 - update_mask
            )  # then the local memory
            reset_mask = view["reset"]
            if reset_mask is not None:
                reset_mask = reset_mask[:, None, None].float()
                state_feat = init_state_feat * reset_mask + state_feat * (
                    1 - reset_mask
                )
                mem = init_mem * reset_mask + mem * (1 - reset_mask)
            all_state_args.append(
                (state_feat, state_pos, init_state_feat, mem, init_mem)
            )
        if ret_state:
            return ress, views, all_state_args
        return ress, views


if __name__ == "__main__":
    print(ARCroco3DStereo.mro())
    cfg = ARCroco3DStereoConfig(
        state_size=256,
        pos_embed="RoPE100",
        rgb_head=True,
        pose_head=True,
        img_size=(224, 224),
        head_type="linear",
        output_mode="pts3d+pose",
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        pose_mode=("exp", -inf, inf),
        enc_embed_dim=1024,
        enc_depth=24,
        enc_num_heads=16,
        dec_embed_dim=768,
        dec_depth=12,
        dec_num_heads=12,
    )
    model = ARCroco3DStereo(cfg)
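
    # A hedged smoke-test sketch (view-dict schema inferred from _encode_views;
    # not part of the original script, so left commented out):
    #
    #   B, H, W = 2, 224, 224
    #   views = [
    #       {
    #           "img": torch.randn(B, 3, H, W),
    #           "ray_map": torch.randn(B, H, W, 6),
    #           "img_mask": torch.ones(B, dtype=torch.bool),
    #           "ray_mask": torch.zeros(B, dtype=torch.bool),
    #           "reset": torch.zeros(B, dtype=torch.bool),
    #       }
    #       for _ in range(2)
    #   ]
    #   out = model(views)
    #   print(len(out.ress))  # one result per view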