import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor

from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel

from . import tensor_to_pil
from utils.file_utils import save_tensor_to_file, load_tensor_from_file

TEX_PIPE = None
VAE = None
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()

@dataclass
class Config:
    video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    seqtex_transformer_path: str = "VAST-AI/SeqTex-Transformer"
    min_noise_level_index: int = 15  # see the WorldMem paper: https://arxiv.org/pdf/2504.12369v1
    num_views: int = 4
    uv_num_views: int = 1
    mv_height: int = 512
    mv_width: int = 512
    uv_height: int = 1024
    uv_width: int = 1024
    flow_shift: float = 5.0
    eval_guidance_scale: float = 1.0
    eval_num_inference_steps: int = 30
    eval_seed: int = 42


cfg = Config()

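# The config above drives the two texture branches used throughout this module:
# `num_views` multi-view frames at `mv_height` x `mv_width`, plus `uv_num_views`
# UV-space frame(s) at `uv_height` x `uv_width`. `flow_shift` is forwarded to the
# FlowMatchEulerDiscreteScheduler below, and the eval_* fields pin the sampling
# settings (steps, guidance scale, seed) used by `generate_texture`.
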
def get_seqtex_pipe():
    """
    Lazily load the SeqTex pipeline for texture generation.

    Must be called within @spaces.GPU context.
    """
    global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
    if TEX_PIPE is not None:
        return TEX_PIPE
    gr.Info("First call: loading the SeqTex pipeline... This may take about 1 minute.")
    with TEX_PIPE_LOCK:
        if TEX_PIPE is not None:
            return TEX_PIPE
        token = os.environ.get("SEQTEX_SPACE_TOKEN", "")
        assert token != "", (
            "Please set the SEQTEX_SPACE_TOKEN environment variable with a Hugging Face token "
            "that has access to VAST-AI/SeqTex-Transformer."
        )
        # Load the transformer (with its auto-configured LoRA adapter) first
        transformer = WanT2TexTransformer3DModel.from_pretrained(
            cfg.seqtex_transformer_path,
            token=token,
        )
        # Pipeline - pass the pre-loaded transformer to avoid re-loading it
        pipe = WanT2TexPipeline.from_pretrained(
            cfg.video_base_name,
            transformer=transformer,
            torch_dtype=torch.bfloat16,
        )
        del transformer
        VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32)
        pipe.vae = VAE
        # Useful normalization constants - delay CUDA initialization until GPU context.
        # Note: LATENTS_STD stores the reciprocal of the VAE's latent std, so
        # encoding multiplies by it and decoding divides by it.
        LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to(torch.float32)
        LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to(torch.float32)
        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
                pipe.scheduler.config, shift=cfg.flow_shift
            )
        )
        # In this scheduler the first timestep is pure noise, so the minimum noise
        # level sits near the clean end of the schedule: typically 1000 - 15.
        min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index
        setattr(pipe, "min_noise_level_index", min_noise_level_index)
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        setattr(pipe, "min_noise_level_timestep", min_noise_level_timestep)
        setattr(pipe, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
        # Publish the pipeline only after it is fully initialized, so the unlocked
        # fast path above never returns a half-built object.
        TEX_PIPE = pipe
    return TEX_PIPE.to("cuda")

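# Worked example of the min-noise-level bookkeeping above (assuming the default
# num_train_timesteps of 1000): with cfg.min_noise_level_index = 15 the index is
# 1000 - 15 = 985, i.e. a timestep near the clean end of the shift=5.0
# flow-matching schedule, and the stored sigma is that timestep divided by 1000,
# following the timestep = sigma * 1000 convention of flow-matching schedulers.
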
def encode_images(
    images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
    """
    Encode images to latent space using the VAE.

    Each frame is encoded as an independent image, with no awareness of the temporal dimension.

    :param images: Input images tensor with shape [B, F, H, W, C].
    :param encode_as_first: Whether to encode every frame as if it were the first frame.
    :return: Encoded latents with shape [B, C', F, H/8, W/8].
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")
    if images.min() < -0.1:
        # Images are in [-1, 1] range; rescale to [0, 1]
        images = (images + 1.0) / 2.0
    if encode_as_first:
        # Fold the frame axis into the batch so every frame is encoded as a
        # single-frame clip (i.e. as the first frame)
        B = images.shape[0]
        images = rearrange(images, "B F H W C -> (B F) C 1 H W")
        latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
        latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
    else:
        raise NotImplementedError("Currently only encoding as the first frame is supported.")
    return latents

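# A note on the "encode_as_first" trick: rearranging [B, F, H, W, C] into (B*F)
# single-frame clips makes the VAE treat every view and every UV map as an
# independent image, presumably because the causal Wan video VAE compresses the
# first frame of a clip without any temporal context.
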
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
    """
    Decode latents back to images using the VAE.

    :param latents: Input latents with shape [B, C, F, H, W].
    :param decode_as_first: Whether to decode every frame as if it were the first frame.
    :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")
    if decode_as_first:
        F = latents.shape[2]
        latents = latents.to(VAE.dtype)
        latents = latents / LATENTS_STD + LATENTS_MEAN
        latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
        images = VAE.decode(latents, return_dict=False)[0]
        images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
    else:
        raise NotImplementedError("Currently only decoding as the first frame is supported.")
    return images

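# decode_images inverts the normalization used in encode_images: because
# LATENTS_STD holds the reciprocal of the VAE's latent std, encoding computes
# (z - mean) * (1/std) while decoding computes latents / (1/std) + mean, i.e.
# latents * std + mean, before running the VAE decoder.
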
def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
    """
    Convert a PIL Image to a tensor. If the image is RGBA, composite it onto a
    black background using the alpha channel as a mask.

    :param image: PIL Image to convert, with values in [0, 255].
    :param device: Target device for the tensor.
    :return: Tensor representation of the image, values in [0.0, 1.0], shape [H, W, C].
    """
    # Convert to RGBA to ensure an alpha channel exists
    image = image.convert("RGBA")
    np_img = np.array(image)
    rgb = np_img[..., :3]
    alpha = np_img[..., 3:4] / 255.0  # Normalize alpha to [0, 1]
    # Blend with a black background using the alpha mask
    rgb = rgb * alpha
    rgb = rgb.astype(np.float32) / 255.0  # Normalize to [0, 1]
    tensor = torch.from_numpy(rgb)
    if device != "cpu":
        tensor = tensor.to(device)
    return tensor

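# Compositing onto black is just premultiplied alpha with a zero background:
# out = rgb * alpha + 0 * (1 - alpha), so fully transparent pixels end up black
# before the final normalization to [0, 1].
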
def generate_texture(
    position_map_path,
    normal_map_path,
    position_images_path,
    normal_images_path,
    condition_image,
    text_prompt,
    selected_view,
    negative_prompt=None,
    device="cuda",
    progress=gr.Progress(),
):
    """
    Use SeqTex to generate a texture for the mesh based on the image condition.

    :param position_map_path: File path to the UV-space position map tensor.
    :param normal_map_path: File path to the UV-space normal map tensor.
    :param position_images_path: File path to the multi-view position images tensor.
    :param normal_images_path: File path to the multi-view normal images tensor.
    :param condition_image: Image condition generated from the selected view.
    :param text_prompt: Text prompt for texture generation.
    :param selected_view: The view selected for generating the image condition.
    :param negative_prompt: Optional negative prompt for texture generation.
    :param device: Device used for loading tensors and seeding the generator.
    :param progress: Gradio progress tracker.
    :return: File path of the generated texture map, the texture map and multi-view frames as PIL images, and a status message.
    """
    position_map = load_tensor_from_file(position_map_path, map_location=device)
    normal_map = load_tensor_from_file(normal_map_path, map_location=device)
    position_images = load_tensor_from_file(position_images_path, map_location=device)
    normal_images = load_tensor_from_file(normal_images_path, map_location=device)
    progress(0, desc="Loading SeqTex pipeline...")
    tex_pipe = get_seqtex_pipe()
    # The pipeline must already be on the GPU at this point
    assert tex_pipe.device.type == "cuda", "SeqTex pipeline must be loaded in GPU context."
    progress(0.2, desc="SeqTex pipeline loaded successfully.")
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3,
    }
    view_id = view_id_map[selected_view]
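    # The selected view index is later passed to the pipeline as
    # inference_img_cond_frame, i.e. it decides which of the four rendered views
    # the condition image is attached to.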
    progress(0.3, desc="Encoding position and normal images...")
    nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0)  # 2 F H W C
    uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
    nat_latents = encode_images(nat_seq, encode_as_first=True)  # B C F H W
    uv_latents = encode_images(uv_seq, encode_as_first=True)  # B C F' H' W'
    nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
    uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
    nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
    uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
    cond_model_latents = (nat_geo_latents, uv_geo_latents)
    # Frame counts in pixel space: the view counts scaled by the VAE's temporal
    # downsampling factor, 2**sum(temperal_downsample)
    num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
    uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))
    progress(0.4, desc="Encoding condition image...")
    if isinstance(condition_image, Image.Image):
        condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
        # Convert the PIL image to a tensor
        condition_image = convert_img_to_tensor(condition_image, device=device)
    condition_image = condition_image.unsqueeze(0).unsqueeze(0)
    gt_latents = (encode_images(condition_image, encode_as_first=True), None)
    progress(0.5, desc="Generating texture with SeqTex...")
    latents = tex_pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
        num_inference_steps=cfg.eval_num_inference_steps,
        guidance_scale=cfg.eval_guidance_scale,
        height=cfg.mv_height,
        width=cfg.mv_width,
        output_type="latent",
        cond_model_latents=cond_model_latents,
        # mask_indices=test_mask_indices,
        uv_height=cfg.uv_height,
        uv_width=cfg.uv_width,
        uv_num_frames=uv_num_frames,
        treat_as_first=True,
        gt_condition=gt_latents,
        inference_img_cond_frame=view_id,
        use_qk_geometry=True,
        task_type="img2tex",
        progress=progress,
    ).frames
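    # With output_type="latent" the pipeline returns its frames still in latent
    # space as a (multi-view latents, UV latents) pair, which is why the two
    # branches are decoded separately below.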
    mv_latents, uv_latents = latents
    progress(0.9, desc="Decoding generated latents to images...")
    mv_frames = decode_images(mv_latents, decode_as_first=True)  # B C 4 H W
    uv_frames = decode_images(uv_latents, decode_as_first=True)  # B C 1 H W
    uv_map_pred = uv_frames[:, :, -1, ...]
    uv_map_pred.squeeze_(0)
    # Lay the four decoded views out side by side in a single image
    mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]
    mv_out = torch.clamp(mv_out, 0.0, 1.0)
    uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
    progress(1, desc="Texture generated successfully.")
    uv_map_pred_path = save_tensor_to_file(uv_map_pred, prefix="uv_map_pred")
    return uv_map_pred_path, tensor_to_pil(uv_map_pred, normalize=False), tensor_to_pil(mv_out, normalize=False), "Step 3: Texture generated successfully."