|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import argparse |
|
import numpy as np |
|
import torch |
|
|
|
from mast3r.model import AsymmetricMASt3R |
|
from mast3r.retrieval.model import RetrievalModel, extract_local_features |
|
|
|
try:
    import faiss

    # Probe for GPU support. NOTE(review): presumably CPU-only faiss builds do not
    # expose StandardGpuResources, which raises AttributeError and selects the
    # CPU fallback below.
    faiss.StandardGpuResources()
except AttributeError:
    import asmk.index

    class FaissCpuL2Index(asmk.index.FaissL2Index):
        """CPU stand-in for ``asmk.index.FaissGpuL2Index``.

        Keeps the same constructor signature (``gpu_id``) so callers that build a
        GPU index still work, but ``_faiss_index_flat`` returns a plain CPU
        ``faiss.IndexFlatL2``. ``gpu_id`` is only stored on the instance.
        """

        def __init__(self, gpu_id):
            super().__init__()
            self.gpu_id = gpu_id

        def _faiss_index_flat(self, dim):
            """Return initialized faiss.IndexFlatL2"""
            return faiss.IndexFlatL2(dim)

    # Monkeypatch: any later lookup of asmk.index.FaissGpuL2Index resolves to the
    # CPU implementation (the ASMK params in Retriever request an index with gpu_id).
    asmk.index.FaissGpuL2Index = FaissCpuL2Index
|
|
|
from asmk import asmk_method |
|
|
|
|
|
def get_args_parser():
    """Create the argument parser for computing a retrieval score matrix.

    Every option is a mandatory string. The parser is built with
    ``add_help=False`` (no ``-h`` option is registered) and
    ``allow_abbrev=False`` (option prefixes are rejected).

    Returns:
        argparse.ArgumentParser: parser exposing ``--model``, ``--input`` and
        ``--outfile``.
    """
    option_specs = (
        ('--model', "shortname of a retrieval model or path to the corresponding .pth"),
        ('--input', "directory containing images or a file containing a list of image paths"),
        ('--outfile', "numpy file where to store the matrix score"),
    )
    arg_parser = argparse.ArgumentParser('Retrieval scores from a set of retrieval',
                                         add_help=False, allow_abbrev=False)
    for flag, description in option_specs:
        arg_parser.add_argument(flag, type=str, required=True, help=description)
    return arg_parser
|
|
|
|
|
def get_impaths(imlistfile):
    """Read image paths from a text file containing one path per line.

    Empty lines and lines starting with ``#`` are skipped.

    Args:
        imlistfile: path to the list file.

    Returns:
        list of str: the remaining lines, in file order.
    """
    with open(imlistfile, 'r') as fid:
        # Read from the opened handle -- ``imlistfile`` itself is only a path.
        lines = fid.read().splitlines()
    return [line for line in lines if line and not line.startswith('#')]
|
|
|
|
|
def get_impaths_from_imdir(imdir, extensions=('png', 'jpg', 'PNG', 'JPG')):
    """List image files directly inside ``imdir``, sorted by file name.

    Args:
        imdir: directory to scan (not recursive).
        extensions: iterable of filename suffixes to keep (case-sensitive,
            matched with ``str.endswith``). Defaults to an immutable tuple so
            the default cannot be shared/mutated across calls.

    Returns:
        list of str: ``os.path.join(imdir, name)`` for each matching entry.

    Raises:
        AssertionError: if ``imdir`` is not a directory.
    """
    assert os.path.isdir(imdir), imdir
    # str.endswith accepts a tuple of suffixes: one C-level check per name.
    suffixes = tuple(extensions)
    return [os.path.join(imdir, fname) for fname in sorted(os.listdir(imdir))
            if fname.endswith(suffixes)]
|
|
|
|
|
def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile):
    """Resolve a list of image paths from a list file or an image directory.

    An existing regular file is parsed as a list of paths (``get_impaths``);
    anything else is scanned as a directory (``get_impaths_from_imdir``,
    which asserts that the path is a directory).
    """
    loader = get_impaths if os.path.isfile(input_imdir_or_imlistfile) else get_impaths_from_imdir
    return loader(input_imdir_or_imlistfile)
