|
|
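"""
MoE FramePack sliding-window inference.

Generates camera-controllable video with a Mixture-of-Experts Wan DiT and a
FramePack-style sliding window over the latent history. Supports three camera
modalities (sekai, nuscenes, openx), optional camera/text classifier-free
guidance, and condition input from a pre-encoded .pth file, a video, or a
single image. See main() for the command-line interface.
"""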
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) |
|
|
sys.path.append(ROOT_DIR) |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import imageio |
|
|
import json |
|
|
from diffsynth import WanVideoAstraPipeline, ModelManager |
|
|
import argparse |
|
|
from torchvision.transforms import v2 |
|
|
from einops import rearrange |
|
|
from scipy.spatial.transform import Rotation as R |
|
|
import random |
|
|
import copy |
|
|
from datetime import datetime |
|
|
|
|
|
VALID_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"} |
|
|
class InlineVideoEncoder: |
|
|
|
|
|
def __init__(self, pipe: WanVideoAstraPipeline, device="cuda"): |
|
|
self.device = getattr(pipe, "device", device) |
|
|
self.tiler_kwargs = {"tiled": True, "tile_size": (34, 34), "tile_stride": (18, 16)} |
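        # These kwargs are forwarded to pipe.encode_video() so the VAE runs tiled;
        # pipe.decode_video() at the end of this script reuses the same values.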
|
|
self.frame_process = v2.Compose([ |
|
|
v2.ToTensor(), |
|
|
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), |
|
|
]) |
|
|
|
|
|
self.pipe = pipe |
|
|
|
|
|
@staticmethod |
|
|
def _crop_and_resize(image: Image.Image) -> Image.Image: |
|
|
target_w, target_h = 832, 480 |
|
|
return v2.functional.resize( |
|
|
image, |
|
|
(round(target_h), round(target_w)), |
|
|
interpolation=v2.InterpolationMode.BILINEAR, |
|
|
) |
|
|
|
|
|
def preprocess_frame(self, image: Image.Image) -> torch.Tensor: |
|
|
image = image.convert("RGB") |
|
|
image = self._crop_and_resize(image) |
|
|
return self.frame_process(image) |
|
|
|
|
|
def load_video_frames(self, video_path: Path) -> Optional[torch.Tensor]: |
|
|
reader = imageio.get_reader(str(video_path)) |
|
|
frames = [] |
|
|
for frame_data in reader: |
|
|
frame = Image.fromarray(frame_data) |
|
|
frames.append(self.preprocess_frame(frame)) |
|
|
reader.close() |
|
|
|
|
|
if not frames: |
|
|
return None |
|
|
|
|
|
frames = torch.stack(frames, dim=0) |
|
|
return rearrange(frames, "T C H W -> C T H W") |
|
|
|
|
|
def encode_frames_to_latents(self, frames: torch.Tensor) -> torch.Tensor: |
|
|
frames = frames.unsqueeze(0).to(self.device, dtype=torch.bfloat16) |
|
|
with torch.no_grad(): |
|
|
latents = self.pipe.encode_video(frames, **self.tiler_kwargs)[0] |
|
|
|
|
|
if latents.dim() == 5 and latents.shape[0] == 1: |
|
|
latents = latents.squeeze(0) |
|
|
return latents.cpu() |
|
|
|
|
|
def image_to_frame_stack( |
|
|
image_path: Path, |
|
|
encoder: InlineVideoEncoder, |
|
|
repeat_count: int = 10 |
|
|
) -> torch.Tensor: |
|
|
"""Repeat a single image into a tensor with specified number of frames, shape [C, T, H, W]""" |
|
|
if image_path.suffix.lower() not in VALID_IMAGE_EXTENSIONS: |
|
|
raise ValueError(f"Unsupported image format: {image_path.suffix}") |
|
|
|
|
|
image = Image.open(str(image_path)) |
|
|
frame = encoder.preprocess_frame(image) |
|
|
frames = torch.stack([frame for _ in range(repeat_count)], dim=0) |
|
|
return rearrange(frames, "T C H W -> C T H W") |
|
|
|
|
|
|
|
|
def load_or_encode_condition( |
|
|
condition_pth_path: Optional[str], |
|
|
condition_video: Optional[str], |
|
|
condition_image: Optional[str], |
|
|
start_frame: int, |
|
|
num_frames: int, |
|
|
device: str, |
|
|
pipe: WanVideoAstraPipeline, |
|
|
) -> tuple[torch.Tensor, dict]: |
|
|
if condition_pth_path: |
|
|
return load_encoded_video_from_pth(condition_pth_path, start_frame, num_frames) |
|
|
|
|
|
encoder = InlineVideoEncoder(pipe=pipe, device=device) |
|
|
|
|
|
if condition_video: |
|
|
video_path = Path(condition_video).expanduser().resolve() |
|
|
if not video_path.exists(): |
|
|
raise FileNotFoundError(f"File not Found: {video_path}") |
|
|
frames = encoder.load_video_frames(video_path) |
|
|
if frames is None: |
|
|
raise ValueError(f"no valid frames in {video_path}") |
|
|
elif condition_image: |
|
|
image_path = Path(condition_image).expanduser().resolve() |
|
|
if not image_path.exists(): |
|
|
raise FileNotFoundError(f"File not Found: {image_path}") |
|
|
frames = image_to_frame_stack(image_path, encoder, repeat_count=10) |
|
|
else: |
|
|
raise ValueError("condition video or image is needed for video generation.") |
|
|
|
|
|
latents = encoder.encode_frames_to_latents(frames) |
|
|
encoded_data = {"latents": latents} |
|
|
|
|
|
if start_frame + num_frames > latents.shape[1]: |
|
|
raise ValueError( |
|
|
f"Not enough frames after encoding: requested {start_frame + num_frames}, available {latents.shape[1]}" |
|
|
) |
|
|
|
|
|
condition_latents = latents[:, start_frame:start_frame + num_frames, :, :] |
|
|
return condition_latents, encoded_data |
|
|
|
|
|
|
|
|
|
|
|
def compute_relative_pose_matrix(pose1, pose2): |
|
|
""" |
|
|
Compute relative pose between two consecutive frames, return 3x4 camera matrix [R_rel | t_rel] |
|
|
|
|
|
Args: |
|
|
pose1: Camera pose of frame i, shape (7,) array [tx1, ty1, tz1, qx1, qy1, qz1, qw1] |
|
|
pose2: Camera pose of frame i+1, shape (7,) array [tx2, ty2, tz2, qx2, qy2, qz2, qw2] |
|
|
|
|
|
Returns: |
|
|
relative_matrix: 3x4 relative pose matrix, |
|
|
first 3 columns are rotation matrix R_rel, |
|
|
last column is translation vector t_rel |
|
|
""" |
|
|
|
|
|
t1 = pose1[:3] |
|
|
q1 = pose1[3:] |
|
|
t2 = pose2[:3] |
|
|
q2 = pose2[3:] |
|
|
|
|
|
|
|
|
rot1 = R.from_quat(q1) |
|
|
rot2 = R.from_quat(q2) |
|
|
rot_rel = rot2 * rot1.inv() |
|
|
R_rel = rot_rel.as_matrix() |
|
|
|
|
|
|
|
|
R1_T = rot1.as_matrix().T |
|
|
t_rel = R1_T @ (t2 - t1) |
|
|
|
|
|
|
|
|
relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)]) |
|
|
|
|
|
return relative_matrix |
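
# Worked example (a sanity check, not used by the pipeline): with
# pose1 = [0, 0, 0, 0, 0, 0, 1] (origin, identity quaternion) and
# pose2 = [1, 0, 0, 0, 0, 0, 1] (unit step along +x, same orientation),
# R_rel is the 3x3 identity and t_rel is [1, 0, 0], so the result is [I | (1, 0, 0)].
# Note that scipy's R.from_quat expects scalar-last [qx, qy, qz, qw] input,
# matching the pose layout documented above.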
|
|
|
|
|
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10): |
|
|
"""Load pre-encoded video data from pth file""" |
|
|
print(f"Loading encoded video from {pth_path}") |
|
|
|
|
|
encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu") |
|
|
full_latents = encoded_data['latents'] |
|
|
|
|
|
print(f"Full latents shape: {full_latents.shape}") |
|
|
print(f"Extracting frames {start_frame} to {start_frame + num_frames}") |
|
|
|
|
|
if start_frame + num_frames > full_latents.shape[1]: |
|
|
raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}") |
|
|
|
|
|
condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :] |
|
|
print(f"Extracted condition latents shape: {condition_latents.shape}") |
|
|
|
|
|
return condition_latents, encoded_data |
|
|
|
|
|
def compute_relative_pose(pose_a, pose_b, use_torch=False): |
|
|
"""Compute relative pose matrix of camera B with respect to camera A""" |
|
|
assert pose_a.shape == (4, 4), f"Camera A extrinsic matrix should be (4,4), got {pose_a.shape}" |
|
|
assert pose_b.shape == (4, 4), f"Camera B extrinsic matrix should be (4,4), got {pose_b.shape}" |
|
|
|
|
|
if use_torch: |
|
|
if not isinstance(pose_a, torch.Tensor): |
|
|
pose_a = torch.from_numpy(pose_a).float() |
|
|
if not isinstance(pose_b, torch.Tensor): |
|
|
pose_b = torch.from_numpy(pose_b).float() |
|
|
|
|
|
pose_a_inv = torch.inverse(pose_a) |
|
|
relative_pose = torch.matmul(pose_b, pose_a_inv) |
|
|
else: |
|
|
if not isinstance(pose_a, np.ndarray): |
|
|
pose_a = np.array(pose_a, dtype=np.float32) |
|
|
if not isinstance(pose_b, np.ndarray): |
|
|
pose_b = np.array(pose_b, dtype=np.float32) |
|
|
|
|
|
pose_a_inv = np.linalg.inv(pose_a) |
|
|
relative_pose = np.matmul(pose_b, pose_a_inv) |
|
|
|
|
|
return relative_pose |
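
# Illustrative check: if pose_a is the 4x4 identity and pose_b is the identity with
# translation (1, 0, 0), then relative_pose = pose_b @ inv(pose_a) = pose_b, i.e. a
# pure (1, 0, 0) translation of camera B relative to camera A.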
|
|
|
|
|
|
|
|
def replace_dit_model_in_manager(): |
|
|
"""Replace DiT model class with MoE version""" |
|
|
from diffsynth.models.wan_video_dit_moe import WanModelMoe |
|
|
from diffsynth.configs.model_config import model_loader_configs |
|
|
|
|
|
for i, config in enumerate(model_loader_configs): |
|
|
keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config |
|
|
|
|
|
if 'wan_video_dit' in model_names: |
|
|
new_model_names = [] |
|
|
new_model_classes = [] |
|
|
|
|
|
for name, cls in zip(model_names, model_classes): |
|
|
if name == 'wan_video_dit': |
|
|
new_model_names.append(name) |
|
|
new_model_classes.append(WanModelMoe) |
|
|
print(f"Replaced model class: {name} -> WanModelMoe") |
|
|
else: |
|
|
new_model_names.append(name) |
|
|
new_model_classes.append(cls) |
|
|
|
|
|
model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource) |
|
|
|
|
|
|
|
|
def add_framepack_components(dit_model): |
|
|
"""Add FramePack related components""" |
|
|
if not hasattr(dit_model, 'clean_x_embedder'): |
|
|
inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0] |
|
|
|
|
|
class CleanXEmbedder(nn.Module): |
|
|
def __init__(self, inner_dim): |
|
|
super().__init__() |
|
|
self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) |
|
|
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) |
|
|
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) |
|
|
|
|
|
def forward(self, x, scale="1x"): |
|
|
if scale == "1x": |
|
|
x = x.to(self.proj.weight.dtype) |
|
|
return self.proj(x) |
|
|
elif scale == "2x": |
|
|
x = x.to(self.proj_2x.weight.dtype) |
|
|
return self.proj_2x(x) |
|
|
elif scale == "4x": |
|
|
x = x.to(self.proj_4x.weight.dtype) |
|
|
return self.proj_4x(x) |
|
|
else: |
|
|
raise ValueError(f"Unsupported scale: {scale}") |
|
|
|
|
|
dit_model.clean_x_embedder = CleanXEmbedder(inner_dim) |
|
|
model_dtype = next(dit_model.parameters()).dtype |
|
|
dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype) |
|
|
print("Added FramePack clean_x_embedder component") |
|
|
|
|
|
|
|
|
def add_moe_components(dit_model, moe_config): |
|
|
"""Add MoE related components - corrected version""" |
|
|
if not hasattr(dit_model, 'moe_config'): |
|
|
dit_model.moe_config = moe_config |
|
|
print("Added MoE config to model") |
|
|
dit_model.top_k = moe_config.get("top_k", 1) |
|
|
|
|
|
|
|
|
dim = dit_model.blocks[0].self_attn.q.weight.shape[0] |
|
|
unified_dim = moe_config.get("unified_dim", 25) |
|
|
num_experts = moe_config.get("num_experts", 4) |
|
|
from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE |
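    # Input widths per modality: sekai and openx feed a flattened 3x4 relative pose
    # (12 values) plus a 1-element condition mask = 13; nuscenes feeds a 3-value
    # translation + 4-value quaternion plus the mask = 8. These match the embeddings
    # built by the generate_*_camera_embeddings_sliding functions below.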
|
|
dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim) |
|
|
dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim) |
|
|
dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) |
|
|
dit_model.global_router = nn.Linear(unified_dim, num_experts) |
|
|
|
|
|
|
|
|
for i, block in enumerate(dit_model.blocks): |
|
|
|
|
|
block.moe = MultiModalMoE( |
|
|
unified_dim=unified_dim, |
|
|
output_dim=dim, |
|
|
num_experts=moe_config.get("num_experts", 4), |
|
|
top_k=moe_config.get("top_k", 2) |
|
|
) |
|
|
|
|
|
print(f"Block {i} added MoE component (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})") |
|
|
|
|
|
|
|
|
def generate_sekai_camera_embeddings_sliding( |
|
|
cam_data, |
|
|
start_frame, |
|
|
initial_condition_frames, |
|
|
new_frames, |
|
|
total_generated, |
|
|
use_real_poses=True, |
|
|
direction="left"): |
|
|
""" |
|
|
Generate camera embeddings for Sekai dataset - sliding window version |
|
|
|
|
|
Args: |
|
|
cam_data: Dictionary containing Sekai camera extrinsic parameters, key 'extrinsic' corresponds to an N*4*4 numpy array |
|
|
start_frame: Current generation start frame index |
|
|
initial_condition_frames: Initial condition frame count |
|
|
new_frames: Number of new frames to generate this time |
|
|
total_generated: Total frames already generated |
|
|
use_real_poses: Whether to use real Sekai camera poses |
|
|
direction: Camera movement direction, default "left" |
|
|
|
|
|
Returns: |
|
|
camera_embedding: Torch tensor of shape (M, 3*4 + 1), where M is the total number of generated frames |
|
|
""" |
|
|
time_compression_ratio = 4 |
|
|
|
|
|
|
|
|
|
|
|
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames |
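    # 1 start latent + 16 4x-compressed + 2 2x-compressed + 1 1x clean reference
    # slots, plus the new frames to generate; this mirrors split_sizes in
    # prepare_framepack_sliding_window_with_camera_moe().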
|
|
|
|
|
if use_real_poses and cam_data is not None and 'extrinsic' in cam_data: |
|
|
print("🔧 Using real Sekai camera data") |
|
|
cam_extrinsic = cam_data['extrinsic'] |
|
|
|
|
|
|
|
|
max_needed_frames = max( |
|
|
start_frame + initial_condition_frames + new_frames, |
|
|
framepack_needed_frames, |
|
|
30 |
|
|
) |
|
|
|
|
|
print(f"🔧 Calculating Sekai camera sequence length:") |
|
|
print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}") |
|
|
print(f" - FramePack requirement: {framepack_needed_frames}") |
|
|
print(f" - Final generation: {max_needed_frames}") |
|
|
|
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
|
|
|
frame_idx = i * time_compression_ratio |
|
|
next_frame_idx = frame_idx + time_compression_ratio |
|
|
|
|
|
if next_frame_idx < len(cam_extrinsic): |
|
|
cam_prev = cam_extrinsic[frame_idx] |
|
|
cam_next = cam_extrinsic[next_frame_idx] |
|
|
relative_pose = compute_relative_pose(cam_prev, cam_next) |
|
|
relative_poses.append(torch.as_tensor(relative_pose[:3, :])) |
|
|
else: |
|
|
|
|
|
print(f"⚠️ Frame {frame_idx} exceeds camera data range, using zero motion") |
|
|
relative_poses.append(torch.zeros(3, 4)) |
|
|
|
|
|
pose_embedding = torch.stack(relative_poses, dim=0) |
|
|
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
|
|
|
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
|
|
|
condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
|
|
print(f"🔧 Sekai real camera embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
else: |
|
|
|
|
|
max_needed_frames = max( |
|
|
start_frame + initial_condition_frames + new_frames, |
|
|
framepack_needed_frames, |
|
|
30) |
|
|
|
|
|
print(f"🔧 Generating Sekai synthetic camera frames: {max_needed_frames}") |
|
|
|
|
|
CONDITION_FRAMES = initial_condition_frames |
|
|
STAGE_1 = new_frames//2 |
|
|
STAGE_2 = new_frames - STAGE_1 |
|
|
|
|
|
if direction=="forward": |
|
|
print("--------------- FORWARD MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
forward_speed = 0.03 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
pose[2, 3] = -forward_speed |
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
elif direction=="left": |
|
|
print("--------------- LEFT TURNING MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
yaw_per_frame = 0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.00 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
elif direction=="right": |
|
|
print("--------------- RIGHT TURNING MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
yaw_per_frame = -0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.00 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
elif direction=="forward_left": |
|
|
print("--------------- FORWARD LEFT MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
yaw_per_frame = 0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.03 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
|
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
elif direction=="forward_right": |
|
|
print("--------------- FORWARD RIGHT MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
yaw_per_frame = -0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.03 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
|
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
elif direction=="s_curve": |
|
|
print("--------------- S CURVE MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1: |
|
|
|
|
|
yaw_per_frame = 0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.03 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
|
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
yaw_per_frame = -0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.03 |
|
|
|
|
|
if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3: |
|
|
radius_shift = -0.01 |
|
|
else: |
|
|
radius_shift = 0.00 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
pose[0, 3] = radius_shift |
|
|
|
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
elif direction=="left_right": |
|
|
print("--------------- LEFT RIGHT MODE ---------------") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < CONDITION_FRAMES: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
elif i < CONDITION_FRAMES+STAGE_1: |
|
|
|
|
|
yaw_per_frame = 0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.00 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
|
|
|
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2: |
|
|
|
|
|
yaw_per_frame = -0.03 |
|
|
|
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
forward_speed = 0.00 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
pose[0, 0] = cos_yaw |
|
|
pose[0, 2] = sin_yaw |
|
|
pose[2, 0] = -sin_yaw |
|
|
pose[2, 2] = cos_yaw |
|
|
pose[2, 3] = -forward_speed |
|
|
|
|
|
else: |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Not Defined Direction: {direction}") |
|
|
|
|
|
pose_embedding = torch.stack(relative_poses, dim=0) |
|
|
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
|
|
|
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
|
|
print(f"🔧 Sekai synthetic camera embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
|
|
|
def generate_openx_camera_embeddings_sliding( |
|
|
encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses): |
|
|
"""Generate camera embeddings for OpenX dataset - sliding window version""" |
|
|
time_compression_ratio = 4 |
|
|
|
|
|
|
|
|
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames |
|
|
|
|
|
if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']: |
|
|
print("🔧 Using OpenX real camera data") |
|
|
cam_extrinsic = encoded_data['cam_emb']['extrinsic'] |
|
|
|
|
|
|
|
|
max_needed_frames = max( |
|
|
start_frame + initial_condition_frames + new_frames, |
|
|
framepack_needed_frames, |
|
|
30 |
|
|
) |
|
|
|
|
|
print(f"🔧 Calculating OpenX camera sequence length:") |
|
|
print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}") |
|
|
print(f" - FramePack requirement: {framepack_needed_frames}") |
|
|
print(f" - Final generation: {max_needed_frames}") |
|
|
|
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
|
|
|
frame_idx = i * time_compression_ratio |
|
|
next_frame_idx = frame_idx + time_compression_ratio |
|
|
|
|
|
if next_frame_idx < len(cam_extrinsic): |
|
|
cam_prev = cam_extrinsic[frame_idx] |
|
|
cam_next = cam_extrinsic[next_frame_idx] |
|
|
relative_pose = compute_relative_pose(cam_prev, cam_next) |
|
|
relative_poses.append(torch.as_tensor(relative_pose[:3, :])) |
|
|
else: |
|
|
|
|
|
print(f"⚠️ Frame {frame_idx} exceeds OpenX camera data range, using zero motion") |
|
|
relative_poses.append(torch.zeros(3, 4)) |
|
|
|
|
|
pose_embedding = torch.stack(relative_poses, dim=0) |
|
|
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
|
|
|
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
|
|
|
condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
|
|
print(f"🔧 OpenX real camera embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
else: |
|
|
print("🔧 Using OpenX synthetic camera data") |
|
|
|
|
|
max_needed_frames = max( |
|
|
start_frame + initial_condition_frames + new_frames, |
|
|
framepack_needed_frames, |
|
|
30 |
|
|
) |
|
|
|
|
|
print(f"🔧 Generating OpenX synthetic camera frames: {max_needed_frames}") |
|
|
relative_poses = [] |
|
|
for i in range(max_needed_frames): |
|
|
|
|
|
|
|
|
roll_per_frame = 0.02 |
|
|
pitch_per_frame = 0.01 |
|
|
yaw_per_frame = 0.015 |
|
|
forward_speed = 0.003 |
|
|
|
|
|
pose = np.eye(4, dtype=np.float32) |
|
|
|
|
|
|
|
|
|
|
|
cos_roll = np.cos(roll_per_frame) |
|
|
sin_roll = np.sin(roll_per_frame) |
|
|
|
|
|
cos_pitch = np.cos(pitch_per_frame) |
|
|
sin_pitch = np.sin(pitch_per_frame) |
|
|
|
|
|
cos_yaw = np.cos(yaw_per_frame) |
|
|
sin_yaw = np.sin(yaw_per_frame) |
|
|
|
|
|
|
|
|
pose[0, 0] = cos_yaw * cos_pitch |
|
|
pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll |
|
|
pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll |
|
|
pose[1, 0] = sin_yaw * cos_pitch |
|
|
pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll |
|
|
pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll |
|
|
pose[2, 0] = -sin_pitch |
|
|
pose[2, 1] = cos_pitch * sin_roll |
|
|
pose[2, 2] = cos_pitch * cos_roll |
|
|
|
|
|
|
|
|
pose[0, 3] = forward_speed * 0.5 |
|
|
pose[1, 3] = forward_speed * 0.3 |
|
|
pose[2, 3] = -forward_speed |
|
|
|
|
|
relative_pose = pose[:3, :] |
|
|
relative_poses.append(torch.as_tensor(relative_pose)) |
|
|
|
|
|
pose_embedding = torch.stack(relative_poses, dim=0) |
|
|
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') |
|
|
|
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_embedding, mask], dim=1) |
|
|
print(f"🔧 OpenX synthetic camera embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
|
|
|
def generate_nuscenes_camera_embeddings_sliding( |
|
|
scene_info, start_frame, initial_condition_frames, new_frames): |
|
|
""" |
|
|
Generate camera embeddings for NuScenes dataset - sliding window version |
|
|
|
|
|
    Kept consistent with train_moe.py.
|
|
""" |
|
|
time_compression_ratio = 4 |
|
|
|
|
|
|
|
|
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames |
|
|
|
|
|
if scene_info is not None and 'keyframe_poses' in scene_info: |
|
|
print("🔧 Using NuScenes real pose data") |
|
|
keyframe_poses = scene_info['keyframe_poses'] |
|
|
|
|
|
if len(keyframe_poses) == 0: |
|
|
print("⚠️ NuScenes keyframe_poses is empty, using zero pose") |
|
|
max_needed_frames = max(framepack_needed_frames, 30) |
|
|
|
|
|
pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32) |
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_sequence, mask], dim=1) |
|
|
print(f"🔧 NuScenes zero pose embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
|
|
|
reference_pose = keyframe_poses[0] |
|
|
|
|
|
max_needed_frames = max(framepack_needed_frames, 30) |
|
|
|
|
|
pose_vecs = [] |
|
|
for i in range(max_needed_frames): |
|
|
if i < len(keyframe_poses): |
|
|
current_pose = keyframe_poses[i] |
|
|
|
|
|
|
|
|
translation = torch.tensor( |
|
|
np.array(current_pose['translation']) - np.array(reference_pose['translation']), |
|
|
dtype=torch.float32 |
|
|
) |
|
|
|
|
|
|
|
|
rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32) |
|
|
|
|
|
pose_vec = torch.cat([translation, rotation], dim=0) |
|
|
else: |
|
|
|
|
|
pose_vec = torch.cat([ |
|
|
torch.zeros(3, dtype=torch.float32), |
|
|
torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) |
|
|
], dim=0) |
|
|
|
|
|
pose_vecs.append(pose_vec) |
|
|
|
|
|
pose_sequence = torch.stack(pose_vecs, dim=0) |
|
|
|
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_sequence, mask], dim=1) |
|
|
print(f"🔧 NuScenes real pose embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
else: |
|
|
print("🔧 Using NuScenes synthetic pose data") |
|
|
max_needed_frames = max(framepack_needed_frames, 30) |
|
|
|
|
|
|
|
|
pose_vecs = [] |
|
|
for i in range(max_needed_frames): |
|
|
|
|
|
angle = i * 0.04 |
|
|
radius = 15.0 |
|
|
|
|
|
|
|
|
x = radius * np.sin(angle) |
|
|
y = 0.0 |
|
|
z = radius * (1 - np.cos(angle)) |
|
|
|
|
|
translation = torch.tensor([x, y, z], dtype=torch.float32) |
|
|
|
|
|
|
|
|
yaw = angle + np.pi/2 |
|
|
|
|
|
rotation = torch.tensor([ |
|
|
np.cos(yaw/2), |
|
|
0.0, |
|
|
0.0, |
|
|
np.sin(yaw/2) |
|
|
], dtype=torch.float32) |
|
|
|
|
|
pose_vec = torch.cat([translation, rotation], dim=0) |
|
|
pose_vecs.append(pose_vec) |
|
|
|
|
|
pose_sequence = torch.stack(pose_vecs, dim=0) |
|
|
|
|
|
|
|
|
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32) |
|
|
condition_end = min(start_frame + initial_condition_frames, max_needed_frames) |
|
|
mask[start_frame:condition_end] = 1.0 |
|
|
|
|
|
camera_embedding = torch.cat([pose_sequence, mask], dim=1) |
|
|
print(f"🔧 NuScenes synthetic left turn pose embedding shape: {camera_embedding.shape}") |
|
|
return camera_embedding.to(torch.bfloat16) |
|
|
|
|
|
def prepare_framepack_sliding_window_with_camera_moe( |
|
|
history_latents, |
|
|
target_frames_to_generate, |
|
|
camera_embedding_full, |
|
|
start_frame, |
|
|
modality_type, |
|
|
max_history_frames=49): |
|
|
"""FramePack sliding window mechanism - MoE version""" |
|
|
|
|
|
C, T, H, W = history_latents.shape |
|
|
|
|
|
|
|
|
|
|
|
total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate |
|
|
indices = torch.arange(0, total_indices_length) |
|
|
split_sizes = [1, 16, 2, 1, target_frames_to_generate] |
|
|
clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \ |
|
|
indices.split(split_sizes, dim=0) |
|
|
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0) |
|
|
|
|
|
|
|
|
if camera_embedding_full.shape[0] < total_indices_length: |
|
|
print(f"⚠️ camera_embedding length insufficient, performing zero padding: current length {camera_embedding_full.shape[0]}, required length {total_indices_length}") |
|
|
shortage = total_indices_length - camera_embedding_full.shape[0] |
|
|
padding = torch.zeros(shortage, camera_embedding_full.shape[1], |
|
|
dtype=camera_embedding_full.dtype, device=camera_embedding_full.device) |
|
|
camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0) |
|
|
|
|
|
|
|
|
combined_camera = torch.zeros( |
|
|
total_indices_length, |
|
|
camera_embedding_full.shape[1], |
|
|
dtype=camera_embedding_full.dtype, |
|
|
device=camera_embedding_full.device) |
|
|
|
|
|
|
|
|
history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone() |
|
|
combined_camera[19 - history_slice.shape[0]:19, :] = history_slice |
|
|
|
|
|
|
|
|
target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone() |
|
|
combined_camera[19:19 + target_slice.shape[0], :] = target_slice |
|
|
|
|
|
|
|
|
combined_camera[:, -1] = 0.0 |
|
|
|
|
|
|
|
|
if T > 0: |
|
|
available_frames = min(T, 19) |
|
|
start_pos = 19 - available_frames |
|
|
combined_camera[start_pos:19, -1] = 1.0 |
|
|
|
|
|
print(f"🔧 MoE Camera mask update:") |
|
|
print(f" - History frames: {T}") |
|
|
print(f" - Valid condition frames: {available_frames if T > 0 else 0}") |
|
|
print(f" - Modality type: {modality_type}") |
|
|
|
|
|
|
|
|
clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device) |
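    # 19 = 16 (4x) + 2 (2x) + 1 (1x) history slots; the most recent history frames are
    # right-aligned into this buffer and then sliced per compression level below.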
|
|
|
|
|
if T > 0: |
|
|
available_frames = min(T, 19) |
|
|
start_pos = 19 - available_frames |
|
|
clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :] |
|
|
|
|
|
clean_latents_4x = clean_latents_combined[:, 0:16, :, :] |
|
|
clean_latents_2x = clean_latents_combined[:, 16:18, :, :] |
|
|
clean_latents_1x = clean_latents_combined[:, 18:19, :, :] |
|
|
|
|
|
if T > 0: |
|
|
start_latent = history_latents[:, 0:1, :, :] |
|
|
else: |
|
|
start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device) |
|
|
|
|
|
clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1) |
|
|
|
|
|
return { |
|
|
'latent_indices': latent_indices, |
|
|
'clean_latents': clean_latents, |
|
|
'clean_latents_2x': clean_latents_2x, |
|
|
'clean_latents_4x': clean_latents_4x, |
|
|
'clean_latent_indices': clean_latent_indices, |
|
|
'clean_latent_2x_indices': clean_latent_2x_indices, |
|
|
'clean_latent_4x_indices': clean_latent_4x_indices, |
|
|
'camera_embedding': combined_camera, |
|
|
'modality_type': modality_type, |
|
|
'current_length': T, |
|
|
'next_length': T + target_frames_to_generate |
|
|
} |
|
|
|
|
|
def overlay_controls(frame_img, pose_vec, icons): |
|
|
""" |
|
|
    Overlay control icons (WASD and arrows) on the frame based on the camera pose.
|
|
pose_vec: 12 elements (flattened 3x4 matrix) + mask |
|
|
""" |
|
|
if pose_vec is None or np.all(pose_vec[:12] == 0): |
|
|
return frame_img |
|
|
|
|
|
|
|
|
|
|
|
tx = pose_vec[3] |
|
|
|
|
|
tz = pose_vec[11] |
|
|
|
|
|
|
|
|
|
|
|
r00 = pose_vec[0] |
|
|
r02 = pose_vec[2] |
|
|
yaw = np.arctan2(r02, r00) |
|
|
|
|
|
|
|
|
r12 = pose_vec[6] |
|
|
r22 = pose_vec[10] |
|
|
pitch = np.arctan2(-r12, r22) |
|
|
|
|
|
|
|
|
TRANS_THRESH = 0.01 |
|
|
ROT_THRESH = 0.005 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_forward = tz < -TRANS_THRESH |
|
|
is_backward = tz > TRANS_THRESH |
|
|
is_left = tx < -TRANS_THRESH |
|
|
is_right = tx > TRANS_THRESH |
|
|
|
|
|
|
|
|
|
|
|
is_turn_left = yaw > ROT_THRESH |
|
|
is_turn_right = yaw < -ROT_THRESH |
|
|
|
|
|
|
|
|
is_turn_up = pitch < -ROT_THRESH |
|
|
is_turn_down = pitch > ROT_THRESH |
|
|
|
|
|
W, H = frame_img.size |
|
|
spacing = 60 |
|
|
|
|
|
def paste_icon(name_active, name_inactive, is_active, x, y): |
|
|
name = name_active if is_active else name_inactive |
|
|
if name in icons: |
|
|
icon = icons[name] |
|
|
|
|
|
frame_img.paste(icon, (int(x), int(y)), icon) |
|
|
|
|
|
|
|
|
base_x_right = 100 |
|
|
base_y = H - 100 |
|
|
|
|
|
|
|
|
paste_icon('move_forward.png', 'not_move_forward.png', is_forward, base_x_right, base_y - spacing) |
|
|
|
|
|
paste_icon('move_left.png', 'not_move_left.png', is_left, base_x_right - spacing, base_y) |
|
|
|
|
|
paste_icon('move_backward.png', 'not_move_backward.png', is_backward, base_x_right, base_y) |
|
|
|
|
|
paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y) |
|
|
|
|
|
|
|
|
base_x_left = W - 150 |
|
|
|
|
|
|
|
|
paste_icon('turn_up.png', 'not_turn_up.png', is_turn_up, base_x_left, base_y - spacing) |
|
|
|
|
|
paste_icon('turn_left.png', 'not_turn_left.png', is_turn_left, base_x_left - spacing, base_y) |
|
|
|
|
|
paste_icon('turn_down.png', 'not_turn_down.png', is_turn_down, base_x_left, base_y) |
|
|
|
|
|
paste_icon('turn_right.png', 'not_turn_right.png', is_turn_right, base_x_left + spacing, base_y) |
|
|
|
|
|
return frame_img |
|
|
|
|
|
|
|
|
def inference_moe_framepack_sliding_window( |
|
|
condition_pth_path=None, |
|
|
condition_video=None, |
|
|
condition_image=None, |
|
|
dit_path=None, |
|
|
wan_model_path=None, |
|
|
output_path="../examples/output_videos/output_moe_framepack_sliding.mp4", |
|
|
start_frame=0, |
|
|
initial_condition_frames=8, |
|
|
frames_per_generation=4, |
|
|
total_frames_to_generate=32, |
|
|
max_history_frames=49, |
|
|
device="cuda", |
|
|
prompt="A video of a scene shot using a pedestrian's front camera while walking", |
|
|
modality_type="sekai", |
|
|
use_real_poses=True, |
|
|
scene_info_path=None, |
|
|
|
|
|
use_camera_cfg=True, |
|
|
camera_guidance_scale=2.0, |
|
|
text_guidance_scale=1.0, |
|
|
|
|
|
moe_num_experts=4, |
|
|
moe_top_k=2, |
|
|
moe_hidden_dim=None, |
|
|
direction="left", |
|
|
use_gt_prompt=True, |
|
|
add_icons=False |
|
|
): |
|
|
""" |
|
|
MoE FramePack sliding window video generation - multi-modal support |
|
|
""" |
|
|
|
|
|
dir_path = os.path.dirname(output_path) |
|
|
os.makedirs(dir_path, exist_ok=True) |
|
|
|
|
|
print(f"🔧 Starting MoE FramePack sliding window generation...") |
|
|
print(f" Modality type: {modality_type}") |
|
|
print(f" Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}") |
|
|
print(f" Text guidance scale: {text_guidance_scale}") |
|
|
print(f" MoE config: experts={moe_num_experts}, top_k={moe_top_k}") |
|
|
|
|
|
|
|
|
replace_dit_model_in_manager() |
|
|
|
|
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") |
|
|
model_manager.load_models([ |
|
|
os.path.join(wan_model_path, "diffusion_pytorch_model.safetensors"), |
|
|
os.path.join(wan_model_path, "models_t5_umt5-xxl-enc-bf16.pth"), |
|
|
os.path.join(wan_model_path, "Wan2.1_VAE.pth"), |
|
|
]) |
|
|
pipe = WanVideoAstraPipeline.from_model_manager(model_manager, device="cuda") |
|
|
|
|
|
|
|
|
dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0] |
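    # Attach a per-block camera pathway: cam_encoder is zero-initialized and projector
    # starts as the identity, i.e. a no-op initialization that is expected to be
    # overwritten when the fine-tuned DiT checkpoint is loaded below (strict=False).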
|
|
for block in pipe.dit.blocks: |
|
|
block.cam_encoder = nn.Linear(13, dim) |
|
|
block.projector = nn.Linear(dim, dim) |
|
|
block.cam_encoder.weight.data.zero_() |
|
|
block.cam_encoder.bias.data.zero_() |
|
|
block.projector.weight = nn.Parameter(torch.eye(dim)) |
|
|
block.projector.bias = nn.Parameter(torch.zeros(dim)) |
|
|
|
|
|
|
|
|
add_framepack_components(pipe.dit) |
|
|
|
|
|
|
|
|
moe_config = { |
|
|
"num_experts": moe_num_experts, |
|
|
"top_k": moe_top_k, |
|
|
"hidden_dim": moe_hidden_dim or dim * 2, |
|
|
"sekai_input_dim": 13, |
|
|
"nuscenes_input_dim": 8, |
|
|
"openx_input_dim": 13 |
|
|
} |
|
|
add_moe_components(pipe.dit, moe_config) |
|
|
|
|
|
|
|
|
dit_state_dict = torch.load(dit_path, map_location="cpu") |
|
|
pipe.dit.load_state_dict(dit_state_dict, strict=False) |
|
|
pipe = pipe.to(device) |
|
|
model_dtype = next(pipe.dit.parameters()).dtype |
|
|
|
|
|
if hasattr(pipe.dit, 'clean_x_embedder'): |
|
|
pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype) |
|
|
|
|
|
|
|
|
pipe.scheduler.set_timesteps(50) |
|
|
|
|
|
|
|
|
print("Loading initial condition frames...") |
|
|
initial_latents, encoded_data = load_or_encode_condition( |
|
|
condition_pth_path, |
|
|
condition_video, |
|
|
condition_image, |
|
|
start_frame, |
|
|
initial_condition_frames, |
|
|
device, |
|
|
pipe, |
|
|
) |
|
|
|
|
|
|
|
|
target_height, target_width = 60, 104 |
|
|
C, T, H, W = initial_latents.shape |
|
|
|
|
|
if H > target_height or W > target_width: |
|
|
h_start = (H - target_height) // 2 |
|
|
w_start = (W - target_width) // 2 |
|
|
initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width] |
|
|
H, W = target_height, target_width |
|
|
|
|
|
history_latents = initial_latents.to(device, dtype=model_dtype) |
|
|
|
|
|
print(f"Initial history_latents shape: {history_latents.shape}") |
|
|
|
|
|
|
|
|
if use_gt_prompt and 'prompt_emb' in encoded_data: |
|
|
print("✅ Using pre-encoded GT prompt embedding") |
|
|
prompt_emb_pos = encoded_data['prompt_emb'] |
|
|
|
|
|
if 'context' in prompt_emb_pos: |
|
|
prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype) |
|
|
if 'context_mask' in prompt_emb_pos: |
|
|
prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype) |
|
|
|
|
|
|
|
|
if text_guidance_scale > 1.0: |
|
|
prompt_emb_neg = pipe.encode_prompt("") |
|
|
print(f"Using Text CFG with GT prompt, guidance scale: {text_guidance_scale}") |
|
|
else: |
|
|
prompt_emb_neg = None |
|
|
print("Not using Text CFG") |
|
|
|
|
|
|
|
|
if 'prompt' in encoded_data['prompt_emb']: |
|
|
gt_prompt_text = encoded_data['prompt_emb']['prompt'] |
|
|
print(f"📝 GT Prompt text: {gt_prompt_text}") |
|
|
else: |
|
|
|
|
|
print(f"🔄 Re-encoding prompt: {prompt}") |
|
|
if text_guidance_scale > 1.0: |
|
|
prompt_emb_pos = pipe.encode_prompt(prompt) |
|
|
prompt_emb_neg = pipe.encode_prompt("") |
|
|
print(f"Using Text CFG, guidance scale: {text_guidance_scale}") |
|
|
else: |
|
|
prompt_emb_pos = pipe.encode_prompt(prompt) |
|
|
prompt_emb_neg = None |
|
|
print("Not using Text CFG") |
|
|
|
|
|
|
|
|
scene_info = None |
|
|
if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path): |
|
|
with open(scene_info_path, 'r') as f: |
|
|
scene_info = json.load(f) |
|
|
print(f"Loading NuScenes scene information: {scene_info_path}") |
|
|
|
|
|
|
|
|
if modality_type == "sekai": |
|
|
camera_embedding_full = generate_sekai_camera_embeddings_sliding( |
|
|
encoded_data.get('cam_emb', None), |
|
|
start_frame, |
|
|
initial_condition_frames, |
|
|
total_frames_to_generate, |
|
|
0, |
|
|
use_real_poses=use_real_poses, |
|
|
direction=direction |
|
|
).to(device, dtype=model_dtype) |
|
|
elif modality_type == "nuscenes": |
|
|
camera_embedding_full = generate_nuscenes_camera_embeddings_sliding( |
|
|
scene_info, |
|
|
start_frame, |
|
|
initial_condition_frames, |
|
|
total_frames_to_generate |
|
|
).to(device, dtype=model_dtype) |
|
|
elif modality_type == "openx": |
|
|
camera_embedding_full = generate_openx_camera_embeddings_sliding( |
|
|
encoded_data, |
|
|
start_frame, |
|
|
initial_condition_frames, |
|
|
total_frames_to_generate, |
|
|
use_real_poses=use_real_poses |
|
|
).to(device, dtype=model_dtype) |
|
|
else: |
|
|
raise ValueError(f"Unsupported modality type: {modality_type}") |
|
|
|
|
|
print(f"Complete camera sequence shape: {camera_embedding_full.shape}") |
|
|
|
|
|
|
|
|
if use_camera_cfg: |
|
|
camera_embedding_uncond = torch.zeros_like(camera_embedding_full) |
|
|
print(f"Creating unconditional camera embedding for CFG") |
|
|
|
|
|
|
|
|
total_generated = 0 |
|
|
all_generated_frames = [] |
|
|
|
|
|
while total_generated < total_frames_to_generate: |
|
|
current_generation = min(frames_per_generation, total_frames_to_generate - total_generated) |
|
|
print(f"\nGeneration step {total_generated // frames_per_generation + 1}") |
|
|
print(f"Current history length: {history_latents.shape[1]}, generating: {current_generation}") |
|
|
|
|
|
|
|
|
framepack_data = prepare_framepack_sliding_window_with_camera_moe( |
|
|
history_latents, |
|
|
current_generation, |
|
|
camera_embedding_full, |
|
|
start_frame, |
|
|
modality_type, |
|
|
max_history_frames |
|
|
) |
|
|
|
|
|
|
|
|
clean_latents = framepack_data['clean_latents'].unsqueeze(0) |
|
|
clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0) |
|
|
clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0) |
|
|
camera_embedding = framepack_data['camera_embedding'].unsqueeze(0) |
|
|
|
|
|
|
|
|
modality_inputs = {modality_type: camera_embedding} |
|
|
|
|
|
|
|
|
if use_camera_cfg: |
|
|
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0) |
|
|
modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch} |
|
|
|
|
|
|
|
|
latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu() |
|
|
clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu() |
|
|
clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu() |
|
|
clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu() |
|
|
|
|
|
|
|
|
new_latents = torch.randn( |
|
|
1, C, current_generation, H, W, |
|
|
device=device, dtype=model_dtype |
|
|
) |
|
|
|
|
|
extra_input = pipe.prepare_extra_input(new_latents) |
|
|
|
|
|
print(f"Camera embedding shape: {camera_embedding.shape}") |
|
|
print(f"Camera mask distribution - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}") |
|
|
|
|
|
|
|
|
timesteps = pipe.scheduler.timesteps |
|
|
|
|
|
for i, timestep in enumerate(timesteps): |
|
|
if i % 10 == 0: |
|
|
print(f" Denoising step {i+1}/{len(timesteps)}") |
|
|
|
|
|
timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype) |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
if use_camera_cfg and camera_guidance_scale > 1.0: |
|
|
|
|
|
                    noise_pred_cond, moe_loss = pipe.dit(
|
|
new_latents, |
|
|
timestep=timestep_tensor, |
|
|
cam_emb=camera_embedding, |
|
|
modality_inputs=modality_inputs, |
|
|
latent_indices=latent_indices, |
|
|
clean_latents=clean_latents, |
|
|
clean_latent_indices=clean_latent_indices, |
|
|
clean_latents_2x=clean_latents_2x, |
|
|
clean_latent_2x_indices=clean_latent_2x_indices, |
|
|
clean_latents_4x=clean_latents_4x, |
|
|
clean_latent_4x_indices=clean_latent_4x_indices, |
|
|
**prompt_emb_pos, |
|
|
**extra_input |
|
|
) |
|
|
|
|
|
|
|
|
                    noise_pred_uncond, moe_loss = pipe.dit(
|
|
new_latents, |
|
|
timestep=timestep_tensor, |
|
|
cam_emb=camera_embedding_uncond_batch, |
|
|
modality_inputs=modality_inputs_uncond, |
|
|
latent_indices=latent_indices, |
|
|
clean_latents=clean_latents, |
|
|
clean_latent_indices=clean_latent_indices, |
|
|
clean_latents_2x=clean_latents_2x, |
|
|
clean_latent_2x_indices=clean_latent_2x_indices, |
|
|
clean_latents_4x=clean_latents_4x, |
|
|
clean_latent_4x_indices=clean_latent_4x_indices, |
|
|
**(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos), |
|
|
**extra_input |
|
|
) |
|
|
|
|
|
|
|
|
noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
|
|
|
|
|
|
|
if text_guidance_scale > 1.0 and prompt_emb_neg: |
|
|
                        noise_pred_text_uncond, moe_loss = pipe.dit(
|
|
new_latents, |
|
|
timestep=timestep_tensor, |
|
|
cam_emb=camera_embedding, |
|
|
modality_inputs=modality_inputs, |
|
|
latent_indices=latent_indices, |
|
|
clean_latents=clean_latents, |
|
|
clean_latent_indices=clean_latent_indices, |
|
|
clean_latents_2x=clean_latents_2x, |
|
|
clean_latent_2x_indices=clean_latent_2x_indices, |
|
|
clean_latents_4x=clean_latents_4x, |
|
|
clean_latent_4x_indices=clean_latent_4x_indices, |
|
|
**prompt_emb_neg, |
|
|
**extra_input |
|
|
) |
|
|
|
|
|
|
|
|
noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond) |
|
|
|
|
|
elif text_guidance_scale > 1.0 and prompt_emb_neg: |
|
|
|
|
|
                    noise_pred_cond, moe_loss = pipe.dit(
|
|
new_latents, |
|
|
timestep=timestep_tensor, |
|
|
cam_emb=camera_embedding, |
|
|
modality_inputs=modality_inputs, |
|
|
latent_indices=latent_indices, |
|
|
clean_latents=clean_latents, |
|
|
clean_latent_indices=clean_latent_indices, |
|
|
clean_latents_2x=clean_latents_2x, |
|
|
clean_latent_2x_indices=clean_latent_2x_indices, |
|
|
clean_latents_4x=clean_latents_4x, |
|
|
clean_latent_4x_indices=clean_latent_4x_indices, |
|
|
**prompt_emb_pos, |
|
|
**extra_input |
|
|
) |
|
|
|
|
|
                    noise_pred_uncond, moe_loss = pipe.dit(
|
|
new_latents, |
|
|
timestep=timestep_tensor, |
|
|
cam_emb=camera_embedding, |
|
|
modality_inputs=modality_inputs, |
|
|
latent_indices=latent_indices, |
|
|
clean_latents=clean_latents, |
|
|
clean_latent_indices=clean_latent_indices, |
|
|
clean_latents_2x=clean_latents_2x, |
|
|
clean_latent_2x_indices=clean_latent_2x_indices, |
|
|
clean_latents_4x=clean_latents_4x, |
|
|
clean_latent_4x_indices=clean_latent_4x_indices, |
|
|
**prompt_emb_neg, |
|
|
**extra_input |
|
|
) |
|
|
|
|
|
noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
|
|
|
|
else: |
|
|
|
|
|
                    noise_pred, moe_loss = pipe.dit(
|
|
new_latents, |
|
|
timestep=timestep_tensor, |
|
|
cam_emb=camera_embedding, |
|
|
modality_inputs=modality_inputs, |
|
|
latent_indices=latent_indices, |
|
|
clean_latents=clean_latents, |
|
|
clean_latent_indices=clean_latent_indices, |
|
|
clean_latents_2x=clean_latents_2x, |
|
|
clean_latent_2x_indices=clean_latent_2x_indices, |
|
|
clean_latents_4x=clean_latents_4x, |
|
|
clean_latent_4x_indices=clean_latent_4x_indices, |
|
|
**prompt_emb_pos, |
|
|
**extra_input |
|
|
) |
|
|
|
|
|
new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents) |
|
|
|
|
|
|
|
|
new_latents_squeezed = new_latents.squeeze(0) |
|
|
history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1) |
|
|
|
|
|
|
|
|
if history_latents.shape[1] > max_history_frames: |
|
|
first_frame = history_latents[:, 0:1, :, :] |
|
|
recent_frames = history_latents[:, -(max_history_frames-1):, :, :] |
|
|
history_latents = torch.cat([first_frame, recent_frames], dim=1) |
|
|
print(f"⚠️ History window full, keeping first frame + latest {max_history_frames-1} frames") |
|
|
|
|
|
print(f"History_latents shape after update: {history_latents.shape}") |
|
|
|
|
|
all_generated_frames.append(new_latents_squeezed) |
|
|
total_generated += current_generation |
|
|
|
|
|
print(f"✅ Generated {total_generated}/{total_frames_to_generate} frames") |
|
|
|
|
|
|
|
|
print("\nDecoding generated video...") |
|
|
|
|
|
all_generated = torch.cat(all_generated_frames, dim=1) |
|
|
final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0) |
|
|
|
|
|
print(f"Final video shape: {final_video.shape}") |
|
|
|
|
|
decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)) |
|
|
|
|
|
print(f"Saving video to {output_path} ...") |
|
|
|
|
|
video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy() |
|
|
video_np = (video_np * 0.5 + 0.5).clip(0, 1) |
|
|
video_np = (video_np * 255).astype(np.uint8) |
|
|
|
|
|
icons = {} |
|
|
video_camera_poses = None |
|
|
if add_icons: |
|
|
|
|
|
icons_dir = os.path.join(ROOT_DIR, 'icons') |
|
|
icon_names = ['move_forward.png', 'not_move_forward.png', |
|
|
'move_backward.png', 'not_move_backward.png', |
|
|
'move_left.png', 'not_move_left.png', |
|
|
'move_right.png', 'not_move_right.png', |
|
|
'turn_up.png', 'not_turn_up.png', |
|
|
'turn_down.png', 'not_turn_down.png', |
|
|
'turn_left.png', 'not_turn_left.png', |
|
|
'turn_right.png', 'not_turn_right.png'] |
|
|
for name in icon_names: |
|
|
path = os.path.join(icons_dir, name) |
|
|
if os.path.exists(path): |
|
|
try: |
|
|
icon = Image.open(path).convert("RGBA") |
|
|
|
|
|
icon = icon.resize((50, 50), Image.Resampling.LANCZOS) |
|
|
icons[name] = icon |
|
|
except Exception as e: |
|
|
print(f"Error loading icon {name}: {e}") |
|
|
else: |
|
|
print(f"⚠️ Warning: Icon {name} not found at {path}") |
|
|
|
|
|
|
|
|
time_compression_ratio = 4 |
|
|
camera_poses = camera_embedding_full.detach().float().cpu().numpy() |
|
|
video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)] |
|
|
|
|
|
with imageio.get_writer(output_path, fps=20) as writer: |
|
|
for i, frame in enumerate(video_np): |
|
|
|
|
|
img = Image.fromarray(frame) |
|
|
|
|
|
if add_icons and video_camera_poses is not None and icons: |
|
|
|
|
|
pose_idx = start_frame + i |
|
|
if pose_idx < len(video_camera_poses): |
|
|
pose_vec = video_camera_poses[pose_idx] |
|
|
img = overlay_controls(img, pose_vec, icons) |
|
|
|
|
|
writer.append_data(np.array(img)) |
|
|
|
|
|
print(f"✅ MoE FramePack sliding window generation completed! Saved to: {output_path}") |
|
|
print(f" Total generated {total_generated} frames (compressed), corresponding to original {total_generated * 4} frames") |
|
|
print(f" Using modality: {modality_type}") |
|
|
|
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser(description="MoE FramePack sliding window video generation - supports multi-modal") |
|
|
|
|
|
|
|
|
parser.add_argument("--condition_pth", |
|
|
type=str, |
|
|
default=None, |
|
|
help="Path to pre-encoded condition pth file") |
|
|
parser.add_argument("--condition_video", |
|
|
type=str, |
|
|
default=None, |
|
|
help="Input video for novel view synthesis.") |
|
|
parser.add_argument("--condition_image", |
|
|
type=str, |
|
|
default=None, |
|
|
|
|
help="Input image for novel view synthesis.") |
|
|
parser.add_argument("--start_frame", type=int, default=0) |
|
|
parser.add_argument("--initial_condition_frames", type=int, default=1) |
|
|
parser.add_argument("--frames_per_generation", type=int, default=8) |
|
|
parser.add_argument("--total_frames_to_generate", type=int, default=24) |
|
|
parser.add_argument("--max_history_frames", type=int, default=100) |
|
|
parser.add_argument("--use_real_poses", default=False) |
|
|
parser.add_argument("--dit_path", type=str, |
|
|
default="../models/Astra/checkpoints/diffusion_pytorch_model.ckpt", |
|
|
help="path to the pretrained DiT MoE model checkpoint") |
|
|
parser.add_argument("--wan_model_path", |
|
|
type=str, |
|
|
default="../models/Wan-AI/Wan2.1-T2V-1.3B", |
|
|
help="path to Wan2.1-T2V-1.3B") |
|
|
parser.add_argument("--output_path", type=str, |
|
|
default='../examples/output_videos/output_moe_framepack_sliding.mp4') |
|
|
parser.add_argument("--prompt", |
|
|
type=str, |
|
|
default="", |
|
|
help="text prompt for video generation") |
|
|
parser.add_argument("--device", type=str, default="cuda") |
|
|
parser.add_argument("--add_icons", action="store_true", default=False, |
|
|
help="Overlay control icons on generated video") |
|
|
|
|
|
|
|
|
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], |
|
|
default="sekai", help="Modality type: sekai, nuscenes, or openx") |
|
|
parser.add_argument("--scene_info_path", type=str, default=None, |
|
|
help="NuScenes scene info file path (for nuscenes modality only)") |
|
|
|
|
|
|
|
|
parser.add_argument("--use_camera_cfg", default=False, |
|
|
help="Use Camera CFG") |
|
|
parser.add_argument("--camera_guidance_scale", type=float, default=2.0, |
|
|
help="Camera guidance scale for CFG") |
|
|
parser.add_argument("--text_guidance_scale", type=float, default=1.0, |
|
|
help="Text guidance scale for CFG") |
|
|
|
|
|
|
|
|
parser.add_argument("--moe_num_experts", type=int, default=3, help="Number of experts") |
|
|
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K experts") |
|
|
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE hidden dimension") |
|
|
parser.add_argument("--direction", type=str, default="left", help="Direction of video trajectory") |
|
|
parser.add_argument("--use_gt_prompt", action="store_true", default=False, |
|
|
help="Use ground truth prompt embedding from dataset") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
print(f"MoE FramePack CFG generation settings:") |
|
|
print(f"Modality type: {args.modality_type}") |
|
|
print(f"Camera CFG: {args.use_camera_cfg}") |
|
|
if args.use_camera_cfg: |
|
|
print(f"Camera guidance scale: {args.camera_guidance_scale}") |
|
|
print(f"Using GT Prompt: {args.use_gt_prompt}") |
|
|
print(f"Text guidance scale: {args.text_guidance_scale}") |
|
|
print(f"MoE config: experts={args.moe_num_experts}, top_k={args.moe_top_k}") |
|
|
print(f"DiT{args.dit_path}") |
|
|
|
|
|
|
|
|
if args.modality_type == "nuscenes" and not args.scene_info_path: |
|
|
print("⚠️ Warning: Using NuScenes modality but scene_info_path not provided, will use synthetic pose data") |
|
|
|
|
|
if not args.use_gt_prompt and (args.prompt is None or args.prompt.strip() == ""): |
|
|
print("⚠️ Warning: No prompt provided, will use empty string as prompt") |
|
|
|
|
|
if not any([args.condition_pth, args.condition_video, args.condition_image]): |
|
|
raise ValueError("Need to provide condition_pth, condition_video, or condition_image as condition input") |
|
|
|
|
|
if args.condition_pth: |
|
|
print(f"Using pre-encoded pth: {args.condition_pth}") |
|
|
elif args.condition_video: |
|
|
print(f"Using condition video for online encoding: {args.condition_video}") |
|
|
elif args.condition_image: |
|
|
print(f"Using condition image for online encoding: {args.condition_image} (repeat 10 frames)") |
|
|
|
|
|
inference_moe_framepack_sliding_window( |
|
|
condition_pth_path=args.condition_pth, |
|
|
condition_video=args.condition_video, |
|
|
condition_image=args.condition_image, |
|
|
dit_path=args.dit_path, |
|
|
wan_model_path=args.wan_model_path, |
|
|
output_path=args.output_path, |
|
|
start_frame=args.start_frame, |
|
|
initial_condition_frames=args.initial_condition_frames, |
|
|
frames_per_generation=args.frames_per_generation, |
|
|
total_frames_to_generate=args.total_frames_to_generate, |
|
|
max_history_frames=args.max_history_frames, |
|
|
device=args.device, |
|
|
prompt=args.prompt, |
|
|
modality_type=args.modality_type, |
|
|
use_real_poses=args.use_real_poses, |
|
|
scene_info_path=args.scene_info_path, |
|
|
|
|
|
use_camera_cfg=args.use_camera_cfg, |
|
|
camera_guidance_scale=args.camera_guidance_scale, |
|
|
text_guidance_scale=args.text_guidance_scale, |
|
|
|
|
|
moe_num_experts=args.moe_num_experts, |
|
|
moe_top_k=args.moe_top_k, |
|
|
moe_hidden_dim=args.moe_hidden_dim, |
|
|
direction=args.direction, |
|
|
use_gt_prompt=args.use_gt_prompt, |
|
|
add_icons=args.add_icons |
|
|
) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
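
# Example invocation (illustrative only; the script filename and the condition image
# path are placeholders and must be adapted to your local checkout):
#   python infer_moe_framepack_sliding.py \
#       --condition_image ../examples/input.png \
#       --dit_path ../models/Astra/checkpoints/diffusion_pytorch_model.ckpt \
#       --wan_model_path ../models/Wan-AI/Wan2.1-T2V-1.3B \
#       --output_path ../examples/output_videos/output_moe_framepack_sliding.mp4 \
#       --modality_type sekai --direction left \
#       --total_frames_to_generate 24 --use_camera_cfg --add_icons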