# Astra/scripts/infer_demo.py
import os
import sys
from pathlib import Path
from typing import Optional
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(ROOT_DIR)
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import imageio
import json
from diffsynth import WanVideoAstraPipeline, ModelManager
import argparse
from torchvision.transforms import v2
from einops import rearrange
from scipy.spatial.transform import Rotation as R
import random
import copy
from datetime import datetime
VALID_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}
class InlineVideoEncoder:
def __init__(self, pipe: WanVideoAstraPipeline, device="cuda"):
self.device = getattr(pipe, "device", device)
self.tiler_kwargs = {"tiled": True, "tile_size": (34, 34), "tile_stride": (18, 16)}
self.frame_process = v2.Compose([
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
self.pipe = pipe
@staticmethod
def _crop_and_resize(image: Image.Image) -> Image.Image:
target_w, target_h = 832, 480
return v2.functional.resize(
image,
(round(target_h), round(target_w)),
interpolation=v2.InterpolationMode.BILINEAR,
)
def preprocess_frame(self, image: Image.Image) -> torch.Tensor:
image = image.convert("RGB")
image = self._crop_and_resize(image)
return self.frame_process(image)
def load_video_frames(self, video_path: Path) -> Optional[torch.Tensor]:
reader = imageio.get_reader(str(video_path))
frames = []
for frame_data in reader:
frame = Image.fromarray(frame_data)
frames.append(self.preprocess_frame(frame))
reader.close()
if not frames:
return None
frames = torch.stack(frames, dim=0)
return rearrange(frames, "T C H W -> C T H W")
def encode_frames_to_latents(self, frames: torch.Tensor) -> torch.Tensor:
frames = frames.unsqueeze(0).to(self.device, dtype=torch.bfloat16)
with torch.no_grad():
latents = self.pipe.encode_video(frames, **self.tiler_kwargs)[0]
if latents.dim() == 5 and latents.shape[0] == 1:
latents = latents.squeeze(0)
return latents.cpu()
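# Minimal usage sketch for InlineVideoEncoder (assumes `pipe` is an already-loaded
# WanVideoAstraPipeline; shapes below assume the Wan2.1 VAE's 8x spatial / 4x temporal
# compression and are illustrative rather than guaranteed):
#   encoder = InlineVideoEncoder(pipe)
#   frames = encoder.load_video_frames(Path("clip.mp4"))   # [3, T, 480, 832], values in [-1, 1]
#   latents = encoder.encode_frames_to_latents(frames)     # roughly [16, T//4, 60, 104], on CPU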
def image_to_frame_stack(
image_path: Path,
encoder: InlineVideoEncoder,
repeat_count: int = 10
) -> torch.Tensor:
"""Repeat a single image into a tensor with specified number of frames, shape [C, T, H, W]"""
if image_path.suffix.lower() not in VALID_IMAGE_EXTENSIONS:
raise ValueError(f"Unsupported image format: {image_path.suffix}")
image = Image.open(str(image_path))
frame = encoder.preprocess_frame(image)
frames = torch.stack([frame for _ in range(repeat_count)], dim=0)
return rearrange(frames, "T C H W -> C T H W")
def load_or_encode_condition(
condition_pth_path: Optional[str],
condition_video: Optional[str],
condition_image: Optional[str],
start_frame: int,
num_frames: int,
device: str,
pipe: WanVideoAstraPipeline,
) -> tuple[torch.Tensor, dict]:
if condition_pth_path:
return load_encoded_video_from_pth(condition_pth_path, start_frame, num_frames)
encoder = InlineVideoEncoder(pipe=pipe, device=device)
if condition_video:
video_path = Path(condition_video).expanduser().resolve()
if not video_path.exists():
raise FileNotFoundError(f"File not Found: {video_path}")
frames = encoder.load_video_frames(video_path)
if frames is None:
raise ValueError(f"no valid frames in {video_path}")
elif condition_image:
image_path = Path(condition_image).expanduser().resolve()
if not image_path.exists():
raise FileNotFoundError(f"File not Found: {image_path}")
frames = image_to_frame_stack(image_path, encoder, repeat_count=10)
else:
raise ValueError("condition video or image is needed for video generation.")
latents = encoder.encode_frames_to_latents(frames)
encoded_data = {"latents": latents}
if start_frame + num_frames > latents.shape[1]:
raise ValueError(
f"Not enough frames after encoding: requested {start_frame + num_frames}, available {latents.shape[1]}"
)
condition_latents = latents[:, start_frame:start_frame + num_frames, :, :]
return condition_latents, encoded_data
def compute_relative_pose_matrix(pose1, pose2):
"""
Compute relative pose between two consecutive frames, return 3x4 camera matrix [R_rel | t_rel]
Args:
pose1: Camera pose of frame i, shape (7,) array [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
pose2: Camera pose of frame i+1, shape (7,) array [tx2, ty2, tz2, qx2, qy2, qz2, qw2]
Returns:
relative_matrix: 3x4 relative pose matrix,
first 3 columns are rotation matrix R_rel,
last column is translation vector t_rel
"""
# Separate translation vector and quaternion
t1 = pose1[:3] # Translation of frame i [tx1, ty1, tz1]
q1 = pose1[3:] # Quaternion of frame i [qx1, qy1, qz1, qw1]
t2 = pose2[:3] # Translation of frame i+1
q2 = pose2[3:] # Quaternion of frame i+1
# 1. Compute relative rotation matrix R_rel
rot1 = R.from_quat(q1) # Rotation of frame i
rot2 = R.from_quat(q2) # Rotation of frame i+1
rot_rel = rot2 * rot1.inv() # Relative rotation = next frame rotation × inverse of current frame rotation
R_rel = rot_rel.as_matrix() # Convert to 3x3 matrix
# 2. Compute relative translation vector t_rel
R1_T = rot1.as_matrix().T # Transpose of current frame rotation matrix (equivalent to inverse)
t_rel = R1_T @ (t2 - t1) # Relative translation = R1^T × (t2 - t1)
# 3. Combine into 3x4 matrix [R_rel | t_rel]
relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])
return relative_matrix
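# Sanity-check sketch for compute_relative_pose_matrix (illustrative values): two poses with
# identity rotation that differ by 0.5 along x should give an identity R_rel and t_rel = [0.5, 0, 0].
#   pose_i = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])   # [tx, ty, tz, qx, qy, qz, qw]
#   pose_j = np.array([0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
#   rel = compute_relative_pose_matrix(pose_i, pose_j)        # shape (3, 4)
#   # rel[:, :3] is (numerically) the identity and rel[:, 3] == [0.5, 0.0, 0.0]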
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
"""Load pre-encoded video data from pth file"""
print(f"Loading encoded video from {pth_path}")
encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
full_latents = encoded_data['latents'] # [C, T, H, W]
print(f"Full latents shape: {full_latents.shape}")
print(f"Extracting frames {start_frame} to {start_frame + num_frames}")
if start_frame + num_frames > full_latents.shape[1]:
raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")
condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
print(f"Extracted condition latents shape: {condition_latents.shape}")
return condition_latents, encoded_data
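# Expected .pth layout sketch (illustrative shapes; the Wan latent space is assumed to have
# 16 channels): a dict with at least a 'latents' tensor of shape [C, T, H, W]. Real files may
# also carry 'prompt_emb' and 'cam_emb' entries, which are used further below.
#   torch.save({"latents": torch.zeros(16, 20, 60, 104)}, "dummy_condition.pth")
#   cond, data = load_encoded_video_from_pth("dummy_condition.pth", start_frame=0, num_frames=8)
#   cond.shape   # -> torch.Size([16, 8, 60, 104])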
def compute_relative_pose(pose_a, pose_b, use_torch=False):
"""Compute relative pose matrix of camera B with respect to camera A"""
assert pose_a.shape == (4, 4), f"Camera A extrinsic matrix should be (4,4), got {pose_a.shape}"
assert pose_b.shape == (4, 4), f"Camera B extrinsic matrix should be (4,4), got {pose_b.shape}"
if use_torch:
if not isinstance(pose_a, torch.Tensor):
pose_a = torch.from_numpy(pose_a).float()
if not isinstance(pose_b, torch.Tensor):
pose_b = torch.from_numpy(pose_b).float()
pose_a_inv = torch.inverse(pose_a)
relative_pose = torch.matmul(pose_b, pose_a_inv)
else:
if not isinstance(pose_a, np.ndarray):
pose_a = np.array(pose_a, dtype=np.float32)
if not isinstance(pose_b, np.ndarray):
pose_b = np.array(pose_b, dtype=np.float32)
pose_a_inv = np.linalg.inv(pose_a)
relative_pose = np.matmul(pose_b, pose_a_inv)
return relative_pose
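# Quick example (illustrative): with pose_a equal to the identity, the relative pose is just pose_b.
#   pose_a = np.eye(4, dtype=np.float32)
#   pose_b = np.eye(4, dtype=np.float32); pose_b[2, 3] = 1.0   # 1 unit along z
#   rel = compute_relative_pose(pose_a, pose_b)                # rel == pose_b (4x4)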
def replace_dit_model_in_manager():
"""Replace DiT model class with MoE version"""
from diffsynth.models.wan_video_dit_moe import WanModelMoe
from diffsynth.configs.model_config import model_loader_configs
for i, config in enumerate(model_loader_configs):
keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config
if 'wan_video_dit' in model_names:
new_model_names = []
new_model_classes = []
for name, cls in zip(model_names, model_classes):
if name == 'wan_video_dit':
new_model_names.append(name)
new_model_classes.append(WanModelMoe)
print(f"Replaced model class: {name} -> WanModelMoe")
else:
new_model_names.append(name)
new_model_classes.append(cls)
model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)
def add_framepack_components(dit_model):
"""Add FramePack related components"""
if not hasattr(dit_model, 'clean_x_embedder'):
inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
class CleanXEmbedder(nn.Module):
def __init__(self, inner_dim):
super().__init__()
self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
def forward(self, x, scale="1x"):
if scale == "1x":
x = x.to(self.proj.weight.dtype)
return self.proj(x)
elif scale == "2x":
x = x.to(self.proj_2x.weight.dtype)
return self.proj_2x(x)
elif scale == "4x":
x = x.to(self.proj_4x.weight.dtype)
return self.proj_4x(x)
else:
raise ValueError(f"Unsupported scale: {scale}")
dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
model_dtype = next(dit_model.parameters()).dtype
dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
print("Added FramePack clean_x_embedder component")
def add_moe_components(dit_model, moe_config):
"""Add MoE related components - corrected version"""
if not hasattr(dit_model, 'moe_config'):
dit_model.moe_config = moe_config
print("Added MoE config to model")
dit_model.top_k = moe_config.get("top_k", 1)
# Dynamically add MoE components for each block
dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
unified_dim = moe_config.get("unified_dim", 25)
num_experts = moe_config.get("num_experts", 4)
from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX uses 13-dim input, similar to sekai but handled independently
dit_model.global_router = nn.Linear(unified_dim, num_experts)
for i, block in enumerate(dit_model.blocks):
# MoE network - input unified_dim, output dim
block.moe = MultiModalMoE(
unified_dim=unified_dim,
output_dim=dim, # Output dimension matches transformer block dim
num_experts=moe_config.get("num_experts", 4),
top_k=moe_config.get("top_k", 2)
)
print(f"Block {i} added MoE component (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
def generate_sekai_camera_embeddings_sliding(
cam_data,
start_frame,
initial_condition_frames,
new_frames,
total_generated,
use_real_poses=True,
direction="left"):
"""
Generate camera embeddings for Sekai dataset - sliding window version
Args:
cam_data: Dictionary containing Sekai camera extrinsic parameters, key 'extrinsic' corresponds to an N*4*4 numpy array
start_frame: Current generation start frame index
initial_condition_frames: Initial condition frame count
new_frames: Number of new frames to generate this time
total_generated: Total frames already generated
use_real_poses: Whether to use real Sekai camera poses
direction: Camera movement direction, default "left"
Returns:
camera_embedding: Torch tensor of shape (M, 3*4 + 1), where M is the total number of generated frames
"""
time_compression_ratio = 4
# Calculate the actual number of camera frames needed for FramePack
# 1 initial frame + 16 frames 4x + 2 frames 2x + 1 frame 1x + new_frames
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
print("🔧 Using real Sekai camera data")
cam_extrinsic = cam_data['extrinsic']
# Ensure generating a sufficiently long camera sequence
max_needed_frames = max(
start_frame + initial_condition_frames + new_frames,
framepack_needed_frames,
30
)
print(f"🔧 Calculating Sekai camera sequence length:")
print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}")
print(f" - FramePack requirement: {framepack_needed_frames}")
print(f" - Final generation: {max_needed_frames}")
relative_poses = []
for i in range(max_needed_frames):
# Calculate the position of the current frame in the original sequence
frame_idx = i * time_compression_ratio
next_frame_idx = frame_idx + time_compression_ratio
if next_frame_idx < len(cam_extrinsic):
cam_prev = cam_extrinsic[frame_idx]
cam_next = cam_extrinsic[next_frame_idx]
relative_pose = compute_relative_pose(cam_prev, cam_next)
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
else:
# Out of range, use zero motion
print(f"⚠️ Frame {frame_idx} exceeds camera data range, using zero motion")
relative_poses.append(torch.zeros(3, 4))
pose_embedding = torch.stack(relative_poses, dim=0)
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
# Create mask sequence of corresponding length
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
# Mark from start_frame to start_frame+initial_condition_frames as condition
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
print(f"🔧 Sekai real camera embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
else:
# Ensure generating a sufficiently long camera sequence
max_needed_frames = max(
start_frame + initial_condition_frames + new_frames,
framepack_needed_frames,
30)
print(f"🔧 Generating Sekai synthetic camera frames: {max_needed_frames}")
CONDITION_FRAMES = initial_condition_frames
STAGE_1 = new_frames//2
STAGE_2 = new_frames - STAGE_1
if direction=="forward":
print("--------------- FORWARD MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Forward
forward_speed = 0.03
pose = np.eye(4, dtype=np.float32)
pose[2, 3] = -forward_speed
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
elif direction=="left":
print("--------------- LEFT TURNING MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Left turn
yaw_per_frame = 0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.00
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
elif direction=="right":
print("--------------- RIGHT TURNING MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Right turn
yaw_per_frame = -0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.00
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
elif direction=="forward_left":
print("--------------- FORWARD LEFT MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Left turn
yaw_per_frame = 0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.03
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
elif direction=="forward_right":
print("--------------- FORWARD RIGHT MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Right turn
yaw_per_frame = -0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.03
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
elif direction=="s_curve":
print("--------------- S CURVE MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1:
# Left turn
yaw_per_frame = 0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.03
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Right turn
yaw_per_frame = -0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.03
# Slight left drift to maintain inertia
if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3:
radius_shift = -0.01
else:
radius_shift = 0.00
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
pose[0, 3] = radius_shift
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
elif direction=="left_right":
print("--------------- LEFT RIGHT MODE ---------------")
relative_poses = []
for i in range(max_needed_frames):
if i < CONDITION_FRAMES:
# Input condition frames default to zero motion camera pose
pose = np.eye(4, dtype=np.float32)
elif i < CONDITION_FRAMES+STAGE_1:
# Left turn
yaw_per_frame = 0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.00
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
# Right turn
yaw_per_frame = -0.03
# Rotation matrix
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Forward
forward_speed = 0.00
pose = np.eye(4, dtype=np.float32)
pose[0, 0] = cos_yaw
pose[0, 2] = sin_yaw
pose[2, 0] = -sin_yaw
pose[2, 2] = cos_yaw
pose[2, 3] = -forward_speed
else:
# The part beyond condition frames and target frames remains stationary
pose = np.eye(4, dtype=np.float32)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
else:
raise ValueError(f"Not Defined Direction: {direction}")
pose_embedding = torch.stack(relative_poses, dim=0)
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
# Create mask sequence of corresponding length
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
print(f"🔧 Sekai synthetic camera embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
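# Layout sketch for the returned embedding (synthetic "left" trajectory, illustrative numbers):
# each row is a flattened 3x4 relative pose (12 values) plus a 1/0 condition-mask flag.
#   emb = generate_sekai_camera_embeddings_sliding(None, 0, 8, 4, 0,
#                                                  use_real_poses=False, direction="left")
#   emb.shape   # -> torch.Size([30, 13]); here max_needed_frames = max(12, 24, 30) = 30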
def generate_openx_camera_embeddings_sliding(
encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
"""Generate camera embeddings for OpenX dataset - sliding window version"""
time_compression_ratio = 4
# Calculate the actual number of camera frames needed for FramePack
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
print("🔧 Using OpenX real camera data")
cam_extrinsic = encoded_data['cam_emb']['extrinsic']
# Ensure generating a sufficiently long camera sequence
max_needed_frames = max(
start_frame + initial_condition_frames + new_frames,
framepack_needed_frames,
30
)
print(f"🔧 Calculating OpenX camera sequence length:")
print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}")
print(f" - FramePack requirement: {framepack_needed_frames}")
print(f" - Final generation: {max_needed_frames}")
relative_poses = []
for i in range(max_needed_frames):
# OpenX uses 4x interval, similar to sekai but handles shorter sequences
frame_idx = i * time_compression_ratio
next_frame_idx = frame_idx + time_compression_ratio
if next_frame_idx < len(cam_extrinsic):
cam_prev = cam_extrinsic[frame_idx]
cam_next = cam_extrinsic[next_frame_idx]
relative_pose = compute_relative_pose(cam_prev, cam_next)
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
else:
# Out of range, use zero motion
print(f"⚠️ Frame {frame_idx} exceeds OpenX camera data range, using zero motion")
relative_poses.append(torch.zeros(3, 4))
pose_embedding = torch.stack(relative_poses, dim=0)
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
# Create mask sequence of corresponding length
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
# Mark from start_frame to start_frame + initial_condition_frames as condition
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
print(f"🔧 OpenX real camera embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
else:
print("🔧 Using OpenX synthetic camera data")
max_needed_frames = max(
start_frame + initial_condition_frames + new_frames,
framepack_needed_frames,
30
)
print(f"🔧 Generating OpenX synthetic camera frames: {max_needed_frames}")
relative_poses = []
for i in range(max_needed_frames):
# OpenX robot operation motion mode - smaller motion amplitude
# Simulate fine operation motion of robot arm
roll_per_frame = 0.02 # Slight roll
pitch_per_frame = 0.01 # Slight pitch
yaw_per_frame = 0.015 # Slight yaw
forward_speed = 0.003 # Slower forward speed
pose = np.eye(4, dtype=np.float32)
# Compound rotation - simulate complex motion of robot arm
# Rotate around X-axis (roll)
cos_roll = np.cos(roll_per_frame)
sin_roll = np.sin(roll_per_frame)
# Rotate around Y-axis (pitch)
cos_pitch = np.cos(pitch_per_frame)
sin_pitch = np.sin(pitch_per_frame)
# Rotate around Z-axis (yaw)
cos_yaw = np.cos(yaw_per_frame)
sin_yaw = np.sin(yaw_per_frame)
# Simplified compound rotation matrix (ZYX order)
pose[0, 0] = cos_yaw * cos_pitch
pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
pose[1, 0] = sin_yaw * cos_pitch
pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
pose[2, 0] = -sin_pitch
pose[2, 1] = cos_pitch * sin_roll
pose[2, 2] = cos_pitch * cos_roll
# Translation - simulate fine movement of robot operation
pose[0, 3] = forward_speed * 0.5 # Slight movement in X direction
pose[1, 3] = forward_speed * 0.3 # Slight movement in Y direction
pose[2, 3] = -forward_speed # Main movement in Z direction (depth)
relative_pose = pose[:3, :]
relative_poses.append(torch.as_tensor(relative_pose))
pose_embedding = torch.stack(relative_poses, dim=0)
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
# Create mask sequence of corresponding length
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
print(f"🔧 OpenX synthetic camera embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
def generate_nuscenes_camera_embeddings_sliding(
scene_info, start_frame, initial_condition_frames, new_frames):
"""
Generate camera embeddings for NuScenes dataset - sliding window version
corrected version, consistent with train_moe.py
"""
time_compression_ratio = 4
# Calculate the actual number of camera frames needed for FramePack
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
if scene_info is not None and 'keyframe_poses' in scene_info:
print("🔧 Using NuScenes real pose data")
keyframe_poses = scene_info['keyframe_poses']
if len(keyframe_poses) == 0:
print("⚠️ NuScenes keyframe_poses is empty, using zero pose")
max_needed_frames = max(framepack_needed_frames, 30)
pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
print(f"🔧 NuScenes zero pose embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
# Use first pose as reference
reference_pose = keyframe_poses[0]
max_needed_frames = max(framepack_needed_frames, 30)
pose_vecs = []
for i in range(max_needed_frames):
if i < len(keyframe_poses):
current_pose = keyframe_poses[i]
# Calculate relative displacement
translation = torch.tensor(
np.array(current_pose['translation']) - np.array(reference_pose['translation']),
dtype=torch.float32
)
# Keep the keyframe rotation as-is (only the translation is made relative to the reference)
rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
else:
# Out of range, use zero pose
pose_vec = torch.cat([
torch.zeros(3, dtype=torch.float32),
torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
], dim=0) # [7D]
pose_vecs.append(pose_vec)
pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
# Create mask
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
print(f"🔧 NuScenes real pose embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
else:
print("🔧 Using NuScenes synthetic pose data")
max_needed_frames = max(framepack_needed_frames, 30)
# Create synthetic motion sequence
pose_vecs = []
for i in range(max_needed_frames):
# Left turn motion mode - similar to left turns in city driving
angle = i * 0.04 # Rotate 0.04 radians per frame (slightly slower turn)
radius = 15.0 # Larger turning radius, more suitable for car turns
# Calculate position on circular arc trajectory
x = radius * np.sin(angle)
y = 0.0 # Keep horizontal plane motion
z = radius * (1 - np.cos(angle))
translation = torch.tensor([x, y, z], dtype=torch.float32)
# Vehicle orientation - always along trajectory tangent direction
yaw = angle + np.pi/2 # Yaw angle relative to initial forward direction
# Quaternion representation of rotation about the Z-axis (up), stored w-first
rotation = torch.tensor([
np.cos(yaw/2), # w (real part)
0.0, # x
0.0, # y
np.sin(yaw/2) # z (imaginary part; yaw rotation is about the Z-axis)
], dtype=torch.float32)
pose_vec = torch.cat([translation, rotation], dim=0) # [7D: tx,ty,tz,qw,qx,qy,qz]
pose_vecs.append(pose_vec)
pose_sequence = torch.stack(pose_vecs, dim=0)
# Create mask
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
mask[start_frame:condition_end] = 1.0
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
print(f"🔧 NuScenes synthetic left turn pose embedding shape: {camera_embedding.shape}")
return camera_embedding.to(torch.bfloat16)
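# NuScenes embeddings differ from Sekai/OpenX: each row is a 7-D pose (tx, ty, tz, qw, qx, qy, qz)
# plus the condition-mask flag, i.e. 8 columns in total (illustrative call with synthetic poses):
#   emb = generate_nuscenes_camera_embeddings_sliding(None, 0, 8, 4)
#   emb.shape   # -> torch.Size([30, 8]); max(1 + 16 + 2 + 1 + 4, 30) = 30 rows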
def prepare_framepack_sliding_window_with_camera_moe(
history_latents,
target_frames_to_generate,
camera_embedding_full,
start_frame,
modality_type,
max_history_frames=49):
"""FramePack sliding window mechanism - MoE version"""
# history_latents: [C, T, H, W] current history latents
C, T, H, W = history_latents.shape
# Fixed index structure (this determines the number of camera frames needed)
# 1 start frame + 16 frames 4x + 2 frames 2x + 1 frame 1x + target_frames_to_generate
total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
indices = torch.arange(0, total_indices_length)
split_sizes = [1, 16, 2, 1, target_frames_to_generate]
clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
indices.split(split_sizes, dim=0)
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
# Check if camera length is sufficient
if camera_embedding_full.shape[0] < total_indices_length:
print(f"⚠️ camera_embedding length insufficient, performing zero padding: current length {camera_embedding_full.shape[0]}, required length {total_indices_length}")
shortage = total_indices_length - camera_embedding_full.shape[0]
padding = torch.zeros(shortage, camera_embedding_full.shape[1],
dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
# Select corresponding part from complete camera sequence
combined_camera = torch.zeros(
total_indices_length,
camera_embedding_full.shape[1],
dtype=camera_embedding_full.dtype,
device=camera_embedding_full.device)
# Camera poses for historical condition frames
history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone()
combined_camera[19 - history_slice.shape[0]:19, :] = history_slice
# Camera poses for target frames
target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
combined_camera[19:19 + target_slice.shape[0], :] = target_slice
# Reset mask according to current history length
combined_camera[:, -1] = 0.0 # First set all to target (0)
# Set condition mask: first 19 frames determined by actual history length
if T > 0:
available_frames = min(T, 19)
start_pos = 19 - available_frames
combined_camera[start_pos:19, -1] = 1.0 # Mark cameras corresponding to valid clean latents as condition
print(f"🔧 MoE Camera mask update:")
print(f" - History frames: {T}")
print(f" - Valid condition frames: {available_frames if T > 0 else 0}")
print(f" - Modality type: {modality_type}")
# Process latents
clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
if T > 0:
available_frames = min(T, 19)
start_pos = 19 - available_frames
clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]
clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
clean_latents_1x = clean_latents_combined[:, 18:19, :, :]
if T > 0:
start_latent = history_latents[:, 0:1, :, :]
else:
start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)
clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)
return {
'latent_indices': latent_indices,
'clean_latents': clean_latents,
'clean_latents_2x': clean_latents_2x,
'clean_latents_4x': clean_latents_4x,
'clean_latent_indices': clean_latent_indices,
'clean_latent_2x_indices': clean_latent_2x_indices,
'clean_latent_4x_indices': clean_latent_4x_indices,
'camera_embedding': combined_camera,
'modality_type': modality_type, # Added modality type information
'current_length': T,
'next_length': T + target_frames_to_generate
}
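# Window sketch (illustrative dummy tensors; 16 latent channels assumed): the index split is
# [1, 16, 2, 1, N] = start frame, 4x history, 2x history, 1x history, and N new frames.
#   hist = torch.zeros(16, 8, 60, 104)                       # [C, T, H, W] current history
#   cam  = torch.zeros(30, 13)                               # full camera embedding (Sekai-style)
#   data = prepare_framepack_sliding_window_with_camera_moe(hist, 4, cam, 0, "sekai")
#   data['clean_latents'].shape      # -> (16, 2, 60, 104)   start frame + 1x frame
#   data['clean_latents_4x'].shape   # -> (16, 16, 60, 104)
#   data['camera_embedding'].shape   # -> (24, 13)           1 + 16 + 2 + 1 + 4 rows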
def overlay_controls(frame_img, pose_vec, icons):
"""
Overlay control icons (WASD and arrows) on frame based on camera pose
pose_vec: 12 elements (flattened 3x4 matrix) + mask
"""
if pose_vec is None or np.all(pose_vec[:12] == 0):
return frame_img
# Extract translation vector (based on flattened 3x4 matrix indices)
# [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
tx = pose_vec[3]
# ty = pose_vec[7]
tz = pose_vec[11]
# Extract rotation (yaw and pitch)
# Yaw: around Y axis. sin(yaw) = r02, cos(yaw) = r00
r00 = pose_vec[0]
r02 = pose_vec[2]
yaw = np.arctan2(r02, r00)
# Pitch: around X axis. sin(pitch) = -r12, cos(pitch) = r22
r12 = pose_vec[6]
r22 = pose_vec[10]
pitch = np.arctan2(-r12, r22)
# Threshold for key activation
TRANS_THRESH = 0.01
ROT_THRESH = 0.005
# Determine key states
# Translation (WASD)
# Assume -Z is forward, +X is right
is_forward = tz < -TRANS_THRESH
is_backward = tz > TRANS_THRESH
is_left = tx < -TRANS_THRESH
is_right = tx > TRANS_THRESH
# Rotation (arrows)
# Yaw: + is left, - is right
is_turn_left = yaw > ROT_THRESH
is_turn_right = yaw < -ROT_THRESH
# Pitch: + is down, - is up
is_turn_up = pitch < -ROT_THRESH
is_turn_down = pitch > ROT_THRESH
W, H = frame_img.size
spacing = 60
def paste_icon(name_active, name_inactive, is_active, x, y):
name = name_active if is_active else name_inactive
if name in icons:
icon = icons[name]
# Paste using alpha channel
frame_img.paste(icon, (int(x), int(y)), icon)
# Overlay WASD (bottom left)
base_x_right = 100
base_y = H - 100
# W
paste_icon('move_forward.png', 'not_move_forward.png', is_forward, base_x_right, base_y - spacing)
# A
paste_icon('move_left.png', 'not_move_left.png', is_left, base_x_right - spacing, base_y)
# S
paste_icon('move_backward.png', 'not_move_backward.png', is_backward, base_x_right, base_y)
# D
paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y)
# Overlay arrows (bottom right)
base_x_left = W - 150
# ↑
paste_icon('turn_up.png', 'not_turn_up.png', is_turn_up, base_x_left, base_y - spacing)
# ←
paste_icon('turn_left.png', 'not_turn_left.png', is_turn_left, base_x_left - spacing, base_y)
# ↓
paste_icon('turn_down.png', 'not_turn_down.png', is_turn_down, base_x_left, base_y)
# →
paste_icon('turn_right.png', 'not_turn_right.png', is_turn_right, base_x_left + spacing, base_y)
return frame_img
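# Decoding sketch for a single 13-D camera row (illustrative): a pure forward step
# (tz < -TRANS_THRESH, no rotation) should light up only the "W" icon.
#   pose_vec = np.zeros(13, dtype=np.float32)
#   pose_vec[0] = pose_vec[5] = pose_vec[10] = 1.0   # identity rotation on the diagonal
#   pose_vec[11] = -0.03                             # tz: forward motion
#   # overlay_controls(frame_img, pose_vec, icons) -> frame with 'move_forward.png' active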
def inference_moe_framepack_sliding_window(
condition_pth_path=None,
condition_video=None,
condition_image=None,
dit_path=None,
wan_model_path=None,
output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
start_frame=0,
initial_condition_frames=8,
frames_per_generation=4,
total_frames_to_generate=32,
max_history_frames=49,
device="cuda",
prompt="A video of a scene shot using a pedestrian's front camera while walking",
modality_type="sekai", # "sekai" or "nuscenes"
use_real_poses=True,
scene_info_path=None, # For NuScenes dataset
# CFG parameters
use_camera_cfg=True,
camera_guidance_scale=2.0,
text_guidance_scale=1.0,
# MoE parameters
moe_num_experts=4,
moe_top_k=2,
moe_hidden_dim=None,
direction="left",
use_gt_prompt=True,
add_icons=False
):
"""
MoE FramePack sliding window video generation - multi-modal support
"""
# Create output directory
dir_path = os.path.dirname(output_path)
os.makedirs(dir_path, exist_ok=True)
print(f"🔧 Starting MoE FramePack sliding window generation...")
print(f" Modality type: {modality_type}")
print(f" Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
print(f" Text guidance scale: {text_guidance_scale}")
print(f" MoE config: experts={moe_num_experts}, top_k={moe_top_k}")
# 1. Model initialization
replace_dit_model_in_manager()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
os.path.join(wan_model_path, "diffusion_pytorch_model.safetensors"),
os.path.join(wan_model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
os.path.join(wan_model_path, "Wan2.1_VAE.pth"),
])
pipe = WanVideoAstraPipeline.from_model_manager(model_manager, device="cuda")
# 2. Add traditional camera encoder (compatibility)
dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
for block in pipe.dit.blocks:
block.cam_encoder = nn.Linear(13, dim)
block.projector = nn.Linear(dim, dim)
block.cam_encoder.weight.data.zero_()
block.cam_encoder.bias.data.zero_()
block.projector.weight = nn.Parameter(torch.eye(dim))
block.projector.bias = nn.Parameter(torch.zeros(dim))
# 3. Add FramePack components
add_framepack_components(pipe.dit)
# 4. Add MoE components
moe_config = {
"num_experts": moe_num_experts,
"top_k": moe_top_k,
"hidden_dim": moe_hidden_dim or dim * 2,
"sekai_input_dim": 13, # Sekai: 12-dim pose + 1-dim mask
"nuscenes_input_dim": 8, # NuScenes: 7-dim pose + 1-dim mask
"openx_input_dim": 13 # OpenX: 12-dim pose + 1-dim mask (similar to sekai)
}
add_moe_components(pipe.dit, moe_config)
# 5. Load trained weights
dit_state_dict = torch.load(dit_path, map_location="cpu")
pipe.dit.load_state_dict(dit_state_dict, strict=False) # Use strict=False to be compatible with newly added MoE components
pipe = pipe.to(device)
model_dtype = next(pipe.dit.parameters()).dtype
if hasattr(pipe.dit, 'clean_x_embedder'):
pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
# Set denoising steps
pipe.scheduler.set_timesteps(50)
# 6. Load initial conditions
print("Loading initial condition frames...")
initial_latents, encoded_data = load_or_encode_condition(
condition_pth_path,
condition_video,
condition_image,
start_frame,
initial_condition_frames,
device,
pipe,
)
# Spatial cropping
target_height, target_width = 60, 104
C, T, H, W = initial_latents.shape
if H > target_height or W > target_width:
h_start = (H - target_height) // 2
w_start = (W - target_width) // 2
initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
H, W = target_height, target_width
history_latents = initial_latents.to(device, dtype=model_dtype)
print(f"Initial history_latents shape: {history_latents.shape}")
# 7. Encode prompt - support CFG
if use_gt_prompt and 'prompt_emb' in encoded_data:
print("✅ Using pre-encoded GT prompt embedding")
prompt_emb_pos = encoded_data['prompt_emb']
# Move prompt_emb to correct device and dtype
if 'context' in prompt_emb_pos:
prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
if 'context_mask' in prompt_emb_pos:
prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
# Generate negative prompt if using Text CFG
if text_guidance_scale > 1.0:
prompt_emb_neg = pipe.encode_prompt("")
print(f"Using Text CFG with GT prompt, guidance scale: {text_guidance_scale}")
else:
prompt_emb_neg = None
print("Not using Text CFG")
# Print GT prompt text if available
if 'prompt' in encoded_data['prompt_emb']:
gt_prompt_text = encoded_data['prompt_emb']['prompt']
print(f"📝 GT Prompt text: {gt_prompt_text}")
else:
# Re-encode using provided prompt parameter
print(f"🔄 Re-encoding prompt: {prompt}")
if text_guidance_scale > 1.0:
prompt_emb_pos = pipe.encode_prompt(prompt)
prompt_emb_neg = pipe.encode_prompt("")
print(f"Using Text CFG, guidance scale: {text_guidance_scale}")
else:
prompt_emb_pos = pipe.encode_prompt(prompt)
prompt_emb_neg = None
print("Not using Text CFG")
# 8. Load scene information (for NuScenes)
scene_info = None
if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
with open(scene_info_path, 'r') as f:
scene_info = json.load(f)
print(f"Loading NuScenes scene information: {scene_info_path}")
# 9. Pre-generate complete camera embedding sequence
if modality_type == "sekai":
camera_embedding_full = generate_sekai_camera_embeddings_sliding(
encoded_data.get('cam_emb', None),
start_frame,
initial_condition_frames,
total_frames_to_generate,
0,
use_real_poses=use_real_poses,
direction=direction
).to(device, dtype=model_dtype)
elif modality_type == "nuscenes":
camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
scene_info,
start_frame,
initial_condition_frames,
total_frames_to_generate
).to(device, dtype=model_dtype)
elif modality_type == "openx":
camera_embedding_full = generate_openx_camera_embeddings_sliding(
encoded_data,
start_frame,
initial_condition_frames,
total_frames_to_generate,
use_real_poses=use_real_poses
).to(device, dtype=model_dtype)
else:
raise ValueError(f"Unsupported modality type: {modality_type}")
print(f"Complete camera sequence shape: {camera_embedding_full.shape}")
# 10. Create unconditional camera embedding for Camera CFG
if use_camera_cfg:
camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
print(f"Creating unconditional camera embedding for CFG")
# 11. Sliding window generation loop
total_generated = 0
all_generated_frames = []
while total_generated < total_frames_to_generate:
current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
print(f"\nGeneration step {total_generated // frames_per_generation + 1}")
print(f"Current history length: {history_latents.shape[1]}, generating: {current_generation}")
# FramePack data preparation - MoE version
framepack_data = prepare_framepack_sliding_window_with_camera_moe(
history_latents,
current_generation,
camera_embedding_full,
start_frame,
modality_type,
max_history_frames
)
# Prepare input
clean_latents = framepack_data['clean_latents'].unsqueeze(0)
clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
# Prepare modality_inputs
modality_inputs = {modality_type: camera_embedding}
# Prepare unconditional camera embedding for CFG
if use_camera_cfg:
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
# Index processing
latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
# Initialize latents to generate
new_latents = torch.randn(
1, C, current_generation, H, W,
device=device, dtype=model_dtype
)
extra_input = pipe.prepare_extra_input(new_latents)
print(f"Camera embedding shape: {camera_embedding.shape}")
print(f"Camera mask distribution - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
# Denoising loop - supports CFG
timesteps = pipe.scheduler.timesteps
for i, timestep in enumerate(timesteps):
if i % 10 == 0:
print(f" Denoising step {i+1}/{len(timesteps)}")
timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
with torch.no_grad():
# CFG inference
if use_camera_cfg and camera_guidance_scale > 1.0:
# Conditional prediction (with camera)
noise_pred_cond, moe_loss = pipe.dit(
new_latents,
timestep=timestep_tensor,
cam_emb=camera_embedding,
modality_inputs=modality_inputs, # MoE modality input
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
**prompt_emb_pos,
**extra_input
)
# Unconditional prediction (no camera)
noise_pred_uncond, moe_loss = pipe.dit(
new_latents,
timestep=timestep_tensor,
cam_emb=camera_embedding_uncond_batch,
modality_inputs=modality_inputs_uncond, # MoE unconditional modality input
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
**(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
**extra_input
)
# Camera CFG
noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
# If using Text CFG at the same time
if text_guidance_scale > 1.0 and prompt_emb_neg:
noise_pred_text_uncond, moe_loss = pipe.dit(
new_latents,
timestep=timestep_tensor,
cam_emb=camera_embedding,
modality_inputs=modality_inputs,
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
**prompt_emb_neg,
**extra_input
)
# Apply Text CFG to results that have already applied Camera CFG
noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
elif text_guidance_scale > 1.0 and prompt_emb_neg:
# Use Text CFG only
noise_pred_cond, moe_loss = pipe.dit(
new_latents,
timestep=timestep_tensor,
cam_emb=camera_embedding,
modality_inputs=modality_inputs,
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
**prompt_emb_pos,
**extra_input
)
noise_pred_uncond, moe_loss = pipe.dit(
new_latents,
timestep=timestep_tensor,
cam_emb=camera_embedding,
modality_inputs=modality_inputs,
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
**prompt_emb_neg,
**extra_input
)
noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
# Standard inference (no CFG)
noise_pred, moe_loss = pipe.dit(
new_latents,
timestep=timestep_tensor,
cam_emb=camera_embedding,
modality_inputs=modality_inputs, # MoE modality input
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
**prompt_emb_pos,
**extra_input
)
new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
# Update history
new_latents_squeezed = new_latents.squeeze(0)
history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
# Maintain sliding window
if history_latents.shape[1] > max_history_frames:
first_frame = history_latents[:, 0:1, :, :]
recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
history_latents = torch.cat([first_frame, recent_frames], dim=1)
print(f"⚠️ History window full, keeping first frame + latest {max_history_frames-1} frames")
print(f"History_latents shape after update: {history_latents.shape}")
all_generated_frames.append(new_latents_squeezed)
total_generated += current_generation
print(f"✅ Generated {total_generated}/{total_frames_to_generate} frames")
# 12. Decode and save
print("\nDecoding generated video...")
all_generated = torch.cat(all_generated_frames, dim=1)
final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
print(f"Final video shape: {final_video.shape}")
decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
print(f"Saving video to {output_path} ...")
video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
video_np = (video_np * 0.5 + 0.5).clip(0, 1)
video_np = (video_np * 255).astype(np.uint8)
icons = {}
video_camera_poses = None
if add_icons:
# Load icon resources for overlay
icons_dir = os.path.join(ROOT_DIR, 'icons')
icon_names = ['move_forward.png', 'not_move_forward.png',
'move_backward.png', 'not_move_backward.png',
'move_left.png', 'not_move_left.png',
'move_right.png', 'not_move_right.png',
'turn_up.png', 'not_turn_up.png',
'turn_down.png', 'not_turn_down.png',
'turn_left.png', 'not_turn_left.png',
'turn_right.png', 'not_turn_right.png']
for name in icon_names:
path = os.path.join(icons_dir, name)
if os.path.exists(path):
try:
icon = Image.open(path).convert("RGBA")
# Adjust icon size
icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
icons[name] = icon
except Exception as e:
print(f"Error loading icon {name}: {e}")
else:
print(f"⚠️ Warning: Icon {name} not found at {path}")
# Get camera poses corresponding to video frames
time_compression_ratio = 4
camera_poses = camera_embedding_full.detach().float().cpu().numpy()
video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]
with imageio.get_writer(output_path, fps=20) as writer:
for i, frame in enumerate(video_np):
# Convert to PIL for overlay
img = Image.fromarray(frame)
if add_icons and video_camera_poses is not None and icons:
# Video frame i corresponds to camera_embedding_full[start_frame + i]
pose_idx = start_frame + i
if pose_idx < len(video_camera_poses):
pose_vec = video_camera_poses[pose_idx]
img = overlay_controls(img, pose_vec, icons)
writer.append_data(np.array(img))
print(f"✅ MoE FramePack sliding window generation completed! Saved to: {output_path}")
print(f" Total generated {total_generated} frames (compressed), corresponding to original {total_generated * 4} frames")
print(f" Using modality: {modality_type}")
def main():
parser = argparse.ArgumentParser(description="MoE FramePack sliding window video generation - supports multi-modal")
# Basic parameters
parser.add_argument("--condition_pth",
type=str,
default=None,
help="Path to pre-encoded condition pth file")
parser.add_argument("--condition_video",
type=str,
default=None,
help="Input video for novel view synthesis.")
parser.add_argument("--condition_image",
type=str,
default=None,
help="Input image for novel view synthesis.")
parser.add_argument("--start_frame", type=int, default=0)
parser.add_argument("--initial_condition_frames", type=int, default=1)
parser.add_argument("--frames_per_generation", type=int, default=8)
parser.add_argument("--total_frames_to_generate", type=int, default=24)
parser.add_argument("--max_history_frames", type=int, default=100)
parser.add_argument("--use_real_poses", default=False)
parser.add_argument("--dit_path", type=str,
default="../models/Astra/checkpoints/diffusion_pytorch_model.ckpt",
help="path to the pretrained DiT MoE model checkpoint")
parser.add_argument("--wan_model_path",
type=str,
default="../models/Wan-AI/Wan2.1-T2V-1.3B",
help="path to Wan2.1-T2V-1.3B")
parser.add_argument("--output_path", type=str,
default='../examples/output_videos/output_moe_framepack_sliding.mp4')
parser.add_argument("--prompt",
type=str,
default="",
help="text prompt for video generation")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--add_icons", action="store_true", default=False,
help="Overlay control icons on generated video")
# Modality type parameters
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
default="sekai", help="Modality type: sekai, nuscenes, or openx")
parser.add_argument("--scene_info_path", type=str, default=None,
help="NuScenes scene info file path (for nuscenes modality only)")
# CFG parameters
parser.add_argument("--use_camera_cfg", default=False,
help="Use Camera CFG")
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
help="Camera guidance scale for CFG")
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
help="Text guidance scale for CFG")
# MoE parameters
parser.add_argument("--moe_num_experts", type=int, default=3, help="Number of experts")
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K experts")
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE hidden dimension")
parser.add_argument("--direction", type=str, default="left", help="Direction of video trajectory")
parser.add_argument("--use_gt_prompt", action="store_true", default=False,
help="Use ground truth prompt embedding from dataset")
args = parser.parse_args()
print(f"MoE FramePack CFG generation settings:")
print(f"Modality type: {args.modality_type}")
print(f"Camera CFG: {args.use_camera_cfg}")
if args.use_camera_cfg:
print(f"Camera guidance scale: {args.camera_guidance_scale}")
print(f"Using GT Prompt: {args.use_gt_prompt}")
print(f"Text guidance scale: {args.text_guidance_scale}")
print(f"MoE config: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
print(f"DiT{args.dit_path}")
# Validate NuScenes parameters
if args.modality_type == "nuscenes" and not args.scene_info_path:
print("⚠️ Warning: Using NuScenes modality but scene_info_path not provided, will use synthetic pose data")
if not args.use_gt_prompt and (args.prompt is None or args.prompt.strip() == ""):
print("⚠️ Warning: No prompt provided, will use empty string as prompt")
if not any([args.condition_pth, args.condition_video, args.condition_image]):
raise ValueError("Need to provide condition_pth, condition_video, or condition_image as condition input")
if args.condition_pth:
print(f"Using pre-encoded pth: {args.condition_pth}")
elif args.condition_video:
print(f"Using condition video for online encoding: {args.condition_video}")
elif args.condition_image:
print(f"Using condition image for online encoding: {args.condition_image} (repeat 10 frames)")
inference_moe_framepack_sliding_window(
condition_pth_path=args.condition_pth,
condition_video=args.condition_video,
condition_image=args.condition_image,
dit_path=args.dit_path,
wan_model_path=args.wan_model_path,
output_path=args.output_path,
start_frame=args.start_frame,
initial_condition_frames=args.initial_condition_frames,
frames_per_generation=args.frames_per_generation,
total_frames_to_generate=args.total_frames_to_generate,
max_history_frames=args.max_history_frames,
device=args.device,
prompt=args.prompt,
modality_type=args.modality_type,
use_real_poses=args.use_real_poses,
scene_info_path=args.scene_info_path,
# CFG parameters
use_camera_cfg=args.use_camera_cfg,
camera_guidance_scale=args.camera_guidance_scale,
text_guidance_scale=args.text_guidance_scale,
# MoE parameters
moe_num_experts=args.moe_num_experts,
moe_top_k=args.moe_top_k,
moe_hidden_dim=args.moe_hidden_dim,
direction=args.direction,
use_gt_prompt=args.use_gt_prompt,
add_icons=args.add_icons
)
if __name__ == "__main__":
main()