import os
import threading

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
# FLUX imports
from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler,
                       FluxControlNetModel, FluxControlNetPipeline)
from einops import rearrange
from PIL import Image
from torchvision.transforms import ToPILImage

from .controlnet_union import ControlNetModel_Union
from .pipeline_controlnet_union_sd_xl import \
    StableDiffusionXLControlNetUnionPipeline
from .render_utils import get_silhouette_image
from utils.file_utils import load_tensor_from_file

IMG_PIPE = None
IMG_PIPE_LOCK = threading.Lock()
# FLUX pipeline state
FLUX_PIPE = None
FLUX_PIPE_LOCK = threading.Lock()
FLUX_SUFFIX = None
FLUX_NEGATIVE = None
CPU_OFFLOAD = False

def get_flux_pipe():
    """
    Lazily load the FLUX pipeline with ControlNet for image generation.
    """
    global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE
    if FLUX_PIPE is not None:
        return FLUX_PIPE
    gr.Info("First call: loading the FLUX pipeline; this may take about a minute.")
    with FLUX_PIPE_LOCK:
        # Double-checked locking: another thread may have finished loading
        # while we were waiting for the lock.
        if FLUX_PIPE is not None:
            return FLUX_PIPE
        FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture."
        FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field."
        base_model = 'black-forest-labs/FLUX.1-dev'
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0'
        controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
        # Use .get() so a missing variable triggers the assertion message
        # instead of a bare KeyError.
        assert os.environ.get("SEQTEX_SPACE_TOKEN"), "Please set the SEQTEX_SPACE_TOKEN environment variable to a Hugging Face token that has access to black-forest-labs/FLUX.1-dev."
        FLUX_PIPE = FluxControlNetPipeline.from_pretrained(
            base_model,
            controlnet=controlnet,
            torch_dtype=torch.bfloat16,
            token=os.environ["SEQTEX_SPACE_TOKEN"]
        )
        # Model CPU offload trades some speed for lower peak GPU memory.
        if CPU_OFFLOAD:
            FLUX_PIPE.enable_model_cpu_offload()
        else:
            FLUX_PIPE.to("cuda")
        return FLUX_PIPE
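
# FLUX.1-dev is a gated checkpoint, so the token above must belong to an
# account that has accepted the model license. A minimal setup sketch (the
# shell command below is an assumption about the deployment, not part of
# this module):
#   export SEQTEX_SPACE_TOKEN=<your HF token>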

def get_sdxl_pipe():
    """
    Lazily load the SDXL pipeline with ControlNet for image generation.
    """
    global IMG_PIPE
    if IMG_PIPE is not None:
        return IMG_PIPE
    gr.Info("First call: loading the SDXL pipeline; this may take about 20 seconds.")
    with IMG_PIPE_LOCK:
        if IMG_PIPE is not None:
            return IMG_PIPE
        eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
        # When testing with a different base model, change the VAE as well.
        vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True)
        IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model,
            vae=vae,
            torch_dtype=torch.float16,
            scheduler=eulera_scheduler,
        )
        # Model CPU offload trades some speed for lower peak GPU memory.
        if CPU_OFFLOAD:
            IMG_PIPE.enable_model_cpu_offload()
        else:
            IMG_PIPE.to("cuda")
        return IMG_PIPE

def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate an image condition with the SDXL ControlNet pipeline from depth and normal images.

    :param depth_img: Depth image from the selected view.
    :param normal_img: Normal image (camera coordinate system) from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: Mask image that guides the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated condition image (PIL Image).
    """
    progress(0.1, desc="Loading SDXL pipeline...")
    pipeline = get_sdxl_pipe()
    progress(0.3, desc="SDXL pipeline loaded successfully.")
    positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft diffuse lighting, uniform lighting, flat lighting, even illumination, matte surface, low contrast, uniform color, foreground"
    negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, harsh lighting, high contrast, bright highlights, specular reflections, shiny surface, glossy, reflective, strong shadows, dramatic lighting, spotlight, direct sunlight, glare, bloom, lens flare'
    img_generation_resolution = 1024  # SDXL performs best at 1024x1024
    image = pipeline(prompt=[positive_prompt]*1,
                     image_list=[0, depth_img, 0, 0, normal_img, 0],
                     negative_prompt=[negative_prompt]*1,
                     generator=torch.Generator(device="cuda").manual_seed(seed),
                     width=img_generation_resolution,
                     height=img_generation_resolution,
                     num_inference_steps=50,
                     union_control=True,
                     union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"),  # enable the depth and normal control branches
                     progress=progress,
                     ).images[0]
    progress(0.9, desc="Condition image generated successfully.")
    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device)
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device)  # (1, 1, H, W)
    mask_tensor = mask_tensor / 255.0  # Normalize mask to [0, 1]
    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    if edge_refinement:
        # Edge refinement runs on the CUDA device.
        rgb_tensor_cuda = rgb_tensor.to("cuda")
        mask_tensor_cuda = mask_tensor.to("cuda")
        rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda)
        rgb_tensor = rgb_tensor_cuda.to(pipeline.device)
    # Blend with a black background using the mask.
    background_tensor = torch.zeros_like(rgb_tensor)
    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())
    progress(1, desc="Condition image generated successfully.")
    return condition_image

def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image:
    """
    Generate an image condition with the FLUX ControlNet pipeline from a depth image only.
    Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 supports depth control but not normal control.

    :param depth_img: Depth image from the selected view.
    :param text_prompt: Text prompt for image generation.
    :param mask: Mask image that guides the subsequent pipeline to focus on the foreground.
    :param seed: Random seed for image generation.
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False).
    :param image_height: Height of the output image.
    :param image_width: Width of the output image.
    :param progress: Progress callback for Gradio.
    :return: Generated condition image (PIL Image).
    """
    progress(0.1, desc="Loading FLUX pipeline...")
    pipeline = get_flux_pipe()
    progress(0.3, desc="FLUX pipeline loaded successfully.")
    # Extend the prompt with the albedo-texture suffix set during pipeline loading.
    positive_prompt = text_prompt + FLUX_SUFFIX
    negative_prompt = FLUX_NEGATIVE
    # Match the generation resolution to the control image.
    width, height = depth_img.size
    progress(0.5, desc="Generating image with FLUX (including model onload and CPU offload)...")
    # Generate with depth control; model CPU offload (if enabled) moves
    # submodules onto the GPU automatically.
    image = pipeline(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        control_image=depth_img,
        width=width,
        height=height,
        controlnet_conditioning_scale=0.8,  # recommended for depth control
        control_guidance_end=0.8,
        num_inference_steps=30,
        guidance_scale=3.5,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]
    progress(0.9, desc="Applying mask and resizing...")
    # Convert to tensors and normalize the mask to [0, 1].
    rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda")
    mask_tensor = mask_tensor / 255.0
    # Resize to the target dimensions.
    rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False)
    # Blend with a black background using the mask.
    background_tensor = torch.zeros_like(rgb_tensor)
    if edge_refinement:
        # Replace edge pixels with nearby interior values.
        rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor)
    rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor)
    # Convert back to a PIL image.
    rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W")
    rgb_tensor = rgb_tensor / 255.0
    to_img = ToPILImage()
    condition_image = to_img(rgb_tensor.cpu())
    progress(1, desc="FLUX condition image generated successfully.")
    return condition_image

def refine_image_edges(rgb_tensor, mask_tensor):
    """
    Refine image edges with morphological operations, removing white fringes while preserving object boundaries.

    Algorithm:
    1. Erode the mask to get eroded_mask.
    2. Erode again to get double_eroded_mask.
    3. XOR eroded_mask with double_eroded_mask to get circle_valid_mask, a thin ring just inside the boundary.
    4. Use circle_valid_mask to extract circle_rgb (clean edge colors).
    5. Dilate circle_rgb so it covers the edge region.
    6. Composite: keep the original RGB where double_eroded_mask is valid, use dilated circle_rgb elsewhere.

    :param rgb_tensor: RGB image tensor of shape (1, C, H, W) on the CUDA device.
    :param mask_tensor: Mask tensor of shape (1, 1, H, W) on the CUDA device, normalized to [0, 1].
    :return: Refined RGB tensor of shape (1, C, H, W).
    """
    # Convert tensors to numpy for OpenCV processing.
    rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8)  # (H, W, C)
    mask_np = mask_tensor.squeeze().cpu().numpy()  # drop batch and channel dimensions
    original_mask_np = (mask_np * 255).astype(np.uint8)  # convert to the 0-255 range
    # 3x3 elliptical kernel for all morphological operations.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    # Step 1: erode the mask.
    eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3)
    # Step 2: erode further to get double_eroded_mask.
    double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5)
    # Step 3: XOR the two masks to get the ring of clean interior pixels.
    circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np)
    # Step 4: extract the clean edge colors through the ring mask.
    circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
    circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8)
    # Step 5: dilate the ring colors outward to cover the edge region (iterations=8).
    dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8)
    # Step 6: composite -- original RGB where double_eroded_mask is valid,
    # dilated ring colors elsewhere.
    double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0
    refined_rgb_np = (rgb_np * double_eroded_mask_3c +
                      dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8)
    # Convert the refined RGB back to a tensor.
    refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda")
    return refined_rgb_tensor
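
# A minimal shape-check sketch for refine_image_edges (the inputs below are
# synthetic, for illustration only; a CUDA device is required):
#   rgb = torch.rand(1, 3, 512, 512, device="cuda") * 255.0
#   mask = (torch.rand(1, 1, 512, 512, device="cuda") > 0.5).float()
#   refined = refine_image_edges(rgb, mask)  # -> (1, 3, 512, 512) float tensor on CUDA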

def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()):
    """
    Generate the image condition from the selected view's silhouette and a text prompt.

    :param position_imgs: Position images from different views.
    :param normal_imgs: Normal images from different views.
    :param mask_imgs: Mask images from different views.
    :param w2c: World-to-camera transformation matrices.
    :param text_prompt: Text prompt for image generation.
    :param selected_view: The view to generate the condition image for.
    :param seed: Random seed for image generation.
    :param model: Image generation model, either "SDXL" or "FLUX".
    :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True).
    :param progress: Progress callback for Gradio.
    :return: Generated condition image and a status message.
    """
    # If any input is a file path, load the tensor from disk.
    if isinstance(position_imgs, str):
        position_imgs = load_tensor_from_file(position_imgs, map_location="cuda")
    if isinstance(normal_imgs, str):
        normal_imgs = load_tensor_from_file(normal_imgs, map_location="cuda")
    if isinstance(mask_imgs, str):
        mask_imgs = load_tensor_from_file(mask_imgs, map_location="cuda")
    if isinstance(w2c, str):
        w2c = load_tensor_from_file(w2c, map_location="cuda")
    position_imgs = position_imgs.to("cuda")
    normal_imgs = normal_imgs.to("cuda")
    mask_imgs = mask_imgs.to("cuda")
    w2c = w2c.to("cuda")
    progress(0, desc="Handling geometry information...")
    silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view)
    depth_img, normal_img, mask = silhouette
    try:
        if model == "SDXL":
            condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "SDXL condition generated successfully."
        elif model == "FLUX":
            # FLUX is disabled in the HF Space; remove this raise to use it
            # locally. Note that FLUX supports depth control only, not normal.
            raise NotImplementedError("The FLUX model is not supported in the HF Space; remove this raise to run it locally.")
            condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress)
            return condition, "FLUX condition generated successfully (depth-only control)."
        else:
            raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.")
    finally:
        torch.cuda.empty_cache()
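
# Example usage (a minimal sketch; the tensor file paths below are
# hypothetical and would be produced by the upstream geometry-rendering step):
#   condition, status = generate_image_condition(
#       "cache/position.pt", "cache/normal.pt", "cache/mask.pt", "cache/w2c.pt",
#       text_prompt="weathered bronze statue",
#       selected_view="First View", seed=42, model="SDXL",
#   )
#   condition.save("condition_image.png")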