File size: 10,507 Bytes
38dbec8
 
 
 
eecf990
38dbec8
eecf990
 
38dbec8
1c05005
38dbec8
eecf990
a399d55
e2ccc8a
 
38dbec8
eecf990
e2ccc8a
 
 
 
 
38dbec8
eecf990
e2ccc8a
eecf990
1c05005
e2ccc8a
 
eecf990
 
38dbec8
1c05005
 
 
 
 
 
 
eecf990
 
 
 
 
38dbec8
aec7186
 
 
 
c882a68
 
 
 
aec7186
c882a68
aec7186
 
03dc078
daf9fe6
c882a68
2728300
 
 
c882a68
 
03dc078
2728300
 
 
 
 
 
 
4b4ce8a
9e70cab
 
c882a68
03dc078
c882a68
4b4ce8a
 
c882a68
 
4b4ce8a
 
9e70cab
c882a68
4b4ce8a
 
 
 
2728300
9e70cab
 
03dc078
4b4ce8a
c882a68
4b4ce8a
 
 
 
 
c882a68
 
4b4ce8a
c882a68
 
4b4ce8a
 
 
c882a68
 
eecf990
751171e
 
4b4ce8a
9e70cab
4b4ce8a
9e70cab
751171e
4b4ce8a
 
 
 
 
 
 
751171e
 
4b4ce8a
 
 
 
 
 
 
 
 
 
 
 
751171e
 
4b4ce8a
751171e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a67b8
1c05005
b0a67b8
 
 
 
eecf990
751171e
 
 
 
1c05005
e973397
f779fbc
1c05005
 
 
 
 
 
 
 
 
f779fbc
daf9fe6
 
f779fbc
3b58a26
aec7186
3b58a26
aec7186
3b58a26
 
 
eecf990
751171e
daf9fe6
e973397
1c05005
38dbec8
eecf990
38dbec8
dc16672
 
 
 
38dbec8
751171e
f779fbc
daf9fe6
eecf990
 
751171e
 
 
 
 
 
 
 
 
 
eecf990
 
f779fbc
e973397
1c05005
eecf990
b0a67b8
eecf990
 
287be50
eecf990
 
 
 
f779fbc
e973397
 
f779fbc
 
e973397
eecf990
a6bc9a4
eecf990
 
daf9fe6
751171e
 
e973397
dc16672
e973397
eecf990
1c05005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6bc9a4
 
 
 
1c05005
e973397
 
1c05005
 
 
 
a6bc9a4
1c05005
 
 
eecf990
38dbec8
eecf990
e973397
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import os
import tempfile
from typing import Any
import torch
import numpy as np
from PIL import Image
import gradio as gr
import trimesh
from transparent_background import Remover
from diffusers import DiffusionPipeline

# Import and setup SPAR3D.
# At import time, build and install the local ./texture_baker and
# ./uv_unwrapper packages with USE_CUDA=1 in the environment, then import
# the SPAR3D system and its utilities.
# NOTE(review): os.system's exit status is ignored, so a failed build does not
# stop here; presumably it surfaces later when spar3d needs those packages —
# confirm.
os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
import spar3d.utils as spar3d_utils
from spar3d.system import SPAR3D

# Constants
# Conditioning image resolution: create_batch resizes to this size, and the
# foreground crop in generate_and_process_3d targets it.
COND_WIDTH = 512
COND_HEIGHT = 512
# Distance passed to spar3d_utils.default_cond_c2w to build the fixed
# conditioning camera pose.
COND_DISTANCE = 2.2
# Field of view passed to spar3d_utils.create_intrinsic_from_fov_rad
# (presumably vertical FOV in radians, per the names — confirm).
COND_FOVY = 0.591627
# RGB colour in [0, 1] that transparent pixels are blended toward in create_batch.
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Initialize models
# All models are loaded once at import time and reused as module globals by
# every call to generate_and_process_3d.
device = spar3d_utils.get_device()
bg_remover = Remover()
spar3d_model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors"
).eval().to(device)

# Initialize FLUX model
# Text-to-image pipeline, loaded in bfloat16 and moved to the same device
# as SPAR3D.
dtype = torch.bfloat16
flux_pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    torch_dtype=dtype
).to(device)