|
|
|
|
|
class Retriever(object):
    """Compute an all-pairs image retrieval score matrix with a MASt3R model + ASMK.

    The constructor loads a retrieval checkpoint, rebuilds the ``RetrievalModel``
    on a MASt3R backbone and restores an ASMK codebook cached next to the
    checkpoint. Calling the instance scores every image of a set against every
    image of the same set.
    """

    def __init__(self, modelname, backbone=None, device='cuda'):
        """Load the retrieval model and its ASMK codebook.

        Args:
            modelname: path to the retrieval checkpoint file. Its ``'args'``
                entry holds the hyper-parameters used to rebuild the model.
            backbone: optional pre-built backbone; when None it is loaded via
                ``AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)``.
            device: torch device the retrieval model is moved to and that
                feature extraction runs on.

        Raises:
            AssertionError: if the checkpoint or its codebook file is missing,
                or the checkpoint's state dict does not fit the model.
        """
        assert os.path.isfile(modelname), modelname
        print(f'Loading retrieval model from {modelname}')
        # The checkpoint stores a non-tensor args object (accessed by attribute
        # below), which needs full unpickling; torch >= 2.6 defaults to
        # weights_only=True, so request it explicitly.
        # SECURITY: unpickling can execute arbitrary code -- only load trusted checkpoints.
        ckpt = torch.load(modelname, 'cpu', weights_only=False)
        ckpt_args = ckpt['args']
        if backbone is None:
            backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)

        # hdims is stored as an '_'-separated string of ints; an empty string is
        # passed through unchanged as "".
        hdims = list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else ""
        self.model = RetrievalModel(
            backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten,
            hdims=hdims, residual=getattr(ckpt_args, 'residual', False),
            postwhiten=ckpt_args.postwhiten, featweights=ckpt_args.featweights,
            nfeat=ckpt_args.nfeat,
        ).to(device)
        self.device = device

        # strict=False: only backbone weights may be absent from the checkpoint
        # (they come from `backbone` above); every other key must match exactly.
        msg = self.model.load_state_dict(ckpt['model'], strict=False)
        assert all(k.startswith('backbone') for k in msg.missing_keys)
        assert len(msg.unexpected_keys) == 0
        self.imsize = ckpt_args.imsize

        # The codebook lives next to the checkpoint: the basename's last
        # '_'-separated token is replaced by 'codebook.pkl'
        # (e.g. 'foo_bar_ckpt.pth' -> 'foo_bar_codebook.pkl').
        dname, bname = os.path.split(modelname)
        bname_splits = bname.split('_')
        cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl')
        assert os.path.isfile(cache_codebook_fname), cache_codebook_fname

        asmk_params = {'index': {'gpu_id': 0},
                       'train_codebook': {'codebook': {'size': ckpt_args.nclusters}},
                       'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False},
                                     'quantize': {'multiple_assignment': 1}, 'aggregate': {}},
                       'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {},
                                     'search': {'topk': None},
                                     'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}}
        self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params)
        # No training data is passed; presumably the codebook is restored from
        # cache_path instead of being trained -- TODO confirm against asmk.
        self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname)

    def __call__(self, input_imdir_or_imlistfile, outfile=None):
        """Compute the score matrix between all pairs of images of a set.

        Args:
            input_imdir_or_imlistfile: an image directory, a text file listing
                image paths (given as ``str`` or ``os.PathLike``), or an
                already-built sequence of image paths.
            outfile: optional ``.npy`` path to save the matrix to; its parent
                directory is created if needed.

        Returns:
            np.ndarray: ``scores[i, j]`` is the ASMK score of query image ``i``
            against indexed image ``j`` (both index the same image list).
        """
        if isinstance(input_imdir_or_imlistfile, (str, os.PathLike)):
            impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile)
        else:
            impaths = input_imdir_or_imlistfile
        print(f'Found {len(impaths)} images')

        feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device)
        feat = feat.cpu().numpy()
        ids = ids.cpu().numpy()

        # Index the images, then query that same index with the same images.
        asmk_dataset = self.asmk.build_ivf(feat, ids)
        _metadata, _query_ids, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids)

        # Row i of `ranks` gives the column each entry of `ranked_scores[i]` belongs
        # to; scatter them back so that column j corresponds to image j.
        # NOTE(review): relies on search topk=None making every row of `ranks`
        # cover all columns; otherwise np.empty_like leaves uninitialized entries.
        scores = np.empty_like(ranked_scores)
        scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores

        if outfile is not None:
            outdir = os.path.dirname(outfile)
            # Create the parent directory when one is given; '' means the cwd,
            # which always exists (and os.makedirs('') would raise).
            if outdir:
                os.makedirs(outdir, exist_ok=True)
            np.save(outfile, scores)
            print(f'Scores matrix saved in {outfile}')
        return scores
|
|