import os
import json
import numpy as np
import clip
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.umt.datasets.video_transforms import (
    Compose, Resize, CenterCrop, Normalize,
    create_random_augment, random_short_side_scale_jitter,
    random_crop, random_resized_crop_with_shift, random_resized_crop,
    horizontal_flip, uniform_crop,
)
from vbench.third_party.umt.datasets.volume_transforms import ClipToTensor
from timm.models import create_model
from vbench.third_party.umt.models.modeling_finetune import vit_large_patch16_224
from tqdm import tqdm


def build_dict():
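    """Map Kinetics-400 class indices (as strings) to lower-cased category names.

    Parses third_party/umt/kinetics_400_categories.txt, where each line holds a
    tab-separated category name and index.
    """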
    CUR_DIR = os.path.dirname(os.path.abspath(__file__))
    path = f'{CUR_DIR}/third_party/umt/kinetics_400_categories.txt'
    results = {}
    with open(path, 'r') as f:
        cat_list = f.readlines()
    cat_list = [c.strip() for c in cat_list]
    for line in cat_list:
        cat, number = line.split('\t')
        results[number] = cat.lower()
    return results


def human_action(umt_path, video_list, device):
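    """Score videos with a UMT ViT-L/16 action recognizer finetuned on Kinetics-400.

    Loads the checkpoint at umt_path via timm's create_model, classifies 16
    frames sampled from each video, and marks a video correct when the action
    named in its filename appears among the confident top-5 predictions.
    Returns the overall accuracy and a per-video list of results.
    """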
    state_dict = torch.load(umt_path, map_location='cpu')
    model = create_model(
        "vit_large_patch16_224",
        pretrained=False,
        num_classes=400,
        all_frames=16,
        tubelet_size=1,
        use_learnable_pos_emb=False,
        fc_drop_rate=0.,
        drop_rate=0.,
        drop_path_rate=0.2,
        attn_drop_rate=0.,
        drop_block_rate=None,
        use_checkpoint=False,
        checkpoint_num=16,
        use_mean_pooling=True,
        init_scale=0.001,
    )
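    # Standard UMT evaluation transform: resize the short side to 256,
    # center-crop to 224x224, convert the clip to a tensor, and apply
    # ImageNet mean/std normalization.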
    data_transform = Compose([
        Resize(256, interpolation='bilinear'),
        CenterCrop(size=(224, 224)),
        ClipToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model = model.to(device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    cat_dict = build_dict()
    cnt = 0
    cor_num = 0
    video_results = []
    for video_path in tqdm(video_list):
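        # Recover the ground-truth action from the filename, which is assumed
        # to embed the prompt (e.g. "a person is <action>-0.mp4").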
        video_label_ls = video_path.split('/')[-1].lower().split('-')[0].split("person is ")[-1].split('_')[0]
        cnt += 1
        images = load_video(video_path, data_transform, num_frames=16)
        images = images.unsqueeze(0)
        images = images.to(device)
        with torch.no_grad():
            logits = torch.sigmoid(model(images))
            results, indices = torch.topk(logits, 5, dim=1)
        indices = indices.squeeze().tolist()
        results = results.squeeze().tolist()
        results = [round(f, 4) for f in results]
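        # Keep only top-5 predictions whose sigmoid score clears 0.85; the
        # video is marked correct when the ground-truth action is among these
        # confident categories.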
        cat_ls = []
        for i in range(5):
            if results[i] >= 0.85:
                cat_ls.append(cat_dict[str(indices[i])])
        flag = False
        for cat in cat_ls:
            if cat == video_label_ls:
                cor_num += 1
                flag = True
                # print(f"{cnt}: {video_path} correct, top-5: {cat_ls}, logits: {results}", flush=True)
                break
        if flag is False:
            # print(f"{cnt}: {video_path} false, gt: {video_label_ls}, top-5: {cat_ls}, logits: {results}", flush=True)
            pass
        video_results.append({'video_path': video_path, 'video_results': flag})
    # print(f"cor num: {cor_num}, total: {cnt}")
    acc = cor_num / cnt
    return acc, video_results


def compute_human_action(json_dir, device, submodules_list):
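    """Entry point for the VBench 'human_action' dimension.

    Expects the UMT checkpoint path as the first element of submodules_list,
    collects the relevant videos via load_dimension_info, and returns the
    overall accuracy together with per-video results.
    """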
    umt_path = submodules_list[0]
    video_list, _ = load_dimension_info(json_dir, dimension='human_action', lang='en')
    all_results, video_results = human_action(umt_path, video_list, device)
    return all_results, video_results
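

# Minimal usage sketch (an assumption, not part of the module): the JSON info
# file and checkpoint paths below are hypothetical placeholders; in practice
# this metric is usually invoked through the VBench evaluator rather than
# called directly.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Hypothetical paths -- replace with your own dimension info JSON and UMT checkpoint.
    json_dir = "./vbench_info/human_action_info.json"
    submodules_list = ["./pretrained/umt_large_patch16_224_k400.pth"]
    acc, per_video = compute_human_action(json_dir, device, submodules_list)
    print(f"human_action accuracy: {acc:.4f} over {len(per_video)} videos")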