# Initialize camera parameters
# Fixed conditioning camera pose and intrinsics. create_batch reuses them for
# every batch, adding a leading batch dimension via unsqueeze(0).
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray | None = None) -> Image.Image:
    """Return *rgb_image* converted to RGBA, optionally using *mask* as alpha.

    Args:
        rgb_image: Source image. It is converted with ``convert('RGBA')``,
            which returns a new image, so the input is not modified.
        mask: Optional per-pixel opacity in ``[0, 1]``. When it has more than
            two dimensions, singleton axes (e.g. a trailing channel axis of an
            ``(H, W, 1)`` mask) are squeezed away. Values outside ``[0, 1]``
            are clipped, so they do not wrap around in the ``uint8`` cast.

    Returns:
        The RGBA image. Its alpha comes from *mask* when one is given, and is
        otherwise whatever ``convert('RGBA')`` produced.

    Raises:
        ValueError: If the (squeezed) mask shape is not ``(height, width)``
            of the image.
    """
    rgba_image = rgb_image.convert('RGBA')
    if mask is None:
        return rgba_image

    mask = np.asarray(mask)
    print("[debug] mask shape before alpha:", mask.shape)
    # Ensure mask is 2D before converting to alpha.
    if mask.ndim > 2:
        mask = mask.squeeze()

    expected_shape = (rgba_image.height, rgba_image.width)
    if mask.shape != expected_shape:
        # Fail early with a clear message instead of an opaque Pillow error.
        raise ValueError(
            f"mask shape {mask.shape} does not match image (height, width) {expected_shape}"
        )

    # Scale in floating point after clipping. Scaling an integer mask or an
    # out-of-range value by 255 would otherwise overflow or wrap in uint8.
    alpha_values = np.clip(mask.astype(np.float64), 0.0, 1.0) * 255.0
    alpha = Image.fromarray(alpha_values.astype(np.uint8))
    print("[debug] alpha size:", alpha.size)
    rgba_image.putalpha(alpha)
    return rgba_image

