import torch
from run_on_video.data_utils import ClipFeatureExtractor
from run_on_video.model_utils import build_inference_model
from utils.tensor_utils import pad_sequences_1d
from moment_detr.span_utils import span_cxw_to_xx
from utils.basic_utils import l2_normalize_np_array
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from moviepy.video.io.VideoFileClip import VideoFileClip


class MomentDETRPredictor:
def __init__(self, ckpt_path, clip_model_name_or_path="ViT-B/32", device="cuda"):
self.clip_len = 2 # seconds
self.device = device
print("Loading feature extractors...")
self.feature_extractor = ClipFeatureExtractor(
framerate=1/self.clip_len, size=224, centercrop=True,
model_name_or_path=clip_model_name_or_path, device=device
)
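        # with framerate=1/clip_len, the extractor yields one CLIP feature per 2-second clip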
print("Loading trained Moment-DETR model...")
self.model = build_inference_model(ckpt_path).to(self.device)
        self.model.eval()

    @torch.no_grad()
def localize_moment(self, video_path, query_list):
"""
Args:
video_path: str, path to the video file
query_list: List[str], each str is a query for this video
"""
# construct model inputs
n_query = len(query_list)
video_feats, video_frames = self.feature_extractor.encode_video(video_path)
video_feats = F.normalize(video_feats, dim=-1, eps=1e-5)
n_frames = len(video_feats)
        # add temporal endpoint features (tef): normalized [start, end] position of each 2-sec clip
tef_st = torch.arange(0, n_frames, 1.0) / n_frames
tef_ed = tef_st + 1.0 / n_frames
tef = torch.stack([tef_st, tef_ed], dim=1).to(self.device) # (n_frames, 2)
video_feats = torch.cat([video_feats, tef], dim=1)
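        # video_feats is now (n_frames, d_clip + 2) after appending the two tef channels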
        assert n_frames <= 75, "The positional embedding of this pretrained Moment-DETR model only supports " \
                               "videos up to 150 secs (i.e., 75 2-sec clips) in length"
video_feats = video_feats.unsqueeze(0).repeat(n_query, 1, 1) # (#text, T, d)
video_mask = torch.ones(n_query, n_frames).to(self.device)
query_feats = self.feature_extractor.encode_text(query_list) # #text * (L, d)
query_feats, query_mask = pad_sequences_1d(
query_feats, dtype=torch.float32, device=self.device, fixed_length=None)
query_feats = F.normalize(query_feats, dim=-1, eps=1e-5)
model_inputs = dict(
src_vid=video_feats,
src_vid_mask=video_mask,
src_txt=query_feats,
src_txt_mask=query_mask
)
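        # model_inputs shapes: src_vid (n_query, T, d_video + 2), src_vid_mask (n_query, T),
        # src_txt (n_query, L, d_text), src_txt_mask (n_query, L)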
# decode outputs
outputs = self.model(**model_inputs)
# #moment_queries refers to the positional embeddings in MomentDETR's decoder, not the input text query
prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #moment_queries=10, #classes=2)
        scores = prob[..., 0]  # (batch_size, #moment_queries); the foreground class has label 0, so take its probability directly
pred_spans = outputs["pred_spans"] # (bsz, #moment_queries, 2)
        print(pred_spans)  # debug: raw normalized (center, width) span predictions
_saliency_scores = outputs["saliency_scores"].half() # (bsz, L)
saliency_scores = []
valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
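        # src_vid_mask is all ones here, so every query shares the same valid length (n_frames)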
for j in range(len(valid_vid_lengths)):
_score = _saliency_scores[j, :int(valid_vid_lengths[j])].tolist()
_score = [round(e, 4) for e in _score]
saliency_scores.append(_score)
# compose predictions
predictions = []
video_duration = n_frames * self.clip_len
for idx, (spans, score) in enumerate(zip(pred_spans.cpu(), scores.cpu())):
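            # pred_spans are normalized (center, width) pairs; convert to (start, end) in seconds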
spans = span_cxw_to_xx(spans) * video_duration
# # (#queries, 3), [st(float), ed(float), score(float)]
cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
cur_query_pred = dict(
query=query_list[idx], # str
vid=video_path,
pred_relevant_windows=cur_ranked_preds, # List([st(float), ed(float), score(float)])
                pred_saliency_scores=saliency_scores[idx]  # List(float), len==n_frames, one score per 2-sec clip
)
predictions.append(cur_query_pred)
        return predictions, video_frames


def run_example():
# load example data
from utils.basic_utils import load_jsonl
video_dir = "run_on_video/example/testing_videos/dogs"
#video_path = "run_on_video/example/testing_videos/"
query_path = "run_on_video/example/queries_highlight.jsonl"
queries = load_jsonl(query_path)
query_text_list = [e["query"] for e in queries]
ckpt_path = "run_on_video/moment_detr_ckpt/model_best.ckpt"
# run predictions
print("Build models...")
clip_model_name_or_path = "ViT-B/32"
# clip_model_name_or_path = "tmp/ViT-B-32.pt"
moment_detr_predictor = MomentDETRPredictor(
ckpt_path=ckpt_path,
clip_model_name_or_path=clip_model_name_or_path,
device="cuda"
)
print("Run prediction...")
video_paths = [os.path.join(video_dir, e) for e in os.listdir(video_dir)]
#video_paths = ["run_on_video/example/testing_videos/celebration_18s.mov"]
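    # run every video in the directory against the same list of text queries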
for video_path in video_paths:
output_dir = os.path.join("run_on_video/example/output/dog/empty_str", os.path.basename(video_path))
predictions, video_frames = moment_detr_predictor.localize_moment(
video_path=video_path, query_list=query_text_list)
        # create the output directory if it does not already exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# print data
for idx, query_data in enumerate(queries):
print("-"*30 + f"idx{idx}")
print(f">> query: {query_data['query']}")
print(f">> video_path: {video_path}")
#print(f">> GT moments: {query_data['relevant_windows']}")
print(f">> Predicted moments ([start_in_seconds, end_in_seconds, score]): "
f"{predictions[idx]['pred_relevant_windows']}")
#print(f">> GT saliency scores (only localized 2-sec clips): {query_data['saliency_scores']}")
print(f">> Predicted saliency scores (for all 2-sec clip): "
f"{predictions[idx]['pred_saliency_scores']}")
            # save the retrieved moments as video clips
            # sort moments by confidence score (the third element in each window)
            predictions[idx]['pred_relevant_windows'] = sorted(
                predictions[idx]['pred_relevant_windows'], key=lambda x: x[2], reverse=True)
for i, (start_time, end_time, score) in enumerate(predictions[idx]['pred_relevant_windows']):
print(start_time, end_time, score)
ffmpeg_extract_subclip(video_path, start_time, end_time, targetname=os.path.join(output_dir, f'moment_{i}.mp4'))
            # store the sorted moment windows as "index. start end score" lines
            with open(os.path.join(output_dir, 'moment_scores.txt'), 'w') as f:
                for i, (start_time, end_time, score) in enumerate(predictions[idx]['pred_relevant_windows']):
                    f.write(f"{i}. {start_time} {end_time} {score}\n")
            # save the video frames sorted by pred_saliency_scores (highest first);
            # sort on the score alone so that tied scores never trigger a tensor comparison
            sorted_frames = [frame for _, frame in sorted(
                zip(predictions[idx]['pred_saliency_scores'], video_frames),
                key=lambda p: p[0], reverse=True)]
            # the sorted saliency scores, aligned with sorted_frames above
            sorted_scores = sorted(predictions[idx]['pred_saliency_scores'], reverse=True)
            print(sorted_scores)
            # save the ranked frames to the output directory
            for i, frame in enumerate(sorted_frames):
                # convert the frame from a (C, H, W) tensor to a PIL image
                frame = frame.permute(1, 2, 0).cpu().numpy()
                frame = frame.astype(np.uint8)
                frame = Image.fromarray(frame)
                frame.save(os.path.join(output_dir, f'{i}.jpg'))
            # save the ranked saliency scores to a text file
            with open(os.path.join(output_dir, 'scores.txt'), 'w') as f:
                for i, score in enumerate(sorted_scores):
                    f.write(f"{i}. {score}\n")


if __name__ == "__main__":
run_example()