Spaces:
Runtime error
Runtime error
import torch | |
import os | |
import numpy as np | |
from tqdm import tqdm | |
from vbench2_beta_i2v.third_party.cotracker.utils.visualizer import Visualizer | |
from vbench2_beta_i2v.utils import load_video, load_dimension_info | |
def transform(vector):
    """Average a collection of 2-D displacement vectors.

    Args:
        vector: Sequence of items whose index 0 is the x component and
            index 1 the y component.

    Returns:
        A two-element list ``[mean_x, mean_y]``.
    """
    x_parts = []
    y_parts = []
    for displacement in vector:
        x_parts.append(displacement[0])
        y_parts.append(displacement[1])
    return [np.mean(x_parts), np.mean(y_parts)]
def transform_class(vector, min_reso, factor=0.005):
    """Turn a mean displacement into coarse direction labels.

    A component only counts as movement when its magnitude exceeds
    ``min_reso * factor`` (for example, 768 * 0.005 = 3.84).

    Args:
        vector: Pair ``(x, y)`` of displacements.
        min_reso: Reference resolution the threshold is scaled by.
        factor: Fraction of ``min_reso`` used as the movement threshold.

    Returns:
        A list with at most one horizontal label ("right"/"left") followed
        by at most one vertical label ("down"/"up"), or ``["static"]`` when
        neither component exceeds the threshold.
    """
    threshold = min_reso * factor
    dx, dy = vector

    horizontal = "right" if dx > threshold else ("left" if dx < -threshold else None)
    vertical = "down" if dy > threshold else ("up" if dy < -threshold else None)

    labels = [label for label in (horizontal, vertical) if label is not None]
    return labels or ["static"]
