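"""Gradio demo: turn a text prompt into a 3D model.

Pipeline: FLUX.1-schnell generates an image from the prompt, the background is
removed with transparent_background, and SPAR3D reconstructs a textured mesh
that is exported as a GLB file.
"""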
import os
import tempfile
from typing import Any
import torch
import numpy as np
from PIL import Image
import gradio as gr
import trimesh
from transparent_background import Remover
from diffusers import DiffusionPipeline

# Build and install SPAR3D's native extensions (texture_baker, uv_unwrapper) at startup, then import the package
os.system("USE_CUDA=1 pip install -vv --no-build-isolation ./texture_baker ./uv_unwrapper")
import spar3d.utils as spar3d_utils
from spar3d.system import SPAR3D

# Conditioning constants: image size, camera distance, vertical FOV (radians) and background color
COND_WIDTH = 512
COND_HEIGHT = 512
COND_DISTANCE = 2.2
COND_FOVY = 0.591627
BACKGROUND_COLOR = [0.5, 0.5, 0.5]

# Initialize models
device = spar3d_utils.get_device()
bg_remover = Remover()
spar3d_model = SPAR3D.from_pretrained(
    "stabilityai/stable-point-aware-3d",
    config_name="config.yaml",
    weight_name="model.safetensors"
).eval().to(device)

# Initialize FLUX model
dtype = torch.bfloat16
flux_pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", 
    torch_dtype=dtype
).to(device)

# Initialize camera parameters
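# SPAR3D is conditioned on a single fixed camera: a default camera-to-world
# pose at COND_DISTANCE and intrinsics derived from COND_FOVY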
c2w_cond = spar3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = spar3d_utils.create_intrinsic_from_fov_rad(
    COND_FOVY, COND_HEIGHT, COND_WIDTH
)

def create_rgba_image(rgb_image: Image.Image, mask: np.ndarray = None) -> Image.Image:
    """Create an RGBA image from RGB image and optional mask."""
    rgba_image = rgb_image.convert('RGBA')
    if mask is not None:
        print("[debug] mask shape before alpha:", mask.shape)
        # Ensure mask is 2D before converting to alpha
        if len(mask.shape) > 2:
            mask = mask.squeeze()
        alpha = Image.fromarray((mask * 255).astype(np.uint8))
        print("[debug] alpha size:", alpha.size)
        rgba_image.putalpha(alpha)
    return rgba_image
    
def create_batch(input_image: Image.Image) -> dict[str, Any]:
    """Prepare image batch for model input."""
    # Ensure input is RGBA
    if input_image.mode != 'RGBA':
        input_image = input_image.convert('RGBA')
    
    # Resize and convert to numpy array
    resized_image = input_image.resize((COND_WIDTH, COND_HEIGHT))
    img_array = np.array(resized_image).astype(np.float32) / 255.0
    
    print("[debug] img_array shape:", img_array.shape)
    
    # Split into RGB and alpha
    rgb = torch.from_numpy(img_array[..., :3]).float()
    alpha = torch.from_numpy(img_array[..., 3:4]).float()
    
    print("[debug] rgb tensor shape:", rgb.shape)
    print("[debug] alpha tensor shape:", alpha.shape)
    
    # Create background blend using torch.lerp()
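    # torch.lerp(bg, rgb, alpha) = bg * (1 - alpha) + rgb * alpha, so fully
    # transparent pixels become the constant gray BACKGROUND_COLOR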
    bg_tensor = torch.tensor(BACKGROUND_COLOR)[None, None, :]
    print("[debug] bg_tensor shape:", bg_tensor.shape)
    
    rgb_cond = torch.lerp(bg_tensor, rgb, alpha)
    print("[debug] rgb_cond shape:", rgb_cond.shape)

    batch = {
        "rgb_cond": rgb_cond.unsqueeze(0),
        "mask_cond": alpha.unsqueeze(0),
        "c2w_cond": c2w_cond.unsqueeze(0),
        "intrinsic_cond": intrinsic.unsqueeze(0), 
        "intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
    }
    
    # Final shapes check
    for k, v in batch.items():
        print(f"[debug] {k} final shape:", v.shape)
        
    return batch

def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, height: int = 1024) -> tuple[str | None, Image.Image | None]:
    """Generate image from prompt and convert to 3D model."""
    try:
        # Generate image using FLUX
        generator = torch.Generator(device=device).manual_seed(seed)

        print("[debug] generating the image using Flux")
        generated_image = flux_pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=4,
            generator=generator,
            guidance_scale=0.0
        ).images[0]
        
        # Process the generated image
        print("[debug] converting the image to rgb")
        rgb_image = generated_image.convert('RGB')
        
        # Remove background
        print("[debug] removing the background by calling bg_remover.process(rgb_image)")
        no_bg_image = bg_remover.process(rgb_image)
        
        # Extract the foreground mask: prefer the alpha channel of the RGBA
        # image returned by Remover; otherwise fall back to an any-nonzero heuristic
        print("[debug] converting to numpy array to extract the mask")
        no_bg_array = np.array(no_bg_image)
        if no_bg_array.ndim == 3 and no_bg_array.shape[-1] == 4:
            mask = no_bg_array[..., 3].astype(np.float32) / 255.0
        else:
            mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
        
        # Create RGBA image
        print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
        rgba_image = create_rgba_image(rgb_image, mask)
        
        # Auto-crop around the foreground
        print(f"[debug] auto-cropping the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
        processed_image = spar3d_utils.foreground_crop(
            rgba_image,
            crop_ratio=1.3,
            newsize=(COND_WIDTH, COND_HEIGHT),
            no_crop=False
        )

        print("[debug] preparing the batch by calling create_batch(processed_image)")
        # Prepare batch for 3D generation
        batch = create_batch(processed_image)
        batch = {k: v.to(device) for k, v in batch.items()}

        # Generate mesh
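        # generate_mesh returns a list of meshes plus auxiliary outputs
        # (discarded here); only the first mesh is kept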
        with torch.no_grad():
            print("[debug] calling torch.autocast(....) to generate the mesh")
            with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
                trimesh_mesh, _ = spar3d_model.generate_mesh(
                    batch,
                    1024,  # texture_resolution
                    remesh="none",
                    vertex_count=-1,
                    estimate_illumination=True
                )
                trimesh_mesh = trimesh_mesh[0]

        # Export to GLB
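        # GLB packs geometry, vertex normals and the baked texture into a single binary file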
        print("[debug] creating tmp dir for the .glb output")
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output.glb')
        
        print("[debug] calling trimesh_mesh.export(...) to export to .glb")
        trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
        
        return output_path, generated_image
        
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        return None, None

# Create Gradio interface
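# The inputs below map positionally to generate_and_process_3d(prompt, seed, width, height)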
demo = gr.Interface(
    fn=generate_and_process_3d,
    inputs=[
        gr.Text(
            label="Enter your prompt",
            placeholder="Describe what you want to generate..."
        ),
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=np.iinfo(np.int32).max,
            step=1,
            value=42
        ),
        gr.Slider(
            label="Width",
            minimum=256,
            maximum=2048,
            step=32,
            value=1024
        ),
        gr.Slider(
            label="Height",
            minimum=256,
            maximum=2048,
            step=32,
            value=1024
        )
    ],
    outputs=[
        gr.File(
            label="Download 3D Model",
            file_types=[".glb"]
        ),
        gr.Image(
            label="Generated Image",
            type="pil"
        )
    ],
    title="Text to 3D Model Generator",
    description="Enter a text prompt to generate an image that will be converted into a 3D model",
)

if __name__ == "__main__":
    demo.queue().launch()