File size: 12,206 Bytes
0a82b18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Whitener and RetrievalModel
# --------------------------------------------------------
import numpy as np
from tqdm import tqdm
import time
import torch
import torch.nn as nn
import mast3r.utils.path_to_dust3r # noqa
from dust3r.utils.image import load_images
# Device used by default throughout this module: CUDA when it is available
# and at least one GPU is visible, otherwise the CPU.
_cuda_usable = torch.cuda.is_available() and torch.cuda.device_count() > 0
default_device = torch.device('cuda' if _cuda_usable else 'cpu')
# from https://github.com/gtolias/how/blob/4d73c88e0ffb55506e2ce6249e2a015ef6ccf79f/how/utils/whitening.py#L20
def pcawhitenlearn_shrinkage(X, s=1.0):
    """Learn PCA whitening with shrinkage from given descriptors.

    Args:
        X: array of shape (N, D), one descriptor per row.
        s: shrinkage exponent applied to the eigenvalues; 1.0 gives full
            whitening, 0.0 gives a plain PCA rotation without rescaling.

    Returns:
        (m, P): ``m`` is the (1, D) descriptor mean and ``P`` a (D, D) matrix
        such that ``(X - m) @ P`` is the whitened data, with components
        ordered by decreasing variance.
    """
    N = X.shape[0]
    # Learning PCA w/o annotations
    m = X.mean(axis=0, keepdims=True)
    Xc = X - m
    Xcov = np.dot(Xc.T, Xc)
    # Symmetrize explicitly to cancel round-off before the eigendecomposition.
    Xcov = (Xcov + Xcov.T) / (2 * N)
    # The covariance is real symmetric: eigh guarantees real eigenvalues and
    # orthonormal eigenvectors, which the general-purpose eig does not.
    eigval, eigvec = np.linalg.eigh(Xcov)
    order = eigval.argsort()[::-1]
    eigval = eigval[order]
    eigvec = eigvec[:, order]
    # Guard against zero / slightly negative eigenvalues from numerical noise.
    eigval = np.clip(eigval, a_min=1e-14, a_max=None)
    # Same as (inv(diag(eigval ** (s / 2))) @ eigvec.T).T, but scales each
    # eigenvector column directly instead of inverting a D x D matrix.
    P = eigvec / np.power(eigval, 0.5 * s)[None, :]
    return m, P
class Dust3rInputFromImageList(torch.utils.data.Dataset):
    """Map-style dataset yielding one dust3r-loaded image per index.

    Item ``i`` is the single entry that ``load_images`` returns for
    ``image_list[i]``. Only ``imsize == 512`` is accepted.
    """

    def __init__(self, image_list, imsize=512):
        super().__init__()
        assert imsize == 512
        self.image_list = image_list
        self.imsize = imsize

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        path = self.image_list[index]
        loaded = load_images([path], size=self.imsize, verbose=False)
        return loaded[0]
class Whitener(nn.Module):
    """Affine whitening layer computing ``(x - m) @ p`` over the last dimension.

    ``m`` (1 x dim) and ``p`` (dim x dim) are float64 parameters initialised to
    the identity transform; they are meant to be overwritten through
    ``load_state_dict`` with statistics learned offline. The computation runs in
    float64 with autocast disabled, and the result is cast back to the input dtype.

    Args:
        dim: size of the last dimension of the input.
        l2norm: if not None, L2-normalise the output along this dimension.
    """

    def __init__(self, dim, l2norm=None):
        super().__init__()
        self.m = torch.nn.Parameter(torch.zeros((1, dim)).double())
        self.p = torch.nn.Parameter(torch.eye(dim, dim).double())
        self.l2norm = l2norm  # if not None, apply l2 norm along a given dimension

    def forward(self, x):
        # Whitening is numerically sensitive: keep it out of mixed precision.
        with torch.autocast(self.m.device.type, enabled=False):
            shape = x.size()
            input_type = x.dtype
            # reshape (not view) so non-contiguous inputs, e.g. transposed
            # feature maps, are accepted.
            x_flat = x.reshape(-1, shape[-1]).to(dtype=self.m.dtype)
            # Center the input data, then apply the PCA transformation.
            pca_output = torch.matmul(x_flat - self.m, self.p)
            # Restore the leading dims; the last dim is the projection width.
            pca_output = pca_output.reshape(*shape[:-1], self.p.shape[-1])
            if self.l2norm is not None:
                pca_output = torch.nn.functional.normalize(pca_output, dim=self.l2norm)
            return pca_output.to(dtype=input_type)
