# VideoCrafterXtend/VBench/vbench/human_action.py
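"""VBench 'human_action' dimension: run a UMT ViT-L/16 Kinetics-400 classifier on each
generated video and check whether the action named in the filename is recognized."""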
import os

import torch
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.umt.datasets.video_transforms import (
    Compose, Resize, CenterCrop, Normalize,
)
from vbench.third_party.umt.datasets.volume_transforms import ClipToTensor
from timm.models import create_model
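# Importing this module registers the UMT `vit_large_patch16_224` architecture with timm,
# so the `create_model("vit_large_patch16_224", ...)` call below resolves to it.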
from vbench.third_party.umt.models.modeling_finetune import vit_large_patch16_224
from tqdm import tqdm


def build_dict():
    """Map Kinetics-400 class indices (as strings) to lower-cased category names."""
    CUR_DIR = os.path.dirname(os.path.abspath(__file__))
    path = f'{CUR_DIR}/third_party/umt/kinetics_400_categories.txt'
    results = {}
    with open(path, 'r') as f:
        cat_list = f.readlines()
    cat_list = [c.strip() for c in cat_list]
    for line in cat_list:
        # Each line is "<category name>\t<class index>".
        cat, number = line.split('\t')
        results[number] = cat.lower()
    return results


def human_action(umt_path, video_list, device):
    """Score a list of videos for the human_action dimension.

    Each video is classified with a UMT ViT-L/16 Kinetics-400 model; a video counts as
    correct when the action parsed from its filename appears among the top-5 predictions
    whose sigmoid score is at least 0.85. Returns (accuracy, per-video results).
    """
    state_dict = torch.load(umt_path, map_location='cpu')
    model = create_model(
        "vit_large_patch16_224",
        pretrained=False,
        num_classes=400,
        all_frames=16,
        tubelet_size=1,
        use_learnable_pos_emb=False,
        fc_drop_rate=0.,
        drop_rate=0.,
        drop_path_rate=0.2,
        attn_drop_rate=0.,
        drop_block_rate=None,
        use_checkpoint=False,
        checkpoint_num=16,
        use_mean_pooling=True,
        init_scale=0.001,
    )
    # Preprocessing: resize to 256, center-crop to 224x224, convert the clip to a tensor,
    # then normalize with ImageNet statistics.
    data_transform = Compose([
        Resize(256, interpolation='bilinear'),
        CenterCrop(size=(224, 224)),
        ClipToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model = model.to(device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    cat_dict = build_dict()
    cnt = 0
    cor_num = 0
    video_results = []
    for video_path in tqdm(video_list):
        # Parse the ground-truth action out of the filename (prompt-style names of the
        # form "a person is <action>-<idx>" or "<action>_<idx>" are assumed).
        video_label_ls = video_path.split('/')[-1].lower().split('-')[0].split("person is ")[-1].split('_')[0]
        cnt += 1
        images = load_video(video_path, data_transform, num_frames=16)
        images = images.unsqueeze(0)
        images = images.to(device)
        with torch.no_grad():
            # Per-class sigmoid scores; keep the top-5 classes and their scores.
            logits = torch.sigmoid(model(images))
            results, indices = torch.topk(logits, 5, dim=1)
            indices = indices.squeeze().tolist()
            results = results.squeeze().tolist()
            results = [round(f, 4) for f in results]
        # Keep only top-5 predictions whose score passes the 0.85 confidence threshold.
        cat_ls = []
        for i in range(5):
            if results[i] >= 0.85:
                cat_ls.append(cat_dict[str(indices[i])])
        # A video is correct if its parsed label matches any confident prediction.
        flag = False
        for cat in cat_ls:
            if cat == video_label_ls:
                cor_num += 1
                flag = True
                break
        video_results.append({'video_path': video_path, 'video_results': flag})
    acc = cor_num / cnt
    return acc, video_results


def compute_human_action(json_dir, device, submodules_list):
    """Compute the human_action score for the videos described in `json_dir`,
    using the UMT checkpoint path given as `submodules_list[0]`."""
    umt_path = submodules_list[0]
    video_list, _ = load_dimension_info(json_dir, dimension='human_action', lang='en')
    all_results, video_results = human_action(umt_path, video_list, device)
    return all_results, video_results
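

# Minimal usage sketch (not part of the original module): the JSON and checkpoint paths
# below are placeholders; substitute the VBench full-info JSON and a downloaded UMT
# Kinetics-400 checkpoint for your setup.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    acc, per_video = compute_human_action(
        json_dir="./vbench_full_info.json",                       # placeholder path
        device=device,
        submodules_list=["./checkpoints/umt_vit_l16_k400.pth"],   # placeholder checkpoint
    )
    print(f"human_action accuracy: {acc:.4f} over {len(per_video)} videos")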