import os
import threading
from dataclasses import dataclass
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers.models import AutoencoderKLWan
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from einops import rearrange
from jaxtyping import Float
from PIL import Image
from torch import Tensor

from wan.pipeline_wan_t2tex_extra import WanT2TexPipeline
from wan.wan_t2tex_transformer_3d_extra import WanT2TexTransformer3DModel

from . import tensor_to_pil
from utils.file_utils import save_tensor_to_file, load_tensor_from_file

TEX_PIPE = None
VAE = None
LATENTS_MEAN, LATENTS_STD = None, None
TEX_PIPE_LOCK = threading.Lock()
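# The globals above are lazily initialized on the first GPU request; TEX_PIPE_LOCK ensures the pipeline is built only once across concurrent requests.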

@dataclass
class Config:
    video_base_name: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    seqtex_transformer_path: str = "VAST-AI/SeqTex-Transformer"
    min_noise_level_index: int = 15 # refer to paper [WorldMem](https://arxiv.org/pdf/2504.12369v1)

    num_views: int = 4
    uv_num_views: int = 1
    mv_height: int = 512
    mv_width: int = 512
    uv_height: int = 1024
    uv_width: int = 1024

    flow_shift: float = 5.0
    eval_guidance_scale: float = 1.0
    eval_num_inference_steps: int = 30
    eval_seed: int = 42

cfg = Config()

def get_seqtex_pipe():
    """
    Lazy load the SeqTex pipeline for texture generation.
    Must be called within @spaces.GPU context.
    """
    global TEX_PIPE, VAE, LATENTS_MEAN, LATENTS_STD
    if TEX_PIPE is not None:
        return TEX_PIPE
    gr.Info("First called, loading SeqTex pipeline... It may take about 1 minute.")
    with TEX_PIPE_LOCK:
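        # Double-checked locking: another request may have finished loading while we waited for the lock.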
        if TEX_PIPE is not None:
            return TEX_PIPE
            
        # Validate the access token before any download that requires it.
        token = os.environ.get("SEQTEX_SPACE_TOKEN", "")
        assert token, "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, which has access to VAST-AI/SeqTex-Transformer."

        # Load transformer with auto-configured LoRA adapter first
        transformer = WanT2TexTransformer3DModel.from_pretrained(
            cfg.seqtex_transformer_path,
            token=token,
        )
        # Pipeline - pass the pre-loaded transformer to avoid re-loading
        TEX_PIPE = WanT2TexPipeline.from_pretrained(
            cfg.video_base_name,
            transformer=transformer,
            torch_dtype=torch.bfloat16
        )
        del transformer

        VAE = AutoencoderKLWan.from_pretrained(cfg.video_base_name, subfolder="vae", torch_dtype=torch.float32)
        TEX_PIPE.vae = VAE
        
        # Some useful parameters - delay CUDA initialization until GPU context
        LATENTS_MEAN = torch.tensor(VAE.config.latents_mean).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to(torch.float32)
        LATENTS_STD = 1.0 / torch.tensor(VAE.config.latents_std).view(
            1, VAE.config.z_dim, 1, 1, 1
        ).to(torch.float32)
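        # LATENTS_STD stores the reciprocal of the VAE's latents_std, so encoding multiplies by it and decoding divides by it.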

        scheduler: FlowMatchEulerDiscreteScheduler = (
            FlowMatchEulerDiscreteScheduler.from_config(
                TEX_PIPE.scheduler.config, shift=cfg.flow_shift
            )
        )
        # Timesteps run from pure noise (index 0) down to clean data, so count the offset from the end (typically 1000 - 15).
        min_noise_level_index = scheduler.config.num_train_timesteps - cfg.min_noise_level_index
        setattr(TEX_PIPE, "min_noise_level_index", min_noise_level_index)
        min_noise_level_timestep = scheduler.timesteps[min_noise_level_index]
        setattr(TEX_PIPE, "min_noise_level_timestep", min_noise_level_timestep)
        setattr(TEX_PIPE, "min_noise_level_sigma", min_noise_level_timestep / 1000.)   
        return TEX_PIPE.to("cuda")

@torch.amp.autocast('cuda', dtype=torch.float32)
def encode_images(
    images: Float[Tensor, "B F H W C"], encode_as_first: bool = False
) -> Float[Tensor, "B C' F H/8 W/8"]:
    """
    Encode images to latent space using VAE.
    Every frame is seen as a separate image, without any awareness of the temporal dimension.
    :param images: Input images tensor with shape [B, F, H, W, C].
    :param encode_as_first: Whether to encode all frames as the first frame.
    :return: Encoded latents with shape [B, C', F, H/8, W/8].
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")

    if images.min() < -0.1:
        # images are in [-1, 1] range
        images = (images + 1.0) / 2.0  # Normalize to [0, 1] range
    if encode_as_first:
        # Encode every frame as if it were the first frame of its own clip
        B = images.shape[0]
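        # Fold the frame dim into the batch dim and add a singleton temporal dim, so the VAE encodes every frame as an independent single-frame clip.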
        images = rearrange(images, "B F H W C -> (B F) C 1 H W")
        latents = (VAE.encode(images).latent_dist.sample() - LATENTS_MEAN) * LATENTS_STD
        latents = rearrange(latents, "(B F) C 1 H W -> B C F H W", B=B)
    else:
        raise NotImplementedError("Currently only support encode as first frame.")

    return latents

@torch.amp.autocast('cuda', dtype=torch.float32)
def decode_images(latents: Float[Tensor, "B C F H W"], decode_as_first: bool = False):
    """
    Decode latents back to images using VAE.
    :param latents: Input latents with shape [B, C, F, H, W].
    :param decode_as_first: Whether to decode all frames as the first frame.
    :return: Decoded images with shape [B, C, F*Nv, H*8, W*8].
    """
    global VAE, LATENTS_MEAN, LATENTS_STD
    VAE = VAE.to("cuda").requires_grad_(False)
    LATENTS_MEAN = LATENTS_MEAN.to("cuda")
    LATENTS_STD = LATENTS_STD.to("cuda")

    if decode_as_first:
        F = latents.shape[2]
        latents = latents.to(VAE.dtype)
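        # Undo the encoding normalization (LATENTS_STD holds the reciprocal of latents_std), then decode each latent frame independently.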
        latents = latents / LATENTS_STD + LATENTS_MEAN
        latents = rearrange(latents, "B C F H W -> (B F) C 1 H W")
        images = VAE.decode(latents, return_dict=False)[0]
        images = rearrange(images, "(B F) C Nv H W -> B C (F Nv) H W", F=F, Nv=1)
    else:
        raise NotImplementedError("Currently only support decode as first frame.")
    return images

def convert_img_to_tensor(image: Image.Image, device="cuda") -> Float[Tensor, "H W C"]:
    """
    Convert a PIL Image to a tensor. If the image has an alpha channel, composite it onto a black background using the alpha mask.
    :param image: PIL Image to convert, with values in [0, 255].
    :param device: Target device for the tensor.
    :return: Tensor representation of the image, with values in [0.0, 1.0] and shape [H, W, C].
    """
    # Convert to RGBA to ensure alpha channel exists
    image = image.convert("RGBA")
    np_img = np.array(image)
    rgb = np_img[..., :3]
    alpha = np_img[..., 3:4] / 255.0  # Normalize alpha to [0, 1]
    # Blend with black background using alpha mask
    rgb = rgb * alpha
    rgb = rgb.astype(np.float32) / 255.0  # Normalize to [0, 1]
    tensor = torch.from_numpy(rgb)
    if device != "cpu":
        tensor = tensor.to(device)
    return tensor

@spaces.GPU(duration=90)
@torch.no_grad()
@torch.inference_mode()
def generate_texture(position_map_path, normal_map_path, position_images_path, normal_images_path, condition_image, text_prompt, selected_view, negative_prompt=None, device="cuda", progress=gr.Progress()):
    """
    Use SeqTex to generate texture for the mesh based on the image condition.
    :param position_images_path: File path to position images tensor
    :param normal_images_path: File path to normal images tensor
    :param condition_image: Image condition generated from the selected view.
    :param text_prompt: Text prompt for texture generation.
    :param selected_view: The view selected for generating the image condition.
    :return: File paths of generated texture map and multi-view frames, and PIL images
    """
    position_map = load_tensor_from_file(position_map_path, map_location=device)
    normal_map = load_tensor_from_file(normal_map_path, map_location=device)
    position_images = load_tensor_from_file(position_images_path, map_location=device)
    normal_images = load_tensor_from_file(normal_images_path, map_location=device)

    progress(0, desc="Loading SeqTex pipeline...")
    tex_pipe = get_seqtex_pipe()
    # The pipeline must already be on the GPU inside the @spaces.GPU context
    assert tex_pipe.device.type == "cuda", "SeqTex pipeline must be loaded in GPU context."
    progress(0.2, desc="SeqTex pipeline loaded successfully.")
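    # Map the UI view label to the frame index used as the image-condition frame.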
    view_id_map = {
        "First View": 0,
        "Second View": 1,
        "Third View": 2,
        "Fourth View": 3
    }
    view_id = view_id_map[selected_view]

    progress(0.3, desc="Encoding position and normal images...")
    nat_seq = torch.cat([position_images.unsqueeze(0), normal_images.unsqueeze(0)], dim=0) # [2, F, H, W, C]: position and normal stacked along the batch dim
    uv_seq = torch.cat([position_map.unsqueeze(0), normal_map.unsqueeze(0)], dim=0) # [2, F', H', W', C]
    nat_latents = encode_images(nat_seq, encode_as_first=True) # B C F H W
    uv_latents = encode_images(uv_seq, encode_as_first=True) # B C F' H' W'
    nat_pos_latents, nat_norm_latents = torch.chunk(nat_latents, 2, dim=0)
    uv_pos_latents, uv_norm_latents = torch.chunk(uv_latents, 2, dim=0)
    nat_geo_latents = torch.cat([nat_pos_latents, nat_norm_latents], dim=1)
    uv_geo_latents = torch.cat([uv_pos_latents, uv_norm_latents], dim=1)
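    # Geometry conditioning is passed as a (multi-view, UV) pair of concatenated position+normal latents.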
    cond_model_latents = (nat_geo_latents, uv_geo_latents)

    # Frame counts passed to the pipeline are scaled by the VAE's temporal downsampling factor (2 ** sum(temperal_downsample)).
    num_frames = cfg.num_views * (2 ** sum(VAE.config.temperal_downsample))
    uv_num_frames = cfg.uv_num_views * (2 ** sum(VAE.config.temperal_downsample))
    
    progress(0.4, desc="Encoding condition image...")
    if isinstance(condition_image, Image.Image):
        condition_image = condition_image.resize((cfg.mv_width, cfg.mv_height), Image.LANCZOS)
        # Convert PIL Image to tensor
        condition_image = convert_img_to_tensor(condition_image, device=device)
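        # Add batch and frame dims: [H, W, C] -> [1, 1, H, W, C], matching the layout expected by encode_images.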
        condition_image = condition_image.unsqueeze(0).unsqueeze(0)
    gt_latents = (encode_images(condition_image, encode_as_first=True), None)

    progress(0.5, desc="Generating texture with SeqTex...")
    latents = tex_pipe(
        prompt=text_prompt,
        negative_prompt=negative_prompt,
        num_frames=num_frames,
        generator=torch.Generator(device=device).manual_seed(cfg.eval_seed),
        num_inference_steps=cfg.eval_num_inference_steps,
        guidance_scale=cfg.eval_guidance_scale,
        height=cfg.mv_height,
        width=cfg.mv_width,
        output_type="latent",

        cond_model_latents=cond_model_latents,
        # mask_indices=test_mask_indices,
        uv_height=cfg.uv_height,
        uv_width=cfg.uv_width,
        uv_num_frames=uv_num_frames,
        treat_as_first=True,
        gt_condition=gt_latents,
        inference_img_cond_frame=view_id,
        use_qk_geometry=True,
        task_type="img2tex", # img2tex
        progress=progress,
    ).frames

    mv_latents, uv_latents = latents
    
    progress(0.9, desc="Decoding generated latents to images...")
    mv_frames = decode_images(mv_latents, decode_as_first=True) # B C 4 H W
    uv_frames = decode_images(uv_latents, decode_as_first=True) # B C 1 H W
    
    uv_map_pred = uv_frames[:, :, -1, ...]
    uv_map_pred.squeeze_(0)
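    # Tile the multi-view frames side by side along the width into a single preview image.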
    mv_out = rearrange(mv_frames[:, :, :cfg.num_views, ...], "B C (F N) H W -> N C (B H) (F W)", N=1)[0]

    mv_out = torch.clamp(mv_out, 0.0, 1.0)
    uv_map_pred = torch.clamp(uv_map_pred, 0.0, 1.0)

    progress(1, desc="Texture generated successfully.")
    uv_map_pred_path = save_tensor_to_file(uv_map_pred, prefix="uv_map_pred")
    return uv_map_pred_path, tensor_to_pil(uv_map_pred, normalize=False), tensor_to_pil(mv_out, normalize=False), "Step 3: Texture generated successfully."