class CameraPredict:
    """Classify camera motion in a video from tracked grid points.

    A tracking model loaded through ``torch.hub`` follows
    ``grid_size * grid_size`` points through the video. The displacement of
    the points in the middle of each grid edge, between the first and the
    last frame, is mapped to pan / tilt / zoom / static labels.
    """

    # Edge-pair patterns that mean the same thing for both edge pairs. Keys
    # are "<direction of first edge>_<direction of second edge>". Scene
    # content moving left is read as the camera panning right, and so on.
    _SHARED_MOTIONS = {
        "left_left": "pan_right",
        "right_right": "pan_left",
        "down_down": "tilt_up",
        "up_up": "tilt_down",
        "static_static": "static",
    }
    # Top edge moving up while the bottom edge moves down => zoom in.
    _TOP_DOWN_MOTIONS = {
        **_SHARED_MOTIONS,
        "up_down": "zoom_in",
        "down_up": "zoom_out",
    }
    # Left edge moving left while the right edge moves right => zoom in.
    _LEFT_RIGHT_MOTIONS = {
        **_SHARED_MOTIONS,
        "left_right": "zoom_in",
        "right_left": "zoom_out",
    }

    def __init__(self, device, submodules_list):
        """Load the tracking model and move it to ``device``.

        Args:
            device: Torch device the model and the videos are moved to.
            submodules_list: Mapping with the keys ``"repo"`` and ``"model"``
                that are forwarded to ``torch.hub.load``.

        Raises:
            KeyError: If ``submodules_list`` lacks ``"repo"`` or ``"model"``.
        """
        self.device = device
        self.grid_size = 10

        # Resolve the keys first so a malformed config fails immediately
        # instead of triggering the TLS fallback below.
        repo = submodules_list["repo"]
        model_name = submodules_list["model"]

        try:
            model = torch.hub.load(repo, model_name)
        except Exception:
            # workaround for CERTIFICATE_VERIFY_FAILED
            # (see: https://github.com/pytorch/pytorch/issues/33288#issuecomment-954160699)
            # NOTE(review): this disables HTTPS certificate verification for
            # the rest of the process, exactly as the original code did.
            import ssl

            ssl._create_default_https_context = ssl._create_unverified_context
            model = torch.hub.load(repo, model_name)

        # Kept outside the try block: a device-transfer failure is not a
        # download problem and must not trigger the TLS fallback.
        self.model = model.to(self.device)

    def infer(self, video_path, save_video=False, save_dir="./saved_videos"):
        """Track the point grid through one video.

        Side effect: stores ``min(height, width)`` of the video in
        ``self.scale``, which ``get_edge_direction`` later reads.

        Args:
            video_path: Path of the video to load.
            save_video: When true, also render the tracks with ``Visualizer``.
            save_dir: Output directory handed to ``Visualizer``.

        Returns:
            The tracks of the first batch element, cast to integers, as a
            NumPy array.
        """
        # load video
        video = load_video(video_path, return_tensor=False)

        # set scale (axes 1 and 2 are treated as height and width)
        height, width = video.shape[1], video.shape[2]
        self.scale = min(height, width)

        video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(self.device)  # B T C H W

        # Pure inference: the result is detached below, so skip autograd.
        with torch.no_grad():
            pred_tracks, pred_visibility = self.model(video, grid_size=self.grid_size)  # B T N 2, B T N 1

        if save_video:
            # splitext instead of [:-4] so extensions of any length work.
            video_name = os.path.splitext(os.path.basename(video_path))[0]
            vis = Visualizer(save_dir=save_dir, pad_value=120, linewidth=3)
            vis.visualize(video, pred_tracks, pred_visibility, filename=video_name)

        return pred_tracks[0].long().detach().cpu().numpy()

    def get_edge_point(self, track):
        """Pick the four points around the middle of each grid edge.

        Args:
            track: Array indexed as ``track[row, col, :]`` with
                ``grid_size`` rows and columns.
                NOTE(review): assumes row 0 is the top edge and column 0 the
                left edge, as the variable names imply -- confirm against the
                tracker's grid ordering.

        Returns:
            Tuple ``(top, down, left, right)``; each entry is a list of four
            coordinate lists.
        """
        middle = self.grid_size // 2
        central = range(middle - 2, middle + 2)
        last = self.grid_size - 1

        top = [list(track[0, i, :]) for i in central]
        down = [list(track[last, i, :]) for i in central]
        left = [list(track[i, 0, :]) for i in central]
        right = [list(track[i, last, :]) for i in central]
        return top, down, left, right

    def get_edge_direction(self, track1, track2):
        """Classify the mean movement of each edge between two frames.

        Requires ``self.scale`` to have been set by ``infer``.

        Args:
            track1: Grid of point positions in the starting frame.
            track2: Grid of point positions in the ending frame.

        Returns:
            Four lists of direction labels, ordered top, down, left, right.
        """
        edge_points1 = self.get_edge_point(track1)
        edge_points2 = self.get_edge_point(track2)

        vector_results = []
        for points1, points2 in zip(edge_points1, edge_points2):
            # Per-point displacement: end position minus start position.
            vectors = [
                [end[0] - start[0], end[1] - start[1]]
                for start, end in zip(points1, points2)
            ]
            vector_results.append(transform(vectors))

        return [transform_class(vector, min_reso=self.scale) for vector in vector_results]

    @staticmethod
    def _classify_pair(first, second, mapping):
        """Map every label combination of two opposite edges through ``mapping``."""
        labels = [
            mapping[f"{a}_{b}"]
            for a in first
            for b in second
            if f"{a}_{b}" in mapping
        ]
        return labels if labels else ["None"]

    def classify_top_down(self, top, down):
        """Return camera motions implied by the top and bottom edge labels.

        Returns ``["None"]`` when no combination is recognised.
        """
        return self._classify_pair(top, down, self._TOP_DOWN_MOTIONS)

    def classify_left_right(self, left, right):
        """Return camera motions implied by the left and right edge labels.

        Returns ``["None"]`` when no combination is recognised.
        """
        return self._classify_pair(left, right, self._LEFT_RIGHT_MOTIONS)

    def camera_classify(self, track1, track2):
        """Combine both edge pairs into one de-duplicated list of motions.

        ``"static"`` and ``"None"`` are dropped whenever another label is
        present.
        """
        top, down, left, right = self.get_edge_direction(track1, track2)
        top_results = self.classify_top_down(top, down)
        left_results = self.classify_left_right(left, right)

        # dict.fromkeys de-duplicates in first-seen order, so the output is
        # reproducible across runs (a set's order depends on string hashing).
        results = list(dict.fromkeys(top_results + left_results))
        if "static" in results and len(results) > 1:
            results.remove("static")
        if "None" in results and len(results) > 1:
            results.remove("None")
        return results

    def predict(self, video_path):
        """Predict the camera motions of a video from its first and last frame."""
        pred_track = self.infer(video_path)
        grid_shape = (self.grid_size, self.grid_size, 2)
        track1 = pred_track[0].reshape(grid_shape)
        track2 = pred_track[-1].reshape(grid_shape)
        return self.camera_classify(track1, track2)
