# SeqTex/utils/image_generation.py
import os
import threading
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
# FLUX imports
from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler,
FluxControlNetModel, FluxControlNetPipeline)
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage
from .controlnet_union import ControlNetModel_Union
from .pipeline_controlnet_union_sd_xl import \
StableDiffusionXLControlNetUnionPipeline
from .render_utils import get_silhouette_image
from utils.file_utils import load_tensor_from_file
IMG_PIPE = None
IMG_PIPE_LOCK = threading.Lock()
# FLUX pipeline globals (lazily initialized)
FLUX_PIPE = None
FLUX_PIPE_LOCK = threading.Lock()
FLUX_SUFFIX = None
FLUX_NEGATIVE = None
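# When True, pipelines stream weights between CPU and GPU (enable_model_cpu_offload),
# trading inference speed for lower peak VRAM.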
CPU_OFFLOAD = False
def get_flux_pipe():
"""
Lazy load the FLUX pipeline with ControlNet for image generation.
"""
global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE
if FLUX_PIPE is not None:
return FLUX_PIPE
gr.Info("First called, loading FLUX pipeline... It may take about 1 minute.")
with FLUX_PIPE_LOCK:
if FLUX_PIPE is not None:
return FLUX_PIPE
FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture."
FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field."
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
assert os.environ["SEQTEX_SPACE_TOKEN"] != "", "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, which has access to black-forest-labs/FLUX.1-dev."
FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
base_model,
controlnet=controlnet,
torch_dtype=torch.bfloat16,
token=os.environ["SEQTEX_SPACE_TOKEN"]
)
        # Model CPU offload trades inference speed for lower peak GPU memory
if CPU_OFFLOAD:
FLUX_PIPE.enable_model_cpu_offload()
else:
FLUX_PIPE.to("cuda")
return FLUX_PIPE
def get_sdxl_pipe():
"""
Lazy load the SDXL pipeline with ControlNet for image generation.
"""
global IMG_PIPE
if IMG_PIPE is not None:
return IMG_PIPE
gr.Info("First called, loading SDXL pipeline... It may take about 20 seconds.")
with IMG_PIPE_LOCK:
if IMG_PIPE is not None:
return IMG_PIPE
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        # When testing with another base model, remember to swap the VAE as well.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True)
IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model,
vae=vae,
torch_dtype=torch.float16,
scheduler=eulera_scheduler,
)
        # Model CPU offload trades inference speed for lower peak GPU memory
if CPU_OFFLOAD:
IMG_PIPE.enable_model_cpu_offload()
else:
IMG_PIPE.to("cuda")
return IMG_PIPE
def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
"""
Generate image condition using SDXL model with ControlNet based on depth and normal images.
:param depth_img: Depth image from the selected view.
:param normal_img: Normal image (Camera Coordinate System) from the selected view.
:param text_prompt: Text prompt for image generation.
:param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground.
:param seed: Random seed for image generation.
:param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
:param image_height: Height of the output image.
:param image_width: Width of the output image.
:param progress: Progress callback for Gradio.
    :return: Generated image condition (PIL Image).
"""
progress(0.1, desc="Loading SDXL pipeline...")
pipeline = get_sdxl_pipe()
progress(0.3, desc="SDXL pipeline loaded successfully.")
positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft diffuse lighting, uniform lighting, flat lighting, even illumination, matte surface, low contrast, uniform color, foreground"
negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, harsh lighting, high contrast, bright highlights, specular reflections, shiny surface, glossy, reflective, strong shadows, dramatic lighting, spotlight, direct sunlight, glare, bloom, lens flare'
img_generation_resolution = 1024 # SDXL performs better at 1024x1024
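    # Slot order per the xinsir/controlnet-union-sdxl-1.0 model card:
    # [openpose, depth, hed/scribble, canny/lineart, normal, segment];
    # only the depth (index 1) and normal (index 4) slots are populated below.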
    image = pipeline(prompt=[positive_prompt],
                     image_list=[0, depth_img, 0, 0, normal_img, 0],
                     negative_prompt=[negative_prompt],
generator=torch.Generator(device="cuda").manual_seed(seed),
width=img_generation_resolution,
height=img_generation_resolution,
num_inference_steps=50,
union_control=True,
union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"), # use depth and normal images
progress=progress,
).images[0]
progress(0.9, desc="Condition tensor generated successfully.")
rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device)
mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device) # Ensure mask is in the correct shape
mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1]
rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
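    # Bilinear resizing turns the binary mask into a soft alpha, feathering the composite below.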
# Apply edge refinement if enabled
if edge_refinement:
# Convert to CUDA device for edge refinement
rgb_tensor_cuda = rgb_tensor.to("cuda")
mask_tensor_cuda = mask_tensor.to("cuda")
rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda)
rgb_tensor = rgb_tensor_cuda.to(pipeline.device)
background_tensor = torch.zeros_like(rgb_tensor)
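    # torch.lerp(bg, fg, m) == bg * (1 - m) + fg * m: composite the foreground over black using the soft mask.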
rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
rgb_tensor = rgb_tensor / 255.
to_img = ToPILImage()
condition_image = to_img(rgb_tensor.cpu())
progress(1, desc="Condition image generated successfully.")
return condition_image
def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
"""
Generate image condition using FLUX model with ControlNet based on depth image only.
Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth.
:param depth_img: Depth image from the selected view.
:param text_prompt: Text prompt for image generation.
:param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
:return: Generated image condition (PIL Image).
"""
progress(0.1, desc="Loading FLUX pipeline...")
pipeline = get_flux_pipe()
progress(0.3, desc="FLUX pipeline loaded successfully.")
# Enhanced prompt for better results
positive_prompt = text_prompt + FLUX_SUFFIX
negative_prompt = FLUX_NEGATIVE
# Get image dimensions
width, height = depth_img.size
progress(0.5, desc="Generating image with FLUX (including onload and cpu offload)...")
# Generate image using FLUX ControlNet with depth control
    # (with CPU offload enabled, submodules are moved to the GPU automatically)
image = pipeline(
prompt=positive_prompt,
negative_prompt=negative_prompt,
control_image=depth_img,
width=width,
height=height,
controlnet_conditioning_scale=0.8, # Recommended for depth
control_guidance_end=0.8,
num_inference_steps=30,
guidance_scale=3.5,
generator=torch.Generator(device="cuda").manual_seed(seed),
).images[0]
progress(0.9, desc="Applying mask and resizing...")
# Convert to tensor and apply mask
rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda")
mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1]
# Resize to target dimensions
rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
# Apply mask (blend with black background)
background_tensor = torch.zeros_like(rgb_tensor)
if edge_refinement:
# replace edge with inner values
rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor)
rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
# Convert back to PIL Image
rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
rgb_tensor = rgb_tensor / 255.0
to_img = ToPILImage()
condition_image = to_img(rgb_tensor.cpu())
progress(1, desc="FLUX condition image generated successfully.")
return condition_image
def refine_image_edges(rgb_tensor, mask_tensor):
"""
Refine image edges using advanced morphological operations to remove white edges while preserving object boundaries.
Algorithm:
1. Erode mask to get eroded_mask
2. Double erode mask to get double_eroded_mask
3. XOR eroded_mask and double_eroded_mask to get circle_valid_mask
4. Use circle_valid_mask to extract circle_rgb (clean edge values)
5. Dilate circle_rgb to cover the edge region
6. Final result: use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background
:param rgb_tensor: RGB image tensor of shape (1, C, H, W) on CUDA device
:param mask_tensor: Mask tensor of shape (1, 1, H, W) on CUDA device, normalized to [0, 1]
:return: refined_rgb_tensor
"""
# Convert tensors to numpy for OpenCV processing
rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8) # (H, W, C)
mask_np = mask_tensor.squeeze().cpu().numpy() # Remove batch and channel dimensions
original_mask_np = (mask_np * 255).astype(np.uint8) # Convert to 0-255 range
    # Create a 3x3 elliptical morphological kernel
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
# Step 1: Erode mask to get eroded_mask
eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3)
# Step 2: Double erode mask to get double_eroded_mask
double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5)
# Step 3: XOR eroded_mask and double_eroded_mask to get circle_valid_mask
circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np)
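    # The XOR of the two erosions leaves a thin ring just inside the silhouette
    # whose colors are free of background bleed.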
# Step 4: Use circle_valid_mask to extract circle_rgb (clean edge values)
circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8)
    # Step 5: Dilate circle_rgb so the clean ring colors cover the whole edge region
dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8)
# Step 6: Final composition
# Use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background
double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
# Final result: original RGB where double_eroded_mask is valid, dilated_circle_rgb elsewhere
refined_rgb_np = (rgb_np * double_eroded_mask_3c +
dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8)
# Convert refined RGB back to tensor
refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
return refined_rgb_tensor
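# Minimal sanity-check sketch for refine_image_edges (illustrative shapes only, not part of the pipeline):
#   rgb = torch.rand(1, 3, 1024, 1024, device="cuda") * 255   # values in [0, 255]
#   mask = (torch.rand(1, 1, 1024, 1024, device="cuda") > 0.5).float()
#   refined = refine_image_edges(rgb, mask)  # (1, 3, 1024, 1024); edge ring replaced by inner colors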
@spaces.GPU()
def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
"""
Generate the image condition based on the selected view's silhouette and text prompt.
:param position_imgs: Position images from different views.
:param normal_imgs: Normal images from different views.
:param mask_imgs: Mask images from different views.
:param w2c: World-to-camera transformation matrices.
:param text_prompt: The text prompt for image generation.
:param selected_view: The selected view for image generation.
    :param seed: Random seed for image generation.
    :param model: The image generation model type; supports "SDXL" and "FLUX".
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :param progress: Progress callback for Gradio.
:return: Generated condition image and status message.
"""
# If any input is a file path, load the tensor from file
if isinstance(position_imgs, str):
position_imgs = load_tensor_from_file(position_imgs, map_location="cuda")
if isinstance(normal_imgs, str):
normal_imgs = load_tensor_from_file(normal_imgs, map_location="cuda")
if isinstance(mask_imgs, str):
mask_imgs = load_tensor_from_file(mask_imgs, map_location="cuda")
if isinstance(w2c, str):
w2c = load_tensor_from_file(w2c, map_location="cuda")
position_imgs = position_imgs.to("cuda")
normal_imgs = normal_imgs.to("cuda")
mask_imgs = mask_imgs.to("cuda")
w2c = w2c.to("cuda")
progress(0, desc="Handling geometry information...")
silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
    depth_img, normal_img, mask = silhouette[:3]  # depth, normal (camera space), foreground mask
try:
if model == "SDXL":
condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
return condition, "SDXL condition generated successfully."
elif model == "FLUX":
# FLUX only supports depth control, not normal
raise NotImplementedError("FLUX model not supported in HF space, please delete it and use it locally")
condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
return condition, "FLUX condition generated successfully (depth-only control)."
else:
raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.")
finally:
torch.cuda.empty_cache()
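# Illustrative call (hypothetical prompt; tensors or .pt paths come from the upstream rendering stage):
#   condition, msg = generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c,
#                                             "a weathered leather backpack", selected_view="First View",
#                                             seed=42, model="SDXL", edge_refinement=True)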