### demo.py
# Define model classes for inference.
###
import argparse
from collections import OrderedDict
import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix

from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer,
                                     MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy
from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG,
                                              calculate_mAP, calculate_nDCG)


class VideoModel(nn.Module):
    """
    Base model for video understanding based on the LaViLa architecture.
    """
    def __init__(self, config):
        """
        Initializes the model.

        Parameters:
            config: config file
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        """Instantiate the backbone specified in the checkpoint and load its weights."""
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        ckpt = torch.load(ckpt_path, map_location='cpu')
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v
        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=old_args.get('load_visual_pretrained', None),
            pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
            model_cfg=cfg['model']
        )
        model.logit_scale.requires_grad = False
        if torch.cuda.is_available():
            model.cuda()
        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # inflate weight
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
                load_temporal_fix='bilinear',
            )
        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
        return model

    def eval(self):
        """Freeze all parameters and put the backbone in evaluation mode."""
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def get_tokenizer(self):
        """Select the tokenizer that matches the text encoder of the loaded architecture."""
        arch = self.arch
        if arch.endswith('DISTILBERT_BASE'):
            tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
        elif arch.endswith('BERT_BASE'):
            tokenizer = MyBertTokenizer('bert-base-uncased')
        elif arch.endswith('BERT_LARGE'):
            tokenizer = MyBertTokenizer('bert-large-uncased')
        elif arch.endswith('GPT2'):
            tokenizer = MyGPT2Tokenizer('gpt2')
        elif arch.endswith('GPT2_MEDIUM'):
            tokenizer = MyGPT2Tokenizer('gpt2-medium')
        elif arch.endswith('GPT2_LARGE'):
            tokenizer = MyGPT2Tokenizer('gpt2-large')
        elif arch.endswith('GPT2_XL'):
            tokenizer = MyGPT2Tokenizer('gpt2-xl')
        else:
            print("Using SimpleTokenizer because of model '{}'. "
                  "Please check if this is what you want".format(arch))
            tokenizer = SimpleTokenizer()
        return tokenizer


class VideoCLSModel(VideoModel):
    """
    Video model for video classification tasks (Charades-Ego, EGTEA).
    """
    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        """Load the label map from disk, or generate it and cache it under meta/<dataset>/."""
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            meta = json.load(open(labelmap, 'r'))
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from lavila.utils.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            if not os.path.exists(meta_dir):
                os.makedirs(meta_dir)
            json.dump(meta, open(f'{meta_dir}/label_map.json', 'w'))
            print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act

    def load_data(self, idx=None):
        """Build the validation dataloader for Charades-Ego / EGTEA."""
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001],
                                            std=[68.5005327, 66.6321579, 70.32316305]),
        ])
        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = datasets.VideoClassyDataset(
                dataset, data_cfg['root'], metadata_val,
                transform=val_transform, is_training=False,
                label_mapping=self.mapping_vn2act, is_trimmed=False,
                num_clips=1, clip_length=data_cfg['clip_length'],
                clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample']
            )
        else:
            raise NotImplementedError
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )
        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        """Encode every class label with the prompt templates and return L2-normalized text embeddings."""
        print('=> Extracting text features')
        text_features = []
        for label in self.labels:
            if isinstance(label, list):
                texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label]
            else:
                texts = [tmpl.format(label) for tmpl in self.templates]
            texts = self.tokenizer(texts)
            if isinstance(texts, tuple):
                # Bert-style tokenizer will output both ids and mask
                texts, masks = texts
                texts = texts.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
            else:
                texts = texts.cuda(non_blocking=True)
                masks = None
            texts = texts.view(-1, 77).contiguous()
            masks = masks.view(-1, 77).contiguous() if masks is not None else None
            if masks is not None:
                class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks)
            else:
                class_embeddings, _ = self.model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)
        return text_features

    @torch.no_grad()
    def forward(self, idx=None):
        """Compute softmax-normalized video-to-text similarity scores over the validation set."""
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for i, values in enumerate(val_loader):
            images = values[0]
            target = values[1]
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)
            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())
        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0):
        """Return predicted and ground-truth action labels for the clip selected by `idx`."""
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.055)[0][0]
        # sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        """Report mAP (Charades-Ego) or mean-class / top-1 accuracy (EGTEA)."""
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError


class VideoMIRModel(VideoModel):
    """
    Video model for video multi-instance retrieval tasks (EK100_MIR).
""" def __init__(self, config): super(VideoMIRModel, self).__init__(config) self.narrations = pd.read_csv(self.cfg['data']['narrations']).values[:, 1] self.text_features = self.get_text_features() self.video_samples = pd.read_csv('meta/ek100_mir/sel_t2v.csv').values[:, 0] def load_data(self, idx=None, t2v=False): print(f"=> Creating dataset") cfg, dataset = self.cfg, self.dataset data_cfg = cfg['data'] crop_size = 224 if '336PX' not in self.arch else 336 val_transform = transforms.Compose([ Permute([3, 0, 1, 2]), # T H W C -> C T H W transforms.Resize(crop_size), transforms.CenterCrop(crop_size), transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]), ]) if dataset == 'ek100_mir': if t2v: metadata_val = 'meta/ek100_mir/sel_t2v.csv' self.relevancy_mat_v2t = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_v2t')) self.relevancy_mat_t2v = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_t2v')) val_dataset = datasets.VideoCaptionDatasetCLIP( 'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform, is_training=False, tokenizer=self.tokenizer, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'] ) elif idx is None: metadata_val = data_cfg['metadata_val'] val_dataset = datasets.get_dataset(val_transform, self.tokenizer, cfg, is_training=False) else: metadata_val = data_cfg['metadata_val'].format(idx) self.relevancy_mat_v2t = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_v2t')) self.relevancy_mat_t2v = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_t2v')) val_dataset = datasets.VideoCaptionDatasetCLIP( 'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform, is_training=False, tokenizer=self.tokenizer, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'] ) else: raise NotImplementedError val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=8, shuffle=False, num_workers=4, pin_memory=True, sampler=None, drop_last=False ) return val_loader @torch.no_grad() def get_text_features(self): print('=> Extracting text features') text_features = [] for text in self.narrations: text = self.tokenizer(text) text = text.cuda(non_blocking=True) text = text.view(-1, 77).contiguous() text_embed, _ = self.model.encode_text(text) text_embed = F.normalize(text_embed, dim=-1).squeeze() text_features.append(text_embed) text_features = torch.stack(text_features, dim=0) return text_features @torch.no_grad() def forward_video(self, text_features=None, idx=None, t2v=False): print('=> Start forwarding') if t2v: val_loader = self.load_data(t2v=t2v) else: val_loader = self.load_data(idx=idx) all_outputs = [] for i, values in enumerate(val_loader): images = values[0].cuda(non_blocking=True) # encode images image_features, _ = self.model.encode_image(images) image_features = image_features / image_features.norm(dim=-1, keepdim=True) if t2v: all_outputs.append(image_features) else: # cosine similarity as logits logits_per_image = image_features @ text_features.t() logits_per_image = torch.softmax(logits_per_image, dim=1) all_outputs.append(logits_per_image.cpu()) all_outputs = torch.cat(all_outputs) if t2v: all_outputs = torch.softmax(text_features @ all_outputs.t(), dim=1).cpu() return all_outputs @torch.no_grad() def predict_v2t(self, idx=0, sid=0): all_outputs = self.forward_video(self.text_features, sid) preds = all_outputs.numpy() relevancy = self.relevancy_mat_v2t[idx] sel = 3 pred_action = self.narrations[(-preds[0]).argsort()[:sel]] gt_action = 
        gt_action = self.narrations[np.where(relevancy == 1)[0]]
        return pred_action, gt_action

    @torch.no_grad()
    def predict_t2v(self, idx=0, sid=0):
        """Return the top retrieved video sample and the ground-truth video indices for one narration."""
        text_features = self.text_features[sid].unsqueeze(0)
        all_outputs = self.forward_video(text_features, t2v=True)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_t2v[idx]
        sel = 1
        pred_video = self.video_samples[(-preds[0]).argsort()[:sel]]
        gt_video = np.where(relevancy == 1)[0]
        return pred_video, gt_video

    @torch.no_grad()
    def evaluate(self):
        """Report mAP and nDCG for V->T and T->V retrieval on EK100_MIR."""
        val_loader = self.load_data()
        cfg, dataset = self.cfg, self.dataset
        if self.dataset == 'ek100_mir':
            all_video_embed = []
            all_text_embed = []
            for i, inputs in enumerate(val_loader):
                inputs = [tensor.cuda(non_blocking=True) for tensor in inputs]
                relevancies = inputs.pop()

                # compute output
                outputs = self.model(
                    *inputs,
                    use_checkpoint=True,
                    norm_embed=cfg['model']['norm_embed']
                )
                image_features = outputs['image_embed']
                text_features = outputs['text_embed']
                all_video_embed.append(image_features.cpu().numpy())
                all_text_embed.append(text_features.cpu().numpy())
            all_text_embed = np.vstack(all_text_embed)
            all_video_embed = np.vstack(all_video_embed)
            similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
            similarity_matrix = (similarity_matrix + 1) / 2
            video_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test')).values[:, 0]
            text_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test_sentence')).values[:, 0]
            indexes = [video_id.tolist().index(elem) for elem in text_id]
            similarity_matrix = similarity_matrix[:, indexes]
            print(similarity_matrix.shape)
            rel_matrix = pd.read_pickle(cfg['data']['relevancy_path'])
            vis_map = calculate_mAP(similarity_matrix, rel_matrix)
            txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
            avg_map = (vis_map + txt_map) / 2
            print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map))
            vis_k_counts = calculate_k_counts(rel_matrix)
            txt_k_counts = calculate_k_counts(rel_matrix.T)
            vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
            txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
            vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
            txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
            avg_nDCG = (vis_nDCG + txt_nDCG) / 2
            print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, avg_nDCG))
        else:
            raise NotImplementedError


def main():
    parser = argparse.ArgumentParser(description='Ego-VPA inference', add_help=False)
    parser.add_argument('--dataset', default='charades_ego', type=str, help='charades_ego/ek100_mir')
    args = parser.parse_args()

    if args.dataset in ['charades_ego']:
        lavila = VideoCLSModel(f"configs/{args.dataset}/zeroshot.yml")
        egovpa = VideoCLSModel(f"configs/{args.dataset}/egovpa.yml")
    elif args.dataset == 'ek100_mir':
        lavila = VideoMIRModel(f"configs/{args.dataset}/zeroshot.yml")
        egovpa = VideoMIRModel(f"configs/{args.dataset}/egovpa.yml")
    else:
        raise NotImplementedError

    lavila.evaluate()
    egovpa.evaluate()
    # egovpa.predict_t2v(idx=0, sid=2119)


if __name__ == '__main__':
    main()
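
# Example invocation (a sketch, not part of the original script): it assumes this file is saved as
# demo.py at the repo root and that the configs/<dataset>/{zeroshot,egovpa}.yml files referenced in
# main() exist. Each run evaluates the zero-shot LaViLa baseline and the Ego-VPA model back to back.
#   python demo.py --dataset charades_ego   # action classification on Charades-Ego (reports mAP)
#   python demo.py --dataset ek100_mir      # multi-instance retrieval on EK100 (reports mAP and nDCG)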