def get_type(video_name):
    """Derive the expected camera-motion label from a video file name.

    The name is searched for one of the known prompt phrases; phrases are
    tried in the order they are listed and the first hit wins.

    Args:
        video_name: File name (or any string) containing the prompt text.

    Returns:
        The motion label that belongs to the matched phrase.

    Raises:
        ValueError: If none of the known phrases occurs in ``video_name``.
    """
    phrase_to_label = (
        ("camera pans left", "pan_left"),
        ("camera pans right", "pan_right"),
        ("camera tilts up", "tilt_up"),
        ("camera tilts down", "tilt_down"),
        ("camera zooms in", "zoom_in"),
        ("camera zooms out", "zoom_out"),
        ("camera static", "static"),
    )
    matched = next(
        (label for phrase, label in phrase_to_label if phrase in video_name),
        None,
    )
    if matched is None:
        raise ValueError("Not a recognized video name")
    return matched
def camera_motion(camera, video_list):
    """Score how often the predicted camera motion matches the prompt.

    For each video the expected motion is read from the file name with
    ``get_type`` and compared against ``camera.predict(video_path)``. A video
    scores 1.0 when the expected label is among the predictions, else 0.0.

    Args:
        camera: Object exposing ``predict(video_path)`` that returns a
            collection of motion labels.
        video_list: Iterable of video paths.

    Returns:
        Tuple ``(avg_score, diff_type_results, video_results)``:
        the mean score over all videos, a dict with the mean score per
        motion type, and a list with one detail dict per video. A mean over
        zero videos is NaN.

    Raises:
        ValueError: Propagated from ``get_type`` when a file name contains
            no known camera prompt.
    """
    sim = []
    video_results = []
    scores_by_type = {
        "pan_left": [],
        "pan_right": [],
        "tilt_up": [],
        "tilt_down": [],
        "zoom_in": [],
        "zoom_out": [],
        "static": [],
    }

    for video_path in tqdm(video_list):
        target_type = get_type(os.path.basename(video_path))
        predict_results = camera.predict(video_path)
        video_score = 1.0 if target_type in predict_results else 0.0

        scores_by_type[target_type].append(video_score)
        video_results.append({
            'video_path': video_path,
            'video_results': video_score,
            'prompt_type': target_type,
            'predict_type': predict_results,
        })
        sim.append(video_score)

    def _mean_or_nan(scores):
        # np.mean([]) yields NaN only alongside a "Mean of empty slice"
        # RuntimeWarning; return the same NaN explicitly, without the warning.
        return np.mean(scores) if scores else np.float64("nan")

    avg_score = _mean_or_nan(sim)
    # Build a fresh dict rather than overwriting values while iterating.
    diff_type_results = {
        motion: _mean_or_nan(scores) for motion, scores in scores_by_type.items()
    }
    return avg_score, diff_type_results, video_results
def compute_camera_motion(json_dir, device, submodules_list):
    """Run the camera-motion evaluation end to end.

    Builds a ``CameraPredict`` from ``device`` and ``submodules_list``, loads
    the video list of the ``camera_motion`` dimension from ``json_dir`` and
    scores it with ``camera_motion``.

    Returns:
        Tuple ``(all_results, diff_type_results, video_results)`` exactly as
        produced by ``camera_motion``.
    """
    videos, _ = load_dimension_info(json_dir, dimension='camera_motion', lang='en')
    predictor = CameraPredict(device, submodules_list)
    overall, per_type, per_video = camera_motion(predictor, videos)
    return overall, per_type, per_video