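"""
Run Moment-DETR inference on raw videos: extract CLIP features for a video and a set of
text queries, localize query-relevant moments, then save the retrieved moment clips,
their scores, and the video frames ranked by predicted saliency.
"""
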
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.video.io.VideoFileClip import VideoFileClip

from run_on_video.data_utils import ClipFeatureExtractor
from run_on_video.model_utils import build_inference_model
from utils.tensor_utils import pad_sequences_1d
from utils.basic_utils import l2_normalize_np_array
from moment_detr.span_utils import span_cxw_to_xx


class MomentDETRPredictor:
    def __init__(self, ckpt_path, clip_model_name_or_path="ViT-B/32", device="cuda"):
        self.clip_len = 2  # seconds
        self.device = device
        print("Loading feature extractors...")
        self.feature_extractor = ClipFeatureExtractor(
            framerate=1/self.clip_len, size=224, centercrop=True,
            model_name_or_path=clip_model_name_or_path, device=device
        )
        print("Loading trained Moment-DETR model...")
        self.model = build_inference_model(ckpt_path).to(self.device)
        self.model.eval()

    @torch.no_grad()
    def localize_moment(self, video_path, query_list):
        """
        Args:
            video_path: str, path to the video file
            query_list: List[str], each str is a query for this video
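        Returns:
            predictions: List[dict], one dict per query, with keys
                `query`, `vid`, `pred_relevant_windows`, `pred_saliency_scores`
            video_frames: frames sampled by the feature extractor (one per 2-sec clip)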
        """
        # construct model inputs
        n_query = len(query_list)
        video_feats, video_frames = self.feature_extractor.encode_video(video_path)
        video_feats = F.normalize(video_feats, dim=-1, eps=1e-5)
        n_frames = len(video_feats)
        # add temporal endpoint features (tef): each clip's normalized [start, end] position in the video
        tef_st = torch.arange(0, n_frames, 1.0) / n_frames
        tef_ed = tef_st + 1.0 / n_frames
        tef = torch.stack([tef_st, tef_ed], dim=1).to(self.device)  # (n_frames, 2)
        video_feats = torch.cat([video_feats, tef], dim=1)
        
        assert n_frames <= 75, "The positional embedding of this pretrained Moment-DETR model only supports " \
                               "videos up to 150 secs (i.e., 75 2-sec clips) in length"
        video_feats = video_feats.unsqueeze(0).repeat(n_query, 1, 1)  # (#text, T, d)
        video_mask = torch.ones(n_query, n_frames).to(self.device)
        query_feats = self.feature_extractor.encode_text(query_list)  # #text * (L, d)
        query_feats, query_mask = pad_sequences_1d(
            query_feats, dtype=torch.float32, device=self.device, fixed_length=None)
        query_feats = F.normalize(query_feats, dim=-1, eps=1e-5)
        model_inputs = dict(
            src_vid=video_feats,
            src_vid_mask=video_mask,
            src_txt=query_feats,
            src_txt_mask=query_mask
        )

        # decode outputs
        outputs = self.model(**model_inputs)
        # #moment_queries refers to the positional embeddings in MomentDETR's decoder, not the input text query
        prob = F.softmax(outputs["pred_logits"], -1)  # (batch_size, #moment_queries=10, #classes=2)
        scores = prob[..., 0]  # (batch_size, #moment_queries); the foreground class has label 0, so take its probability directly
        pred_spans = outputs["pred_spans"]  # (bsz, #moment_queries, 2)
        print(pred_spans)
        _saliency_scores = outputs["saliency_scores"].half()  # (bsz, L)
        saliency_scores = []
        valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
        for j in range(len(valid_vid_lengths)):
            _score = _saliency_scores[j, :int(valid_vid_lengths[j])].tolist()
            _score = [round(e, 4) for e in _score]
            saliency_scores.append(_score)
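        # saliency_scores[i][j] is the predicted saliency of the j-th 2-sec clip for the i-th query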

        # compose predictions
        predictions = []
        video_duration = n_frames * self.clip_len
        for idx, (spans, score) in enumerate(zip(pred_spans.cpu(), scores.cpu())):
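            # spans are predicted as normalized (center, width); convert to (start, end) in seconds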
            spans = span_cxw_to_xx(spans) * video_duration
            # (#moment_queries, 3), each row is [st(float), ed(float), score(float)]
            cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
            cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
            cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
            cur_query_pred = dict(
                query=query_list[idx],  # str
                vid=video_path,
                pred_relevant_windows=cur_ranked_preds,  # List([st(float), ed(float), score(float)])
                pred_saliency_scores=saliency_scores[idx]  # List(float), len==n_frames, one score per 2-sec clip
            )
            predictions.append(cur_query_pred)

        return predictions, video_frames


def run_example():
    # load example data
    from utils.basic_utils import load_jsonl
    video_dir = "run_on_video/example/testing_videos/dogs"
    
    #video_path = "run_on_video/example/testing_videos/"
    query_path = "run_on_video/example/queries_highlight.jsonl"
    queries = load_jsonl(query_path)
    query_text_list = [e["query"] for e in queries]
    ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"

    # run predictions
    print("Build models...")
    clip_model_name_or_path = "ViT-B/32"
    # clip_model_name_or_path = "tmp/ViT-B-32.pt"
    moment_detr_predictor = MomentDETRPredictor(
        ckpt_path=ckpt_path,
        clip_model_name_or_path=clip_model_name_or_path,
        device="cuda"
    )
    print("Run prediction...")
    video_paths = [os.path.join(video_dir, e) for e in os.listdir(video_dir)]
    #video_paths = ["run_on_video/example/testing_videos/celebration_18s.mov"]

    for video_path in video_paths:
        output_dir = os.path.join("run_on_video/example/output/dog/empty_str", os.path.basename(video_path))
        predictions, video_frames = moment_detr_predictor.localize_moment(
            video_path=video_path, query_list=query_text_list)
        # create the output directory if it does not exist
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # print predictions and save the retrieved moments and ranked frames for each query
        for idx, query_data in enumerate(queries):
            print("-"*30 + f"idx{idx}")
            print(f">> query: {query_data['query']}")
            print(f">> video_path: {video_path}")
            #print(f">> GT moments: {query_data['relevant_windows']}")
            print(f">> Predicted moments ([start_in_seconds, end_in_seconds, score]): "
                f"{predictions[idx]['pred_relevant_windows']}")
            #print(f">> GT saliency scores (only localized 2-sec clips): {query_data['saliency_scores']}")
            print(f">> Predicted saliency scores (for all 2-sec clip): "
                f"{predictions[idx]['pred_saliency_scores']}")
            # output the retrieved moments, sorted by score (the third element of each window)
            predictions[idx]['pred_relevant_windows'] = sorted(
                predictions[idx]['pred_relevant_windows'], key=lambda x: x[2], reverse=True)
            for i, (start_time, end_time, score) in enumerate(predictions[idx]['pred_relevant_windows']):
                print(start_time, end_time, score)
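                # cut the predicted moment out of the source video and save it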
                ffmpeg_extract_subclip(video_path, start_time, end_time, targetname=os.path.join(output_dir, f'moment_{i}.mp4'))
            # store the sorted moment windows: rank, start, end, score
            with open(os.path.join(output_dir, 'moment_scores.txt'), 'w') as f:
                for i, (start_time, end_time, score) in enumerate(predictions[idx]['pred_relevant_windows']):
                    f.write(f"{i}. {start_time} {end_time} {score}\n")
            # save the video frames sorted by pred_saliency_scores (highest saliency first);
            # sort on the score only, since frame tensors cannot be compared on ties
            sorted_frames = [frame for _, frame in sorted(
                zip(predictions[idx]['pred_saliency_scores'], video_frames),
                key=lambda x: x[0], reverse=True)]
            sorted_scores = sorted(predictions[idx]['pred_saliency_scores'], reverse=True)
            print(sorted_scores)
            # save the ranked frames to the output directory
            for i, frame in enumerate(sorted_frames):
                # convert the frame tensor (C, H, W) to a uint8 numpy array, then to a PIL image
                frame = frame.permute(1, 2, 0).cpu().numpy()
                frame = frame.astype(np.uint8)
                frame = Image.fromarray(frame)
                frame.save(os.path.join(output_dir, f"{i}.jpg"))
            # save the sorted saliency scores (rank order matches the saved frames)
            with open(os.path.join(output_dir, 'scores.txt'), 'w') as f:
                for i, score in enumerate(sorted_scores):
                    f.write(f"{i}. {score}\n")



if __name__ == "__main__":
    run_example()