import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel
from . import tensor_to_pil
from utils.file_utils import save_tensor_to_file, load_tensor_from_file
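
# Lazily-initialized module globals, populated on the first call to get_seqtex_pipe()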
TEX_PIPE = None
VAE = None
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()


@dataclass
class Config:
    video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    seqtex_transformer_path: str = "VAST-AI/SeqTex-Transformer"
    min_noise_level_index: int = 15  # refer to the WorldMem paper (https://arxiv.org/pdf/2504.12369v1)
    num_views: int = 4
    uv_num_views: int = 1
    mv_height: int = 512
    mv_width: int = 512
    uv_height: int = 1024
    uv_width: int = 1024
    flow_shift: float = 5.0
    eval_guidance_scale: float = 1.0
    eval_num_inference_steps: int = 30
    eval_seed: int = 42


cfg = Config()


def get_seqtex_pipe():
    """
    Lazily load the SeqTex pipeline for texture generation.
    Must be called within the @spaces.GPU context.
    """
    global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
    if TEX_PIPE is not None:
        return TEX_PIPE
    gr.Info("Loading the SeqTex pipeline for the first time... This may take about a minute.")
    with TEX_PIPE_LOCK:
        if TEX_PIPE is not None:
            return TEX_PIPE
        # The gated transformer weights require a Hugging Face token with access to VAST-AI/SeqTex-Transformer.
        assert os.environ.get("SEQTEX_SPACE_TOKEN", "") != "", (
            "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, "
            "which has access to VAST-AI/SeqTex-Transformer."
        )
        # Load the transformer with its auto-configured LoRA adapter first
        transformer = WanT2TexTransformer3DModel.from_pretrained(
            cfg.seqtex_transformer_path,
            token=os.environ["SEQTEX_SPACE_TOKEN"]
        )
        # Pipeline - pass the pre-loaded transformer to avoid re-loading it
        TEX_PIPE = WanT2TexPipeline.from_pretrained(
            cfg.video_base_name,
            transformer=transformer,
            torch_dtype=torch.bfloat16
        )
        del transformer
        VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32)
        TEX_PIPE.vae = VAE
        # Some useful parameters - delay CUDA initialization until the GPU context
        LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to(torch.float32)
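        # LATENTS_STD below is stored as the reciprocal of the per-channel std, so
        # encode_images multiplies by it and decode_images divides by it.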
        LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to(torch.float32)
        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
                TEX_PIPE.scheduler.config, shift=cfg.flow_shift
            )
        )
        # In this scheduler the first timestep corresponds to pure noise, so the index counts
        # down from num_train_timesteps: typically 1000 - 15.
        min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index
        setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
        # For flow matching, sigma = timestep / num_train_timesteps
        setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)
    return TEX_PIPE.to("cuda")


@torch.amp.autocast('cuda', dtype=torch.float32)
def encode_images(
    images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
    """
    Encode images to latent space using the VAE.
    Every frame is treated as a separate image, without any awareness of the temporal dimension.
    :param images: Input images tensor with shape [B, F, H, W, C].
    :param encode_as_first: Whether to encode all frames as the first frame.
    :return: Encoded latents with shape [B, C', F, H/8, W/8].
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")
    if images.min() < -0.1:
        # images are in the [-1, 1] range
        images = (images + 1.0) / 2.0  # Normalize to the [0, 1] range
    if encode_as_first:
        # encode every frame as if it were the first frame
        B = images.shape[0]
        images = rearrange(images, "B F H W C -> (B F) C 1 H W")
        latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
        latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
    else:
        raise NotImplementedError("Currently only encoding as the first frame is supported.")
    return latents


@torch.amp.autocast('cuda', dtype=torch.float32)
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
    """
    Decode latents back to images using the VAE.
    :param latents: Input latents with shape [B, C, F, H, W].
    :param decode_as_first: Whether to decode all frames as the first frame.
    :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")
    if decode_as_first:
        F = latents.shape[2]
        latents = latents.to(VAE.dtype)
        latents = latents / LATENTS_STD + LATENTS_MEAN
        latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
        images = VAE.decode(latents, return_dict=False)[0]
        images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
    else:
        raise NotImplementedError("Currently only decoding as the first frame is supported.")
    return images


def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
    """
    Convert a PIL Image to a tensor. If the image is RGBA, composite it onto a black background using the alpha channel as a mask.
    :param image: PIL Image to convert, with values in [0, 255].
    :param device: Target device for the tensor.
    :return: Tensor representation of the image, with values in [0.0, 1.0] and shape [H, W, C].
    """
    # Convert to RGBA to ensure an alpha channel exists
    image = image.convert("RGBA")
    np_img = np.array(image)
    rgb = np_img[..., :3]
    alpha = np_img[..., 3:4] / 255.0  # Normalize alpha to [0, 1]
    # Blend with a black background using the alpha mask
    rgb = rgb * alpha
    rgb = rgb.astype(np.float32) / 255.0  # Normalize to [0, 1]
    tensor = torch.from_numpy(rgb)
    if device != "cpu":
        tensor = tensor.to(device)
    return tensor


@spaces.GPU(duration=90)
@torch.no_grad
@torch.inference_mode
def generate_texture(position_map_path, normal_map_path, position_images_path, normal_images_path, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
    """
    Use SeqTex to generate a texture for the mesh based on the image condition.
    :param position_map_path: File path to the UV-space position map tensor
    :param normal_map_path: File path to the UV-space normal map tensor
    :param position_images_path: File path to the multi-view position images tensor
    :param normal_images_path: File path to the multi-view normal images tensor
    :param condition_image: Image condition generated from the selected view.
    :param text_prompt: Text prompt for texture generation.
    :param selected_view: The view selected for generating the image condition.
    :param negative_prompt: Optional negative prompt for texture generation.
    :return: File path of the generated texture map, the texture and multi-view frames as PIL images, and a status message
    """
    position_map = load_tensor_from_file(position_map_path, map_location=device)
    normal_map = load_tensor_from_file(normal_map_path, map_location=device)
    position_images = load_tensor_from_file(position_images_path, map_location=device)
    normal_images = load_tensor_from_file(normal_images_path, map_location=device)

    progress(0, desc="Loading SeqTex pipeline...")
    tex_pipe = get_seqtex_pipe()
    # Ensure the pipeline is on the GPU
    assert tex_pipe.device.type == "cuda", "SeqTex pipeline must be loaded in GPU context."
    progress(0.2, desc="SeqTex pipeline loaded successfully.")

    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3,
    }
    view_id = view_id_map[selected_view]

    progress(0.3, desc="Encoding position and normal images...")
    nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0)  # 2 F H W C (position and normal stacked along batch)
    uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0)
    nat_latents = encode_images(nat_seq, encode_as_first=True)  # B C F H W
    uv_latents = encode_images(uv_seq, encode_as_first=True)  # B C F' H' W'
    nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
    uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
    nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
    uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
    cond_model_latents = (nat_geo_latents, uv_geo_latents)
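    # Convert view counts to the pixel-space frame counts the pipeline expects; the Wan VAE
    # compresses time by 2 ** sum(temperal_downsample) ("temperal" is the spelling used in the diffusers config).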
    num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
    uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))

    progress(0.4, desc="Encoding condition image...")
    if isinstance(condition_image, Image.Image):
        condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
        # Convert the PIL Image to a tensor
        condition_image = convert_img_to_tensor(condition_image, device=device)
    condition_image = condition_image.unsqueeze(0).unsqueeze(0)  # 1 1 H W C
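    # gt_condition is (multi-view condition latents, UV condition); the UV slot is left as None
    # since no UV-space condition is supplied for the img2tex task.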
    gt_latents = (encode_images(condition_image, encode_as_first=True), None)

    progress(0.5, desc="Generating texture with SeqTex...")
    latents = tex_pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
        num_inference_steps=cfg.eval_num_inference_steps,
        guidance_scale=cfg.eval_guidance_scale,
        height=cfg.mv_height,
        width=cfg.mv_width,
        output_type="latent",
        cond_model_latents=cond_model_latents,
        # mask_indices=test_mask_indices,
        uv_height=cfg.uv_height,
        uv_width=cfg.uv_width,
        uv_num_frames=uv_num_frames,
        treat_as_first=True,
        gt_condition=gt_latents,
        inference_img_cond_frame=view_id,
        use_qk_geometry=True,
        task_type="img2tex",
        progress=progress,
    ).frames
    mv_latents, uv_latents = latents

    progress(0.9, desc="Decoding generated latents to images...")
    mv_frames = decode_images(mv_latents, decode_as_first=True)  # B C 4 H W
    uv_frames = decode_images(uv_latents, decode_as_first=True)  # B C 1 H W
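    # The last decoded UV frame holds the predicted texture map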
    uv_map_pred = uv_frames[:, :, -1, ...]
    uv_map_pred.squeeze_(0)
    mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]  # C H (4*W): the four views tiled horizontally
    mv_out = torch.clamp(mv_out, 0.0, 1.0)
    uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)
    progress(1, desc="Texture generated successfully.")

    uv_map_pred_path = save_tensor_to_file(uv_map_pred, prefix="uv_map_pred")
    return uv_map_pred_path, tensor_to_pil(uv_map_pred, normalize=False), tensor_to_pil(mv_out, normalize=False), "Step 3: Texture generated successfully."