Upload 29 files
- README.md +6 -6
- app_mir.py +115 -0
- cache/cache.txt +0 -0
- configs/base.yml +23 -0
- configs/ek100_mir/egovpa.yml +39 -0
- configs/ek100_mir/zeroshot.yml +21 -0
- demo.py +461 -0
- lavila/data/datasets.py +542 -0
- lavila/data/video_transforms.py +186 -0
- lavila/models/bpe_simple_vocab_16e6.txt.gz +3 -0
- lavila/models/distributed_utils.py +89 -0
- lavila/models/models.py +252 -0
- lavila/models/openai_clip.py +237 -0
- lavila/models/openai_model.py +535 -0
- lavila/models/prompt_tuning.py +291 -0
- lavila/models/timesformer.py +650 -0
- lavila/models/tokenizer.py +239 -0
- lavila/models/utils.py +110 -0
- lavila/utils/config.py +18 -0
- lavila/utils/evaluation.py +36 -0
- lavila/utils/evaluation_charades.py +56 -0
- lavila/utils/evaluation_ek100mir.py +201 -0
- lavila/utils/preprocess.py +86 -0
- meta/ek100_mir/EPIC_100_retrieval_test_sentence.csv +0 -0
- meta/ek100_mir/relevancy_sel_t2v.npy +3 -0
- meta/ek100_mir/relevancy_sel_v2t.npy +3 -0
- meta/ek100_mir/sel_t2v.csv +10 -0
- meta/ek100_mir/sel_v2t.csv +3 -0
- requirements.txt +14 -0
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
-title:
+title: Ego VPA
-emoji:
+emoji: 📉
-colorFrom:
+colorFrom: blue
-colorTo:
+colorTo: blue
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.17.0
 app_file: app.py
 pinned: false
-license:
+license: apache-2.0
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app_mir.py
ADDED
@@ -0,0 +1,115 @@
### app_mir.py
# User interface for the demo.
###

import os, pdb
import pandas as pd
import gradio as gr
from gradio_rich_textbox import RichTextbox

from demo import VideoMIRModel


def load_v2t_samples(data_root):
    sample_videos = []
    df = pd.read_csv("meta/ek100_mir/sel_v2t.csv", header=None)
    idx2sid = {}
    for i, x in enumerate(df[0].values):
        sample_videos.append(f'{data_root}/video/gif/{x}.gif')
        idx2sid[i] = x

    return sample_videos, idx2sid

def load_t2v_samples(data_root):
    sample_text = ['cut the sausage', 'stir vegetables into salmon', 'rinse cutting board']
    idx2sid = {0: 2119, 1: 1730, 2: 1276}

    return sample_text, idx2sid


def format_pred(pred, gt):
    tp = '[color=green]{}[/color]'
    fp = '[color=red]{}[/color]'
    fmt_pred = []
    for x in pred:
        if x in gt:
            fmt_pred.append(tp.format(x))
        else:
            fmt_pred.append(fp.format(x))

    return ', '.join(fmt_pred)

def main():
    lavila = VideoMIRModel("configs/ek100_mir/zeroshot.yml")
    egovpa = VideoMIRModel("configs/ek100_mir/egovpa.yml")
    v2t_samples, idx2sid_v2t = load_v2t_samples('data/ek100_mir')
    t2v_samples, idx2sid_t2v = load_t2v_samples('data/ek100_mir')
    print(v2t_samples)

    def predict_v2t(idx):
        sid = idx2sid_v2t[idx]
        zeroshot_action, gt_action = lavila.predict_v2t(idx, sid)
        egovpa_action, gt_action = egovpa.predict_v2t(idx, sid)
        zeroshot_action = format_pred(zeroshot_action, gt_action)
        egovpa_action = format_pred(egovpa_action, gt_action)

        return gt_action, zeroshot_action, egovpa_action

    def predict_t2v(idx):
        sid = idx2sid_t2v[idx]
        zeroshot_video, gt_video = lavila.predict_t2v(idx, sid)
        egovpa_video, gt_video = egovpa.predict_t2v(idx, sid)

        return gt_video, zeroshot_video, egovpa_video

    with gr.Blocks() as demo:
        with gr.Tab("Video-to-text retrieval"):
            gr.Markdown(
                """
                # Ego-VPA Demo
                Choose a sample video and click predict to view the text queried by the selected video
                (<span style="color:green">correct</span>/<span style="color:red">incorrect</span>).
                """
            )

            with gr.Row():
                with gr.Column():
                    video = gr.Image(label="video query", height='300px', interactive=False)
                with gr.Column():
                    idx = gr.Number(label="Idx", visible=False)
                    label = RichTextbox(label="Ground Truth", visible=False)
                    zeroshot = RichTextbox(label="LaViLa (zero-shot) prediction")
                    ours = RichTextbox(label="Ego-VPA prediction")
                    btn = gr.Button("Predict", variant="primary")
                    btn.click(predict_v2t, inputs=[idx], outputs=[label, zeroshot, ours])
            gr.Examples(examples=[[i, x] for i, x in enumerate(v2t_samples)], inputs=[idx, video])

        with gr.Tab("Text-to-video retrieval"):
            gr.Markdown(
                """
                # Ego-VPA Demo
                Choose a sample narration and click predict to view the video queried by the selected text.
                """
            )

            with gr.Row():
                with gr.Column():
                    text = gr.Text(label="text query")
                with gr.Column():
                    idx = gr.Number(label="Idx", visible=False)
                    zeroshot = gr.Textbox(label="LaViLa (zero-shot) prediction")
                    #zeroshot = gr.Gallery(label="LaViLa (zero-shot) prediction", columns=[3], rows=[1], object_fit="contain", height="auto")
                    ours = gr.Textbox(label="Ego-VPA prediction")
                    #ours = gr.Gallery(label="Ego-VPA prediction", columns=[3], rows=[1], object_fit="contain", height="auto")
                    btn = gr.Button("Predict", variant="primary")
                    btn.click(predict_t2v, inputs=[idx], outputs=[label, zeroshot, ours])
            gr.Examples(examples=[[i, x] for i, x in enumerate(t2v_samples)], inputs=[idx, text])


    demo.launch(share=True)


if __name__ == "__main__":
    main()
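A quick sketch of what format_pred above produces (the inputs here are made-up narrations, not rows from sel_v2t.csv): predictions found in the ground truth are wrapped in green color tags and the rest in red, which RichTextbox renders in the UI.

    pred = ['cut the sausage', 'open the fridge']
    gt = ['cut the sausage']
    print(format_pred(pred, gt))
    # -> [color=green]cut the sausage[/color], [color=red]open the fridge[/color]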
cache/cache.txt
ADDED
File without changes
configs/base.yml
ADDED
@@ -0,0 +1,23 @@
model:
  pretrain: ""
  resume: ""
  timesformer_freeze_space: false
  drop_path_rate: 0.1
  dropout_ratio: 0.5
  freeze_vis_backbone: false
  freeze_txt_backbone: false
  use_vn_classifier: false

data:
  dataset: ek100_mir
  root: datasets/EK100/video_ht256px
  metadata: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
  metadata_val: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
  relevancy_path: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
  clip_length: 16
  clip_stride: 4
  sparse_sample: false
  num_crops: 1
  num_clips: 1
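The per-task configs below override fields in this base file. lavila/utils/config.py (part of this upload but not shown here) provides load_cfg; a minimal sketch of such an overlay loader, assuming PyYAML and a simple section-wise merge, might look like the following. The actual implementation may differ.

    import yaml

    def load_cfg_sketch(task_yml, base_yml="configs/base.yml"):
        # Hypothetical loader: the real load_cfg in lavila/utils/config.py may differ.
        with open(base_yml) as f:
            cfg = yaml.safe_load(f) or {}
        with open(task_yml) as f:
            task = yaml.safe_load(f) or {}
        for section, values in task.items():
            if isinstance(values, dict):
                cfg.setdefault(section, {}).update(values)  # task values override base defaults
            else:
                cfg[section] = values
        return cfg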
configs/ek100_mir/egovpa.yml
ADDED
@@ -0,0 +1,39 @@
model:
  pretrain: ../ckpt/ek100mir.pt
  freeze_vis_backbone: true
  freeze_txt_backbone: true
  inflat_posemb: true  # false for cascade models; true for single-stage models (default: true)
  num_frames: 16
  text_prompt:
    n_ctx: 8
    use_bank: true
  visual_prompt:
    num_layers: 12
    prompt_dim: 512
    num_tokens: 128
    deep: true
    deep_shared: false
    split_st: false
    pt_spt: true
    pt_tmp: false
    style: VoP_c_pool
    n_seg: 16  # number of segments per video (n_seg=clip_length -> 1 frame/seg)
    K_s: 8  # boundary of intra-frame/inter-frame attention (VoP_f+c)
  pool:
    size: 10


data:
  dataset: ek100_mir
  #root: /data/EK100/video_ht256px
  #metadata: /data/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
  #metadata_val: /data/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
  #relevancy_path: /data/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
  root: data/ek100_mir/video
  metadata_val: data/ek100_mir/csv/{}.csv
  relevancy_path: meta/ek100_mir/relevancy_sel.npy
  narrations: meta/ek100_mir/EPIC_100_retrieval_test_sentence.csv
  clip_length: 16
configs/ek100_mir/zeroshot.yml
ADDED
@@ -0,0 +1,21 @@
model:
  pretrain: /store/nosnap/results/LaViLa/checkpoints/pt/TSF-B/lavila_best.pth
  freeze_vis_backbone: true
  freeze_txt_backbone: true
  inflat_posemb: true  # false for cascade models; true for single-stage models (default: true)
  num_frames: 16

data:
  dataset: ek100_mir
  #root: /data/EK100/video_ht256px
  #metadata: /data/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
  #metadata_val: /data/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
  #relevancy_path: /data/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
  root: data/ek100_mir/video
  metadata_val: data/ek100_mir/csv/{}.csv
  relevancy_path: meta/ek100_mir/relevancy_sel.npy
  narrations: meta/ek100_mir/EPIC_100_retrieval_test_sentence.csv
  clip_length: 16
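Both demo configs point metadata_val at data/ek100_mir/csv/{}.csv; VideoMIRModel.load_data in demo.py below fills the {} placeholder with the selected sample id, so each retrieval query reads its own small CSV. For example (the sample id 2119 comes from load_t2v_samples in app_mir.py):

    data_cfg = {'metadata_val': 'data/ek100_mir/csv/{}.csv'}
    metadata_val = data_cfg['metadata_val'].format(2119)
    # -> 'data/ek100_mir/csv/2119.csv'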
demo.py
ADDED
@@ -0,0 +1,461 @@
### demo.py
# Define model classes for inference.
###

import argparse
from collections import OrderedDict
import json
import numpy as np
import os
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix

from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy
from lavila.utils.evaluation_ek100mir import (calculate_k_counts, calculate_IDCG, calculate_mAP, calculate_nDCG)


class VideoModel(nn.Module):
    """ Base model for video understanding based on LaViLa architecture. """
    def __init__(self, config):
        """ Initializes the model.
        Parameters:
            config: config file
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        ckpt = torch.load(ckpt_path, map_location='cpu')

        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=old_args.get('load_visual_pretrained', None),
            pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
            model_cfg=cfg['model']
        )
        model.logit_scale.requires_grad = False

        if torch.cuda.is_available():
            model.cuda()

        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # inflate weight
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
                load_temporal_fix='bilinear',
            )
        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))

        return model

    def eval(self):
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    def get_tokenizer(self):
        arch = self.arch
        if arch.endswith('DISTILBERT_BASE'):
            tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
        elif arch.endswith('BERT_BASE'):
            tokenizer = MyBertTokenizer('bert-base-uncased')
        elif arch.endswith('BERT_LARGE'):
            tokenizer = MyBertTokenizer('bert-large-uncased')
        elif arch.endswith('GPT2'):
            tokenizer = MyGPT2Tokenizer('gpt2')
        elif arch.endswith('GPT2_MEDIUM'):
            tokenizer = MyGPT2Tokenizer('gpt2-medium')
        elif arch.endswith('GPT2_LARGE'):
            tokenizer = MyGPT2Tokenizer('gpt2-large')
        elif arch.endswith('GPT2_XL'):
            tokenizer = MyGPT2Tokenizer('gpt2-xl')
        else:
            print("Using SimpleTokenizer because of model '{}'. "
                  "Please check if this is what you want".format(arch))
            tokenizer = SimpleTokenizer()

        return tokenizer


class VideoCLSModel(VideoModel):
    """ Video model for video classification tasks (Charades-Ego, EGTEA). """
    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            meta = json.load(open(labelmap, 'r'))
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from lavila.utils.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            if not os.path.exists(meta_dir):
                os.makedirs(meta_dir)
            json.dump(meta, open(f'{meta_dir}/label_map.json', 'w'))
            print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")

        return labels, mapping_vn2act

    def load_data(self, idx=None):
        print(f"=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]),
        ])

        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = datasets.VideoClassyDataset(
                dataset, data_cfg['root'], metadata_val,
                transform=val_transform, is_training=False,
                label_mapping=self.mapping_vn2act, is_trimmed=False,
                num_clips=1, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample']
            )
        else:
            raise NotImplementedError

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        print('=> Extracting text features')
        text_features = []
        for label in self.labels:
            if isinstance(label, list):
                texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label]
            else:
                texts = [tmpl.format(label) for tmpl in self.templates]
            texts = self.tokenizer(texts)
            if isinstance(texts, tuple):
                # Bert-style tokenizer will output both ids and mask
                texts, masks = texts
                texts = texts.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
            else:
                texts = texts.cuda(non_blocking=True)
                masks = None
            texts = texts.view(-1, 77).contiguous()
            masks = masks.view(-1, 77).contiguous() if masks is not None else None
            if masks is not None:
                class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks)
            else:
                class_embeddings, _ = self.model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)

            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)

        return text_features

    @torch.no_grad()
    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for i, values in enumerate(val_loader):
            images = values[0]
            target = values[1]

            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # cosine similarity as logits
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)

            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)

        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.055)[0][0]
        #sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError


class VideoMIRModel(VideoModel):
    """ Video model for video multi-instance retrieval tasks (EK100_MIR). """
    def __init__(self, config):
        super(VideoMIRModel, self).__init__(config)
        self.narrations = pd.read_csv(self.cfg['data']['narrations']).values[:, 1]
        self.text_features = self.get_text_features()
        self.video_samples = pd.read_csv('meta/ek100_mir/sel_t2v.csv').values[:, 0]

    def load_data(self, idx=None, t2v=False):
        print(f"=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]),
        ])

        if dataset == 'ek100_mir':
            if t2v:
                metadata_val = 'meta/ek100_mir/sel_t2v.csv'
                self.relevancy_mat_v2t = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_v2t'))
                self.relevancy_mat_t2v = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_t2v'))
                val_dataset = datasets.VideoCaptionDatasetCLIP(
                    'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform,
                    is_training=False, tokenizer=self.tokenizer,
                    clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride']
                )
            elif idx is None:
                metadata_val = data_cfg['metadata_val']
                val_dataset = datasets.get_dataset(val_transform, self.tokenizer, cfg, is_training=False)
            else:
                metadata_val = data_cfg['metadata_val'].format(idx)
                self.relevancy_mat_v2t = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_v2t'))
                self.relevancy_mat_t2v = np.load(data_cfg['relevancy_path'].replace('sel', 'sel_t2v'))
                val_dataset = datasets.VideoCaptionDatasetCLIP(
                    'ek100_mir_demo', data_cfg['root'], metadata_val, val_transform,
                    is_training=False, tokenizer=self.tokenizer,
                    clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride']
                )
        else:
            raise NotImplementedError

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        print('=> Extracting text features')
        text_features = []
        for text in self.narrations:
            text = self.tokenizer(text)
            text = text.cuda(non_blocking=True)
            text = text.view(-1, 77).contiguous()
            text_embed, _ = self.model.encode_text(text)
            text_embed = F.normalize(text_embed, dim=-1).squeeze()
            text_features.append(text_embed)

        text_features = torch.stack(text_features, dim=0)

        return text_features

    @torch.no_grad()
    def forward_video(self, text_features=None, idx=None, t2v=False):
        print('=> Start forwarding')
        if t2v:
            val_loader = self.load_data(t2v=t2v)
        else:
            val_loader = self.load_data(idx=idx)
        all_outputs = []
        for i, values in enumerate(val_loader):
            images = values[0].cuda(non_blocking=True)

            # encode images
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            if t2v:
                all_outputs.append(image_features)
            else:
                # cosine similarity as logits
                logits_per_image = image_features @ text_features.t()
                logits_per_image = torch.softmax(logits_per_image, dim=1)
                all_outputs.append(logits_per_image.cpu())

        all_outputs = torch.cat(all_outputs)
        if t2v:
            all_outputs = torch.softmax(text_features @ all_outputs.t(), dim=1).cpu()

        return all_outputs

    @torch.no_grad()
    def predict_v2t(self, idx=0, sid=0):
        all_outputs = self.forward_video(self.text_features, sid)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_v2t[idx]
        sel = 3
        pred_action = self.narrations[(-preds[0]).argsort()[:sel]]
        gt_action = self.narrations[np.where(relevancy == 1)[0]]
        return pred_action, gt_action

    @torch.no_grad()
    def predict_t2v(self, idx=0, sid=0):
        text_features = self.text_features[sid].unsqueeze(0)
        all_outputs = self.forward_video(text_features, t2v=True)
        preds = all_outputs.numpy()
        relevancy = self.relevancy_mat_t2v[idx]
        sel = 3
        pred_video = self.video_samples[(-preds[0]).argsort()[:sel]]
        gt_video = np.where(relevancy == 1)[0]
        return pred_video, gt_video

    @torch.no_grad()
    def evaluate(self):
        val_loader = self.load_data()
        cfg, dataset = self.cfg, self.dataset
        if self.dataset == 'ek100_mir':
            all_video_embed = []
            all_text_embed = []
            for i, inputs in enumerate(val_loader):
                inputs = [tensor.cuda(non_blocking=True) for tensor in inputs]
                relevancies = inputs.pop()

                # compute output
                outputs = self.model(
                    *inputs,
                    use_checkpoint=True,
                    norm_embed=cfg['model']['norm_embed']
                )

                image_features = outputs['image_embed']
                text_features = outputs['text_embed']
                all_video_embed.append(image_features.cpu().numpy())
                all_text_embed.append(text_features.cpu().numpy())

            all_text_embed = np.vstack(all_text_embed)
            all_video_embed = np.vstack(all_video_embed)
            similarity_matrix = np.matmul(all_video_embed, all_text_embed.T)
            similarity_matrix = (similarity_matrix + 1) / 2
            video_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test')).values[:, 0]
            text_id = pd.read_csv(cfg['data']['metadata'].replace('train', 'test_sentence')).values[:, 0]
            indexes = [video_id.tolist().index(elem) for elem in text_id]
            similarity_matrix = similarity_matrix[:, indexes]
            print(similarity_matrix.shape)
            rel_matrix = pd.read_pickle(
                cfg['data']['relevancy_path']
            )
            vis_map = calculate_mAP(similarity_matrix, rel_matrix)
            txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
            avg_map = (vis_map + txt_map) / 2
            print('mAP: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_map, txt_map, avg_map))
            vis_k_counts = calculate_k_counts(rel_matrix)
            txt_k_counts = calculate_k_counts(rel_matrix.T)
            vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
            txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
            vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
            txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
            avg_nDCG = (vis_nDCG + txt_nDCG) / 2
            print('nDCG: V->T: {:.3f} T->V: {:.3f} AVG: {:.3f}'.format(vis_nDCG, txt_nDCG, avg_nDCG))

        else:
            raise NotImplementedError


def main():
    parser = argparse.ArgumentParser(description='Ego-VPA inference', add_help=False)
    parser.add_argument('--dataset',
                        default='charades_ego',
                        type=str, help='charades_ego/ek100_mir')
    args = parser.parse_args()

    if args.dataset in ['charades_ego']:
        lavila = VideoCLSModel(f"configs/{args.dataset}/zeroshot.yml")
        egovpa = VideoCLSModel(f"configs/{args.dataset}/egovpa.yml")
    elif args.dataset == 'ek100_mir':
        #lavila = VideoMIRModel(f"configs/{args.dataset}/zeroshot.yml")
        egovpa = VideoMIRModel(f"configs/{args.dataset}/egovpa.yml")
    else:
        raise NotImplementedError

    #lavila.evaluate()
    #egovpa.evaluate()
    egovpa.predict_t2v(idx=0, sid=2119)


if __name__ == '__main__':
    main()
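As a standalone script, demo.py exposes a single --dataset flag (default charades_ego). A programmatic usage sketch mirroring main() above, assuming the checkpoint paths referenced in the YAML configs exist locally:

    # Sketch only: builds the Ego-VPA retrieval model and runs the same call as main().
    from demo import VideoMIRModel

    egovpa = VideoMIRModel("configs/ek100_mir/egovpa.yml")
    pred_video, gt_video = egovpa.predict_t2v(idx=0, sid=2119)  # top-3 clips for the text query with sid 2119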
lavila/data/datasets.py
ADDED
@@ -0,0 +1,542 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import csv
|
8 |
+
import glob
|
9 |
+
import json
|
10 |
+
import numpy as np
|
11 |
+
import os.path as osp
|
12 |
+
import pickle
|
13 |
+
import random
|
14 |
+
|
15 |
+
import decord
|
16 |
+
import pandas as pd
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
def datetime2sec(str):
|
21 |
+
hh, mm, ss = str.split(':')
|
22 |
+
return int(hh) * 3600 + int(mm) * 60 + float(ss)
|
23 |
+
|
24 |
+
|
25 |
+
def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
|
26 |
+
if chunk_len == -1:
|
27 |
+
vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
|
28 |
+
second_offset = second
|
29 |
+
if end_second is not None:
|
30 |
+
end_second = min(end_second, len(vr) / vr.get_avg_fps())
|
31 |
+
else:
|
32 |
+
end_second = len(vr) / vr.get_avg_fps()
|
33 |
+
else:
|
34 |
+
chunk_start = int(second) // chunk_len * chunk_len
|
35 |
+
second_offset = second - chunk_start
|
36 |
+
vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
|
37 |
+
if fps == -1:
|
38 |
+
fps = vr.get_avg_fps()
|
39 |
+
|
40 |
+
# calculate frame_ids
|
41 |
+
frame_offset = int(np.round(second_offset * fps))
|
42 |
+
total_duration = max(int((end_second - second) * fps), clip_length)
|
43 |
+
if chunk_len == -1:
|
44 |
+
if end_second <= second:
|
45 |
+
raise ValueError("end_second should be greater than second")
|
46 |
+
else:
|
47 |
+
frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
|
48 |
+
else:
|
49 |
+
frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)
|
50 |
+
|
51 |
+
# load frames
|
52 |
+
if max(frame_ids) < len(vr):
|
53 |
+
try:
|
54 |
+
frames = vr.get_batch(frame_ids).asnumpy()
|
55 |
+
except decord.DECORDError as error:
|
56 |
+
print(error)
|
57 |
+
frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
|
58 |
+
else:
|
59 |
+
# find the remaining frames in the next chunk
|
60 |
+
try:
|
61 |
+
frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
|
62 |
+
frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
|
63 |
+
vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
|
64 |
+
frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
|
65 |
+
frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
|
66 |
+
frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
|
67 |
+
frames = np.concatenate([frames_part1, frames_part2], axis=0)
|
68 |
+
# the next chunk does not exist; the current chunk is the last one
|
69 |
+
except (RuntimeError, decord.DECORDError) as error:
|
70 |
+
print(error)
|
71 |
+
frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
|
72 |
+
frames = vr.get_batch(frame_ids).asnumpy()
|
73 |
+
|
74 |
+
frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
|
75 |
+
return torch.stack(frames, dim=0)
|
76 |
+
|
77 |
+
|
78 |
+
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
|
79 |
+
seg_size = float(end_frame - start_frame - 1) / num_segments
|
80 |
+
seq = []
|
81 |
+
for i in range(num_segments):
|
82 |
+
start = int(np.round(seg_size * i) + start_frame)
|
83 |
+
end = int(np.round(seg_size * (i + 1)) + start_frame)
|
84 |
+
end = min(end, end_frame)
|
85 |
+
if jitter:
|
86 |
+
frame_id = np.random.randint(low=start, high=(end + 1))
|
87 |
+
else:
|
88 |
+
frame_id = (start + end) // 2
|
89 |
+
seq.append(frame_id)
|
90 |
+
return seq
|
91 |
+
|
92 |
+
|
93 |
+
def video_loader_by_frames(root, vid, frame_ids):
|
94 |
+
vr = decord.VideoReader(osp.join(root, vid))
|
95 |
+
try:
|
96 |
+
frames = vr.get_batch(frame_ids).asnumpy()
|
97 |
+
frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
|
98 |
+
except (IndexError, decord.DECORDError) as error:
|
99 |
+
print(error)
|
100 |
+
print("Erroneous video: ", vid)
|
101 |
+
frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
|
102 |
+
return torch.stack(frames, dim=0)
|
103 |
+
|
104 |
+
|
105 |
+
class VideoCaptionDatasetBase(torch.utils.data.Dataset):
|
106 |
+
def __init__(self, dataset, root, metadata, is_trimmed=True):
|
107 |
+
self.dataset = dataset
|
108 |
+
self.root = root
|
109 |
+
self.is_trimmed = is_trimmed
|
110 |
+
|
111 |
+
if self.dataset == 'ego4d':
|
112 |
+
with open(metadata, 'rb') as f:
|
113 |
+
self.samples = pickle.load(f)
|
114 |
+
elif self.dataset == 'ego4d_mcq':
|
115 |
+
with open(metadata, 'r') as f:
|
116 |
+
self.samples = json.load(f)
|
117 |
+
elif self.dataset in ['ek100_cls', 'ek100_mir']:
|
118 |
+
video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
|
119 |
+
fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
|
120 |
+
self.samples = []
|
121 |
+
with open(metadata) as f:
|
122 |
+
csv_reader = csv.reader(f)
|
123 |
+
_ = next(csv_reader) # skip the header
|
124 |
+
for row in csv_reader:
|
125 |
+
pid, vid = row[1:3]
|
126 |
+
# start_frame, end_frame = int(row[6]), int(row[7])
|
127 |
+
# Deprecated: some videos might have fps mismatch issue
|
128 |
+
start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
|
129 |
+
narration = row[8]
|
130 |
+
verb, noun = int(row[10]), int(row[12])
|
131 |
+
vid_path = '{}/{}.MP4'.format(pid, vid)
|
132 |
+
fps = fps_dict[osp.join(self.root, vid_path)]
|
133 |
+
start_frame = int(np.round(fps * start_timestamp))
|
134 |
+
end_frame = int(np.ceil(fps * end_timestamp))
|
135 |
+
self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
|
136 |
+
if self.dataset == 'ek100_mir':
|
137 |
+
self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
|
138 |
+
if 'train' in metadata:
|
139 |
+
self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
|
140 |
+
elif 'test' in metadata:
|
141 |
+
self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
|
142 |
+
else:
|
143 |
+
raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
|
144 |
+
self.relevancy = .1
|
145 |
+
elif self.dataset == 'ek100_mir_demo':
|
146 |
+
df = pd.read_csv(metadata, header=None)
|
147 |
+
fps = 59.94
|
148 |
+
self.samples = []
|
149 |
+
for i in range(len(df)):
|
150 |
+
vid_path, start_timestamp, end_timestamp, narration, verb, noun = df.iloc[i:i+1].values[0].tolist()[1:]
|
151 |
+
start_frame = int(np.round(fps * start_timestamp))
|
152 |
+
end_frame = int(np.ceil(fps * end_timestamp))
|
153 |
+
self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
|
154 |
+
|
155 |
+
elif self.dataset == 'egtea':
|
156 |
+
video_list = glob.glob(osp.join(self.root, '*/*'))
|
157 |
+
len_dict = {video: len(decord.VideoReader(video)) for video in video_list}
|
158 |
+
|
159 |
+
vn_list, labels = [], []
|
160 |
+
for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
|
161 |
+
row = row.strip()
|
162 |
+
vn = int(row.split(' ')[-1])
|
163 |
+
vn_list.append(vn)
|
164 |
+
narration = ' '.join(row.split(' ')[:-1])
|
165 |
+
labels.append(narration.replace('_', ' ').lower())
|
166 |
+
# labels.append(narration)
|
167 |
+
mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}
|
168 |
+
|
169 |
+
self.samples = []
|
170 |
+
with open(metadata) as f:
|
171 |
+
for row in f:
|
172 |
+
clip_id, action_idx = row.strip().split(' ')[:2]
|
173 |
+
video_id = '-'.join(clip_id.split('-')[:3])
|
174 |
+
vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
|
175 |
+
vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
|
176 |
+
self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
|
177 |
+
elif self.dataset == 'charades_ego':
|
178 |
+
video_list = glob.glob(osp.join(self.root, '*.mp4'))
|
179 |
+
fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
|
180 |
+
self.samples = []
|
181 |
+
with open(metadata) as f:
|
182 |
+
csv_reader = csv.reader(f)
|
183 |
+
_ = next(csv_reader) # skip the header
|
184 |
+
for row in csv_reader:
|
185 |
+
video_id = row[0]
|
186 |
+
if self.is_trimmed:
|
187 |
+
for action_tuple in row[9].split(';'):
|
188 |
+
if not action_tuple:
|
189 |
+
continue
|
190 |
+
action, start_timestamp, end_timestamp = action_tuple.split(' ')
|
191 |
+
start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
|
192 |
+
vid_path = '{}.mp4'.format(video_id)
|
193 |
+
fps = fps_dict[osp.join(self.root, vid_path)]
|
194 |
+
start_frame = int(np.round(fps * start_timestamp))
|
195 |
+
end_frame = int(np.ceil(fps * end_timestamp))
|
196 |
+
self.samples.append((vid_path, start_frame, end_frame, action))
|
197 |
+
else:
|
198 |
+
if not row[9]:
|
199 |
+
action_list = []
|
200 |
+
else:
|
201 |
+
action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
|
202 |
+
vid_path = '{}.mp4'.format(video_id)
|
203 |
+
fps = fps_dict[osp.join(self.root, vid_path)]
|
204 |
+
duration = fps * float(row[10])
|
205 |
+
self.samples.append((vid_path, 0, duration, action_list))
|
206 |
+
elif self.dataset == 'charades_ego_trimmed':
|
207 |
+
with open(metadata, 'rb') as f:
|
208 |
+
self.samples = pickle.load(f)
|
209 |
+
else:
|
210 |
+
raise NotImplementedError
|
211 |
+
|
212 |
+
def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
|
213 |
+
narration_selection='random'):
|
214 |
+
if self.dataset == 'ego4d':
|
215 |
+
if len(self.samples[i]) == 4:
|
216 |
+
vid, start_second, end_second, narration = self.samples[i]
|
217 |
+
frames = video_loader(self.root, vid, start_second,
|
218 |
+
end_second=end_second,
|
219 |
+
clip_length=clip_length,
|
220 |
+
jitter=is_training)
|
221 |
+
if isinstance(narration, list):
|
222 |
+
if narration_selection == 'random':
|
223 |
+
narration = random.choice(narration)
|
224 |
+
elif narration_selection == 'concat':
|
225 |
+
narration = '. '.join(narration)
|
226 |
+
elif narration_selection == 'list':
|
227 |
+
narration = narration
|
228 |
+
else:
|
229 |
+
raise ValueError
|
230 |
+
return frames, narration
|
231 |
+
elif len(self.samples[i]) == 5:
|
232 |
+
# TODO: need better filtering strategy based on nll
|
233 |
+
vid, start_second, end_second, narration, _ = self.samples[i]
|
234 |
+
frames = video_loader(self.root, vid, start_second,
|
235 |
+
end_second=end_second,
|
236 |
+
clip_length=clip_length,
|
237 |
+
jitter=is_training)
|
238 |
+
if isinstance(narration, list):
|
239 |
+
if narration_selection == 'random':
|
240 |
+
narration = random.choice(narration)
|
241 |
+
elif narration_selection == 'concat':
|
242 |
+
narration = '. '.join(narration)
|
243 |
+
elif narration_selection == 'list':
|
244 |
+
narration = narration
|
245 |
+
else:
|
246 |
+
raise ValueError
|
247 |
+
return frames, narration
|
248 |
+
elif self.dataset == 'ego4d_mcq':
|
249 |
+
itemMCQ = self.samples[str(i)]
|
250 |
+
answerIndex = itemMCQ['answer']
|
251 |
+
textQuery = itemMCQ['query']['clip_text']
|
252 |
+
sampleOptions = itemMCQ['choices']
|
253 |
+
frames_options = []
|
254 |
+
narration_options = []
|
255 |
+
for option_id in range(len(sampleOptions)):
|
256 |
+
option = sampleOptions[str(option_id)]
|
257 |
+
frames = video_loader(self.root, option['video_uid'],
|
258 |
+
float(option['clip_start']), end_second=float(option['clip_end']),
|
259 |
+
clip_length=clip_length,
|
260 |
+
jitter=is_training)
|
261 |
+
frames_options.append(frames)
|
262 |
+
narration_options.append(option['clip_text'])
|
263 |
+
return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
|
264 |
+
elif self.dataset == 'ek100_mir':
|
265 |
+
vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
|
266 |
+
# from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
|
267 |
+
# frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
|
268 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
|
269 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
270 |
+
if is_training:
|
271 |
+
positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
|
272 |
+
if positive_list != []:
|
273 |
+
pos = random.sample(positive_list, min(len(positive_list), 1))[0]
|
274 |
+
if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
|
275 |
+
return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
|
276 |
+
else:
|
277 |
+
return frames, (narration, 1)
|
278 |
+
elif self.dataset == 'ek100_mir_demo':
|
279 |
+
vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
|
280 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
|
281 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
282 |
+
return frames, (narration, 1)
|
283 |
+
|
284 |
+
elif self.dataset == 'ek100_cls':
|
285 |
+
vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
|
286 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
|
287 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
288 |
+
return frames, '{}:{}'.format(verb, noun)
|
289 |
+
elif self.dataset == 'egtea':
|
290 |
+
vid_path, start_frame, end_frame, sentence = self.samples[i]
|
291 |
+
if is_training:
|
292 |
+
assert num_clips == 1
|
293 |
+
if end_frame < clip_length * clip_stride:
|
294 |
+
frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
|
295 |
+
zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
|
296 |
+
frames = torch.cat((frames, zeros), dim=0)
|
297 |
+
frames = frames[::clip_stride]
|
298 |
+
else:
|
299 |
+
start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
|
300 |
+
frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
|
301 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
302 |
+
else:
|
303 |
+
if end_frame < clip_length * clip_stride:
|
304 |
+
frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
|
305 |
+
zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
|
306 |
+
frames = torch.cat((frames, zeros), dim=0)
|
307 |
+
frames = frames[::clip_stride]
|
308 |
+
frames = frames.repeat(num_clips, 1, 1, 1)
|
309 |
+
else:
|
310 |
+
frame_ids = []
|
311 |
+
for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
|
312 |
+
frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
|
313 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
314 |
+
return frames, sentence
|
315 |
+
elif self.dataset == 'charades_ego':
|
316 |
+
vid_path, start_frame, end_frame, action_list = self.samples[i]
|
317 |
+
if sparse_sample:
|
318 |
+
frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
|
319 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
320 |
+
else:
|
321 |
+
if end_frame < clip_length * clip_stride:
|
322 |
+
frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
|
323 |
+
zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
|
324 |
+
frames = torch.cat((frames, zeros), dim=0)
|
325 |
+
frames = frames[::clip_stride]
|
326 |
+
frames = frames.repeat(num_clips, 1, 1, 1)
|
327 |
+
else:
|
328 |
+
frame_ids = []
|
329 |
+
for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
|
330 |
+
frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
|
331 |
+
#print('frame_ids:', frame_ids)
|
332 |
+
frames = video_loader_by_frames(self.root, vid_path, frame_ids)
|
333 |
+
return frames, action_list, vid_path
|
334 |
+
elif self.dataset == 'charades_ego_trimmed':
|
335 |
+
vid, start_second, end_second, narration = self.samples[i]
|
336 |
+
frames = video_loader(self.root, vid, start_second,
|
337 |
+
end_second=end_second,
|
338 |
+
chunk_len=-1, # no chunk for CharadesEgo
|
339 |
+
fps=-1, # could be variable fps
|
340 |
+
clip_length=clip_length,
|
341 |
+
jitter=is_training)
|
342 |
+
return frames, narration
|
343 |
+
else:
|
344 |
+
raise NotImplementedError
|
345 |
+
|
346 |
+
def __getitem__(self, i):
|
347 |
+
raise NotImplementedError
|
348 |
+
|
349 |
+
def __len__(self):
|
350 |
+
return len(self.samples)
|
351 |
+
|
352 |
+
|
353 |
+
class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
|
354 |
+
def __init__(self, dataset, root, metadata, transform=None,
|
355 |
+
is_training=True, tokenizer=None,
|
356 |
+
clip_length=32, clip_stride=2, sparse_sample=False,
|
357 |
+
narration_selection='random',
|
358 |
+
num_hard_negatives=0,
|
359 |
+
subsample_stride=None):
|
360 |
+
super().__init__(dataset, root, metadata)
|
361 |
+
|
362 |
+
self.full_samples = self.samples.copy()
|
363 |
+
if isinstance(subsample_stride, int):
|
364 |
+
self.samples = self.samples[::subsample_stride]
|
365 |
+
self.transform = transform
|
366 |
+
self.is_training = is_training
|
367 |
+
self.tokenizer = tokenizer
|
368 |
+
self.clip_length = clip_length
|
369 |
+
self.clip_stride = clip_stride
|
370 |
+
self.sparse_sample = sparse_sample
|
371 |
+
self.narration_selection = narration_selection
|
372 |
+
self.num_hard_negatives = num_hard_negatives
|
373 |
+
if num_hard_negatives > 0:
|
374 |
+
assert self.dataset == 'htm_aa'
|
375 |
+
|
376 |
+
def __getitem__(self, i):
|
377 |
+
frames, caption = self.get_raw_item(
|
378 |
+
i, is_training=self.is_training,
|
379 |
+
clip_length=self.clip_length,
|
380 |
+
clip_stride=self.clip_stride,
|
381 |
+
sparse_sample=self.sparse_sample,
|
382 |
+
narration_selection=self.narration_selection,
|
383 |
+
)
|
384 |
+
|
385 |
+
        # ek100_mir will also output relevancy value
        if isinstance(caption, tuple):
            caption, relevancy = caption
        else:
            relevancy = 0.

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        # tokenize caption
        if self.tokenizer is not None:
            caption = self.tokenizer(caption)

        if isinstance(caption, tuple):
            caption, mask = caption
            return frames, caption, mask, relevancy
        else:
            return frames, caption, relevancy


class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
    def __init__(self, dataset, root, metadata, transform=None,
                 is_training=True, tokenizer=None,
                 clip_length=32, clip_stride=2, sparse_sample=False,
                 narration_selection='random'):
        super().__init__(dataset, root, metadata)

        self.full_samples = self.samples.copy()
        self.transform = transform
        self.is_training = is_training
        self.tokenizer = tokenizer
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample
        self.narration_selection = narration_selection

    def __getitem__(self, i):

        textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
            i, is_training=self.is_training,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
            narration_selection=self.narration_selection,
        )

        # apply transformation
        if self.transform is not None:
            frames_options = [self.transform(frames) for frames in frames_options]

        # tokenize caption
        if self.tokenizer is not None:
            textQuery = self.tokenizer(textQuery)
            narration_options = self.tokenizer(narration_options)
            if isinstance(textQuery, tuple):
                textQuery, mask_query = textQuery
                narration_options, mask_options = narration_options
                return (
                    textQuery, torch.stack(frames_options, dim=0),
                    narration_options, answerIndex, q_type,
                    mask_query, mask_options
                )
            else:
                return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type


class VideoClassyDataset(VideoCaptionDatasetBase):
    def __init__(
        self, dataset, root, metadata, transform=None,
        is_training=True, label_mapping=None,
        num_clips=1,
        clip_length=32, clip_stride=2,
        sparse_sample=False,
        is_trimmed=True,
    ):
        super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)

        self.transform = transform
        self.is_training = is_training
        self.label_mapping = label_mapping
        self.num_clips = num_clips
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample

    def __getitem__(self, i):
        frames, label, vid_path = self.get_raw_item(
            i, is_training=self.is_training,
            num_clips=self.num_clips,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
        )

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        if self.label_mapping is not None:
            if isinstance(label, list):
                # multi-label case
                res_array = np.zeros(len(self.label_mapping))
                for lbl in label:
                    res_array[self.label_mapping[lbl]] = 1.
                label = res_array
            else:
                label = self.label_mapping[label]

        return frames, label, vid_path


def get_dataset(train_transform, tokenizer, cfg, is_training=True):
    narration_selection = cfg.get('narration_selection', 'random')
    num_hard_neg = cfg.get('num_hard_neg', 0)
    data_cfg = cfg['data']
    if cfg['model']['arch'].startswith('CLIP') or cfg['model']['arch'].startswith('VCLM'):
        if is_training:
            metadata = data_cfg['metadata']
        else:
            metadata = data_cfg['metadata_val']

        return VideoCaptionDatasetCLIP(
            data_cfg['dataset'], data_cfg['root'], metadata, train_transform,
            is_training=is_training,
            tokenizer=tokenizer,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
            narration_selection=narration_selection,
            num_hard_negatives=num_hard_neg
        )
    else:
        raise NotImplementedError


def get_downstream_dataset(transform, tokenizer, cfg, is_training=True, num_clips=0, label_mapping=None):
    data_cfg = cfg['data']
    n_clips = num_clips if num_clips > 0 else data_cfg['num_clips']
    if is_training:
        metadata = data_cfg['metadata']
        return VideoClassyDataset(
            data_cfg['dataset'], data_cfg['root'], metadata, transform,
            is_training=True, label_mapping=label_mapping,
            num_clips=n_clips,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
        )
    else:
        metadata = data_cfg['metadata_val']
        return VideoClassyDataset(
            data_cfg['dataset'], data_cfg['root'], metadata, transform,
            is_training=False, label_mapping=label_mapping,
            num_clips=n_clips,
            clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
            sparse_sample=data_cfg['sparse_sample'],
            is_trimmed=not data_cfg['dataset'] == 'charades_ego'
        )
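Usage note (not part of the diff): a minimal sketch of how `get_dataset` above can be wired up for EK100-MIR retrieval. The config keys mirror the `data_cfg[...]` lookups in the function; the literal paths and clip settings below are illustrative assumptions, not values taken from this repo.

# Hypothetical sketch; the paths and numeric values here are assumptions for illustration only.
from lavila.data.datasets import get_dataset

cfg = {
    'model': {'arch': 'CLIP_OPENAI_TIMESFORMER_BASE'},
    'data': {
        'dataset': 'ek100_mir',
        'root': 'datasets/EK100/videos',                                # assumed path
        'metadata': 'datasets/EK100/EPIC_100_retrieval_train.csv',      # assumed path
        'metadata_val': 'datasets/EK100/EPIC_100_retrieval_test.csv',   # assumed path
        'clip_length': 16, 'clip_stride': 4, 'sparse_sample': False,    # assumed values
    },
}
# transform and tokenizer may be None; the dataset then returns raw frames and the caption string
val_set = get_dataset(train_transform=None, tokenizer=None, cfg=cfg, is_training=False)
frames, caption, relevancy = val_set[0]   # ek100_mir items also carry a relevancy score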
lavila/data/video_transforms.py
ADDED
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Sequence
import torch
import torch.nn as nn
from torchvision import transforms


class Permute(nn.Module):
    """
    Permutation as an op
    """

    def __init__(self, ordering):
        super().__init__()
        self.ordering = ordering

    def forward(self, frames):
        """
        Args:
            frames in some ordering, by default (C, T, H, W)
        Returns:
            frames in the ordering that was specified
        """
        return frames.permute(self.ordering)


class TemporalCrop(nn.Module):
    """
    Convert the video into smaller clips temporally.
    """

    def __init__(
        self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
    ):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride
        self.frame_stride = frame_stride

    def forward(self, video):
        assert video.ndim == 4, "Must be (C, T, H, W)"
        res = []
        for start in range(
            0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
        ):
            end = start + (self.frames) * self.frame_stride
            res.append(video[:, start: end: self.frame_stride, ...])
        return res


def crop_boxes(boxes, x_offset, y_offset):
    """
    Peform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to peform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and weight to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optinal. If not None, resize the images to scale_size before
            performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes


class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
        temporal crops to get spatial crops, and should be used with
        -2 in the spatial crop at the slowfast augmentation stage (so full
        frames are passed in here). Will return a larger list with the
        3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT)
        or 3x10 testing in SlowFast etc.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 6:
            self.crops_to_ext = [0, 1, 2]
            # I guess Swin uses 5 crops without flipping, but that doesn't
            # make sense given they first resize to 224 and take 224 crops.
            # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
            # So I'm assuming we can use flipped crops and that will add sth..
            self.flipped_crops_to_ext = [0, 1, 2]
        elif num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError(
                "Nothing else supported yet, "
                "slowfast only takes 0, 1, 2 as arguments"
            )

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res
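Usage note (not part of the diff): the two crop modules above compose into a test-time cropping pipeline; a quick sketch on a dummy tensor:

# Split a dummy (C, T, H, W) video into 16-frame clips, then take 3 spatial crops of each clip.
import torch
from lavila.data.video_transforms import TemporalCrop, SpatialCrop

video = torch.randn(3, 32, 256, 320)                            # (C, T, H, W)
clips = TemporalCrop(frames_per_clip=16, stride=16)(video)      # 2 clips of shape (3, 16, 256, 320)
crops = SpatialCrop(crop_size=224, num_crops=3)(clips)          # 6 clips of shape (3, 16, 224, 224)
print(len(clips), len(crops), crops[0].shape)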
lavila/models/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
lavila/models/distributed_utils.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from
# `https://github.com/facebookresearch/vissl/blob/main/vissl/utils/distributed_utils.py` and
# `https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/generic/distributed_util.py`
# Modified by Yue Zhao
# The original code is under MIT License

import torch
import torch.distributed as dist
from typing import Tuple


def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, str]:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This helper function converts to the correct
    device and returns the tensor + original device.
    """
    orig_device = "cpu" if not tensor.is_cuda else "gpu"
    if (
        torch.distributed.is_available()
        and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
        and not tensor.is_cuda
    ):
        tensor = tensor.cuda()
    return (tensor, orig_device)


def convert_to_normal_tensor(tensor: torch.Tensor, orig_device: str) -> torch.Tensor:
    """
    For some backends, such as NCCL, communication only works if the
    tensor is on the GPU. This converts the tensor back to original device.
    """
    if tensor.is_cuda and orig_device == "cpu":
        tensor = tensor.cpu()
    return tensor


def is_distributed_training_run() -> bool:
    return (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and (torch.distributed.get_world_size() > 1)
    )


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]


def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients
    """
    if tensor.ndim == 0:
        # 0 dim tensors cannot be gathered. so unsqueeze
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
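Usage note (not part of the diff): `gather_from_all` keeps the autograd graph intact and degrades to a plain concat when no process group is initialized, so the same code path also works in a single process:

# Single-process sketch: embeddings pass through unchanged, and gradients still reach the local shard.
import torch
from lavila.models.distributed_utils import gather_from_all

local_embeds = torch.randn(8, 256, requires_grad=True)
all_embeds = gather_from_all(local_embeds)   # (world_size * 8, 256); here just (8, 256)
all_embeds.pow(2).mean().backward()
print(local_embeds.grad.shape)               # torch.Size([8, 256])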
lavila/models/models.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

from lavila.models.openai_clip import load as load_openai_clip
from lavila.models.openai_model import QuickGELU, Transformer
from lavila.models.timesformer import SpaceTimeTransformer
from lavila.models.utils import remap_keys, rsetattr
from lavila.models.prompt_tuning import PromptLearner


class CLIP(nn.Module):
    def __init__(self,
                 cfg,
                 embed_dim: int,
                 # vision
                 vision_width: int,
                 vision_model: nn.Module,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 tempearture_init=0.07,
                 **kwargs,
                 ):
        super().__init__()

        self.context_length = context_length
        self.vision_width = vision_width
        self.tune_bias = cfg.get('tune_bias', False)
        self.freeze_vis_backbone = cfg.get('freeze_vis_backbone', False)
        self.freeze_txt_backbone = cfg.get('freeze_txt_backbone', False)

        self.visual = vision_model
        self.t_step = cfg.get('t_step', self.visual.num_frames)
        txt_prompt_cfg = cfg.get('text_prompt', {})
        self.n_ctx = txt_prompt_cfg.get('n_ctx', 0)
        self.txt_use_bank = txt_prompt_cfg.get('use_bank', False)
        if self.txt_use_bank:
            self.transformer = Transformer(
                width=transformer_width,
                layers=transformer_layers,
                heads=transformer_heads,
                attn_mask=self.build_attention_mask(),
                prompt_cfg=txt_prompt_cfg,
                prompt_learner=PromptLearner(transformer_width, self.n_ctx),
                prompt_generator=self.visual.prompt_generator
            )
        else:
            self.transformer = Transformer(
                width=transformer_width,
                layers=transformer_layers,
                heads=transformer_heads,
                attn_mask=self.build_attention_mask(),
                prompt_cfg=txt_prompt_cfg,
                prompt_learner=PromptLearner(transformer_width, self.n_ctx)
            )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = nn.LayerNorm(transformer_width)  # used to be `models.transformer.LayerNorm``

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        print("=> initialize initial temperature with {}".format(tempearture_init))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))

        self.initialize_parameters()

        freeze_list = []
        if self.freeze_vis_backbone:
            print("=> Freeze visual backbone")
            freeze_list += self.visual.param_list + [self.image_projection]

        if self.freeze_txt_backbone:
            print("=> Freeze text backbone")
            if self.tune_bias:
                freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n and 'bias' not in n]
                freeze_list += [m for n, m in self.ln_final.named_parameters() if 'bias' not in n]
            else:
                freeze_list += [m for n, m in self.transformer.named_parameters() if 'prompt' not in n]
                freeze_list += list(self.ln_final.parameters())
            freeze_list += list(self.token_embedding.parameters())
            freeze_list += [self.positional_embedding] + [self.text_projection]

        for p in freeze_list:
            p.requires_grad = False

        # text prompts
        if self.n_ctx > 0:
            if self.txt_use_bank:
                prompt_dim = self.visual.prompt_dim
                if prompt_dim != transformer_width:
                    self.transformer.prompt_inproj = nn.Linear(transformer_width, prompt_dim, bias=False)
                else:
                    self.transformer.prompt_inproj = nn.Identity()
                self.transformer.prompt_outproj = nn.Linear(prompt_dim, transformer_width, bias=False)
                nn.init.kaiming_normal_(
                    self.transformer.prompt_outproj.weight, a=0, mode='fan_out')

        params_to_update = [n for n, m in self.named_parameters() if m.requires_grad]
        num_opt_params = sum([m.numel() for m in self.parameters() if m.requires_grad])
        num_fz_params = sum([m.numel() for m in self.parameters() if not m.requires_grad])
        print("=> Params to update: {}".format(params_to_update))
        print("=> Update/Frozen: {}/{}".format(num_opt_params, num_fz_params))

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def encode_image(self, image, use_checkpoint=False, apply_project=True, istrain=False, gamma=1.0):
        x, ps_loss = self.visual(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)

        if isinstance(x, list):
            assert len(x) == 1
            x = x[0]
        if apply_project:
            x = x @ self.image_projection

        return x, ps_loss

    def encode_text(self, text, use_checkpoint=False, istrain=False, gamma=1.0):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        B = x.shape[0]
        eot = text.argmax(dim=-1)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, ps_loss = self.transformer(x, self.positional_embedding, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma, eot=eot)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), self.n_ctx + eot] @ self.text_projection

        return x, ps_loss

    def forward(self, image, text, use_checkpoint=False, norm_embed=False, istrain=False, gamma=1.0):
        image_embed, ps_loss_img = self.encode_image(image, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
        text_embed, ps_loss_txt = self.encode_text(text, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)

        if norm_embed:
            image_embed = F.normalize(image_embed, dim=-1)
            text_embed = F.normalize(text_embed, dim=-1)
        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp(),
                'ps_loss': ps_loss_img + ps_loss_txt}

    def train(self, mode=True):
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for m in self.modules():
            m.training = mode

        if mode:
            if self.freeze_vis_backbone and not self.tune_bias:
                for n, m in self.visual.named_modules():
                    if 'prompt' not in n:
                        m.training = False

            if self.freeze_txt_backbone and not self.tune_bias:
                for n, m in self.transformer.named_modules():
                    if 'prompt' not in n:
                        m.training = False

                self.token_embedding.training = False
                self.ln_final.training = False


def CLIP_OPENAI_TIMESFORMER_BASE(
    num_frames=4, timesformer_gated_xattn=False, temperature_init=0.07,
    project_embed_dim=256, **kwargs
):
    cfg = kwargs.pop('model_cfg', {})
    vision_model = SpaceTimeTransformer(
        num_frames=num_frames,
        time_init='zeros',
        attention_style='frozen-in-time',
        ln_pre=True,
        act_layer=QuickGELU,
        is_tanh_gating=timesformer_gated_xattn,
        drop_path_rate=cfg.get('drop_path_rate', 0),
        tune_bias=cfg.get('tune_bias', False),
        prompt_cfg=cfg.get('visual_prompt', {})
    )
    clip_model, _ = load_openai_clip('ViT-B/16', 'cpu')
    print("=> Loading CLIP (ViT-B/16) weights")
    remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=12)
    res = vision_model.load_state_dict(remapped_state_dict, strict=False)
    print(res)

    vision_model.head = nn.Identity()
    vision_model.pre_logits = nn.Identity()
    vision_model.fc = nn.Identity()
    model = CLIP(
        cfg,
        embed_dim=project_embed_dim,
        vision_width=768,
        vision_model=vision_model,
        context_length=77,
        vocab_size=49408,
        transformer_width=512,
        transformer_heads=8,
        transformer_layers=12,
        tempearture_init=temperature_init,
        **kwargs
    )
    model.transformer.load_state_dict(clip_model.transformer.state_dict(), strict=False)
    model.token_embedding.load_state_dict(clip_model.token_embedding.state_dict())
    model.positional_embedding.data.copy_(clip_model.positional_embedding.data)
    model.ln_final.load_state_dict(clip_model.ln_final.state_dict())
    if project_embed_dim == clip_model.text_projection.shape[1]:
        print("=> Loading CLIP's text_projection, image_projection and logit_scale directly")
        model.image_projection.data.copy_(clip_model.visual.proj.data)
        model.text_projection.data.copy_(clip_model.text_projection.data)
        model.logit_scale.data.copy_(clip_model.logit_scale.data)
    return model
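Usage note (not part of the diff): `build_attention_mask` in the wrapper above builds the standard additive causal mask used by the text transformer; a tiny standalone sketch of the same construction (context length 5 instead of 77):

# Rows are query positions; -inf entries are added to the attention logits and mask out future tokens.
import torch

context_length = 5
mask = torch.empty(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)    # keep the lower triangle (including the diagonal) at 0
print(mask)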
lavila/models/openai_clip.py
ADDED
@@ -0,0 +1,237 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/clip.py
# Modified by Yue Zhao
# The original code is under MIT License

import hashlib
import os
import urllib
import warnings
from typing import Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .openai_model import build_model
from .tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
    """Load a CLIP model
    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).
    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"
    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)
    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length
    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
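Usage note (not part of the diff): the loader above caches checkpoints under cache/clip and pairs with `tokenize`; a short sketch (the first call downloads the ViT-B/16 weights):

# Load OpenAI CLIP and encode a couple of captions with the bundled BPE vocabulary.
import torch
from lavila.models import openai_clip

model, preprocess = openai_clip.load("ViT-B/16", device="cpu")
tokens = openai_clip.tokenize(["cut the sausage", "rinse cutting board"])   # (2, 77)
with torch.no_grad():
    text_feat = model.encode_text(tokens)
print(text_feat.shape)   # torch.Size([2, 512]) for ViT-B/16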
lavila/models/openai_model.py
ADDED
@@ -0,0 +1,535 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/model.py
# Modified by Yue Zhao
# The original code is under MIT License

from collections import OrderedDict
from typing import Tuple, Union
from einops import rearrange

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import nn
import pdb


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = nn.LayerNorm(d_model)  # used to be `models.transformer.LayerNorm`
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = nn.LayerNorm(d_model)  # used to be `models.transformer.LayerNorm`
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward_part1(self, x):
        return self.attention(self.ln_1(x))

    def forward_part2(self, x):
        return self.mlp(self.ln_2(x))

    def forward(self, x: torch.Tensor, use_checkpoint=False):
        if use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part1, x)
        else:
            x = x + self.forward_part1(x)

        if use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, prompt_cfg={}, prompt_learner=None, prompt_generator=None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
        self.num_tokens = prompt_cfg.pop('n_ctx', 0)
        self.use_bank = prompt_cfg.pop('use_bank', False)
        if self.num_tokens > 0:
            self.prompt_learner = prompt_learner
            self.prompt_generator = prompt_generator
            self.k_s = 0
            if self.prompt_generator is not None:
                if self.prompt_generator.use_bank:
                    self.k_s = len(self.prompt_generator.prompt_pool)
            self.prompt_inproj = None
            self.prompt_outproj = None

    def forward(self, x: torch.Tensor, pos_emb, use_checkpoint=False, istrain=False, gamma=1.0, eot=None):
        ps_loss = x.new_zeros([1])
        BZ = x.size(1)
        if not self.use_bank:
            if self.num_tokens > 0:
                ctx = self.prompt_learner()
                ctx = ctx.unsqueeze(1).expand(-1, BZ, -1)
                x = torch.cat((
                    x[:1, :, :],  # SOT
                    ctx,
                    x[1:, :, :]
                ), dim=0)
            x = x[:pos_emb.size(0)] + pos_emb.unsqueeze(1)

        for i, blk in enumerate(self.resblocks):
            if self.num_tokens > 0 and self.use_bank:
                k = self.num_tokens
                num_tokens = 0 if i == 0 else self.num_tokens
                x = torch.cat((x[:1, :, :], x[num_tokens+1:, :, :]), dim=0)
                query = self.prompt_inproj(x[eot, torch.arange(BZ), :].detach())
                if i < self.k_s:
                    out = self.prompt_generator.prompt_pool[i](query, k, istrain=istrain, gamma=gamma)
                    ctx = self.prompt_outproj(out['prompts'])
                    ctx = ctx.transpose(1, 0) + pos_emb.unsqueeze(1)[1:self.num_tokens+1, :]
                    ps_loss += out.get('ps_loss', 0)
                else:
                    ctx = self.prompt_learner()
                    ctx = ctx.unsqueeze(1).expand(-1, BZ, -1)
                    ctx = ctx + pos_emb.unsqueeze(1)[1:self.num_tokens+1, :]

                x = torch.cat((
                    x[:1, :, :],  # SOT
                    ctx,
                    x[1:, :, :]
                ), dim=0)
                x = x[:pos_emb.size(0)]

            if use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        return x, ps_loss


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, apply_project=True, use_checkpoint=False, cls_at_last=True):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, use_checkpoint=use_checkpoint)
        x = x.permute(1, 0, 2)  # LND -> NLD

        if cls_at_last:
            x = self.ln_post(x[:, 0, :])

            if self.proj is not None and apply_project:
                x = x @ self.proj

            return x
        else:
            return x[:, 1:, :]


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, apply_project=True, use_checkpoint=False):
        if image.ndim == 4:
            return self.visual(image.type(self.dtype))
        else:
            image = image.permute(0, 2, 1, 3, 4)  # BCTHW -> BTCHW
            bb, tt, _, _, _ = image.shape
            x = self.visual(image.reshape(-1, *image.shape[2:]), apply_project=apply_project, use_checkpoint=use_checkpoint)  # ND
            x = x.view(bb, tt, -1)
            image_features = x.mean(1)
            # image_features = x.max(1).values
            return image_features

    def encode_text(self, text, use_checkpoint=False):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, use_checkpoint=use_checkpoint)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text, use_checkpoint=False, norm_embed=True):
        image_features = self.encode_image(image, use_checkpoint=use_checkpoint)
        text_features = self.encode_text(text, use_checkpoint=use_checkpoint)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # # cosine similarity as logits
        # logit_scale = self.logit_scale.exp()
        # logits_per_image = logit_scale * image_features @ text_features.t()
        # logits_per_text = logits_per_image.t()

        # # shape = [global_batch_size, global_batch_size]
        # return logits_per_image, logits_per_text

        return {'image_embed': image_features,
|
466 |
+
'text_embed': text_features,
|
467 |
+
'logit_scale': self.logit_scale.exp()}
|
468 |
+
|
469 |
+
|
470 |
+
def convert_weights(model: nn.Module):
|
471 |
+
"""Convert applicable model parameters to fp16"""
|
472 |
+
|
473 |
+
def _convert_weights_to_fp16(l):
|
474 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
475 |
+
l.weight.data = l.weight.data.half()
|
476 |
+
if l.bias is not None:
|
477 |
+
l.bias.data = l.bias.data.half()
|
478 |
+
|
479 |
+
if isinstance(l, nn.MultiheadAttention):
|
480 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
481 |
+
tensor = getattr(l, attr)
|
482 |
+
if tensor is not None:
|
483 |
+
tensor.data = tensor.data.half()
|
484 |
+
|
485 |
+
for name in ["text_projection", "proj"]:
|
486 |
+
if hasattr(l, name):
|
487 |
+
attr = getattr(l, name)
|
488 |
+
if attr is not None:
|
489 |
+
attr.data = attr.data.half()
|
490 |
+
|
491 |
+
model.apply(_convert_weights_to_fp16)
|
492 |
+
|
493 |
+
|
494 |
+
def build_model(state_dict: dict):
|
495 |
+
vit = "visual.proj" in state_dict
|
496 |
+
|
497 |
+
if vit:
|
498 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
499 |
+
vision_layers = len(
|
500 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]
|
501 |
+
)
|
502 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
503 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
504 |
+
image_resolution = vision_patch_size * grid_size
|
505 |
+
else:
|
506 |
+
counts: list = [
|
507 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]
|
508 |
+
]
|
509 |
+
vision_layers = tuple(counts)
|
510 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
511 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
512 |
+
vision_patch_size = None
|
513 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
514 |
+
image_resolution = output_width * 32
|
515 |
+
|
516 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
517 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
518 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
519 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
520 |
+
transformer_heads = transformer_width // 64
|
521 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
522 |
+
|
523 |
+
model = CLIP(
|
524 |
+
embed_dim,
|
525 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
526 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
527 |
+
)
|
528 |
+
|
529 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
530 |
+
if key in state_dict:
|
531 |
+
del state_dict[key]
|
532 |
+
|
533 |
+
convert_weights(model)
|
534 |
+
model.load_state_dict(state_dict)
|
535 |
+
return model.eval()
|
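For orientation, the retrieval interface of this CLIP wrapper is small: `encode_image` accepts either images or 5-D video tensors (B, C, T, H, W) and mean-pools the per-frame features, `encode_text` takes the feature at the EOT position, and `forward` returns L2-normalized embeddings together with the learned `logit_scale`. The sketch below shows how those outputs turn into a similarity matrix; the tiny hyperparameters and random inputs are illustrative assumptions, not the configuration this Space actually loads.

```python
# Illustrative only: toy hyperparameters and random inputs, not the repo's real config.
import torch
from lavila.models.openai_model import CLIP

model = CLIP(
    embed_dim=64,
    image_resolution=32, vision_layers=2, vision_width=64, vision_patch_size=16,
    context_length=77, vocab_size=49408,
    transformer_width=64, transformer_heads=1, transformer_layers=2,
).eval()

video = torch.randn(2, 3, 4, 32, 32)          # B x C x T x H x W
text = torch.zeros(2, 77, dtype=torch.long)   # stand-in for tokenized captions
text[:, 0], text[:, 1] = 49406, 49407         # <|startoftext|>, <|endoftext|>

with torch.no_grad():
    out = model(video, text)
    # cosine similarity scaled by the learned temperature
    sim = out['logit_scale'] * out['image_embed'] @ out['text_embed'].t()
print(sim.shape)  # torch.Size([2, 2])
```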
lavila/models/prompt_tuning.py
ADDED
@@ -0,0 +1,291 @@
1 |
+
|
2 |
+
import math
|
3 |
+
from functools import reduce
|
4 |
+
from operator import mul
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
import pdb
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
class PromptLearner(nn.Module):
|
13 |
+
def __init__(self, ctx_dim=512, n_ctx=16):
|
14 |
+
super(PromptLearner, self).__init__()
|
15 |
+
self.n_ctx = n_ctx
|
16 |
+
self.ctx_dim = ctx_dim
|
17 |
+
|
18 |
+
# initialize prompts
|
19 |
+
ctx_vectors = torch.empty(n_ctx, ctx_dim)
|
20 |
+
nn.init.normal_(ctx_vectors, std=0.02)
|
21 |
+
prompt_prefix = " ".join(["X"] * n_ctx)
|
22 |
+
self.ctx = nn.Parameter(ctx_vectors) # to be optimized
|
23 |
+
print(f'Initial context: "{prompt_prefix}"')
|
24 |
+
print(f"Number of context words (tokens): {n_ctx}")
|
25 |
+
|
26 |
+
def forward(self):
|
27 |
+
return self.ctx
|
28 |
+
|
29 |
+
class PromptPoolLearner(nn.Module):
|
30 |
+
def __init__(self, prompt_dim=256, size=128, length=1):
|
31 |
+
super(PromptPoolLearner, self).__init__()
|
32 |
+
self.prompt_dim = prompt_dim
|
33 |
+
self.length = length
|
34 |
+
self.size = size
|
35 |
+
|
36 |
+
# initiate prompt
|
37 |
+
self.prompt_values = nn.Parameter(torch.zeros(size, length, prompt_dim))
|
38 |
+
self.id_table = torch.ones([size]).cuda()
|
39 |
+
|
40 |
+
# xavier_uniform initialization
|
41 |
+
nn.init.uniform_(self.prompt_values.data, -1, 1)
|
42 |
+
|
43 |
+
def l2_normalize(self, x, dim=None, epsilon=1e-12):
|
44 |
+
"""Normalizes a given vector or matrix."""
|
45 |
+
square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
|
46 |
+
x_inv_norm = torch.rsqrt(torch.maximum(square_sum, torch.tensor(epsilon, device=x.device)))
|
47 |
+
return x * x_inv_norm
|
48 |
+
|
49 |
+
def forward(self, query, k=0, istrain=False, gamma=1.0):
|
50 |
+
BZ = query.shape[0]
|
51 |
+
out = dict()
|
52 |
+
query = self.l2_normalize(query.squeeze(1), dim=1)
|
53 |
+
keys = self.prompt_values.mean(dim=1)
|
54 |
+
keys = self.l2_normalize(keys, dim=1)
|
55 |
+
similarity = torch.matmul(query, keys.t())
|
56 |
+
|
57 |
+
if k > 0 and k < self.size:
|
58 |
+
|
59 |
+
if istrain:
|
60 |
+
inv_freq = self.id_table.sum() / self.id_table.float()
|
61 |
+
weights = (similarity + 1) / 2 * gamma + (1 - gamma) * torch.softmax(inv_freq, dim=-1)
|
62 |
+
idx = torch.multinomial(weights, k, replacement=False)
|
63 |
+
else:
|
64 |
+
idx = torch.argsort(similarity, dim=-1, descending=True)[:, :k]
|
65 |
+
|
66 |
+
prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True)
|
67 |
+
self.id_table[prompt_id] += id_counts
|
68 |
+
prompts = self.prompt_values[idx.flatten(), ...].view(BZ, k * self.length, self.prompt_dim)
|
69 |
+
else:
|
70 |
+
idx = torch.arange(self.size).unsqueeze(0).expand(BZ, -1)
|
71 |
+
prompts = self.prompt_values.flatten(0, 1).unsqueeze(0).expand(BZ, -1, -1)
|
72 |
+
|
73 |
+
prompts = self.l2_normalize(prompts, dim=-1)
|
74 |
+
out['prompts'] = prompts
|
75 |
+
sel_sim = similarity[torch.arange(BZ).view(-1, 1), idx]
|
76 |
+
sel_key = keys[idx.flatten(), ...].view(BZ, k, self.prompt_dim)
|
77 |
+
diff = F.mse_loss((sel_sim.unsqueeze(1) @ sel_key).squeeze(1), query.detach(), reduction='sum') / BZ
|
78 |
+
ksim = torch.sum(torch.abs(torch.matmul(keys, keys.t()) - torch.eye(self.size).to(keys.device))) / BZ
|
79 |
+
out['ps_loss'] = diff + ksim
|
80 |
+
|
81 |
+
return out
|
82 |
+
|
83 |
+
|
84 |
+
class VisualPromptLearner(nn.Module):
|
85 |
+
def __init__(self, patch_size=16, embed_dim=768, num_layers=12, prompt_dim=256, num_tokens=5, deep=False,
|
86 |
+
deep_shared=False, split_st=False, dropout=0.1, pool={}):
|
87 |
+
super(VisualPromptLearner, self).__init__()
|
88 |
+
self.num_layers = num_layers
|
89 |
+
self.embed_dim = embed_dim
|
90 |
+
self.prompt_dim = prompt_dim
|
91 |
+
self.num_tokens = num_tokens # number of prompted tokens
|
92 |
+
self.prompt_dropout = nn.Dropout(dropout)
|
93 |
+
pool_size = pool.get('size', 0)
|
94 |
+
self.pool_length = pool.get('length', 1)
|
95 |
+
self.use_bank = True if pool_size > 0 and num_tokens <= (pool_size * self.pool_length) else False
|
96 |
+
if self.use_bank:
|
97 |
+
print(f'Using feature bank with size {pool_size} (dimension: {prompt_dim})')
|
98 |
+
|
99 |
+
if prompt_dim != embed_dim:
|
100 |
+
self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
|
101 |
+
else:
|
102 |
+
self.prompt_inproj = nn.Identity()
|
103 |
+
|
104 |
+
if self.use_bank:
|
105 |
+
self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
|
106 |
+
nn.init.kaiming_normal_(
|
107 |
+
self.prompt_outproj.weight, a=0, mode='fan_out')
|
108 |
+
else:
|
109 |
+
self.prompt_outproj = nn.Identity()
|
110 |
+
|
111 |
+
self.split_st = split_st # split spatial and temporal prompts
|
112 |
+
|
113 |
+
# initiate prompt:
|
114 |
+
val = math.sqrt(6. / float(3 * reduce(mul, (patch_size, patch_size), 1) + prompt_dim))
|
115 |
+
if split_st:
|
116 |
+
if self.use_bank:
|
117 |
+
pool['size'] //= 2
|
118 |
+
self.spatial_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
|
119 |
+
self.temporal_prompt_pool = PromptPoolLearner(prompt_dim, **pool)
|
120 |
+
else:
|
121 |
+
self.spatial_prompt_embeddings = nn.Parameter(torch.zeros(
|
122 |
+
1, num_tokens // 2, prompt_dim))
|
123 |
+
self.temporal_prompt_embeddings = nn.Parameter(torch.zeros(
|
124 |
+
1, num_tokens // 2, prompt_dim))
|
125 |
+
# xavier_uniform initialization
|
126 |
+
nn.init.uniform_(self.spatial_prompt_embeddings.data, -val, val)
|
127 |
+
nn.init.uniform_(self.temporal_prompt_embeddings.data, -val, val)
|
128 |
+
else:
|
129 |
+
if self.use_bank:
|
130 |
+
self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
|
131 |
+
else:
|
132 |
+
self.prompt_embeddings = nn.Parameter(torch.zeros(
|
133 |
+
1, num_tokens, prompt_dim))
|
134 |
+
# xavier_uniform initialization
|
135 |
+
nn.init.uniform_(self.prompt_embeddings.data, -val, val)
|
136 |
+
|
137 |
+
self.deep = deep or deep_shared
|
138 |
+
self.deep_shared = deep_shared
|
139 |
+
if deep and (not deep_shared):
|
140 |
+
total_d_layer = num_layers - 1
|
141 |
+
if split_st:
|
142 |
+
if self.use_bank:
|
143 |
+
self.spatial_deep_prompt_pool = nn.ModuleList([
|
144 |
+
PromptPoolLearner(prompt_dim, **pool)
|
145 |
+
for i in range(total_d_layer)])
|
146 |
+
self.temporal_deep_prompt_pool = nn.ModuleList([
|
147 |
+
PromptPoolLearner(prompt_dim, **pool)
|
148 |
+
for i in range(total_d_layer)])
|
149 |
+
else:
|
150 |
+
self.spatial_deep_prompt_embeddings = nn.Parameter(torch.zeros(
|
151 |
+
total_d_layer, num_tokens // 2, prompt_dim))
|
152 |
+
self.temporal_deep_prompt_embeddings = nn.Parameter(torch.zeros(
|
153 |
+
total_d_layer, num_tokens // 2, prompt_dim))
|
154 |
+
# xavier_uniform initialization
|
155 |
+
nn.init.uniform_(self.spatial_deep_prompt_embeddings.data, -val, val)
|
156 |
+
nn.init.uniform_(self.temporal_deep_prompt_embeddings.data, -val, val)
|
157 |
+
else:
|
158 |
+
if self.use_bank:
|
159 |
+
self.deep_prompt_pool = nn.ModuleList([
|
160 |
+
PromptPoolLearner(prompt_dim, **pool)
|
161 |
+
for i in range(total_d_layer)])
|
162 |
+
else:
|
163 |
+
self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
|
164 |
+
total_d_layer, num_tokens, prompt_dim))
|
165 |
+
# xavier_uniform initialization
|
166 |
+
nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
|
167 |
+
|
168 |
+
def forward(self, query=None, layer=0, istrain=False, gamma=1.0):
|
169 |
+
query = query.detach()
|
170 |
+
query = self.prompt_inproj(query)
|
171 |
+
ps_loss = query.new_zeros([1])
|
172 |
+
if self.split_st:
|
173 |
+
if self.deep and (not self.deep_shared) and layer > 0:
|
174 |
+
if self.use_bank:
|
175 |
+
k = (self.num_tokens // 2) // self.pool_length
|
176 |
+
spatial_out = self.spatial_deep_prompt_pool[layer-1](query, k, istrain, gamma)
|
177 |
+
spatial_prompts = spatial_out['prompts']
|
178 |
+
temporal_out = self.temporal_deep_prompt_pool[layer-1](query, k, istrain, gamma)
|
179 |
+
temporal_prompts = temporal_out['prompts']
|
180 |
+
ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
|
181 |
+
else:
|
182 |
+
spatial_prompts = self.spatial_deep_prompt_embeddings[layer-1]
|
183 |
+
temporal_prompts = self.temporal_deep_prompt_embeddings[layer-1]
|
184 |
+
else:
|
185 |
+
if self.use_bank:
|
186 |
+
k = (self.num_tokens // 2) // self.pool_length
|
187 |
+
spatial_out = self.spatial_prompt_pool(query, k, istrain, gamma)
|
188 |
+
spatial_prompts = spatial_out['prompts']
|
189 |
+
temporal_out = self.temporal_prompt_pool(query, k, istrain, gamma)
|
190 |
+
temporal_prompts = temporal_out['prompts']
|
191 |
+
ps_loss += spatial_out.get('ps_loss', 0) + temporal_out.get('ps_loss', 0)
|
192 |
+
else:
|
193 |
+
spatial_prompts = self.spatial_prompt_embeddings
|
194 |
+
temporal_prompts = self.temporal_prompt_embeddings
|
195 |
+
|
196 |
+
prompts = torch.cat((spatial_prompts, temporal_prompts), dim=1)
|
197 |
+
|
198 |
+
else:
|
199 |
+
if self.deep and (not self.deep_shared) and layer > 0:
|
200 |
+
if self.use_bank:
|
201 |
+
k = self.num_tokens // self.pool_length
|
202 |
+
out = self.deep_prompt_pool[layer-1](query, k, istrain, gamma)
|
203 |
+
prompts = out['prompts']
|
204 |
+
ps_loss += out.get('ps_loss', 0)
|
205 |
+
else:
|
206 |
+
prompts = self.deep_prompt_embeddings[layer-1]
|
207 |
+
else:
|
208 |
+
if self.use_bank:
|
209 |
+
k = self.num_tokens // self.pool_length
|
210 |
+
out = self.prompt_pool(query, k, istrain, gamma)
|
211 |
+
prompts = out['prompts']
|
212 |
+
ps_loss += out.get('ps_loss', 0)
|
213 |
+
else:
|
214 |
+
prompts = self.prompt_embeddings
|
215 |
+
|
216 |
+
prompts = self.prompt_dropout(self.prompt_outproj(prompts))
|
217 |
+
return prompts, ps_loss
|
218 |
+
|
219 |
+
|
220 |
+
class CMM(nn.Module):
|
221 |
+
'''Context modeling module'''
|
222 |
+
def __init__(self, num_tokens=8, num_frames=16, embed_dim=768, prompt_dim=256, dropout=0., num_layer=1, shared=False, pool={}):
|
223 |
+
super(CMM, self).__init__()
|
224 |
+
self.num_tokens = num_tokens
|
225 |
+
self.num_frames = num_frames
|
226 |
+
self.embed_dim = embed_dim
|
227 |
+
self.prompt_dim = prompt_dim
|
228 |
+
self.pool_size = pool.get('size', 0)
|
229 |
+
self.pool_length = pool.get('length', 1)
|
230 |
+
self.use_bank = True if self.pool_size > 0 else False
|
231 |
+
self.use_rnn = not self.use_bank
|
232 |
+
if self.use_rnn:
|
233 |
+
self.rnn = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim,
|
234 |
+
num_layers=1, batch_first=True, dropout=dropout, bidirectional=True)
|
235 |
+
self.shared = shared
|
236 |
+
self.prompt_dropout = nn.Dropout(dropout)
|
237 |
+
|
238 |
+
if self.use_bank:
|
239 |
+
print(f'Using feature bank with size {self.pool_size} (dimension: {prompt_dim})')
|
240 |
+
if self.use_rnn:
|
241 |
+
self.prompt_inproj = nn.Linear(embed_dim * 2, prompt_dim)
|
242 |
+
nn.init.kaiming_normal_(
|
243 |
+
self.prompt_inproj.weight, a=0, mode='fan_out')
|
244 |
+
else:
|
245 |
+
if embed_dim != prompt_dim:
|
246 |
+
self.prompt_inproj = nn.Linear(embed_dim, prompt_dim, bias=False)
|
247 |
+
else:
|
248 |
+
self.prompt_inproj = nn.Identity()
|
249 |
+
|
250 |
+
self.prompt_outproj = nn.Linear(prompt_dim, embed_dim, bias=False)
|
251 |
+
nn.init.kaiming_normal_(
|
252 |
+
self.prompt_outproj.weight, a=0, mode='fan_out')
|
253 |
+
|
254 |
+
if shared:
|
255 |
+
self.prompt_pool = PromptPoolLearner(prompt_dim, **pool)
|
256 |
+
else:
|
257 |
+
self.prompt_pool = nn.ModuleList([
|
258 |
+
PromptPoolLearner(prompt_dim, **pool)
|
259 |
+
for i in range(num_layer)])
|
260 |
+
else:
|
261 |
+
self.fc = nn.Linear(embed_dim * 2, embed_dim * num_tokens)
|
262 |
+
|
263 |
+
def forward(self, x, layer=0, istrain=False, gamma=1.0):
|
264 |
+
BZ = x.size(0)
|
265 |
+
x = x.detach()
|
266 |
+
x = rearrange(x, 'b (f n) d -> b f n d', f=self.num_frames)
|
267 |
+
x = torch.mean(x, dim=2)
|
268 |
+
|
269 |
+
if self.use_rnn:
|
270 |
+
x, _ = self.rnn(x)
|
271 |
+
|
272 |
+
ps_loss = x.new_zeros([1])
|
273 |
+
if self.use_bank:
|
274 |
+
query = self.prompt_inproj(x).flatten(0, 1)
|
275 |
+
k = self.num_tokens // self.pool_length
|
276 |
+
if self.shared:
|
277 |
+
out = self.prompt_pool(query, k, istrain, gamma)
|
278 |
+
else:
|
279 |
+
out = self.prompt_pool[layer](query, k, istrain, gamma)
|
280 |
+
|
281 |
+
prompts = rearrange(out['prompts'], '(b f) p d -> b (f p) d', f=self.num_frames)
|
282 |
+
prompts = self.prompt_outproj(prompts)
|
283 |
+
ps_loss += out.get('ps_loss', 0) * self.num_frames
|
284 |
+
|
285 |
+
else:
|
286 |
+
prompts = self.fc(x)
|
287 |
+
prompts = rearrange(prompts, 'b f (p d) -> b (f p) d', p=self.num_tokens)
|
288 |
+
|
289 |
+
return prompts, ps_loss
|
290 |
+
|
291 |
+
|
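`PromptPoolLearner` is the prompt-bank lookup used by the prompt modules above: each bank entry's key is the mean of its `length` prompt vectors, keys and query are L2-normalized, and the top-k most similar entries are gathered (during training they are instead sampled with an extra usage-balancing weight, and a query-alignment plus key-decorrelation penalty is returned as `ps_loss`). The snippet below re-implements only the inference-time selection step with made-up shapes; it deliberately avoids instantiating the class itself, whose usage table is allocated with `.cuda()` and therefore assumes a GPU.

```python
# Standalone illustration of the top-k prompt selection (shapes are made up).
import torch
import torch.nn.functional as F

B, pool_size, length, prompt_dim, k = 4, 16, 1, 256, 3

prompt_values = torch.randn(pool_size, length, prompt_dim)   # the learnable bank
query = torch.randn(B, prompt_dim)                           # e.g. frame/CLS features

keys = F.normalize(prompt_values.mean(dim=1), dim=1)         # one key per bank entry
q = F.normalize(query, dim=1)
similarity = q @ keys.t()                                    # B x pool_size

idx = similarity.argsort(dim=-1, descending=True)[:, :k]     # eval-time: plain top-k
prompts = prompt_values[idx.flatten()].view(B, k * length, prompt_dim)
print(prompts.shape)  # torch.Size([4, 3, 256])
```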
lavila/models/timesformer.py
ADDED
@@ -0,0 +1,650 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Part of the code is from https://github.com/m-bain/frozen-in-time/blob/main/model/video_transformer.py
|
8 |
+
# Modified by Yue Zhao
|
9 |
+
# The original code is under MIT License
|
10 |
+
|
11 |
+
"""
|
12 |
+
Implementations of Video Transformers in PyTorch
|
13 |
+
A PyTorch implementation of space-time transformer as described in
|
14 |
+
'Frozen in Time: A Joint Image and Video Encoder for End-to-End Retrieval' - https://arxiv.org/abs/2104.00650
|
15 |
+
A PyTorch implementation of timesformer as described in
|
16 |
+
'Is Space-Time Attention All You Need for Video Understanding?' - https://arxiv.org/abs/2102.05095
|
17 |
+
Acknowledgments:
|
18 |
+
- This code builds on Ross Wightman's vision_transformer code in pytorch-image-models:
|
19 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
20 |
+
- It is also inspired by lucidrains timesformer implementation:
|
21 |
+
https://github.com/lucidrains/TimeSformer-pytorch
|
22 |
+
Hacked together by Max Bain
|
23 |
+
"""
|
24 |
+
|
25 |
+
from collections import OrderedDict, defaultdict
|
26 |
+
from functools import partial, reduce
|
27 |
+
import operator
|
28 |
+
import copy
|
29 |
+
|
30 |
+
import torch
|
31 |
+
import torch.utils.checkpoint as checkpoint
|
32 |
+
from einops import rearrange, repeat
|
33 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
34 |
+
from torch import einsum, nn
|
35 |
+
import torch.nn.functional as F
|
36 |
+
import pdb
|
37 |
+
|
38 |
+
from lavila.models.prompt_tuning import VisualPromptLearner, CMM
|
39 |
+
|
40 |
+
|
41 |
+
def attn(q, k, v):
|
42 |
+
sim = einsum('b i d, b j d -> b i j', q, k)
|
43 |
+
attn = sim.softmax(dim=-1)
|
44 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
45 |
+
return out
|
46 |
+
|
47 |
+
|
48 |
+
class Mlp(nn.Module):
|
49 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
50 |
+
super().__init__()
|
51 |
+
out_features = out_features or in_features
|
52 |
+
hidden_features = hidden_features or in_features
|
53 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
54 |
+
self.act = act_layer()
|
55 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
56 |
+
self.drop = nn.Dropout(drop)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
x = self.fc1(x)
|
60 |
+
x = self.act(x)
|
61 |
+
x = self.drop(x)
|
62 |
+
x = self.fc2(x)
|
63 |
+
x = self.drop(x)
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
class VideoPatchEmbed(nn.Module):
|
68 |
+
""" Video to Patch Embedding
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
|
72 |
+
num_frames=8, ln_pre=False):
|
73 |
+
super().__init__()
|
74 |
+
img_size = to_2tuple(img_size)
|
75 |
+
patch_size = to_2tuple(patch_size)
|
76 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * num_frames
|
77 |
+
self.img_size = img_size
|
78 |
+
self.patch_size = patch_size
|
79 |
+
self.num_patches = num_patches
|
80 |
+
self.num_frames = num_frames
|
81 |
+
self.embed_dim = embed_dim
|
82 |
+
# ln_pre is inserted to be compatible with CLIP-style model
|
83 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=not ln_pre)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
B, F, C, H, W = x.shape
|
87 |
+
assert F <= self.num_frames
|
88 |
+
x = x.view(-1, C, H, W)
|
89 |
+
x = self.proj(x)
|
90 |
+
return x
|
91 |
+
|
92 |
+
|
93 |
+
class VarAttention(nn.Module):
|
94 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
|
95 |
+
initialize='random', num_tokens=0):
|
96 |
+
super().__init__()
|
97 |
+
self.num_heads = num_heads
|
98 |
+
head_dim = dim // num_heads
|
99 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
100 |
+
self.scale = qk_scale or head_dim ** -0.5
|
101 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
102 |
+
self.proj = nn.Linear(dim, dim)
|
103 |
+
if initialize == 'zeros':
|
104 |
+
self.qkv.weight.data.fill_(0)
|
105 |
+
self.qkv.bias.data.fill_(0)
|
106 |
+
# fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
|
107 |
+
# are multiplied by 0*0, which is hard for the model to move out of.
|
108 |
+
self.proj.weight.data.fill_(1)
|
109 |
+
self.proj.bias.data.fill_(0)
|
110 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
111 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
112 |
+
self.num_tokens = num_tokens
|
113 |
+
|
114 |
+
def forward(self, x, einops_from, einops_to, einops_dims, cfg):
|
115 |
+
style = cfg.get('style', 'default')
|
116 |
+
pt_att = cfg.get('pt_att', True)
|
117 |
+
n_seg = cfg.get('n_seg', 4)
|
118 |
+
if 'VoP' in style:
|
119 |
+
return self.forward_VoP(x, einops_from, einops_to, einops_dims, n_seg)
|
120 |
+
elif style == 'attall':
|
121 |
+
return self.forward_attall(x, pt_att)
|
122 |
+
else:
|
123 |
+
return self.forward_features(x, einops_from, einops_to, einops_dims, pt_att)
|
124 |
+
|
125 |
+
def forward_features(self, x, einops_from, einops_to, einops_dims, pt_att=True):
|
126 |
+
h = self.num_heads
|
127 |
+
num_tokens = self.num_tokens
|
128 |
+
if self.num_tokens > 0 and not pt_att:
|
129 |
+
prompts = x[:, 1:self.num_tokens+1, :]
|
130 |
+
x = torch.cat((
|
131 |
+
x[:, :1, :], # cls_token
|
132 |
+
x[:, self.num_tokens+1:, :] # patch embeddings
|
133 |
+
), dim=1)
|
134 |
+
num_tokens = 0
|
135 |
+
|
136 |
+
# project x to q, k, v values
|
137 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
138 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
139 |
+
|
140 |
+
q *= self.scale
|
141 |
+
|
142 |
+
# splice out CLS token at index 1 (and prompts)
|
143 |
+
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v)) # Bh x () x d
|
144 |
+
|
145 |
+
# let CLS token attend to key / values of all patches across time and space
|
146 |
+
cls_out = attn(cls_q, k, v) # Bh x (1 + p) x d
|
147 |
+
# rearrange across time or space
|
148 |
+
q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_)) # Bh x NT x d -> Bhr x s x d
|
149 |
+
|
150 |
+
# expand cls token keys and values across time or space and concat
|
151 |
+
r = q_.shape[0] // cls_k.shape[0]
|
152 |
+
cls_k, cls_v = map(lambda t: repeat(t, 'b p d -> (b r) p d', r=r), (cls_k, cls_v)) # Bhr x (1 + p) x d
|
153 |
+
k_ = torch.cat((cls_k, k_), dim=1)
|
154 |
+
v_ = torch.cat((cls_v, v_), dim=1)
|
155 |
+
|
156 |
+
# attention
|
157 |
+
out = attn(q_, k_, v_)
|
158 |
+
|
159 |
+
# merge back time or space
|
160 |
+
out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) # Bh x NT x d
|
161 |
+
|
162 |
+
# concat back the cls token
|
163 |
+
out = torch.cat((cls_out, out), dim=1) # Bh x (1 + p + NT) x d
|
164 |
+
|
165 |
+
# merge back the heads
|
166 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # B x (1 + p + NT) x hd
|
167 |
+
if self.num_tokens > 0 and not pt_att:
|
168 |
+
out = torch.cat((
|
169 |
+
out[:, :1, :], # cls_tokens
|
170 |
+
prompts,
|
171 |
+
out[:, 1:, :] # patch embeddings
|
172 |
+
), dim=1)
|
173 |
+
|
174 |
+
# to out
|
175 |
+
x = self.proj(out)
|
176 |
+
x = self.proj_drop(x)
|
177 |
+
return x
|
178 |
+
|
179 |
+
def forward_VoP(self, x, einops_from, einops_to, einops_dims, n_seg=4):
|
180 |
+
# position-specific prompts for spatial attention
|
181 |
+
h = self.num_heads
|
182 |
+
num_tokens = self.num_tokens
|
183 |
+
|
184 |
+
# project x to q, k, v values
|
185 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1) # B x (1+p+NT) x hd
|
186 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) # Bh x (1+p+NT) x d
|
187 |
+
|
188 |
+
q *= self.scale
|
189 |
+
|
190 |
+
# splice out CLS token at index 1 and prompts
|
191 |
+
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:num_tokens+1], t[:, num_tokens+1:]), (q, k, v)) # Bh x () x d
|
192 |
+
# let CLS token attend to key / values of all patches across time and space
|
193 |
+
cls_out = attn(cls_q[:, :1, :], k, v) # cls token: Bh x 1 x d
|
194 |
+
|
195 |
+
# segment prompts into s segments in time
|
196 |
+
pstep = num_tokens // n_seg
|
197 |
+
pseg = [range(st, en) for st, en in zip(range(1, num_tokens+1, pstep), range(pstep+1, num_tokens+2, pstep))]
|
198 |
+
p_q, p_k, p_v = map(lambda t: rearrange(t[:, pseg, :], 'b s p d -> (b s) p d'), (cls_q, cls_k, cls_v)) # prompt query: (Bh x n_seg) x p_per_seg x d
|
199 |
+
|
200 |
+
# segment patch embeddings into s segments in time
|
201 |
+
q_, k_, v_ = map(lambda t: rearrange(t, 'b (f n) d -> b f n d', **einops_dims), (q_, k_, v_)) # Bh x T x N x d
|
202 |
+
num_frames = k_.size(1)
|
203 |
+
tstep = num_frames // n_seg
|
204 |
+
tseg = [range(st, en) for st, en in zip(range(0, num_frames, tstep), range(tstep, num_frames+1, tstep))]
|
205 |
+
q_, k_, v_ = map(lambda t: t[:, tseg, ...], (q_, k_, v_)) # Bh x n_seg x f_per_seg x n x d
|
206 |
+
q_, k_, v_ = map(lambda t: rearrange(t, 'b s f n d -> (b s) (f n) d'), (q_, k_, v_)) # (Bh x n_seg) x (f_per_seg x n) x d
|
207 |
+
|
208 |
+
# concatenate prompts and patch embeddings
|
209 |
+
k_, v_ = map(lambda t: torch.cat((t[0], t[1]), dim=1), ((p_k, k_), (p_v, v_)))
|
210 |
+
p_out = attn(p_q, k_, v_) # (Bh x n_seg) x p_per_seg x d
|
211 |
+
out = attn(q_, k_, v_) # (Bh x n_seg) x (f_per_seg x n) x d
|
212 |
+
p_out = rearrange(p_out, '(b s) p d -> b (s p) d', s=n_seg) # Bh x p x d
|
213 |
+
out = rearrange(out, '(b s) (f n) d -> b (s f n) d', s=n_seg, f=tstep) # Bh x NT x d
|
214 |
+
|
215 |
+
# merge tokens
|
216 |
+
out = torch.cat((cls_out, p_out, out), dim=1) # Bh x (1+p+NT) x d
|
217 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # B x (NT+1) x hd
|
218 |
+
|
219 |
+
# to out
|
220 |
+
x = self.proj(out)
|
221 |
+
x = self.proj_drop(x)
|
222 |
+
return x
|
223 |
+
|
224 |
+
def forward_attall(self, x, pt_att=True):
|
225 |
+
h = self.num_heads
|
226 |
+
if self.num_tokens > 0 and not pt_att:
|
227 |
+
prompts = x[:, 1:self.num_tokens+1, :]
|
228 |
+
x = torch.cat((
|
229 |
+
x[:, :1, :], # cls_token
|
230 |
+
x[:, self.num_tokens+1:, :] # patch embeddings
|
231 |
+
), dim=1)
|
232 |
+
|
233 |
+
# project x to q, k, v values
|
234 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
235 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
236 |
+
|
237 |
+
q *= self.scale
|
238 |
+
|
239 |
+
# all tokens attend to all tokens
|
240 |
+
out = attn(q, k, v)
|
241 |
+
|
242 |
+
# merge back the heads
|
243 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h) # B x (1 + p + NT) x hd
|
244 |
+
if self.num_tokens > 0 and not pt_att:
|
245 |
+
out = torch.cat((
|
246 |
+
out[:, :1, :], # cls_tokens
|
247 |
+
prompts,
|
248 |
+
out[:, 1:, :] # patch embeddings
|
249 |
+
), dim=1)
|
250 |
+
|
251 |
+
# to out
|
252 |
+
x = self.proj(out)
|
253 |
+
x = self.proj_drop(x)
|
254 |
+
return x
|
255 |
+
|
256 |
+
|
257 |
+
class SpaceTimeBlock(nn.Module):
|
258 |
+
|
259 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
260 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, time_init='zeros',
|
261 |
+
attention_style='frozen-in-time', is_tanh_gating=False, num_tokens=0, split_st=False):
|
262 |
+
super().__init__()
|
263 |
+
|
264 |
+
self.split_st = split_st # split spatial and temporal prompts
|
265 |
+
if split_st:
|
266 |
+
num_tokens = num_tokens // 2
|
267 |
+
self.num_tokens = num_tokens # learnable prompts
|
268 |
+
|
269 |
+
self.norm1 = norm_layer(dim)
|
270 |
+
self.attn = VarAttention(
|
271 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens)
|
272 |
+
|
273 |
+
self.timeattn = VarAttention(
|
274 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, num_tokens=num_tokens,
|
275 |
+
initialize=time_init)
|
276 |
+
|
277 |
+
if is_tanh_gating:
|
278 |
+
self.alpha_timeattn = nn.Parameter(torch.zeros([]))
|
279 |
+
|
280 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
281 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
282 |
+
self.norm2 = norm_layer(dim)
|
283 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
284 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
285 |
+
self.norm3 = norm_layer(dim)
|
286 |
+
|
287 |
+
self.attention_style = attention_style
|
288 |
+
|
289 |
+
def forward(self, x, einops_from_space, einops_to_space, einops_from_time, einops_to_time,
|
290 |
+
time_n, space_f, use_checkpoint=False, pt_spt=True, pt_tmp=True, style='default', n_seg=4):
|
291 |
+
if self.split_st:
|
292 |
+
spatial_prompts = x[:, 1:self.num_tokens+1, :]
|
293 |
+
x = torch.cat((
|
294 |
+
x[:, :1, :], # cls_token
|
295 |
+
x[:, self.num_tokens+1:, :] # temporal prompts and patch embeddings
|
296 |
+
), dim=1)
|
297 |
+
|
298 |
+
if use_checkpoint:
|
299 |
+
time_output = checkpoint.checkpoint(
|
300 |
+
self.timeattn, self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp}
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
time_output = self.timeattn(self.norm3(x), einops_from_time, einops_to_time, {"n": time_n}, {'pt_att': pt_tmp})
|
304 |
+
if hasattr(self, "alpha_timeattn"):
|
305 |
+
time_output = torch.tanh(self.alpha_timeattn) * time_output
|
306 |
+
time_residual = x + time_output
|
307 |
+
|
308 |
+
if self.split_st:
|
309 |
+
temporal_prompts = time_residual[:, 1:self.num_tokens+1, :]
|
310 |
+
time_residual = torch.cat((
|
311 |
+
time_residual[:, :1, :], # cls_token
|
312 |
+
spatial_prompts,
|
313 |
+
time_residual[:, self.num_tokens+1:, :] # patch embeddings
|
314 |
+
), dim=1)
|
315 |
+
|
316 |
+
cfg = {'style': style, 'pt_att': pt_spt, 'n_seg': n_seg}
|
317 |
+
if use_checkpoint:
|
318 |
+
space_output = checkpoint.checkpoint(
|
319 |
+
self.attn, self.norm1(time_residual), einops_from_space, einops_to_space, {"f": space_f}, cfg
|
320 |
+
)
|
321 |
+
else:
|
322 |
+
space_output = self.attn(self.norm1(time_residual), einops_from_space,
|
323 |
+
einops_to_space, {"f": space_f}, cfg)
|
324 |
+
if self.attention_style == 'frozen-in-time':
|
325 |
+
space_residual = x + self.drop_path(space_output)
|
326 |
+
else:
|
327 |
+
raise NotImplementedError
|
328 |
+
|
329 |
+
if self.split_st:
|
330 |
+
space_residual = torch.cat((
|
331 |
+
                space_residual[:, :self.num_tokens+1, :], # cls_token and spatial prompts
|
332 |
+
temporal_prompts,
|
333 |
+
space_residual[:, self.num_tokens+1:, :] # patch embeddings
|
334 |
+
), dim=1)
|
335 |
+
|
336 |
+
x = space_residual + self.drop_path(self.mlp(self.norm2(space_residual)))
|
337 |
+
|
338 |
+
return x
|
339 |
+
|
340 |
+
|
341 |
+
class SpaceTimeTransformer(nn.Module):
|
342 |
+
""" Vision Transformer
|
343 |
+
A PyTorch impl of : `Space-Time Transformer` from Frozen-in-time - by Max Bain.
|
344 |
+
https://arxiv.org/abs/2104.00650
|
345 |
+
Based off:
|
346 |
+
- ViT implementation from the timm library [https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py]
|
347 |
+
lucidrains timesformer implementation [https://github.com/lucidrains/TimeSformer-pytorch].
|
348 |
+
Notable differences:
|
349 |
+
- allows for variable length input frames (<= num_frames)
|
350 |
+
- allows for variable length input resolution (<= (img_size, img_size)) [UNTESTED]
|
351 |
+
- different attention block mechanism
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
355 |
+
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
356 |
+
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
|
357 |
+
num_frames=8, time_init='rand', attention_style='frozen-in-time', ln_pre=False,
|
358 |
+
act_layer=nn.GELU, is_tanh_gating=False, tune_bias=False, prompt_cfg={}):
|
359 |
+
"""
|
360 |
+
Args:
|
361 |
+
img_size (int, tuple): input image size
|
362 |
+
patch_size (int, tuple): patch size
|
363 |
+
in_chans (int): number of input channels
|
364 |
+
num_classes (int): number of classes for classification head
|
365 |
+
embed_dim (int): embedding dimension
|
366 |
+
depth (int): depth of transformer
|
367 |
+
num_heads (int): number of attention heads
|
368 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
369 |
+
qkv_bias (bool): enable bias for qkv if True
|
370 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
371 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
372 |
+
drop_rate (float): dropout rate
|
373 |
+
attn_drop_rate (float): attention dropout rate
|
374 |
+
drop_path_rate (float): stochastic depth rate
|
375 |
+
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
|
376 |
+
norm_layer: (nn.Module): normalization layer
|
377 |
+
num_frames: (int) maximum number of frames expected as input
|
378 |
+
time_init: (str) how to initialise the time attention layer, 'zeros' allows for the timesformer to start off
|
379 |
+
as ViT.
|
380 |
+
attention_style: (str) how to attend to space and time.
|
381 |
+
"""
|
382 |
+
super().__init__()
|
383 |
+
self.num_classes = num_classes
|
384 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
385 |
+
self.num_frames = num_frames
|
386 |
+
self.embed_dim = embed_dim
|
387 |
+
self.tune_bias = tune_bias
|
388 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
389 |
+
print("######USING ATTENTION STYLE: ", attention_style)
|
390 |
+
self.param_list = []
|
391 |
+
if hybrid_backbone is not None:
|
392 |
+
raise NotImplementedError('hybrid backbone not implemented')
|
393 |
+
else:
|
394 |
+
self.patch_embed = VideoPatchEmbed(
|
395 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, num_frames=num_frames, ln_pre=ln_pre)
|
396 |
+
self.param_list += list(self.patch_embed.parameters())
|
397 |
+
num_patches = self.patch_embed.num_patches
|
398 |
+
self.patches_per_frame = num_patches // num_frames
|
399 |
+
|
400 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
401 |
+
self.pos_embed = nn.Parameter(
|
402 |
+
torch.zeros(1, self.patches_per_frame + 1,
|
403 |
+
embed_dim)) # remember to take pos_embed[1:] for tiling over time
|
404 |
+
self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
|
405 |
+
self.param_list += [self.cls_token, self.pos_embed, self.temporal_embed]
|
406 |
+
|
407 |
+
if ln_pre:
|
408 |
+
self.ln_pre = nn.LayerNorm(embed_dim)
|
409 |
+
if self.tune_bias:
|
410 |
+
self.param_list += [m for n, m in self.ln_pre.named_parameters() if 'bias' not in n]
|
411 |
+
else:
|
412 |
+
self.param_list += list(self.ln_pre.parameters())
|
413 |
+
else:
|
414 |
+
self.ln_pre = None
|
415 |
+
|
416 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
417 |
+
|
418 |
+
# config for prompts
|
419 |
+
self.num_tokens = prompt_cfg.get('num_tokens', 0)
|
420 |
+
self.prompt_dim = prompt_cfg.get('prompt_dim', 768)
|
421 |
+
self.pt_spt = prompt_cfg.pop('pt_spt', True)
|
422 |
+
self.pt_tmp = prompt_cfg.pop('pt_tmp', True)
|
423 |
+
self.style = prompt_cfg.pop('style', 'default')
|
424 |
+
self.query = prompt_cfg.pop('query', 'cls')
|
425 |
+
self.n_seg = prompt_cfg.pop('n_seg', 4)
|
426 |
+
self.k_s = prompt_cfg.pop('K_s', depth)
|
427 |
+
self.st = prompt_cfg.pop('st', 0)
|
428 |
+
self.end = prompt_cfg.pop('end', depth)
|
429 |
+
assert self.st <= self.end
|
430 |
+
if self.style == 'default':
|
431 |
+
print(f'Prompting {self.st}-{self.end} layer of the visual backbone')
|
432 |
+
elif self.style == 'VoP_c' and self.k_s < depth:
|
433 |
+
self.prompt_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
|
434 |
+
elif self.style == 'VoP_c_pool':
|
435 |
+
self.prompt_temp_embed = nn.Parameter(torch.zeros(1, self.n_seg, embed_dim))
|
436 |
+
trunc_normal_(self.prompt_temp_embed, std=.02)
|
437 |
+
|
438 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
439 |
+
|
440 |
+
blocks = []
|
441 |
+
for i in range(depth):
|
442 |
+
stblk_cfg = {}
|
443 |
+
if self.num_tokens > 0:
|
444 |
+
stblk_cfg = {'num_tokens': prompt_cfg['num_tokens'], 'split_st': prompt_cfg.get('split_st', False)}
|
445 |
+
blocks.append(
|
446 |
+
SpaceTimeBlock(
|
447 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
448 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, time_init=time_init,
|
449 |
+
attention_style=attention_style, act_layer=act_layer, is_tanh_gating=is_tanh_gating, **stblk_cfg)
|
450 |
+
)
|
451 |
+
|
452 |
+
self.blocks = nn.ModuleList(blocks)
|
453 |
+
self.norm = norm_layer(embed_dim)
|
454 |
+
if self.tune_bias:
|
455 |
+
self.param_list += reduce(operator.add, [[m for n, m in x.named_parameters() if 'bias' not in n] for x in self.blocks])
|
456 |
+
self.param_list += [m for n, m in self.norm.named_parameters() if 'bias' not in n]
|
457 |
+
else:
|
458 |
+
self.param_list += reduce(operator.add, [list(x.parameters()) for x in self.blocks])
|
459 |
+
self.param_list += list(self.norm.parameters())
|
460 |
+
|
461 |
+
# Representation layer
|
462 |
+
if representation_size:
|
463 |
+
self.num_features = representation_size
|
464 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
465 |
+
('fc', nn.Linear(embed_dim, representation_size)),
|
466 |
+
('act', nn.Tanh())
|
467 |
+
]))
|
468 |
+
if self.tune_bias:
|
469 |
+
self.param_list += [m for n, m in self.pre_logits.named_parameters() if 'bias' not in n]
|
470 |
+
else:
|
471 |
+
self.param_list += list(self.pre_logits.parameters())
|
472 |
+
else:
|
473 |
+
self.pre_logits = nn.Identity()
|
474 |
+
|
475 |
+
# Classifier head
|
476 |
+
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
477 |
+
|
478 |
+
trunc_normal_(self.pos_embed, std=.02)
|
479 |
+
trunc_normal_(self.cls_token, std=.02)
|
480 |
+
|
481 |
+
# if num_frames > 1, then we perform ViT inflation and initialise time attention to zero so not necessary.
|
482 |
+
if num_frames == 1:
|
483 |
+
self.apply(self._init_weights)
|
484 |
+
|
485 |
+
# einops transformations
|
486 |
+
self.einops_from_space = 'b (f n) d'
|
487 |
+
self.einops_to_space = '(b f) n d'
|
488 |
+
self.einops_from_time = 'b (f n) d'
|
489 |
+
self.einops_to_time = '(b n) f d'
|
490 |
+
|
491 |
+
# freeze the backbone and only learn the prompts
|
492 |
+
self.prompt_learner = None
|
493 |
+
if self.num_tokens > 0:
|
494 |
+
if 'VoP_c' in self.style:
|
495 |
+
pool = prompt_cfg.pop('pool', {}) if 'pool' in self.style else {}
|
496 |
+
if self.k_s > 0:
|
497 |
+
self.prompt_generator = CMM(self.num_tokens // self.n_seg, self.n_seg, embed_dim, self.prompt_dim, num_layer=self.k_s, \
|
498 |
+
shared=prompt_cfg.get('deep_shared', False), pool=pool)
|
499 |
+
n_prompt_layer = depth - self.k_s
|
500 |
+
|
501 |
+
else:
|
502 |
+
n_prompt_layer = self.end - self.st
|
503 |
+
|
504 |
+
if n_prompt_layer > 0:
|
505 |
+
prompt_cfg['num_layers'] = n_prompt_layer
|
506 |
+
prompt_cfg['prompt_dim'] = embed_dim
|
507 |
+
self.prompt_learner = VisualPromptLearner(patch_size, embed_dim, **prompt_cfg)
|
508 |
+
|
509 |
+
for p in self.param_list:
|
510 |
+
            p.requires_grad = False
|
511 |
+
|
512 |
+
def _init_weights(self, m):
|
513 |
+
if isinstance(m, nn.Linear):
|
514 |
+
trunc_normal_(m.weight, std=.02)
|
515 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
516 |
+
nn.init.constant_(m.bias, 0)
|
517 |
+
elif isinstance(m, nn.LayerNorm):
|
518 |
+
nn.init.constant_(m.bias, 0)
|
519 |
+
nn.init.constant_(m.weight, 1.0)
|
520 |
+
|
521 |
+
@torch.jit.ignore
|
522 |
+
def no_weight_decay(self):
|
523 |
+
return {'pos_embed', 'cls_token'}
|
524 |
+
|
525 |
+
def get_classifier(self):
|
526 |
+
return self.head
|
527 |
+
|
528 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
529 |
+
self.num_classes = num_classes
|
530 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
531 |
+
|
532 |
+
def forward_features(self, x, use_checkpoint=False, cls_at_last=True, istrain=False, gamma=1.0):
|
533 |
+
# print(x.shape)
|
534 |
+
b, curr_frames, channels, _, _ = x.shape
|
535 |
+
x = self.patch_embed(x)
|
536 |
+
x = x.flatten(2).transpose(2, 1)
|
537 |
+
x = x.reshape(b, -1, self.patch_embed.embed_dim)
|
538 |
+
|
539 |
+
BF = x.shape[0]
|
540 |
+
cls_tokens = self.cls_token.expand(BF, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
541 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
542 |
+
# positional embed needs to be tiled for each frame (this does [1,2,3] --> [1,2,3,1,2,3]...)
|
543 |
+
cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
|
544 |
+
tile_pos_embed = self.pos_embed[:, 1:, :].repeat(1, self.num_frames, 1)
|
545 |
+
# temporal embed needs to be repeated within each frame (this does [1,2,3] --> [1,1,1,2,2,2,3,3,3]...)
|
546 |
+
tile_temporal_embed = self.temporal_embed.repeat_interleave(self.patches_per_frame, 1)
|
547 |
+
total_pos_embed = tile_pos_embed + tile_temporal_embed
|
548 |
+
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) # 1 x (NT + 1) x D
|
549 |
+
|
550 |
+
curr_patches = x.shape[1]
|
551 |
+
x = x + total_pos_embed[:, :curr_patches] # B x (NT + 1) x D
|
552 |
+
ps_loss = x.new_zeros([1])
|
553 |
+
# incorporate prompts
|
554 |
+
if self.num_tokens > 0:
|
555 |
+
if 'VoP_c' in self.style and self.k_s > 0:
|
556 |
+
ctx, ps = self.prompt_generator(x[:, 1:, :], 0, istrain=istrain, gamma=gamma)
|
557 |
+
ps_loss += ps
|
558 |
+
if self.prompt_generator.use_bank:
|
559 |
+
prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
|
560 |
+
ctx = ctx + prompt_temp_embed
|
561 |
+
|
562 |
+
elif self.prompt_learner is not None:
|
563 |
+
ctx, ps = self.prompt_learner(x[:, :1, :], 0, istrain=istrain, gamma=gamma)
|
564 |
+
ps_loss += ps
|
565 |
+
if ctx.size(0) != BF:
|
566 |
+
ctx = ctx.expand(BF, -1, -1)
|
567 |
+
|
568 |
+
x = torch.cat((
|
569 |
+
x[:, :1, :], # cls_token
|
570 |
+
ctx,
|
571 |
+
x[:, 1:, :]
|
572 |
+
), dim=1)
|
573 |
+
|
574 |
+
if self.ln_pre is not None:
|
575 |
+
x = self.ln_pre(x)
|
576 |
+
x = self.pos_drop(x)
|
577 |
+
n = self.patches_per_frame
|
578 |
+
f = curr_frames
|
579 |
+
|
580 |
+
for i, blk in enumerate(self.blocks):
|
581 |
+
if self.num_tokens > 0 and i > 0 and i >= self.st and i < self.end:
|
582 |
+
if 'VoP_c' in self.style:
|
583 |
+
if i < self.k_s:
|
584 |
+
ctx, ps = self.prompt_generator(x[:, self.num_tokens+1:, :], i, istrain=istrain, gamma=gamma)
|
585 |
+
ps_loss += ps
|
586 |
+
if self.prompt_generator.use_bank:
|
587 |
+
prompt_temp_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
|
588 |
+
ctx = ctx + prompt_temp_embed
|
589 |
+
else:
|
590 |
+
ctx, ps = self.prompt_learner(x[:, :1, :], i-self.k_s, istrain=istrain, gamma=gamma)
|
591 |
+
ps_loss += ps
|
592 |
+
|
593 |
+
if 'pool' in self.style:
|
594 |
+
prompt_embed = self.prompt_temp_embed.repeat_interleave(self.num_tokens // self.n_seg, 1)
|
595 |
+
else:
|
596 |
+
prompt_embed = self.prompt_embed.repeat_interleave(self.num_tokens // self.num_frames, 1)
|
597 |
+
ctx = ctx + prompt_embed
|
598 |
+
if ctx.size(0) != BF:
|
599 |
+
ctx = ctx.expand(BF, -1, -1)
|
600 |
+
|
601 |
+
elif (i - self.st) < self.prompt_learner.num_layers:
|
602 |
+
ctx, ps = self.prompt_learner(x[:, :1, :], i-self.st, istrain=istrain, gamma=gamma)
|
603 |
+
ps_loss += ps
|
604 |
+
if ctx.size(0) != BF:
|
605 |
+
ctx = ctx.expand(BF, -1, -1)
|
606 |
+
|
607 |
+
x = torch.cat((
|
608 |
+
x[:, :1, :], # cls_token
|
609 |
+
ctx,
|
610 |
+
x[:, self.num_tokens+1:, :]
|
611 |
+
), dim=1)
|
612 |
+
|
613 |
+
style = 'default' if i >= self.k_s else self.style
|
614 |
+
pt_tmp = self.pt_tmp if i >= self.st and i < self.end else False
|
615 |
+
pt_spt = self.pt_spt if i >= self.st and i < self.end else False
|
616 |
+
x = blk(x, self.einops_from_space, self.einops_to_space, self.einops_from_time,
|
617 |
+
self.einops_to_time,
|
618 |
+
time_n=n, space_f=f, use_checkpoint=use_checkpoint, pt_spt=pt_spt,
|
619 |
+
pt_tmp=pt_tmp, style=style, n_seg=self.n_seg)
|
620 |
+
|
621 |
+
if cls_at_last:
|
622 |
+
x = self.norm(x)
|
623 |
+
x = x[:, 0]
|
624 |
+
x = self.pre_logits(x)
|
625 |
+
|
626 |
+
return x, ps_loss
|
627 |
+
else:
|
628 |
+
return self.norm(x), ps_loss
|
629 |
+
|
630 |
+
def forward(self, x, use_checkpoint=False, istrain=False, gamma=1.0):
|
631 |
+
# Note: B C T H W => B T C H W
|
632 |
+
# The default input order is different from the one in Frozen-in-Time
|
633 |
+
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
634 |
+
x, ps_loss = self.forward_features(x, use_checkpoint=use_checkpoint, istrain=istrain, gamma=gamma)
|
635 |
+
x = self.head(x)
|
636 |
+
|
637 |
+
return x, ps_loss
|
638 |
+
|
639 |
+
def train(self, mode=True):
|
640 |
+
if not isinstance(mode, bool):
|
641 |
+
raise ValueError("training mode is expected to be boolean")
|
642 |
+
self.training = mode
|
643 |
+
for m in self.modules():
|
644 |
+
m.training = mode
|
645 |
+
|
646 |
+
if mode and self.num_tokens > 0:
|
647 |
+
for n, m in self.named_modules():
|
648 |
+
if 'prompt' not in n:
|
649 |
+
m.training = False
|
650 |
+
|
lavila/models/tokenizer.py
ADDED
@@ -0,0 +1,239 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
# Modified by Yue Zhao
# The original code is under MIT License

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re
import torch

from transformers import (BertTokenizer, DistilBertTokenizer, GPT2Tokenizer)


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def __call__(self, texts, context_length=77):
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self.encoder["<|startoftext|>"]
        eot_token = self.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            tokens = tokens[:context_length]
            result[i, :len(tokens)] = torch.tensor(tokens)

        if len(result) == 1:
            return result[0]
        return result


class MyBertTokenizer(object):
    def __init__(self, name=''):
        print('=> Initialize MyBertTokenizer ({})'.format(name))
|
168 |
+
self.tokenizer = BertTokenizer.from_pretrained(name)
|
169 |
+
self.bos_token_id, self.eos_token_id = self.tokenizer('').input_ids
|
170 |
+
self.pad_token_id = 0
|
171 |
+
|
172 |
+
def __call__(self, texts, context_length=77):
|
173 |
+
if isinstance(texts, str):
|
174 |
+
texts = [texts]
|
175 |
+
result = torch.zeros(len(texts), context_length, dtype=torch.long)
|
176 |
+
mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
|
177 |
+
for i, text in enumerate(texts):
|
178 |
+
tokens = self.tokenizer(text)
|
179 |
+
input_ids = tokens.input_ids[:context_length]
|
180 |
+
attention_mask = tokens.attention_mask[:context_length]
|
181 |
+
result[i, :len(input_ids)] = torch.tensor(input_ids)
|
182 |
+
mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
|
183 |
+
|
184 |
+
if len(result) == 1:
|
185 |
+
return result[0], mask[0]
|
186 |
+
return result, mask
|
187 |
+
|
188 |
+
|
189 |
+
class MyDistilBertTokenizer(object):
|
190 |
+
def __init__(self, name=''):
|
191 |
+
print('=> Initialize MyDistilBertTokenizer ({})'.format(name))
|
192 |
+
self.tokenizer = DistilBertTokenizer.from_pretrained(name)
|
193 |
+
|
194 |
+
def __call__(self, texts, context_length=77):
|
195 |
+
if isinstance(texts, str):
|
196 |
+
texts = [texts]
|
197 |
+
result = torch.zeros(len(texts), context_length, dtype=torch.long)
|
198 |
+
mask = torch.zeros(len(texts), context_length, dtype=torch.float32)
|
199 |
+
for i, text in enumerate(texts):
|
200 |
+
tokens = self.tokenizer(text)
|
201 |
+
input_ids = tokens.input_ids[:context_length]
|
202 |
+
attention_mask = tokens.attention_mask[:context_length]
|
203 |
+
result[i, :len(input_ids)] = torch.tensor(input_ids)
|
204 |
+
mask[i, :len(attention_mask)] = torch.tensor(attention_mask)
|
205 |
+
|
206 |
+
if len(result) == 1:
|
207 |
+
return result[0], mask[0]
|
208 |
+
return result, mask
|
209 |
+
|
210 |
+
|
211 |
+
class MyGPT2Tokenizer(object):
|
212 |
+
def __init__(self, name='', add_bos=False):
|
213 |
+
print('=> Initialize MyGPT2Tokenizer ({})'.format(name))
|
214 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(name)
|
215 |
+
self.bos_token_id, self.eos_token_id = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
|
216 |
+
self.pad_token_id = 0
|
217 |
+
self.add_bos = add_bos
|
218 |
+
# num_added_tokens = self.tokenizer.add_special_tokens({'pad_token': "[PAD]"})
|
219 |
+
# print('num_added_tokens={}'.format(len(num_added_tokens)))
|
220 |
+
|
221 |
+
def __call__(self, texts, context_length=77):
|
222 |
+
if isinstance(texts, str):
|
223 |
+
texts = [texts]
|
224 |
+
result = torch.zeros(len(texts), context_length, dtype=torch.long)
|
225 |
+
for i, text in enumerate(texts):
|
226 |
+
tokens = self.tokenizer(text)
|
227 |
+
if not self.add_bos:
|
228 |
+
input_ids = tokens.input_ids[:context_length - 1]
|
229 |
+
input_ids = input_ids + [self.tokenizer.eos_token_id] # add [EOS]
|
230 |
+
else:
|
231 |
+
input_ids = tokens.input_ids[:context_length - 2]
|
232 |
+
input_ids = [self.tokenizer.bos_token_id] + input_ids + [self.tokenizer.eos_token_id] # add [EOS]
|
233 |
+
# attention_mask = tokens.attention_mask[:context_length]
|
234 |
+
# attention_mask = attention_mask + [0.] * pad_length
|
235 |
+
result[i, :len(input_ids)] = torch.tensor(input_ids)
|
236 |
+
|
237 |
+
if len(result) == 1:
|
238 |
+
return result[0]
|
239 |
+
return result
|
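Not part of the upload: a minimal usage sketch of the SimpleTokenizer above, assuming the repo root is on PYTHONPATH so lavila.models.tokenizer resolves and the bundled bpe_simple_vocab_16e6.txt.gz sits next to tokenizer.py.

from lavila.models.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                      # loads the bundled BPE vocab via default_bpe()
ids = tokenizer('cut the sausage')                 # LongTensor of shape (77,), zero-padded
print(ids[:6])                                     # <|startoftext|> id, BPE ids, <|endoftext|> id, 0, ...
print(tokenizer.decode(tokenizer.encode('cut the sausage')))   # round-trips to 'cut the sausage '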
lavila/models/utils.py
ADDED
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
import functools
import torch
import torch.nn.functional as F


def inflate_positional_embeds(
    current_model_state_dict, new_state_dict,
    num_frames=4,
    load_temporal_fix='bilinear',
):
    # allow loading of timesformer with fewer num_frames
    curr_keys = list(current_model_state_dict.keys())
    temporal_embed = ['visual.temporal_embed', 'visual.prompt_embed']
    for x in temporal_embed:
        if x in new_state_dict and x in curr_keys:
            load_temporal_embed = new_state_dict[x]
            load_num_frames = load_temporal_embed.shape[1]
            curr_num_frames = num_frames
            embed_dim = load_temporal_embed.shape[2]

            if load_num_frames != curr_num_frames:
                if load_num_frames > curr_num_frames:
                    print(f'### loaded SpaceTimeTransformer model has MORE frames than current...'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    new_temporal_embed = load_temporal_embed[:, :curr_num_frames, :]
                else:
                    print(f'### loaded SpaceTimeTransformer model has FEWER frames than current...'
                          f'### loading {x} weights, filling in the extras via {load_temporal_fix}')
                    if load_temporal_fix == 'zeros':
                        new_temporal_embed = torch.zeros([load_temporal_embed.shape[0], curr_num_frames, embed_dim])
                        new_temporal_embed[:, :load_num_frames] = load_temporal_embed
                    elif load_temporal_fix in ['interp', 'bilinear']:
                        # interpolate
                        # unsqueeze so pytorch thinks its an image
                        mode = 'nearest'
                        if load_temporal_fix == 'bilinear':
                            mode = 'bilinear'
                        load_temporal_embed = load_temporal_embed.unsqueeze(0)
                        new_temporal_embed = F.interpolate(load_temporal_embed,
                                                           (curr_num_frames, embed_dim), mode=mode).squeeze(0)
                    else:
                        raise NotImplementedError
                new_state_dict[x] = new_temporal_embed
    # allow loading with smaller spatial patches. assumes custom border crop, to append the
    # border patches to the input sequence
    if 'visual.pos_embed' in new_state_dict and 'visual.pos_embed' in curr_keys:
        load_pos_embed = new_state_dict['visual.pos_embed']
        load_num_patches = load_pos_embed.shape[1]
        curr_pos_embed = current_model_state_dict['visual.pos_embed']
        if load_num_patches != curr_pos_embed.shape[1]:
            raise NotImplementedError(
                'Loading models with different spatial resolution / patch number not yet implemented, sorry.')

    return new_state_dict


def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    def _getattr(obj, attr):
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))


# util functions to convert CLIP-style model keys to TimeSformer-style
def remap_keys(clip_state_dict, transformer_layers=12):
    remapped_state_dict = OrderedDict()
    key_mapping = {
        "class_embedding": "cls_token",
        "positional_embedding": "pos_embed",
        "conv1.weight": "patch_embed.proj.weight",
        "ln_pre.weight": "ln_pre.weight",
        "ln_pre.bias": "ln_pre.bias",
        "ln_post.weight": "norm.weight",
        "ln_post.bias": "norm.bias",
    }
    for layer in range(transformer_layers):
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_weight"] = f"blocks.{layer}.attn.qkv.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.in_proj_bias"] = f"blocks.{layer}.attn.qkv.bias"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.weight"] = f"blocks.{layer}.attn.proj.weight"
        key_mapping[f"transformer.resblocks.{layer}.attn.out_proj.bias"] = f"blocks.{layer}.attn.proj.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.weight"] = f"blocks.{layer}.norm1.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_1.bias"] = f"blocks.{layer}.norm1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.weight"] = f"blocks.{layer}.mlp.fc1.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_fc.bias"] = f"blocks.{layer}.mlp.fc1.bias"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.weight"] = f"blocks.{layer}.mlp.fc2.weight"
        key_mapping[f"transformer.resblocks.{layer}.mlp.c_proj.bias"] = f"blocks.{layer}.mlp.fc2.bias"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.weight"] = f"blocks.{layer}.norm2.weight"
        key_mapping[f"transformer.resblocks.{layer}.ln_2.bias"] = f"blocks.{layer}.norm2.bias"

    for key in clip_state_dict:
        if key == 'proj':
            continue  # due to possible dim mismatch, we load this later
        if key == "class_embedding":
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0).unsqueeze(0)
        if key == "positional_embedding":
            clip_state_dict[key] = clip_state_dict[key].unsqueeze(0)
        remapped_state_dict[key_mapping[key]] = clip_state_dict[key]

    return remapped_state_dict
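Not part of the upload: a toy check of inflate_positional_embeds above with hypothetical shapes, showing how a 4-frame temporal embedding from a checkpoint is interpolated for a 16-frame model.

import torch
from lavila.models.utils import inflate_positional_embeds

curr = {'visual.temporal_embed': torch.zeros(1, 16, 768)}   # model built for 16 frames
ckpt = {'visual.temporal_embed': torch.randn(1, 4, 768)}    # checkpoint trained with 4 frames
out = inflate_positional_embeds(curr, ckpt, num_frames=16, load_temporal_fix='bilinear')
print(out['visual.temporal_embed'].shape)                   # torch.Size([1, 16, 768])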
lavila/utils/config.py
ADDED
@@ -0,0 +1,18 @@
import yaml

def load_base_cfg():
    with open('configs/base.yml', 'r') as fp:
        cfg = yaml.load(fp, Loader=yaml.SafeLoader)
    return cfg

def load_cfg(cfg_file):
    cfg = load_base_cfg()
    with open(cfg_file, 'r') as fp:
        exp_cfg = yaml.load(fp, Loader=yaml.SafeLoader)

    cfg['model'].update(exp_cfg.get('model', {}))
    cfg['data'].update(exp_cfg.get('data', {}))
    dataset = cfg['data'].get('dataset')
    return cfg
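Not part of the upload: how load_cfg is typically called, assuming the working directory is the repo root so the relative paths to configs/base.yml and the experiment YAML resolve.

from lavila.utils.config import load_cfg

cfg = load_cfg('configs/ek100_mir/zeroshot.yml')   # base.yml defaults overridden by the experiment file
print(cfg['model'])
print(cfg['data'])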
lavila/utils/evaluation.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def get_mean_accuracy(cm):
    list_acc = []
    for i in range(len(cm)):
        acc = 0
        if cm[i, :].sum() > 0:
            acc = cm[i, i] / cm[i, :].sum()
        list_acc.append(acc)

    return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm)
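Not part of the upload: a toy run of accuracy() on a 3-sample, 2-class batch.

import torch
from lavila.utils.evaluation import accuracy

logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
target = torch.tensor([1, 0, 0])
top1 = accuracy(logits, target, topk=(1,))[0]
print(top1.item())   # ~66.67: two of the three argmax predictions match the targets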
lavila/utils/evaluation_charades.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np


def compute_map(submission_array, gt_array):
    """ Returns mAP, weighted mAP, and AP array """
    m_aps = []
    n_classes = submission_array.shape[1]
    for oc_i in range(n_classes):
        sorted_idxs = np.argsort(-submission_array[:, oc_i])
        tp = gt_array[:, oc_i][sorted_idxs] == 1
        fp = np.invert(tp)
        n_pos = tp.sum()
        if n_pos < 0.1:
            m_aps.append(float('nan'))
            continue
        fp.sum()
        f_pcs = np.cumsum(fp)
        t_pcs = np.cumsum(tp)
        prec = t_pcs / (f_pcs+t_pcs).astype(float)
        avg_prec = 0
        for i in range(submission_array.shape[0]):
            if tp[i]:
                avg_prec += prec[i]
        m_aps.append(avg_prec / n_pos.astype(float))
    m_aps = np.array(m_aps)
    # m_ap = np.mean(m_aps)
    m_ap = m_aps[~np.isnan(m_aps)]
    print(f'num of available classes: {len(m_ap)}')
    m_ap = m_ap.mean()  # compute mean w/o nan
    w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float))
    return m_ap, w_ap, m_aps


def charades_map(submission_array, gt_array):
    """
    Approximate version of the charades evaluation function
    For precise numbers, use the submission file with the official matlab script
    """
    fix = submission_array.copy()
    empty = np.sum(gt_array, axis=1) == 0
    fix[empty, :] = np.NINF
    return compute_map(fix, gt_array)


def create_submission(video_list, predictions, out_file):
    assert len(video_list) == predictions.shape[0]
    with open(out_file, 'w') as f:
        for i, video_id in enumerate(video_list):
            pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist()))
            f.write('{} {}\n\n'.format(video_id, pred_str))
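Not part of the upload: a toy call of charades_map on made-up scores for 4 clips and 3 classes; every clip has at least one ground-truth label, so no row is masked out.

import numpy as np
from lavila.utils.evaluation_charades import charades_map

scores = np.random.rand(4, 3)                                  # per-clip class scores
gt = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 1], [0, 0, 1]])    # multi-label ground truth
m_ap, w_ap, ap_per_class = charades_map(scores, gt)
print(m_ap, ap_per_class)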
lavila/utils/evaluation_ek100mir.py
ADDED
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of the code is from
# `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/NDCG.py`
# and
# `https://github.com/mwray/Joint-Part-of-Speech-Embeddings/tree/main/src/evaluation/mAP.py`
# Modified by Yue Zhao

import numpy as np


def calculate_DCG(similarity_matrix, relevancy_matrix, k_counts):
    """
    Calculates the Discounted Cumulative Gain (DCG) between two modalities for
    the first modality.
    DCG = \sum_{i=1}^k \frac{rel_i}{log_2(i + 1)}
    i.e. the sum of the k relevant retrievals which is calculated as the scaled
    relevancy for the ith item. The scale is designed such that early
    retrievals are more important than later retrievals.
    Params:
        - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith,jth] element is the predicted similarity
          between the ith item from the first modality and the jth item from
          the second modality.
        - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
          above). The [ith, jth] element is the semantic relevancy between the
          ith item from the first modality and the jth item from the second
          modality.
        - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
          includes information on which items to use to calculate the DCG for
          (see calculate_k_counts for more info on this matrix).
    Returns:
        - The DCG for each item in the first modality, a n1 length vector.
    """
    x_sz, y_sz = similarity_matrix.shape
    ranks = np.argsort(similarity_matrix)[:, ::-1]
    # Create vector of size (n,) where n is the length of the last dimension in
    # similarity matrix
    # This vector is of the form log(i+1)
    logs = np.log2(np.arange(y_sz) + 2)
    # Convert logs into the divisor for the DCG calculation, of size similarity
    # matrix
    divisors = np.repeat(np.expand_dims(logs, axis=0), x_sz, axis=0)

    # mask out the sorted relevancy matrix to only use the first k relevant
    # retrievals for each item.
    columns = np.repeat(np.expand_dims(np.arange(x_sz), axis=1), y_sz, axis=1)
    numerators = relevancy_matrix[columns, ranks] * k_counts
    # Calculate the final DCG score (note that this isn't expected to sum to 1)
    return np.sum(numerators / divisors, axis=1)


def calculate_k_counts(relevancy_matrix):
    """
    Works out the maximum number of allowed retrievals when working out the
    Discounted Cumulative Gain. For each query the DCG only uses the first k
    items retrieved which constitute the k relevant items for that query
    (otherwise the nDCG scores can be deceptively high for bad rankings).
    Params:
        - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith, jth] element is the semantic relevancy
          between the ith item from the first modality and the jth item from
          the second modality.
    Returns:
        - Matrix of size n1 x n2 (see relevancy matrix for more info). This is
          created as a mask such that if the [ith, jth] element is 1 it
          represents a valid item to use for the calculation of DCG for the
          ith item after sorting. For example, if relevancy matrix of:
            [[1, 0.5, 0],
             [0, 0  , 1]]
          is given, then the k_counts matrix will be:
            [[1, 1, 0],
             [1, 0, 0]]
          i.e. the first row has 2 non-zero items, so the first two retrieved
          items should be used in the calculation. In the second row there is
          only 1 relevant item, therefore only the first retrieved item should
          be used for the DCG calculation.
    """
    return (np.sort(relevancy_matrix)[:, ::-1] > 0).astype(int)


def calculate_IDCG(relevancy_matrix, k_counts):
    """
    Calculates the Ideal Discounted Cumulative Gain (IDCG) which is the value
    of the Discounted Cumulative Gain (DCG) for a perfect retrieval, i.e. the
    items in the second modality were retrieved in order of their descending
    relevancy.
    Params:
        - relevancy_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the
          second modality. The [ith, jth] element is the semantic relevancy
          between the ith item from the first modality and the jth item from
          the second modality.
        - k_counts: matrix of size n1 x n2 (see similarity_matrix above) which
          includes information on which items to use to calculate the DCG for
          (see calculate_k_counts for more info on this matrix).
    """
    return calculate_DCG(relevancy_matrix, relevancy_matrix, k_counts)


def calculate_nDCG(similarity_matrix, relevancy_matrix, k_counts=None, IDCG=None, reduction='mean'):
    """
    Calculates the normalised Discounted Cumulative Gain (nDCG) between two
    modalities for the first modality using the Discounted Cumulative Gain
    (DCG) and the Ideal Discounted Cumulative Gain (IDCG).
    nDCG = \frac{DCG}{IDCG}
    Params:
        - similarity_matrix: matrix of size n1 x n2 where n1 is the number of
          items in the first modality and n2 is the number of items in the second
          modality. The [ith,jth] element is the predicted similarity between
          the ith item from the first modality and the jth item from the second
          modality.
        - relevancy_matrix: matrix of size n1 x n2 (see similarity_matrix
          above). The [ith, jth] element is the semantic relevancy between the
          ith item from the first modality and the jth item from the second
          modality.
        - k_counts: optional parameter: matrix of size n1 x n2 (see
          similarity_matrix above) which includes information on which items to
          use to calculate the DCG for (see calculate_k_counts for more info on
          this matrix). This will be calculated using calculate_IDCG if not
          present, but should be pre-processed for efficiency.
        - IDCG: Optional parameter which includes the pre-processed Ideal
          Discounted Cumulative Gain (IDCG). This is a vector of size n1 (see
          similarity_matrix above) which contains the IDCG value for each item
          from the first modality. This will be calculated using calculate_IDCG
          if not present, but should be pre-processed for efficiency.
        - reduction: what to use to reduce the different nDCG scores. By
          default this applies np.mean across all different queries.
    Returns:
        - The nDCG values for the first modality.
    """
    if k_counts is None:
        k_counts = calculate_k_counts(relevancy_matrix)
    DCG = calculate_DCG(similarity_matrix, relevancy_matrix, k_counts)
    if IDCG is None:
        IDCG = calculate_IDCG(relevancy_matrix, k_counts)
    if reduction == 'mean':
        return np.mean(DCG / IDCG)
    elif reduction is None:
        return DCG / IDCG


def calculate_mAP(sim_mat, relevancy_matrix):
    """
    Computes the mean average precision according to the following formula of
    average precision:
    \frac{\sum_{k=1}^n p(k) x rel(k)}{num_rel_docs}
    where p(k) is the precision at k, rel(k) is an indicator function
    determining whether the kth returned item is relevant or not and
    num_rel_docs is the number of relevant items to find within the search.
    The mean average precision is the mean of the average precision for each
    query item (i.e row in the matrix)
    This function takes in two parameters:
        - sim_mat: a NxM matrix which represents the similarity between two
          modalities (with modality 1 being of size N and modality 2 of size M).
        - relevancy_matrix: an NxM matrix which represents the relevancy between two
          modalities of items (with modality 1 being of size N and modality 2 of
          size M).
    """
    # Find the order of the items in modality 2 according to modality 1
    ranked_order = (-sim_mat).argsort()
    ranked_sim_mat = sim_mat[np.arange(sim_mat.shape[0])[:, None], ranked_order]
    # re-order the relevancy matrix to accommodate the proposals
    ranked_rel_mat = relevancy_matrix[np.arange(relevancy_matrix.shape[0])[:, None], ranked_order]

    # find the number of relevant items found at each k
    cumulative_rel_mat = np.cumsum(ranked_rel_mat, axis=1)
    # Mask this ensuring that it is non zero if the kth term is 1 (rel(k) above)
    cumulative_rel_mat[ranked_rel_mat != 1] = 0
    # find the divisor for p(k)
    divisor = np.arange(ranked_rel_mat.shape[1]) + 1

    # find the number of relevant docs per query item
    number_rel_docs = np.sum(ranked_rel_mat == 1, axis=1)

    # find the average precision per query, within np.sum finds p(k) * rel(k)
    avg_precision = np.sum(cumulative_rel_mat / divisor, axis=1) / number_rel_docs
    mAP = np.mean(avg_precision)
    return mAP


def get_mAP(similarity_matrix, rel_matrix):
    vis_map = calculate_mAP(similarity_matrix, rel_matrix)
    txt_map = calculate_mAP(similarity_matrix.T, rel_matrix.T)
    return vis_map, txt_map, (vis_map + txt_map) / 2


def get_nDCG(similarity_matrix, rel_matrix):
    vis_k_counts = calculate_k_counts(rel_matrix)
    txt_k_counts = calculate_k_counts(rel_matrix.T)
    vis_IDCG = calculate_IDCG(rel_matrix, vis_k_counts)
    txt_IDCG = calculate_IDCG(rel_matrix.T, txt_k_counts)
    vis_nDCG = calculate_nDCG(similarity_matrix, rel_matrix, k_counts=vis_k_counts, IDCG=vis_IDCG)
    txt_nDCG = calculate_nDCG(similarity_matrix.T, rel_matrix.T, k_counts=txt_k_counts, IDCG=txt_IDCG)
    return vis_nDCG, txt_nDCG, (vis_nDCG + txt_nDCG) / 2
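Not part of the upload: a toy sanity check of get_mAP and get_nDCG above on a hypothetical 3-clip x 4-caption similarity matrix with graded relevancy; every row and column contains at least one relevant pair, so no division by zero occurs.

import numpy as np
from lavila.utils.evaluation_ek100mir import get_mAP, get_nDCG

sim = np.random.rand(3, 4)                  # e.g. cosine similarities between clips and captions
rel = np.array([[1.0, 0.5, 0.0, 0.0],       # graded semantic relevancy in [0, 1]
                [0.0, 1.0, 0.0, 1.0],
                [0.0, 0.0, 1.0, 0.5]])
v2t_map, t2v_map, avg_map = get_mAP(sim, rel)
v2t_ndcg, t2v_ndcg, avg_ndcg = get_nDCG(sim, rel)
print(f'mAP={avg_map:.3f}  nDCG={avg_ndcg:.3f}')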
lavila/utils/preprocess.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv

from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer


def generate_label_map(dataset):
    if dataset == 'ek100_cls':
        print("Preprocess ek100 action label space")
        vn_list = []
        mapping_vn2narration = {}
        for f in [
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
        ]:
            csv_reader = csv.reader(open(f))
            _ = next(csv_reader)  # skip the header
            for row in csv_reader:
                vn = '{}:{}'.format(int(row[10]), int(row[12]))
                narration = row[8]
                if vn not in vn_list:
                    vn_list.append(vn)
                if vn not in mapping_vn2narration:
                    mapping_vn2narration[vn] = [narration]
                else:
                    mapping_vn2narration[vn].append(narration)
                # mapping_vn2narration[vn] = [narration]
        vn_list = sorted(vn_list)
        print('# of action= {}'.format(len(vn_list)))
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
        print(labels[:5])
    elif dataset == 'charades_ego':
        print("=> preprocessing charades_ego action label space")
        vn_list = []
        labels = []
        with open('/data/CharadesEgo/CharadesEgo/Charades_v1_classes.txt') as f:
            csv_reader = csv.reader(f)
            for row in csv_reader:
                vn = row[0][:4]
                vn_list.append(vn)
                narration = row[0][5:]
                labels.append(narration)
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        print(labels[:5])
    elif dataset == 'egtea':
        print("=> preprocessing egtea action label space")
        labels = []
        with open('/data/EGTEA/action_idx.txt') as f:
            for row in f:
                row = row.strip()
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
                # labels.append(narration)
        mapping_vn2act = {label: i for i, label in enumerate(labels)}
        print(len(labels), labels[:5])
    else:
        raise NotImplementedError
    return labels, mapping_vn2act


def generate_tokenizer(model):
    if model.endswith('DISTILBERT_BASE'):
        tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
    elif model.endswith('BERT_BASE'):
        tokenizer = MyBertTokenizer('bert-base-uncased')
    elif model.endswith('BERT_LARGE'):
        tokenizer = MyBertTokenizer('bert-large-uncased')
    elif model.endswith('GPT2'):
        tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
    elif model.endswith('GPT2_MEDIUM'):
        tokenizer = MyGPT2Tokenizer('gpt2-medium', add_bos=True)
    elif model.endswith('GPT2_LARGE'):
        tokenizer = MyGPT2Tokenizer('gpt2-large', add_bos=True)
    elif model.endswith('GPT2_XL'):
        tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
    else:
        print("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want".format(model))
        tokenizer = SimpleTokenizer()
    return tokenizer
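Not part of the upload: how generate_tokenizer dispatches on the model-name suffix; names matching none of the BERT/GPT-2 suffixes fall back to the CLIP SimpleTokenizer. The model name below is only an illustrative example.

from lavila.utils.preprocess import generate_tokenizer

tok = generate_tokenizer('CLIP_OPENAI_TIMESFORMER_BASE')   # no BERT/GPT-2 suffix -> SimpleTokenizer (with a notice)
ids = tok('rinse cutting board')
print(ids.shape)                                           # torch.Size([77])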
meta/ek100_mir/EPIC_100_retrieval_test_sentence.csv
ADDED
The diff for this file is too large to render.
See raw diff
meta/ek100_mir/relevancy_sel_t2v.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6eed7e20dbe71de579e467ca9ab340154ba434461c34a7d089f9291c90739d9f
size 232160
meta/ek100_mir/relevancy_sel_v2t.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4054061cdc842aa090de3e5bca380af61be9e2e7d93cdacca0aa13237026ce4
size 92336
meta/ek100_mir/sel_t2v.csv
ADDED
@@ -0,0 +1,10 @@
4465,P16/P16_04.MP4,307.63,322.87,cut the sausage,7,86
4466,P16/P16_04.MP4,317.15,338.93,cut the sausage,7,86
4467,P16/P16_04.MP4,333.73,367.15,cut the sausage,7,86
3552,P11/P11_17.MP4,462.77,473.31,stir vegetables into salmon,10,94
3555,P11/P11_17.MP4,481.05,492.14,stir vegetables into salmon,10,94
3557,P11/P11_17.MP4,492.24,502.75,stir vegetables into salmon,10,94
730,P01/P01_15.MP4,605.88,619.71,rinse cutting board,2,18
2526,P07/P07_15.MP4,72.62,82.25,rinse board,2,18
4494,P16/P16_04.MP4,25.09,30.59,wash the cutting board,2,18
1760,P04/P04_28.MP4,9.03,12.69,pour coconut milk into pan,9,64
meta/ek100_mir/sel_v2t.csv
ADDED
@@ -0,0 +1,3 @@
762,P01/P01_15.MP4,72.97,73.87,take sponge,0,9
9051,P30/P30_08.MP4,203.34,204.35,open microwave,3,90
5920,P22/P22_01.MP4,224.87,226.46,rinse sponge,2,9
requirements.txt
ADDED
@@ -0,0 +1,14 @@
timm==0.5.4
torch==1.11.0
torchvision==0.12.0
decord==0.6.0
einops==0.4.1
pandas==1.4.2
pytorchvideo==0.1.5
transformers==4.27
ftfy==4.4.3
spacy==3.4.1
scikit-learn==1.1.1
numpy==1.22.3
gradio==4.19.1
gradio_rich_textbox==0.4.2