def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Build the single-image conditioning batch passed to SPAR3D.

    The image is resized to ``COND_WIDTH`` x ``COND_HEIGHT`` and scaled to
    ``[0, 1]``. Its colour is then alpha-blended toward ``BACKGROUND_COLOR``
    using the alpha channel as the blend weight. Images that are neither
    ``RGB`` nor ``RGBA`` (e.g. ``L``, ``LA``, ``P``, ``CMYK``) are converted
    to ``RGBA`` first. Without this, the channel split below would see a
    2-D array or the wrong number or meaning of channels. Images without
    alpha get a fully opaque mask.

    Args:
        input_image: Conditioning image.

    Returns:
        Dict with keys:
        - ``rgb_cond``: CPU float32 tensor ``[1, H, W, 3]`` (channel-last).
        - ``mask_cond``: CPU float32 tensor ``[1, H, W, 1]``.
        - ``c2w_cond``, ``intrinsic_cond``, ``intrinsic_normed_cond``: the
          module-level camera tensors with a leading batch dim added.
    """
    # Normalise unusual modes so the array below always has 3 (RGB) or
    # 4 (RGBA) channels with the 4th channel being true alpha.
    if input_image.mode not in ("RGB", "RGBA"):
        input_image = input_image.convert("RGBA")

    # Resize and convert input image to numpy array.
    resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
    img_array = np.array(resized_image).astype(np.float32) / 255.0
    print("[debug] img_array shape:", img_array.shape)

    # Extract RGB and alpha channels.
    if img_array.shape[-1] == 4:  # RGBA
        rgb = img_array[..., :3]
        mask = img_array[..., 3:4]
    else:  # RGB
        rgb = img_array
        mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)

    # Convert to tensors while keeping channel-last format.
    rgb = torch.from_numpy(rgb).float()  # [H, W, 3]
    mask = torch.from_numpy(mask).float()  # [H, W, 1]
    print("[debug] rgb tensor shape:", rgb.shape)
    print("[debug] mask tensor shape:", mask.shape)

    # Background colour, broadcastable against [H, W, 3].
    bg_tensor = torch.tensor(BACKGROUND_COLOR).view(1, 1, 3)  # [1, 1, 3]
    print("[debug] bg_tensor shape:", bg_tensor.shape)

    # lerp(bg, rgb, mask) = bg + mask * (rgb - bg): opaque pixels keep their
    # colour, transparent ones become the background.
    rgb_cond = torch.lerp(bg_tensor, rgb, mask)  # [H, W, 3]
    print("[debug] rgb_cond shape after blend:", rgb_cond.shape)

    # Add the batch dimension; SPAR3D's image tokenizer takes [B, H, W, C].
    rgb_cond = rgb_cond.unsqueeze(0)  # [1, H, W, 3]
    mask = mask.unsqueeze(0)  # [1, H, W, 1]

    print("[debug] rgb_cond final shape:", rgb_cond.shape)
    print("[debug] mask final shape:", mask.shape)

    batch = {
        "rgb_cond": rgb_cond,  # [1, H, W, 3]
        "mask_cond": mask,  # [1, H, W, 1]
        "c2w_cond": c2w_cond.unsqueeze(0),  # [1, 4, 4]
        "intrinsic_cond": intrinsic.unsqueeze(0),  # [1, 3, 3]
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),  # [1, 3, 3]
    }

    print("\nFinal batch shapes:")
    for k, v in batch.items():
        print(f"[debug] {k} final shape:", v.shape)
    print("\nrgb_cond max:", batch["rgb_cond"].max())
    print("rgb_cond min:", batch["rgb_cond"].min())
    print("mask_cond unique values:", torch.unique(batch["mask_cond"]))

    return batch

def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
    """Run SPAR3D's point-diffusion stage and return a conditioning point cloud.

    Args:
        batch: Conditioning batch (as built by ``create_batch`` and moved to
            *device*). Only a batch size of 1 is supported.
        system: Loaded SPAR3D model. Its ``forward_pdiff_cond`` and
            ``sampler.sample_batch_progressive`` are used.
        guidance_scale: Guidance scale forwarded to the sampler.
        seed: Seed applied to torch's and NumPy's global RNGs before sampling.
            Repeated calls with the same seed are therefore reproducible, to
            the extent the sampler draws from those RNGs. The random
            subsampling below does.
        device: Device forwarded to the sampler.

    Returns:
        The final progressive sample, permuted with ``permute(0, 2, 1)``,
        normalized by ``spar3d_utils.normalize_pc_bbox`` and randomly
        subsampled to at most 512 entries along dim 1.

    Raises:
        AssertionError: If the batch size is not 1.
        RuntimeError: If the sampler yields no samples.
        Exception: Anything raised by ``forward_pdiff_cond`` is re-raised
            after printing debug information about the input tensor.
    """
    print("\n[debug] Starting forward_model")
    print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
    print("[debug] Input mask_cond shape:", batch["mask_cond"].shape)

    batch_size = batch["rgb_cond"].shape[0]
    assert batch_size == 1, f"Expected batch size 1, got {batch_size}"

    # Honour the `seed` argument, which was previously accepted but ignored.
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Print value ranges for debugging.
    print("\nValue ranges:")
    print("rgb_cond max:", batch["rgb_cond"].max())
    print("rgb_cond min:", batch["rgb_cond"].min())
    print("mask_cond unique values:", torch.unique(batch["mask_cond"]))

    # Inference only: no autograd graph is needed for tokenizing or sampling.
    # The sampler loop must stay inside this block because the progressive
    # sampler is consumed lazily.
    with torch.no_grad():
        print("\n[debug] Generating point cloud tokens")
        try:
            cond_tokens = system.forward_pdiff_cond(batch)
            print("[debug] cond_tokens shape:", cond_tokens.shape)
        except Exception as e:
            print("\n[ERROR] Failed in forward_pdiff_cond:")
            print(e)
            print("\nInput tensor properties:")
            print("rgb_cond dtype:", batch["rgb_cond"].dtype)
            print("rgb_cond device:", batch["rgb_cond"].device)
            print("rgb_cond requires_grad:", batch["rgb_cond"].requires_grad)
            raise

        print("\n[debug] Sampling points")
        sample_iter = system.sampler.sample_batch_progressive(
            batch_size,
            cond_tokens,
            guidance_scale=guidance_scale,
            device=device
        )

        # Keep only the last step's "xstart" prediction.
        samples = None
        for x in sample_iter:
            samples = x["xstart"]

    if samples is None:
        # Previously this surfaced as an UnboundLocalError.
        raise RuntimeError("Point-cloud sampler produced no samples")

    print("[debug] samples shape before permute:", samples.shape)
    pc_cond = samples.permute(0, 2, 1).float()
    print("[debug] pc_cond shape after permute:", pc_cond.shape)

    # Normalize point cloud.
    pc_cond = spar3d_utils.normalize_pc_bbox(pc_cond)
    print("[debug] pc_cond shape after normalize:", pc_cond.shape)

    # Randomly subsample to (at most) 512 points.
    pc_cond = pc_cond[:, torch.randperm(pc_cond.shape[1])[:512]]
    print("[debug] pc_cond final shape:", pc_cond.shape)

    return pc_cond

def generate_and_process_3d(
    prompt: str, seed: int = 42
) -> tuple[str | None, str | None, Image.Image | None]:
    """Generate an image from *prompt* with FLUX and lift it to a 3D GLB model.

    Pipeline:
    1. Generate a 1024x1024 image with FLUX.1-schnell (4 steps, guidance 0).
    2. Remove the background and crop to the foreground.
    3. Build the SPAR3D batch and sample a conditioning point cloud.
    4. Generate a textured mesh and export it as ``output.glb`` in a fresh
       temporary directory.

    Args:
        prompt: Text prompt for the image generator.
        seed: RNG seed. It is coerced to ``int`` because UI widgets may
            deliver it as a float.

    Returns:
        ``(glb_path, glb_path, generated_image)``, one value per Gradio output
        (3D preview, file download, 2D image). On any failure the traceback is
        printed and ``(None, None, None)`` is returned. This keeps the arity
        matching the three outputs.
    """
    width: int = 1024
    height: int = 1024

    try:
        # np.random.seed and torch.Generator.manual_seed require an int.
        seed = int(seed)

        # Set random seeds.
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Generate image using FLUX.
        generator = torch.Generator(device=device).manual_seed(seed)
        print("[debug] generating the image using Flux")
        generated_image = flux_pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=4,
            generator=generator,
            guidance_scale=0.0
        ).images[0]

        print("[debug] converting the image to rgb")
        rgb_image = generated_image.convert('RGB')

        print("[debug] removing the background by calling bg_remover.process(rgb_image)")
        # bg_remover returns a PIL Image already, no need to convert.
        no_bg_image = bg_remover.process(rgb_image)
        print(f"[debug] no_bg_image type: {type(no_bg_image)}, mode: {no_bg_image.mode}")

        # Convert to RGBA if not already.
        rgba_image = no_bg_image.convert('RGBA')
        print(f"[debug] rgba_image mode: {rgba_image.mode}")

        print("[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...)")
        processed_image = spar3d_utils.foreground_crop(
            rgba_image,
            crop_ratio=1.3,
            newsize=(COND_WIDTH, COND_HEIGHT),
            no_crop=False
        )

        # Show the processed image alpha channel for debugging.
        alpha = np.array(processed_image)[:, :, 3]
        print(f"[debug] Alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, unique: {np.unique(alpha)}")

        # Prepare batch for processing.
        print("[debug] preparing the batch by calling create_batch(processed_image)")
        batch = create_batch(processed_image)
        batch = {k: v.to(device) for k, v in batch.items()}

        # Generate point cloud.
        pc_cond = forward_model(
            batch,
            spar3d_model,
            guidance_scale=3.0,
            seed=seed,
            device=device
        )
        batch["pc_cond"] = pc_cond

        # Generate mesh.
        with torch.no_grad():
            print("[debug] calling torch.autocast(....) to generate the mesh")
            with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
                trimesh_mesh, _ = spar3d_model.generate_mesh(
                    batch,
                    2048,  # texture_resolution
                    remesh="none",
                    vertex_count=-1,
                    estimate_illumination=True
                )
                trimesh_mesh = trimesh_mesh[0]

        # Export to GLB. The directory is intentionally not cleaned up: the
        # returned path must stay valid for the UI to serve the file.
        print("[debug] creating tmp dir for the .glb output")
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')

        print("[debug] calling trimesh_mesh.export(...) to export to .glb")
        trimesh_mesh.export(output_path, file_type="glb", include_normals=True)

        return output_path, output_path, generated_image

    except Exception as e:
        # Top-level UI boundary: log the full traceback, then return one empty
        # value per output. The old 2-tuple did not match the 3 outputs.
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None, None
        
# Create Gradio interface
# Inputs: a text prompt and an integer-stepped seed slider, passed to
# generate_and_process_3d as (prompt, seed).
# Outputs, in order: a 3D preview, a downloadable .glb file and the
# generated 2D image, filled from the tuple generate_and_process_3d returns.
demo = gr.Interface(
    fn=generate_and_process_3d,
    inputs=[
        gr.Text(
            label="Enter your prompt",
            placeholder="Describe what you want to generate..."
        ),
        # Seed range: 0 .. int32 max, default 42 (same as the function default).
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=np.iinfo(np.int32).max,
            step=1,
            value=42
        )
    ],
    outputs=[
        # clear_color is an all-zero RGBA value (presumably a transparent
        # viewer background).
        gr.Model3D(
            label="3D Model Preview",
            clear_color=[0.0, 0.0, 0.0, 0.0],
        ),
        gr.File(
            label="Download 3D Model",
            file_types=[".glb"]
        ),
        gr.Image(
            label="Generated Image",
            type="pil"
        ),
    ],
    title="Text to 3D Model Generator",
    description="Enter a text prompt to generate an image that will be converted into a 3D model",
)

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()