import argparse
import json
import math
import os
from pathlib import Path
from typing import List

import numpy as np
import torch
import torchvision
from PIL import Image

from inference_config import SamplingConfig
from .log import log

_IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
_VIDEO_EXTENSIONS = [".mp4"]
# Supported context lengths: 1 frame for image input, 9 frames for video input.
_SUPPORTED_CONTEXT_LEN = [1, 9]
# Total number of frames fed to the video tokenizer (context frames are padded up to this).
NUM_TOTAL_FRAMES = 33


def add_common_arguments(parser):
    """Add common command line arguments.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to
    """
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos")
    parser.add_argument(
        "--input_image_or_video_path",
        type=str,
        help="Path to the input image or video",
    )
    parser.add_argument(
        "--batch_input_path",
        type=str,
        help="Path to a JSONL file listing the input images or videos",
    )
    parser.add_argument(
        "--num_input_frames",
        type=int,
        default=9,
        help="Number of input frames for world generation",
        choices=_SUPPORTED_CONTEXT_LEN,
    )
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")
    parser.add_argument(
        "--offload_guardrail_models",
        action="store_true",
        help="Offload guardrail models after inference",
    )
    parser.add_argument(
        "--offload_diffusion_decoder",
        action="store_true",
        help="Offload diffusion decoder after inference",
    )
    parser.add_argument(
        "--offload_ar_model",
        action="store_true",
        help="Offload AR model after inference",
    )
    parser.add_argument(
        "--offload_tokenizer",
        action="store_true",
        help="Offload discrete tokenizer model after inference",
    )
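

# Example (a hedged sketch, not part of the pipeline): wiring these arguments into a
# parser. The "--input_type" flag read by validate_args() is assumed to be added by the
# caller elsewhere; it is shown here only for illustration.
#
#   parser = argparse.ArgumentParser(description="World generation inference")
#   add_common_arguments(parser)
#   parser.add_argument("--input_type", type=str, default="video")
#   args = parser.parse_args(["--input_image_or_video_path", "input.mp4"])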


def validate_args(args: argparse.Namespace, inference_type: str):
    """Validate command line arguments for base and video2world generation."""
    assert inference_type in [
        "base",
        "video2world",
    ], "Invalid inference_type, must be 'base' or 'video2world'"

    # Image inputs provide exactly one frame of context.
    if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1:
        args.num_input_frames = 1
        log.info(f"Set num_input_frames to 1 for {args.input_type} input")

    if args.num_input_frames == 1:
        if "4B" in args.ar_model_dir:
            log.warning(
                "The failure rate for the 4B model with image input is ~15%. The 12B / 13B models have a smaller failure rate. Please be cautious and refer to README.md for more details."
            )
        elif "5B" in args.ar_model_dir:
            log.warning(
                "The failure rate for the 5B model with image input is ~7%. The 12B / 13B models have a smaller failure rate. Please be cautious and refer to README.md for more details."
            )

    assert (
        args.input_image_or_video_path or args.batch_input_path
    ), "--input_image_or_video_path or --batch_input_path must be provided."
    if inference_type == "video2world" and not args.batch_input_path:
        assert args.prompt, "--prompt is required for single video generation."
    args.data_resolution = [640, 1024]

    num_gpus = int(os.getenv("WORLD_SIZE", 1))
    assert num_gpus <= 1, "Only single-GPU inference is supported for now"

    # Create the output folder.
    Path(args.video_save_folder).mkdir(parents=True, exist_ok=True)

    sampling_config = SamplingConfig(
        echo=True,
        temperature=args.temperature,
        top_p=args.top_p,
        compile_sampling=True,
    )
    return sampling_config
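

# Example (hedged): a typical call after parsing, assuming args came from a parser
# built with add_common_arguments(). It normalizes num_input_frames, sets
# data_resolution, and returns the SamplingConfig used downstream.
#
#   sampling_config = validate_args(args, "video2world")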


def resize_input(video: torch.Tensor, resolution: List[int]):
    r"""
    Perform aspect-ratio-preserving resizing followed by center cropping to bring
    the video to the target resolution.

    Args:
        video (torch.Tensor): Input video tensor
        resolution (List[int]): Target resolution as [height, width]

    Returns:
        Cropped video
    """
    orig_h, orig_w = video.shape[2], video.shape[3]
    target_h, target_w = resolution

    # Scale so that both target dimensions are covered, then crop the excess.
    scaling_ratio = max(target_w / orig_w, target_h / orig_h)
    resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w)))
    video_resized = torchvision.transforms.functional.resize(video, resizing_shape)
    video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution)
    return video_cropped
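

# Worked example (shapes assumed for illustration): for a 480x854 input and a
# [640, 1024] target, scaling_ratio = max(1024/854, 640/480) = 4/3, so the video is
# resized to (640, 1139) and then center-cropped to (640, 1024).
#
#   video = torch.rand(33, 3, 480, 854)      # (T, C, H, W)
#   out = resize_input(video, [640, 1024])   # -> torch.Size([33, 3, 640, 1024])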


def load_image_from_list(flist: List[str], data_resolution: List[int]) -> dict:
    """
    Function to load images from a list of image paths.

    Args:
        flist (List[str]): List of image paths
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input images
    """
    all_videos = dict()
    for img_path in flist:
        ext = os.path.splitext(img_path)[1]
        if ext in _IMAGE_EXTENSIONS:
            # Convert to RGB so alpha or grayscale inputs yield 3-channel tensors.
            img = Image.open(img_path).convert("RGB")

            # Repeat the single frame into a static clip, then map [0, 1] -> [-1, 1].
            img = torchvision.transforms.functional.to_tensor(img)
            static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1)
            static_vid = static_vid * 2 - 1

            log.debug(
                f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
            )
            static_vid = resize_input(static_vid, data_resolution)
            fname = os.path.basename(img_path)
            # (T, C, H, W) -> (1, C, T, H, W)
            all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0)

    return all_videos
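

# Example (hedged; "example.png" is a hypothetical path): each image becomes a
# 33-frame static clip in [-1, 1], keyed by its basename.
#
#   videos = load_image_from_list(["example.png"], data_resolution=[640, 1024])
#   videos["example.png"].shape  # -> torch.Size([1, 3, 33, 640, 1024])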


def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict:
    """
    Function to read input images from a JSONL file.

    Args:
        batch_input_path (str): Path to JSONL file containing visual input paths
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input images
    """
    flist = []
    with open(batch_input_path, "r") as f:
        for line in f:
            data = json.loads(line.strip())
            flist.append(data["visual_input"])

    return load_image_from_list(flist, data_resolution=data_resolution)


def read_input_image(input_path: str, data_resolution: List[int]) -> dict:
    """
    Function to read input image.

    Args:
        input_path (str): Path to input image
        data_resolution (List[int]): Data resolution

    Returns:
        Dict containing input image
    """
    flist = [input_path]
    return load_image_from_list(flist, data_resolution=data_resolution)


def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
    r"""
    Function to read input videos from a JSONL file.

    Args:
        batch_input_path (str): Path to JSONL file containing visual input paths
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input videos
    """
    flist = []
    with open(batch_input_path, "r") as f:
        for line in f:
            data = json.loads(line.strip())
            flist.append(data["visual_input"])
    return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)


def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
    """
    Function to read input video.

    Args:
        input_path (str): Path to input video
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input video
    """
    flist = [input_path]
    return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)


def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict:
    """
    Function to load videos from a list of video paths.

    Args:
        flist (List[str]): List of video paths
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input videos
    """
    all_videos = dict()

    for video_path in flist:
        ext = os.path.splitext(video_path)[-1]
        if ext in _VIDEO_EXTENSIONS:
            video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")
            # Map uint8 [0, 255] -> float [-1, 1].
            video = video.float() / 255.0
            video = video * 2 - 1

            nframes_in_video = video.shape[0]
            if nframes_in_video < num_input_frames:
                fname = os.path.basename(video_path)
                log.warning(
                    f"Video {fname} has {nframes_in_video} frames, less than the required {num_input_frames} frames. Skipping."
                )
                continue

            # Keep the last num_input_frames frames as context.
            video = video[-num_input_frames:, :, :, :]

            # Pad to NUM_TOTAL_FRAMES by repeating the last frame.
            video = torch.cat(
                (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)),
                dim=0,
            )

            # (T, H, W, C) -> (T, C, H, W)
            video = video.permute(0, 3, 1, 2)

            log.debug(
                f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
            )
            video = resize_input(video, data_resolution)

            fname = os.path.basename(video_path)
            # (T, C, H, W) -> (1, C, T, H, W)
            all_videos[fname] = video.transpose(0, 1).unsqueeze(0)

    return all_videos
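

# Example (hedged; "clip.mp4" is a hypothetical path): the last 9 frames are kept as
# context and the final frame is repeated out to 33 total frames.
#
#   videos = load_videos_from_list(["clip.mp4"], data_resolution=[640, 1024], num_input_frames=9)
#   videos["clip.mp4"].shape  # -> torch.Size([1, 3, 33, 640, 1024])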


def load_vision_input(
    input_type: str,
    batch_input_path: str,
    input_image_or_video_path: str,
    data_resolution: List[int],
    num_input_frames: int,
):
    """
    Function to load vision input.

    Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed
    the padded video tensors to the video tokenizer to obtain tokens. The tokens will be
    truncated based on num_input_frames when feeding to the autoregressive model.

    Args:
        input_type (str): Type of input
        batch_input_path (str): Path to JSONL file listing input images or videos
        input_image_or_video_path (str): Path to input image or video
        data_resolution (List[int]): Data resolution
        num_input_frames (int): Number of frames in context

    Returns:
        Dict containing input videos
    """
    if batch_input_path:
        log.info(f"Reading batch inputs from path: {batch_input_path}")
        if input_type in ("image", "text_and_image"):
            input_videos = read_input_images(batch_input_path, data_resolution=data_resolution)
        elif input_type in ("video", "text_and_video"):
            input_videos = read_input_videos(
                batch_input_path,
                data_resolution=data_resolution,
                num_input_frames=num_input_frames,
            )
        else:
            raise ValueError(f"Invalid input type {input_type}")
    else:
        if input_type in ("image", "text_and_image"):
            input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution)
        elif input_type in ("video", "text_and_video"):
            input_videos = read_input_video(
                input_image_or_video_path,
                data_resolution=data_resolution,
                num_input_frames=num_input_frames,
            )
        else:
            raise ValueError(f"Invalid input type {input_type}")
    return input_videos
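

# Example (hedged; the path is hypothetical): single-input dispatch. With
# batch_input_path set instead, the same call would read paths from a JSONL file.
#
#   input_videos = load_vision_input(
#       input_type="video",
#       batch_input_path=None,
#       input_image_or_video_path="clip.mp4",
#       data_resolution=[640, 1024],
#       num_input_frames=9,
#   )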


def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]:
    """
    Function to convert output tensors to numpy format for saving.

    Args:
        video_batch (List[torch.Tensor]): List of output tensors, each (C, T, H, W) with values in [0, 1]

    Returns:
        List of numpy arrays
    """
    # (C, T, H, W) float in [0, 1] -> (T, H, W, C) uint8 on CPU.
    return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch]
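

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the inference pipeline): load a
    # hypothetical clip, convert it for saving, and write it back out with
    # torchvision. Assumes "sample.mp4" exists and has at least 9 frames.
    loaded = load_vision_input(
        input_type="video",
        batch_input_path=None,
        input_image_or_video_path="sample.mp4",
        data_resolution=[640, 1024],
        num_input_frames=9,
    )
    for fname, vid in loaded.items():
        # vid is (1, C, T, H, W) in [-1, 1]; map back to [0, 1] before saving.
        frames = prepare_video_batch_for_saving([(vid[0] + 1) / 2])[0]
        torchvision.io.write_video(f"smoke_{fname}", torch.from_numpy(frames), fps=24)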