Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Update gradio_app.py
Browse files
- gradio_app.py +12 -0
    	
        gradio_app.py
    CHANGED
    
    | @@ -88,6 +88,8 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig | |
| 88 | 
             
                try:
         | 
| 89 | 
             
                    # Generate image using FLUX
         | 
| 90 | 
             
                    generator = torch.Generator(device=device).manual_seed(seed)
         | 
|  | |
|  | |
| 91 | 
             
                    generated_image = flux_pipe(
         | 
| 92 | 
             
                        prompt=prompt,
         | 
| 93 | 
             
                        width=width,
         | 
| @@ -98,19 +100,24 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig | |
| 98 | 
             
                    ).images[0]
         | 
| 99 |  | 
| 100 | 
             
                    # Process the generated image
         | 
|  | |
| 101 | 
             
                    rgb_image = generated_image.convert('RGB')
         | 
| 102 |  | 
| 103 | 
             
                    # Remove background
         | 
|  | |
| 104 | 
             
                    no_bg_image = bg_remover.process(rgb_image)
         | 
| 105 |  | 
| 106 | 
             
                    # Convert to numpy array to extract mask
         | 
|  | |
| 107 | 
             
                    no_bg_array = np.array(no_bg_image)
         | 
| 108 | 
             
                    mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
         | 
| 109 |  | 
| 110 | 
             
                    # Create RGBA image
         | 
|  | |
| 111 | 
             
                    rgba_image = create_rgba_image(rgb_image, mask)
         | 
| 112 |  | 
| 113 | 
             
                    # Auto crop with foreground
         | 
|  | |
| 114 | 
             
                    processed_image = spar3d_utils.foreground_crop(
         | 
| 115 | 
             
                        rgba_image,
         | 
| 116 | 
             
                        crop_ratio=1.3,
         | 
| @@ -118,12 +125,14 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig | |
| 118 | 
             
                        no_crop=False
         | 
| 119 | 
             
                    )
         | 
| 120 |  | 
|  | |
| 121 | 
             
                    # Prepare batch for 3D generation
         | 
| 122 | 
             
                    batch = create_batch(processed_image)
         | 
| 123 | 
             
                    batch = {k: v.to(device) for k, v in batch.items()}
         | 
| 124 |  | 
| 125 | 
             
                    # Generate mesh
         | 
| 126 | 
             
                    with torch.no_grad():
         | 
|  | |
| 127 | 
             
                        with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
         | 
| 128 | 
             
                            trimesh_mesh, _ = spar3d_model.generate_mesh(
         | 
| 129 | 
             
                                batch,
         | 
| @@ -135,8 +144,11 @@ def generate_and_process_3d(prompt: str, seed: int = 42, width: int = 1024, heig | |
| 135 | 
             
                            trimesh_mesh = trimesh_mesh[0]
         | 
| 136 |  | 
| 137 | 
             
                    # Export to GLB
         | 
|  | |
| 138 | 
             
                    temp_dir = tempfile.mkdtemp()
         | 
| 139 | 
             
                    output_path = os.path.join(temp_dir, 'output.glb')
         | 
|  | |
|  | |
| 140 | 
             
                    trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
         | 
| 141 |  | 
| 142 | 
             
                    return output_path, generated_image
         | 
|  | |
| 88 | 
             
                try:
         | 
| 89 | 
             
                    # Generate image using FLUX
         | 
| 90 | 
             
                    generator = torch.Generator(device=device).manual_seed(seed)
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    print("[debug] generating the image using Flux")
         | 
| 93 | 
             
                    generated_image = flux_pipe(
         | 
| 94 | 
             
                        prompt=prompt,
         | 
| 95 | 
             
                        width=width,
         | 
|  | |
| 100 | 
             
                    ).images[0]
         | 
| 101 |  | 
| 102 | 
             
                    # Process the generated image
         | 
| 103 | 
            +
                    print("[debug] converting the image to rgb")
         | 
| 104 | 
             
                    rgb_image = generated_image.convert('RGB')
         | 
| 105 |  | 
| 106 | 
             
                    # Remove background
         | 
| 107 | 
            +
                    print("[debug] removing the background by calling bg_remover.process(rgb_image)")
         | 
| 108 | 
             
                    no_bg_image = bg_remover.process(rgb_image)
         | 
| 109 |  | 
| 110 | 
             
                    # Convert to numpy array to extract mask
         | 
| 111 | 
            +
                    print("[debug] converting to numpy array to extract the mask")
         | 
| 112 | 
             
                    no_bg_array = np.array(no_bg_image)
         | 
| 113 | 
             
                    mask = (no_bg_array.sum(axis=2) > 0).astype(np.float32)
         | 
| 114 |  | 
| 115 | 
             
                    # Create RGBA image
         | 
| 116 | 
            +
                    print("[debug] creating the RGBA image using create_rgba_image(rgb_image, mask)")
         | 
| 117 | 
             
                    rgba_image = create_rgba_image(rgb_image, mask)
         | 
| 118 |  | 
| 119 | 
             
                    # Auto crop with foreground
         | 
| 120 | 
            +
                    print(f"[debug] auto-cromming the rgba_image using spar3d_utils.foreground_crop(...). newsize=(COND_WIDTH, COND_HEIGHT) = ({COND_WIDTH}, {COND_HEIGHT})")
         | 
| 121 | 
             
                    processed_image = spar3d_utils.foreground_crop(
         | 
| 122 | 
             
                        rgba_image,
         | 
| 123 | 
             
                        crop_ratio=1.3,
         | 
|  | |
| 125 | 
             
                        no_crop=False
         | 
| 126 | 
             
                    )
         | 
| 127 |  | 
| 128 | 
            +
                    print("[debug] preparing the batch by calling create_batch(processed_image)")
         | 
| 129 | 
             
                    # Prepare batch for 3D generation
         | 
| 130 | 
             
                    batch = create_batch(processed_image)
         | 
| 131 | 
             
                    batch = {k: v.to(device) for k, v in batch.items()}
         | 
| 132 |  | 
| 133 | 
             
                    # Generate mesh
         | 
| 134 | 
             
                    with torch.no_grad():
         | 
| 135 | 
            +
                        print("[debug] calling torch.autocast(....) to generate the mesh")
         | 
| 136 | 
             
                        with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16):
         | 
| 137 | 
             
                            trimesh_mesh, _ = spar3d_model.generate_mesh(
         | 
| 138 | 
             
                                batch,
         | 
|  | |
| 144 | 
             
                            trimesh_mesh = trimesh_mesh[0]
         | 
| 145 |  | 
| 146 | 
             
                    # Export to GLB
         | 
| 147 | 
            +
                    print("[debug] creating tmp dir for the .glb output")
         | 
| 148 | 
             
                    temp_dir = tempfile.mkdtemp()
         | 
| 149 | 
             
                    output_path = os.path.join(temp_dir, 'output.glb')
         | 
| 150 | 
            +
                    
         | 
| 151 | 
            +
                    print("[debug] calling trimesh_mesh.export(...) to export to .glb")
         | 
| 152 | 
             
                    trimesh_mesh.export(output_path, file_type="glb", include_normals=True)
         | 
| 153 |  | 
| 154 | 
             
                    return output_path, generated_image
         | 
