import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor

from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel

from . import tensor_to_pil
from utils.file_utils import save_tensor_to_file, load_tensor_from_file

# Lazily-initialised, process-wide model state; populated by get_seqtex_pipe().
TEX_PIPE = None
VAE = None
# LATENTS_MEAN is the VAE's per-channel latent mean; LATENTS_STD holds the
# *reciprocal* of the per-channel latent std (see get_seqtex_pipe).
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()


@dataclass
class Config:
    """Static configuration for the SeqTex texture-generation demo."""

    # Hugging Face repo ids of the base video model and the SeqTex transformer.
    video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    seqtex_transformer_path: str = "VAST-AI/SeqTex-Transformer"
    min_noise_level_index: int = 15  # refer to paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)

    # View counts and resolutions of the multi-view (mv) and UV outputs.
    num_views: int = 4
    uv_num_views: int = 1
    mv_height: int = 512
    mv_width: int = 512
    uv_height: int = 1024
    uv_width: int = 1024

    # Sampling parameters.
    flow_shift: float = 5.0
    eval_guidance_scale: float = 1.0
    eval_num_inference_steps: int = 30
    eval_seed: int = 42


cfg = Config()


def get_seqtex_pipe():
    """
    Lazily load the SeqTex pipeline for texture generation.

    Must be called within a ``@spaces.GPU`` context, because the finished
    pipeline is moved to CUDA.

    The module globals (``VAE``, ``LATENTS_MEAN``, ``LATENTS_STD`` and
    ``TEX_PIPE``) are only assigned once everything is fully configured, with
    ``TEX_PIPE`` assigned last. The fast path below reads ``TEX_PIPE``
    without holding the lock, so a concurrent caller must never observe a
    half-configured pipeline, and a failure part-way through loading must not
    leave a broken pipeline cached.

    :return: The cached ``WanT2TexPipeline``, on CUDA.
    :raises RuntimeError: If ``SEQTEX_SPACE_TOKEN`` is unset or empty.
    """
    global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
    if TEX_PIPE is not None:
        return TEX_PIPE

    gr.Info("First called, loading SeqTex pipeline... It may take about 1 minute.")
    with TEX_PIPE_LOCK:
        if TEX_PIPE is not None:
            return TEX_PIPE

        # Validate the token before using it: an unset variable would otherwise
        # surface as a bare KeyError, and an empty one only after a failed download.
        token = os.environ.get("SEQTEX_SPACE_TOKEN", "")
        if not token:
            raise RuntimeError(
                "Please set the SEQTEX_SPACE_TOKEN environment variable with your "
                "Hugging Face token, which has access to VAST-AI/SeqTex-Transformer."
            )

        # Load transformer with auto-configured LoRA adapter first
        transformer = WanT2TexTransformer3DModel.from_pretrained(
            cfg.seqtex_transformer_path, token=token
        )
        # Pipeline - pass the pre-loaded transformer to avoid re-loading
        pipe = WanT2TexPipeline.from_pretrained(
            cfg.video_base_name, transformer=transformer, torch_dtype=torch.bfloat16
        )
        del transformer

        # Replace the pipeline's VAE with a float32 copy.
        vae = AutoencoderKLWan.from_pretrained(
            cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32
        )
        pipe.vae = vae

        # Latent normalisation statistics, kept on CPU here (CUDA initialisation is
        # deferred); encode/decode move them to CUDA. Note the std is stored inverted.
        latents_mean = torch.tensor(vae.config.latents_mean).view(
            1, vae.config.z_dim, 1, 1, 1
        ).to(torch.float32)
        latents_inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(
            1, vae.config.z_dim, 1, 1, 1
        ).to(torch.float32)

        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
                pipe.scheduler.config, shift=cfg.flow_shift
            )
        )
        num_train_timesteps = scheduler.config.num_train_timesteps
        # In this scheduler the first timestep is pure noise, so the index is
        # counted from the end (typically 1000 - 15).
        min_noise_level_index = num_train_timesteps - cfg.min_noise_level_index
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        pipe.min_noise_level_index = min_noise_level_index
        pipe.min_noise_level_timestep = min_noise_level_timestep
        # The scheduler defines timesteps as sigmas * num_train_timesteps.
        pipe.min_noise_level_sigma = min_noise_level_timestep / num_train_timesteps

        pipe = pipe.to("cuda")

        # Publish only now that everything is configured; TEX_PIPE goes last
        # because it is the flag the lock-free fast path checks.
        VAE = vae
        LATENTS_MEAN, LATENTS_STD = latents_mean, latents_inv_std
        TEX_PIPE = pipe
    return TEX_PIPE


