### demo.py
# Define model classes for inference.
###
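# NOTE: this demo assumes the LaViLa codebase (the `lavila` package) is importable
# from the working directory, and that the YAML configs referenced in main() point
# to valid pretrained checkpoints.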
from collections import OrderedDict
import json
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix

from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy


class VideoModel(nn.Module):
    """Base model for video understanding based on the LaViLa architecture."""

    def __init__(self, config):
        """Initializes the model.

        Parameters:
            config: path to the YAML config file describing the model and data.
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']  # prompt templates used to turn class names into text queries
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # Strip the 'module.' prefix that DistributedDataParallel adds to parameter names.
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        # Recover the architecture and related options from the checkpoint's training args.
        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=old_args.get('load_visual_pretrained', None),
            pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
            model_cfg=cfg['model']
        )
        model.logit_scale.requires_grad = False
        if torch.cuda.is_available():
            model.cuda()
        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # Inflate the positional embeddings when the number of frames differs
            # from the checkpoint.
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
                load_temporal_fix='bilinear',
            )
        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
        return model

    def eval(self):
        # Freeze all parameters and put the backbone in inference mode.
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def get_tokenizer(self):
        arch = self.arch
        if arch.endswith('DISTILBERT_BASE'):
            tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
        elif arch.endswith('BERT_BASE'):
            tokenizer = MyBertTokenizer('bert-base-uncased')
        elif arch.endswith('BERT_LARGE'):
            tokenizer = MyBertTokenizer('bert-large-uncased')
        elif arch.endswith('GPT2'):
            tokenizer = MyGPT2Tokenizer('gpt2')
        elif arch.endswith('GPT2_MEDIUM'):
            tokenizer = MyGPT2Tokenizer('gpt2-medium')
        elif arch.endswith('GPT2_LARGE'):
            tokenizer = MyGPT2Tokenizer('gpt2-large')
        elif arch.endswith('GPT2_XL'):
            tokenizer = MyGPT2Tokenizer('gpt2-xl')
        else:
            print("Using SimpleTokenizer because of model '{}'. "
                  "Please check if this is what you want".format(arch))
            tokenizer = SimpleTokenizer()
        return tokenizer


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            meta = json.load(open(labelmap, 'r'))
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            # Build the label map from the dataset annotations and cache it for reuse.
            from lavila.utils.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            if not os.path.exists(meta_dir):
                os.makedirs(meta_dir)
            json.dump(meta, open(f'{meta_dir}/label_map.json', 'w'))
            print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act

    def load_data(self, idx=None):
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            # Channel-wise pixel statistics on the 0-255 scale.
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305]
            ),
        ])
        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = datasets.VideoClassyDataset(
                dataset, data_cfg['root'], metadata_val,
                transform=val_transform, is_training=False,
                label_mapping=self.mapping_vn2act, is_trimmed=False,
                num_clips=1, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample']
            )
        else:
            raise NotImplementedError
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )
        return val_loader

    def get_text_features(self):
        print('=> Extracting text features')
        text_features = []
        for label in self.labels:
            # Fill each prompt template with the class name(s) for this label.
            if isinstance(label, list):
                texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label]
            else:
                texts = [tmpl.format(label) for tmpl in self.templates]
            texts = self.tokenizer(texts)
            if isinstance(texts, tuple):
                # BERT-style tokenizers output both token ids and an attention mask.
                texts, masks = texts
                texts = texts.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
            else:
                texts = texts.cuda(non_blocking=True)
                masks = None
            texts = texts.view(-1, 77).contiguous()
            masks = masks.view(-1, 77).contiguous() if masks is not None else None
            if masks is not None:
                class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks)
            else:
                class_embeddings, _ = self.model.encode_text(texts)
            # Average the normalized embeddings over templates, then re-normalize.
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)
        return text_features

    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for i, values in enumerate(val_loader):
            images = values[0]
            target = values[1]
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # Encode images.
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Cosine similarity against the class text embeddings as logits.
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)
            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())
        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets

    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        # The number of returned actions is chosen from where the cumulative
        # probability of the sorted scores first exceeds 0.055.
        sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.055)[0][0]
        # sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError


def main():
    lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
    egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
    lavila.evaluate()
    egovpa.evaluate()


if __name__ == '__main__':
    main()
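
# Example usage beyond main() (a sketch; assumes the configs above exist with their
# checkpoints downloaded, and that a CUDA device is available, since inputs and
# text features are moved to GPU):
#
#   model = VideoCLSModel("configs/charades_ego/egovpa.yml")
#   pred_actions, gt_actions = model.predict(idx=0)  # assumes 'metadata_val' in the
#                                                    # config is a format string indexed by idx
#   model.evaluate()                                 # dataset-level metric (mAP on Charades-Ego)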