import spaces
import argparse
import os
import time
from os import path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import imageio
import numpy as np
import torch
import rembg
from PIL import Image
from torchvision.transforms import v2
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from einops import rearrange, repeat
from tqdm import tqdm
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import gradio as gr
import shutil
import tempfile
from functools import partial
from optimum.quanto import quantize, qfloat8, freeze
from flux_8bit_lora import FluxPipeline

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground, images_to_video

# Set up cache path
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

torch.backends.cuda.matmul.allow_tf32 = True


class timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {round(end - self.start, 2)}s")


def find_cuda():
    # Prefer an explicit CUDA_HOME / CUDA_PATH, then fall back to locating nvcc
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home and os.path.exists(cuda_home):
        return cuda_home

    nvcc_path = shutil.which('nvcc')
    if nvcc_path:
        cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
        return cuda_path

    return None


cuda_path = find_cuda()
if cuda_path:
    print(f"CUDA installation found at: {cuda_path}")
else:
    print("CUDA installation not found")

# Load Flux text-to-image pipeline
base_model = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16, token=huggingface_token)

print('Loading and fusing lora, please wait...')
pipe.load_lora_weights(hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors"))
# We need this scaling because SimpleTuner fixes the alpha to 16, might be fixed later in diffusers
# See https://github.com/huggingface/diffusers/issues/9134
pipe.fuse_lora(lora_scale=1.)
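# Note: fusing the LoRA and then unloading it (below) leaves a single set of
# transformer weights, presumably so that the qfloat8 quantization step sees
# one fused module rather than base weights plus a separate bf16 adapter.
# This rationale is inferred from the code, not documented behavior.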
pipe.unload_lora_weights()

print('Quantizing, please wait...')
quantize(pipe.transformer, qfloat8)
freeze(pipe.transformer)
print('Model quantized!')
pipe.to('cuda')

# Load 3D generation models
config_path = 'configs/instant-mesh-large.yaml'
config = OmegaConf.load(config_path)
config_name = os.path.basename(config_path).replace('.yaml', '')
model_config = config.model_config
infer_config = config.infer_config

IS_FLEXICUBES = config_name.startswith('instant-mesh')

device = torch.device('cuda')

# Load diffusion model for 3D generation
print('Loading diffusion model ...')
pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="zero123plus",
    torch_dtype=torch.float16,
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipeline.scheduler.config, timestep_spacing='trailing'
)

# Load custom white-background UNet
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
pipeline.unet.load_state_dict(state_dict, strict=True)

pipeline = pipeline.to(device)

# Load reconstruction model
print('Loading reconstruction model ...')
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
model = instantiate_from_config(model_config)
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
# Keep only the LRM generator weights, stripping the 'lrm_generator.' prefix
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)

print('Loading Finished!')


def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
    if is_flexicubes:
        cameras = torch.linalg.inv(c2ws)
        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    else:
        extrinsics = c2ws.flatten(-2)
        intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
    return cameras


def preprocess(input_image, do_remove_background):
    rembg_session = rembg.new_session() if do_remove_background else None
    if do_remove_background:
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)
    return input_image


ts_cutoff = 2


@spaces.GPU
def generate_flux_image(prompt, height, width, steps, scales, seed):
    return pipe(
        prompt=prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        generator=torch.Generator().manual_seed(int(seed)),
        guidance_scale=float(scales),
        timestep_to_start_cfg=ts_cutoff,
    ).images[0]


@spaces.GPU
def generate_mvs(input_image, sample_steps, sample_seed):
    seed_everything(sample_seed)

    z123_image = pipeline(
        input_image,
        num_inference_steps=sample_steps
    ).images[0]

    show_image = np.asarray(z123_image, dtype=np.uint8)
    show_image = torch.from_numpy(show_image)
    # Re-tile the 3x2 grid of generated views into a 2x3 grid for display
    show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
    show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
    show_image = Image.fromarray(show_image.numpy())

    return z123_image, show_image


@spaces.GPU
def make3d(images):
    global model
    if IS_FLEXICUBES:
        model.init_flexicubes_geometry(device, use_renderer=False)
    model = model.eval()

    images = np.asarray(images, dtype=np.float32) / 255.0
    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
    # Split the 3x2 tiled multi-view image into six individual (c, h, w) views
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)

    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
    render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)

    images = images.unsqueeze(0).to(device)
    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)

    mesh_fpath = tempfile.NamedTemporaryFile(suffix=".obj", delete=False).name
    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
    mesh_dirname = os.path.dirname(mesh_fpath)
    mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")

    with torch.no_grad():
        planes = model.forward_planes(images, input_cameras)

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=False,
            **infer_config,
        )

        vertices, faces, vertex_colors = mesh_out
        vertices = vertices[:, [1, 2, 0]]

        save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
        save_obj(vertices, faces, vertex_colors, mesh_fpath)

    return mesh_fpath, mesh_glb_fpath
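# A note on the layout assumed above (inferred from the rearrange patterns and
# consistent with zero123plus-v1.2's output format): the pipeline returns its
# six views tiled into one image as a 3-row x 2-column grid of 320x320 tiles.
# generate_mvs re-tiles that grid to 2x3 purely for display, while make3d
# splits the original 3x2 grid into per-view tensors for reconstruction.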
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Flux Image to 3D Model Generator
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Your Image Description",
                placeholder="E.g., A serene landscape with mountains and a lake at sunset",
                lines=3
            )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Group():
                    with gr.Row():
                        height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                        width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)

                    with gr.Row():
                        steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
                        scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)

                    seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)

            generate_btn = gr.Button("Generate 3D Model", variant="primary")

        with gr.Column(scale=4):
            flux_output = gr.Image(label="Generated Flux Image")
            mv_show_images = gr.Image(label="Generated Multi-views")
            with gr.Row():
                with gr.Tab("OBJ"):
                    output_model_obj = gr.Model3D(label="Output Model (OBJ Format)")
                with gr.Tab("GLB"):
                    output_model_glb = gr.Model3D(label="Output Model (GLB Format)")

    mv_images = gr.State()

    def process_pipeline(prompt, height, width, steps, scales, seed):
        flux_image = generate_flux_image(prompt, height, width, steps, scales, seed)
        processed_image = preprocess(flux_image, do_remove_background=True)
        mv_images, show_image = generate_mvs(processed_image, steps, seed)
        obj_path, glb_path = make3d(mv_images)
        return flux_image, show_image, obj_path, glb_path

    generate_btn.click(
        fn=process_pipeline,
        inputs=[prompt, height, width, steps, scales, seed],
        outputs=[flux_output, mv_show_images, output_model_obj, output_model_glb]
    )

if __name__ == "__main__":
    demo.launch()