def weighted_spoc(feat, attn):
    """
    Attention-weighted sum pooling followed by L2 normalisation.

    feat: BxNxC local features
    attn: BxN per-feature weights
    output: BxC, L2-normalised weighted sum of the features over N
    """
    weights = attn.unsqueeze(2)
    pooled = torch.sum(feat * weights, dim=1)
    return torch.nn.functional.normalize(pooled, dim=1)
def how_select_local(feat, attn, nfeat):
    """
    Keep the local features with the highest attention scores.

    feat: BxNxC
    attn: BxN
    nfeat: number of features to keep; a negative value in [-1, 0) is read as
        a fraction of N (e.g. -0.5 keeps half of them).
    returns: (selected features BxKxC, their attention BxK, their indices BxK)
        with K = min(nfeat, N), ordered by decreasing attention.
    """
    n_local = feat.size(1)
    if nfeat < 0:
        assert nfeat >= -1.0
        keep = int(-nfeat * n_local)
    else:
        keep = int(nfeat)
    keep = min(keep, attn.size(1))
    best_attn, best_idx = torch.topk(attn, keep, dim=1)
    gather_idx = best_idx[:, :, None].expand(-1, -1, feat.size(2))
    best_feat = torch.gather(feat, 1, gather_idx)
    return best_feat, best_attn, best_idx
