# Source: Hugging Face repository file uploaded by gokaygokay
# (commit 3d535fa, "Upload 43 files", verified, 5.02 kB).
import os
import tempfile
import time

import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import FluxPipeline
from huggingface_hub import hf_hub_download
from gradio_litmodel3d import LitModel3D

from sf3d.system import SF3D
import sf3d.utils as sf3d_utils
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
torch.backends.cuda.matmul.allow_tf32 = True
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Set up environment and cache
cache_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path
if not os.path.exists(cache_path):
os.makedirs(cache_path, exist_ok=True)
# Initialize Flux pipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, token=huggingface_token)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
# Initialize SF3D model
sf3d_model = SF3D.from_pretrained(
"stabilityai/stable-fast-3d",
config_name="config.yaml",
weight_name="model.safetensors",
token=huggingface_token
)
sf3d_model.eval().cuda()
# Constants for SF3D
COND_WIDTH, COND_HEIGHT = 512, 512
COND_DISTANCE, COND_FOVY_DEG = 1.6, 40
BACKGROUND_COLOR = [0.5, 0.5, 0.5]
c2w_cond = sf3d_utils.default_cond_c2w(COND_DISTANCE)
intrinsic, intrinsic_normed_cond = sf3d_utils.create_intrinsic_from_fov_deg(
COND_FOVY_DEG, COND_HEIGHT, COND_WIDTH
)
def generate_image(prompt, height, width, steps, scales, seed):
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
return pipe(
prompt=[prompt],
generator=torch.Generator().manual_seed(int(seed)),
num_inference_steps=int(steps),
guidance_scale=float(scales),
height=int(height),
width=int(width),
max_sequence_length=256
).images[0]
def create_batch(input_image: Image.Image) -> dict:
img_cond = torch.from_numpy(
np.asarray(input_image.resize((COND_WIDTH, COND_HEIGHT))).astype(np.float32) / 255.0
).float().clip(0, 1)
mask_cond = img_cond[:, :, -1:]
rgb_cond = torch.lerp(
torch.tensor(BACKGROUND_COLOR)[None, None, :], img_cond[:, :, :3], mask_cond
)
batch_elem = {
"rgb_cond": rgb_cond,
"mask_cond": mask_cond,
"c2w_cond": c2w_cond.unsqueeze(0),
"intrinsic_cond": intrinsic.unsqueeze(0),
"intrinsic_normed_cond": intrinsic_normed_cond.unsqueeze(0),
}
return {k: v.unsqueeze(0) for k, v in batch_elem.items()}
def generate_3d_model(input_image):
with torch.no_grad():
with torch.autocast(device_type="cuda", dtype=torch.float16):
model_batch = create_batch(input_image)
model_batch = {k: v.cuda() for k, v in model_batch.items()}
trimesh_mesh, _ = sf3d_model.generate_mesh(model_batch, 1024)
trimesh_mesh = trimesh_mesh[0]
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".glb")
trimesh_mesh.export(tmp_file.name, file_type="glb", include_normals=True)
return tmp_file.name
def process_and_generate(prompt, height, width, steps, scales, seed):
# Generate image from prompt
generated_image = generate_image(prompt, height, width, steps, scales, seed)
# Generate 3D model from the image
glb_file = generate_3d_model(generated_image)
return generated_image, glb_file
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Text-to-3D Model Generator")
with gr.Row():
with gr.Column(scale=3):
prompt = gr.Textbox(label="Your Image Description", lines=3)
with gr.Accordion("Advanced Settings", open=False):
height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
seed = gr.Number(label="Seed", value=3413, precision=0)
generate_btn = gr.Button("Generate 3D Model", variant="primary")
with gr.Column(scale=4):
output_image = gr.Image(label="Generated Image")
output_3d = LitModel3D(label="3D Model", clear_color=[0.0, 0.0, 0.0, 0.0])
generate_btn.click(
process_and_generate,
inputs=[prompt, height, width, steps, scales, seed],
outputs=[output_image, output_3d]
)
if __name__ == "__main__":
demo.launch()