def _vae_on_cuda():
    """
    Move the cached VAE and latent statistics to CUDA and return them.

    :return: ``(vae, latents_mean, latents_inv_std)``, all on CUDA.
    :raises RuntimeError: If called before ``get_seqtex_pipe()`` has loaded the VAE.
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    if VAE is None or LATENTS_MEAN is None or LATENTS_STD is None:
        raise RuntimeError("SeqTex VAE is not loaded; call get_seqtex_pipe() first.")
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")
    return VAE, LATENTS_MEAN, LATENTS_STD


@torch.amp.autocast('cuda', dtype=torch.float32)
def encode_images(
    images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
    """
    Encode images to latent space using the VAE.

    Every frame is treated as a separate image, without any awareness of the
    temporal dimension.

    :param images: Input images with shape [B, F, H, W, C], expected in [0, 1].
        Inputs whose minimum is below -0.1 are assumed to be in [-1, 1] and are
        rescaled to [0, 1] first.
    :param encode_as_first: Encode every frame as if it were the first frame
        of a video. Only ``True`` is implemented.
    :return: Normalised latents with shape [B, C', F, H/8, W/8].
    :raises NotImplementedError: If ``encode_as_first`` is False.
    """
    vae, latents_mean, latents_inv_std = _vae_on_cuda()

    if not encode_as_first:
        raise NotImplementedError("Currently only support encode as first frame.")

    if images.min() < -0.1:  # images are in [-1, 1] range
        images = (images + 1.0) / 2.0  # Normalize to [0, 1] range

    batch = images.shape[0]
    # Fold frames into the batch with a singleton time axis, so the video VAE
    # sees B*F independent single-frame clips.
    frames = rearrange(images, "B F H W C -> (B F) C 1 H W")
    latents = (vae.encode(frames).latent_dist.sample() - latents_mean) * latents_inv_std
    return rearrange(latents, "(B F) C 1 H W -> B C F H W", B=batch)


@torch.amp.autocast('cuda', dtype=torch.float32)
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
    """
    Decode latents back to images using the VAE.

    Inverse of ``encode_images``: latents are de-normalised with the stored
    statistics, and every latent frame is decoded as an independent clip.

    :param latents: Normalised latents with shape [B, C, F, H, W].
    :param decode_as_first: Decode every frame as if it were the first frame.
        Only ``True`` is implemented.
    :return: Decoded images with shape [B, C, F*Nv, H*8, W*8] (Nv == 1 here).
    :raises NotImplementedError: If ``decode_as_first`` is False.
    """
    vae, latents_mean, latents_inv_std = _vae_on_cuda()

    if not decode_as_first:
        raise NotImplementedError("Currently only support decode as first frame.")

    num_frames = latents.shape[2]
    latents = latents.to(vae.dtype)
    # Undo encode_images' normalisation (LATENTS_STD is the reciprocal std).
    latents = latents / latents_inv_std + latents_mean
    latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
    images = vae.decode(latents, return_dict=False)[0]
    return rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=num_frames, Nv=1)


def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
    """
    Convert a PIL image to an RGB float tensor.

    The image is converted to RGBA, and RGB is multiplied by the alpha channel.
    Transparent pixels therefore become black, and images without alpha are
    treated as opaque.

    :param image: PIL image with values in [0, 255].
    :param device: Target device for the tensor; ``"cpu"`` skips the transfer.
    :return: float32 tensor of shape [H, W, 3] with values in [0.0, 1.0].
    """
    rgba = np.array(image.convert("RGBA"))
    alpha = rgba[..., 3:4] / 255.0  # Normalize alpha to [0, 1]
    # Blend with black background using alpha mask, then normalize to [0, 1].
    rgb = (rgba[..., :3] * alpha).astype(np.float32) / 255.0
    tensor = torch.from_numpy(rgb)
    if device != "cpu":
        tensor = tensor.to(device)
    return tensor


@spaces.GPU(duration=90)
@torch.no_grad
@torch.inference_mode
def generate_texture(
    position_map_path,
    normal_map_path,
    position_images_path,
    normal_images_path,
    condition_image,
    text_prompt,
    selected_view,
    negative_prompt=None,
    device="cuda",
    progress=gr.Progress(),
):
    """
    Use SeqTex to generate a texture for the mesh from an image condition.

    :param position_map_path: File path to the UV-space position map tensor.
    :param normal_map_path: File path to the UV-space normal map tensor.
    :param position_images_path: File path to the multi-view position images tensor.
    :param normal_images_path: File path to the multi-view normal images tensor.
    :param condition_image: Image condition generated for ``selected_view``.
        One of:
        - a PIL image, resized to ``cfg.mv_width`` x ``cfg.mv_height``;
        - a numpy array, converted with ``Image.fromarray``;
        - an already prepared [1, 1, H, W, C] tensor.
    :param text_prompt: Text prompt for texture generation.
    :param selected_view: One of "First View", "Second View", "Third View",
        "Fourth View": the view the condition image was generated from.
    :param negative_prompt: Optional negative prompt forwarded to the pipeline.
    :param device: Device used to load the inputs and seed the generator.
    :param progress: Gradio progress tracker.
    :return: ``(uv_map_path, uv_map_image, multiview_image, status_message)``:
        - the file path of the saved UV texture tensor;
        - PIL previews of the UV texture and of the multi-view frames;
        - a status string.
    :raises gr.Error: If ``selected_view`` is unknown or no condition image is given.
    """
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3,
    }
    # Validate user input before the expensive tensor loading and pipeline
    # initialisation, so mistakes surface as readable UI errors.
    if selected_view not in view_id_map:
        raise gr.Error(
            f"Unknown view {selected_view!r}; expected one of {list(view_id_map)}."
        )
    if condition_image is None:
        raise gr.Error(
            "No condition image provided. Please generate the image condition first."
        )
    view_id = view_id_map[selected_view]

    position_map = load_tensor_from_file(position_map_path, map_location=device)
    normal_map = load_tensor_from_file(normal_map_path, map_location=device)
    position_images = load_tensor_from_file(position_images_path, map_location=device)
    normal_images = load_tensor_from_file(normal_images_path, map_location=device)

    progress(0, desc="Loading SeqTex pipeline...")
    tex_pipe = get_seqtex_pipe()
    # assert tex_pipe is in gpu
    assert tex_pipe.device.type == "cuda", "SeqTex pipeline must be loaded in GPU context."
    progress(0.2, desc="SeqTex pipeline loaded successfully.")

    progress(0.3, desc="Encoding position and normal images...")
    # Batch position and normal renders together along a new leading axis of
    # size 2 (assumes each loaded tensor is [F, H, W, C] - TODO confirm).
    nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0)
    uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
    nat_latents = encode_images(nat_seq, encode_as_first=True)  # 2 C F H W
    uv_latents = encode_images(uv_seq, encode_as_first=True)  # 2 C F' H' W'

    # Split back into position / normal latents and stack them along channels.
    nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
    uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
    nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
    uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
    cond_model_latents = (nat_geo_latents, uv_geo_latents)

    # Frame counts are scaled by 2 ** (number of temporal downsampling stages);
    # `temperal_downsample` is the diffusers spelling of the config key.
    temporal_factor = 2 ** sum(VAE.config.temperal_downsample)
    num_frames = cfg.num_views * temporal_factor
    uv_num_frames = cfg.uv_num_views * temporal_factor

    progress(0.4, desc="Encoding condition image...")
    if isinstance(condition_image, np.ndarray):
        # e.g. a gr.Image component configured with type="numpy"
        condition_image = Image.fromarray(condition_image)
    if isinstance(condition_image, Image.Image):
        condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
        # Convert PIL Image to tensor, then add batch and frame axes: [1, 1, H, W, C].
        condition_image = convert_img_to_tensor(condition_image, device=device)
        condition_image = condition_image.unsqueeze(0).unsqueeze(0)
    gt_latents = (encode_images(condition_image, encode_as_first=True), None)

    progress(0.5, desc="Generating texture with SeqTex...")
    latents = tex_pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
        num_inference_steps=cfg.eval_num_inference_steps,
        guidance_scale=cfg.eval_guidance_scale,
        height=cfg.mv_height,
        width=cfg.mv_width,
        output_type="latent",
        cond_model_latents=cond_model_latents,
        uv_height=cfg.uv_height,
        uv_width=cfg.uv_width,
        uv_num_frames=uv_num_frames,
        treat_as_first=True,
        gt_condition=gt_latents,
        inference_img_cond_frame=view_id,
        use_qk_geometry=True,
        task_type="img2tex",  # img2tex
        progress=progress,
    ).frames
    mv_latents, uv_latents = latents

    progress(0.9, desc="Decoding generated latents to images...")
    mv_frames = decode_images(mv_latents, decode_as_first=True)  # B C 4 H W
    uv_frames = decode_images(uv_latents, decode_as_first=True)  # B C 1 H W

    uv_map_pred = uv_frames[:, :, -1, ...]
    uv_map_pred.squeeze_(0)
    # Tile the views side by side into a single image.
    mv_out = rearrange(
        mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1
    )[0]

    mv_out = torch.clamp(mv_out, 0.0, 1.0)
    uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)

    progress(1, desc="Texture generated successfully.")
    uv_map_pred_path = save_tensor_to_file(uv_map_pred, prefix="uv_map_pred")
    return (
        uv_map_pred_path,
        tensor_to_pil(uv_map_pred, normalize=False),
        tensor_to_pil(mv_out, normalize=False),
        "Step 3: Texture generated successfully.",
    )