import logging
from pathlib import Path
from typing import List, Tuple

import cv2
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from torchvision.transforms.functional import resize

import decord

# Make decord return torch tensors directly.
decord.bridge.set_bridge("torch")


def load_prompts(prompt_path: Path) -> List[str]:
    """Reads one prompt per non-empty line of a text file."""
    with open(prompt_path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_videos(video_path: Path) -> List[Path]:
    """Reads video paths (one per non-empty line), resolved relative to the list file's directory."""
    with open(video_path, "r", encoding="utf-8") as file:
        return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_images(image_path: Path) -> List[Path]:
    """Reads image paths (one per non-empty line), resolved relative to the list file's directory."""
    with open(image_path, "r", encoding="utf-8") as file:
        return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
    """Extracts and caches the first frame of each video, returning the frame paths."""
    first_frames_dir = videos_path[0].parent.parent / "first_frames"
    first_frames_dir.mkdir(parents=True, exist_ok=True)

    first_frame_paths = []
    for video_path in videos_path:
        frame_path = first_frames_dir / f"{video_path.stem}.png"
        if frame_path.exists():
            first_frame_paths.append(frame_path)
            continue

        # Read the first frame with OpenCV and cache it as a PNG.
        cap = cv2.VideoCapture(str(video_path))
        ret, frame = cap.read()
        if not ret:
            cap.release()
            raise RuntimeError(f"Failed to read video: {video_path}")

        cv2.imwrite(str(frame_path), frame)
        logging.info(f"Saved first frame to {frame_path}")
        cap.release()

        first_frame_paths.append(frame_path)

    return first_frame_paths


def load_binary_mask_compressed(path, shape, device, dtype):
    """
    Loads a bit-packed binary mask (np.packbits output) of shape [F, C, H, W].

    Returns a tuple (mask, mask_interp): the mask rearranged to [C, F, H, W],
    and a copy trilinearly resized to (F // 4 + 1, H // 8, W // 8) and
    re-binarized with a 0.5 threshold.
    """
    with open(path, "rb") as f:
        packed = np.frombuffer(f.read(), dtype=np.uint8)
    unpacked = np.unpackbits(packed)[: np.prod(shape)]
    mask_loaded = torch.from_numpy(unpacked).to(device, dtype).reshape(shape)

    # Downsample the mask 4x temporally (plus one frame) and 8x spatially,
    # then threshold back to a hard binary mask.
    mask_interp = torch.nn.functional.interpolate(
        rearrange(mask_loaded, "f c h w -> c f h w").unsqueeze(0),
        size=(shape[0] // 4 + 1, shape[2] // 8, shape[3] // 8),
        mode="trilinear",
        align_corners=False,
    ).squeeze(0)
    mask_interp[mask_interp >= 0.5] = 1.0
    mask_interp[mask_interp < 0.5] = 0.0

    return rearrange(mask_loaded, "f c h w -> c f h w"), mask_interp


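# A minimal sketch of the inverse operation, assuming masks are stored as raw
# np.packbits output (the format load_binary_mask_compressed reads). The
# helper name is hypothetical; it is not called anywhere in this module.
def save_binary_mask_compressed(mask: torch.Tensor, path) -> None:
    arr = mask.detach().cpu().numpy().astype(np.uint8).reshape(-1)
    with open(path, "wb") as f:
        f.write(np.packbits(arr).tobytes())

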
def preprocess_image_with_resize(
    image_path: Path | str,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single image.

    Args:
        image_path: Path to the image file.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        torch.Tensor: Image tensor with shape [C, H, W] where:
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(image_path, str):
        image_path = Path(image_path)

    image = np.array(Image.open(image_path.as_posix()).resize((width, height)))
    image = torch.from_numpy(image).float()
    image = image.permute(2, 0, 1).contiguous()
    return image


def preprocess_video_with_resize(
    video_path: Path | str,
    max_num_frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single video.

    The function processes the video through these steps:
        1. If the frame count exceeds max_num_frames, sample frames evenly;
           if it falls short, pad by repeating the last frame.
        2. If video dimensions don't match (height, width), resize frames.

    Args:
        video_path: Path to the video file.
        max_num_frames: Maximum number of frames to keep.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        A torch.Tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(video_path, str):
        video_path = Path(video_path)
    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)
    if video_num_frames < max_num_frames:
        # Pad short videos by repeating the last frame.
        frames = video_reader.get_batch(list(range(video_num_frames)))
        last_frame = frames[-1:]
        num_repeats = max_num_frames - video_num_frames
        repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
        frames = torch.cat([frames, repeated_frames], dim=0)
        return frames.float().permute(0, 3, 1, 2).contiguous()
    else:
        # Sample frames evenly, then truncate to exactly max_num_frames.
        indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
        frames = video_reader.get_batch(indices)
        frames = frames[:max_num_frames].float()
        frames = frames.permute(0, 3, 1, 2).contiguous()
        return frames


def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Video tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width

    The function processes the video through these steps:
        1. Finds the nearest frame bucket <= video frame count.
        2. Downsamples frames evenly to match the bucket size.
        3. Finds the nearest resolution bucket based on dimensions.
        4. Resizes frames to match the bucket resolution.
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)
    valid_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
    if len(valid_buckets) == 0:
        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")

    # The largest bucket frame count that does not exceed the video's frame count.
    nearest_frame_bucket = min(
        valid_buckets,
        key=lambda bucket: video_num_frames - bucket[0],
    )[0]
    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
    frames = video_reader.get_batch(frame_indices)
    frames = frames[:nearest_frame_bucket].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()

    # Pick the bucket whose (height, width) is closest in L1 distance, then resize.
    nearest_res = min(valid_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
    nearest_res = (nearest_res[1], nearest_res[2])
    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

    return frames
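

# Hedged usage sketch: exercises the helpers above end to end. The file names
# and target sizes below are hypothetical placeholders, not part of the
# original pipeline; adjust them to your own data layout.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    prompts = load_prompts(Path("prompts.txt"))
    videos = load_videos(Path("videos.txt"))
    print(f"Loaded {len(prompts)} prompts and {len(videos)} video paths")
    frames = preprocess_video_with_resize(videos[0], max_num_frames=49, height=480, width=720)
    print(f"Resized video tensor: {tuple(frames.shape)}")  # [F, C, H, W]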