class RetrievalModel(nn.Module):
    """Retrieval head on top of a dust3r/mast3r image encoder.

    Pipeline: backbone encoder features -> pre-whitening (identity if disabled)
    -> projector (optionally residual) -> post-whitening (identity if disabled).
    The L2 norm of the projected features, taken before post-whitening, is the
    per-feature attention used to select local features (``forward_local``) or
    to pool a global descriptor (``forward_global``).

    Args:
        backbone: encoder exposing ``enc_embed_dim`` and ``_encode_image``.
        freeze_backbone: if truthy, backbone parameters get
            ``requires_grad=False`` and are left out of ``state_dict``.
        prewhiten: None disables pre-whitening; otherwise the re-estimation
            frequency in epochs used by ``reinitialize_whitening``. Any value
            other than -1 also freezes the whitening parameters.
        hdims: output sizes of the projector layers; empty means no projector.
        residual: add the pre-whitened backbone features to the projector output.
        postwhiten: same convention as ``prewhiten``, for post-whitening.
        featweights: attention type; only 'l2norm' is implemented.
        nfeat: number (or negative fraction, see ``how_select_local``) of local
            features kept by ``forward_local``.
        pretrained_retrieval: optional checkpoint path whose 'model' entry is
            loaded non-strictly.
    """

    def __init__(self, backbone, freeze_backbone=1, prewhiten=None, hdims=(1024,), residual=False, postwhiten=None,
                 featweights='l2norm', nfeat=300, pretrained_retrieval=None):
        super().__init__()
        self.backbone = backbone
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
        self.backbone_dim = backbone.enc_embed_dim

        self.prewhiten = nn.Identity() if prewhiten is None else Whitener(self.backbone_dim)
        self.prewhiten_freq = prewhiten
        if prewhiten is not None and prewhiten != -1:
            # parameters are set by reinitialize_whitening, not by the optimizer
            for p in self.prewhiten.parameters():
                p.requires_grad = False

        self.residual = residual
        self.projector = self.build_projector(hdims, residual)
        self.dim = hdims[-1] if len(hdims) > 0 else self.backbone_dim

        self.postwhiten_freq = postwhiten
        self.postwhiten = nn.Identity() if postwhiten is None else Whitener(self.dim)
        if postwhiten is not None and postwhiten != -1:
            assert len(hdims) > 0
            for p in self.postwhiten.parameters():
                p.requires_grad = False

        self.featweights = featweights
        if featweights == 'l2norm':
            # a staticmethod rather than a lambda keeps the model picklable
            self.attention = self._l2norm_attention
        else:
            raise NotImplementedError(featweights)
        self.nfeat = nfeat

        self.pretrained_retrieval = pretrained_retrieval
        if self.pretrained_retrieval is not None:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            ckpt = torch.load(pretrained_retrieval, 'cpu')
            msg = self.load_state_dict(ckpt['model'], strict=False)
            # only backbone / post-whitening weights may be absent from the checkpoint
            assert len(msg.unexpected_keys) == 0 and all(k.startswith('backbone')
                                                         or k.startswith('postwhiten') for k in msg.missing_keys)

    @staticmethod
    def _l2norm_attention(x):
        """Per-feature attention: L2 norm over the last dimension."""
        return x.norm(dim=-1)

    def build_projector(self, hdims, residual):
        """Build the projector mapping backbone features to ``hdims[-1]`` dims.

        Intermediate layers are Linear -> LayerNorm -> GELU; the last layer is a
        plain Linear. An empty ``hdims`` yields an identity projector.
        """
        if len(hdims) == 0:
            return nn.Identity()
        if residual:
            # the projector output is added to its input, so sizes must match
            assert hdims[-1] == self.backbone_dim
        d = self.backbone_dim
        layers = []
        for h in hdims[:-1]:
            layers.append(nn.Linear(d, h))
            d = h
            layers.append(nn.LayerNorm(d))
            layers.append(nn.GELU())
        layers.append(nn.Linear(d, hdims[-1]))
        return nn.Sequential(*layers)

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return the module state, omitting backbone entries when the backbone is frozen."""
        ss = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        if self.freeze_backbone:
            # Honour ``prefix`` so the filter still applies when this model is a
            # submodule, and delete in place so a caller-supplied
            # ``destination`` (and its metadata) reflects the filtering.
            backbone_prefix = prefix + 'backbone'
            for k in [k for k in ss if k.startswith(backbone_prefix)]:
                del ss[k]
        return ss

    def _encode_for_whitening(self, d, device):
        """Backbone features for one item of the batch_size=1 whitening loader."""
        return self.backbone._encode_image(d['img'][0, ...].to(device),
                                           true_shape=d['true_shape'][0, ...])[0]

    def _project(self, backbone_feat):
        """Pre-whiten and project backbone features (post-whitening not applied)."""
        prewhitened = self.prewhiten(backbone_feat)
        return self.projector(prewhitened) + (0.0 if not self.residual else prewhitened)

    @staticmethod
    def _fit_whitener(whitener, loader, feature_fn, max_nfeat_per_image):
        """Fit ``whitener`` by PCA-whitening the features ``feature_fn`` yields over ``loader``."""
        with torch.no_grad():
            features = []
            for d in tqdm(loader):
                feat = feature_fn(d).flatten(0, 1)
                if max_nfeat_per_image is not None and max_nfeat_per_image < feat.size(0):
                    # keep the features with the largest L2 norm
                    l2norms = torch.linalg.vector_norm(feat, dim=1)
                    feat = feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
                features.append(feat.cpu())
            features = torch.cat(features, dim=0).numpy()
            m, P = pcawhitenlearn_shrinkage(features)
            whitener.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})

    def reinitialize_whitening(self, epoch, train_dataset, nimgs=5000, log_writer=None, max_nfeat_per_image=None,
                               seed=0, device=default_device):
        """Re-estimate the pre-/post-whitening layers from a subset of training images.

        Pre-whitening is refreshed at epoch 0 and at every multiple of a positive
        ``prewhiten_freq``, unless ``pretrained_retrieval`` was given.
        Post-whitening is refreshed at epoch 0 when ``postwhiten_freq`` is 0 or
        -1, and at every multiple of a positive ``postwhiten_freq`` for epoch > 0.
        The train/eval mode of every submodule is restored afterwards.

        Args:
            epoch: current epoch index.
            train_dataset: must provide ``imlist_dataset_n_images(nimgs, seed)``.
            nimgs: number of images used to estimate the statistics.
            log_writer: optional object receiving ``add_scalar`` timing logs.
            max_nfeat_per_image: if set, keep only that many features per image,
                those with the largest L2 norm.
            seed: forwarded to ``imlist_dataset_n_images``.
            device: device the images are moved to before encoding.
        """
        do_prewhiten = self.prewhiten_freq is not None and self.pretrained_retrieval is None and \
            (epoch == 0 or (self.prewhiten_freq > 0 and epoch % self.prewhiten_freq == 0))
        do_postwhiten = self.postwhiten_freq is not None and ((epoch == 0 and self.postwhiten_freq in [0, -1])
                                                              or (self.postwhiten_freq > 0 and
                                                                  epoch % self.postwhiten_freq == 0 and epoch > 0))
        if not (do_prewhiten or do_postwhiten):
            return

        # Remember each submodule's mode so the caller's train/eval state is
        # restored instead of silently leaving the model in eval mode.
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            imdataset = train_dataset.imlist_dataset_n_images(nimgs, seed)
            loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False, num_workers=8,
                                                 pin_memory=True)
            if do_prewhiten:
                print('Re-initialization of pre-whitening')
                t = time.time()
                self._fit_whitener(self.prewhiten, loader,
                                   lambda d: self._encode_for_whitening(d, device),
                                   max_nfeat_per_image)
                prewhiten_time = time.time() - t
                print(f'Done in {prewhiten_time:.1f} seconds')
                if log_writer is not None:
                    log_writer.add_scalar('time/prewhiten', prewhiten_time, epoch)
            if do_postwhiten:
                print('Re-initialization of post-whitening')
                t = time.time()
                # uses the (possibly just refreshed) pre-whitening
                self._fit_whitener(self.postwhiten, loader,
                                   lambda d: self._project(self._encode_for_whitening(d, device)),
                                   max_nfeat_per_image)
                postwhiten_time = time.time() - t
                print(f'Done in {postwhiten_time:.1f} seconds')
                if log_writer is not None:
                    log_writer.add_scalar('time/postwhiten', postwhiten_time, epoch)
        finally:
            for module, mode in saved_modes:
                module.training = mode

    def extract_features_and_attention(self, x):
        """Return (post-whitened projected features, attention) for a batch dict with 'img' and 'true_shape'."""
        backbone_feat = self.backbone._encode_image(x['img'], true_shape=x['true_shape'])[0]
        proj_feat = self._project(backbone_feat)
        # attention is computed before post-whitening
        attention = self.attention(proj_feat)
        proj_feat_whitened = self.postwhiten(proj_feat)
        return proj_feat_whitened, attention

    def forward_local(self, x):
        """Return (features, attention, indices) of the ``self.nfeat`` highest-attention local features."""
        feat, attn = self.extract_features_and_attention(x)
        return how_select_local(feat, attn, self.nfeat)

    def forward_global(self, x):
        """Return the attention-weighted, L2-normalised global descriptor."""
        feat, attn = self.extract_features_and_attention(x)
        return weighted_spoc(feat, attn)

    def forward(self, x):
        return self.forward_global(x)
def identity(x):  # to avoid Can't pickle local object 'extract_local_features.<locals>.<lambda>'
    """Return ``x`` unchanged.

    Module-level stand-in for ``lambda x: x`` so it can be pickled, e.g. when
    used as the ``collate_fn`` of the multi-worker DataLoader in
    ``extract_local_features``.
    """
    return x
@torch.no_grad()
def extract_local_features(model, images, imsize, seed=0, tocpu=False, max_nfeat_per_image=None,
                           max_nfeat_per_image2=None, device=default_device):
    """Extract local descriptors for every image with ``model.forward_local``.

    Args:
        model: module exposing ``forward_local`` (e.g. a RetrievalModel); it is
            switched to eval mode.
        images: list of image paths (wrapped in ``Dust3rInputFromImageList``),
            or an already-built dataset whose items are dicts with 'img' and
            'true_shape'.
        imsize: image size forwarded to ``Dust3rInputFromImageList``.
        seed: seed of the random subsampling controlled by
            ``max_nfeat_per_image``, making it reproducible.
        tocpu: move each image's features to the CPU as they are produced.
        max_nfeat_per_image: if set, randomly keep at most this many features
            per image.
        max_nfeat_per_image2: if set, then keep at most the first this-many
            features per image.
        device: device the images are moved to before calling the model.

    Returns:
        (features, imids): all kept features concatenated along dim 0, and an
        int64 tensor giving the source image index of each feature row.
    """
    model.eval()
    imdataset = Dust3rInputFromImageList(images, imsize=imsize) if isinstance(images, list) else images
    loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False,
                                         num_workers=8, pin_memory=True, collate_fn=identity)
    # Dedicated CPU generator: subsampling honours ``seed`` without touching
    # the global RNG state.
    generator = torch.Generator().manual_seed(seed)
    features = []
    imids = []
    for i, d in enumerate(tqdm(loader)):
        dd = d[0]  # collate_fn=identity with batch_size=1 yields a 1-element list
        dd['img'] = dd['img'].to(device, non_blocking=True)
        feat, _, _ = model.forward_local(dd)
        feat = feat.flatten(0, 1)
        if max_nfeat_per_image is not None and feat.size(0) > max_nfeat_per_image:
            perm = torch.randperm(feat.size(0), generator=generator)
            feat = feat[perm[:max_nfeat_per_image], :]
        if max_nfeat_per_image2 is not None and feat.size(0) > max_nfeat_per_image2:
            feat = feat[:max_nfeat_per_image2, :]
        if tocpu:
            feat = feat.cpu()
        features.append(feat)
        # Build ids directly as int64: going through a float tensor shaped like
        # the features loses precision for large indices in half/bfloat16.
        imids.append(torch.full((feat.size(0),), i, dtype=torch.int64, device=feat.device))
    features = torch.cat(features, dim=0)
    imids = torch.cat(imids, dim=0)
    return features, imids
|