import os import threading import cv2 import gradio as gr import numpy as np import spaces import torch import torch.nn.functional as F # Add FLUX imports from diffusers import (AutoencoderKL, EulerAncestralDiscreteScheduler, FluxControlNetModel, FluxControlNetPipeline) from einops import rearrange from PIL import Image from torchvision.transforms import ToPILImage from .controlnet_union import ControlNetModel_Union from .pipeline_controlnet_union_sd_xl import \ StableDiffusionXLControlNetUnionPipeline from .render_utils import get_silhouette_image from utils.file_utils import load_tensor_from_file IMG_PIPE = None IMG_PIPE_LOCK = threading.Lock() # Add FLUX pipeline variables FLUX_PIPE = None FLUX_PIPE_LOCK = threading.Lock() FLUX_SUFFIX = None FLUX_NEGATIVE = None CPU_OFFLOAD = False def get_flux_pipe(): """ Lazy load the FLUX pipeline with ControlNet for image generation. """ global FLUX_PIPE, FLUX_SUFFIX, FLUX_NEGATIVE if FLUX_PIPE is not None: return FLUX_PIPE gr.Info("First called, loading FLUX pipeline... It may take about 1 minute.") with FLUX_PIPE_LOCK: if FLUX_PIPE is not None: return FLUX_PIPE FLUX_SUFFIX = ", albedo texture, high-quality, 8K, flat shaded, diffuse color only, orthographic view, seamless texture pattern, detailed surface texture." FLUX_NEGATIVE = "ugly, PBR, lighting, shadows, highlights, specular, reflections, ambient occlusion, global illumination, bloom, glare, lens flare, glow, shiny, glossy, noise, grain, blurry, bokeh, depth of field." base_model = 'black-forest-labs/FLUX.1-dev' controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0' controlnet = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16) assert os.environ["SEQTEX_SPACE_TOKEN"] != "", "Please set the SEQTEX_SPACE_TOKEN environment variable with your Hugging Face token, which has access to black-forest-labs/FLUX.1-dev." FLUX_PIPE = FluxControlNetPipeline.from_pretrained( base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, token=os.environ["SEQTEX_SPACE_TOKEN"] ) # Use model CPU offload for better GPU utilization during inference if CPU_OFFLOAD: FLUX_PIPE.enable_model_cpu_offload() else: FLUX_PIPE.to("cuda") return FLUX_PIPE def get_sdxl_pipe(): """ Lazy load the SDXL pipeline with ControlNet for image generation. """ global IMG_PIPE if IMG_PIPE is not None: return IMG_PIPE gr.Info("First called, loading SDXL pipeline... It may take about 20 seconds.") with IMG_PIPE_LOCK: if IMG_PIPE is not None: return IMG_PIPE eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler") # when test with other base model, you need to change the vae also. vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16) controlnet_model = ControlNetModel_Union.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True) IMG_PIPE = StableDiffusionXLControlNetUnionPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet_model, vae=vae, torch_dtype=torch.float16, scheduler=eulera_scheduler, ) # Use model CPU offload for better GPU utilization during inference if CPU_OFFLOAD: IMG_PIPE.enable_model_cpu_offload() else: IMG_PIPE.to("cuda") return IMG_PIPE def generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image: """ Generate image condition using SDXL model with ControlNet based on depth and normal images. :param depth_img: Depth image from the selected view. :param normal_img: Normal image (Camera Coordinate System) from the selected view. :param text_prompt: Text prompt for image generation. :param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground. :param seed: Random seed for image generation. :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False). :param image_height: Height of the output image. :param image_width: Width of the output image. :param progress: Progress callback for Gradio. :return: Generated image condition (e.g., PIL Image). """ progress(0.1, desc="Loading SDXL pipeline...") pipeline = get_sdxl_pipe() progress(0.3, desc="SDXL pipeline loaded successfully.") positive_prompt = text_prompt + ", photo-realistic style, high quality, 8K, highly detailed texture, soft diffuse lighting, uniform lighting, flat lighting, even illumination, matte surface, low contrast, uniform color, foreground" negative_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, harsh lighting, high contrast, bright highlights, specular reflections, shiny surface, glossy, reflective, strong shadows, dramatic lighting, spotlight, direct sunlight, glare, bloom, lens flare' img_generation_resolution = 1024 # SDXL performs better at 1024x1024 image = pipeline(prompt=[positive_prompt]*1, image_list=[0, depth_img, 0, 0, normal_img, 0], negative_prompt=[negative_prompt]*1, generator=torch.Generator(device="cuda").manual_seed(seed), width=img_generation_resolution, height=img_generation_resolution, num_inference_steps=50, union_control=True, union_control_type=torch.Tensor([0, 1, 0, 0, 1, 0]).to("cuda"), # use depth and normal images progress=progress, ).images[0] progress(0.9, desc="Condition tensor generated successfully.") rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(pipeline.device) mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to(pipeline.device) # Ensure mask is in the correct shape mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1] rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False) mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False) # Apply edge refinement if enabled if edge_refinement: # Convert to CUDA device for edge refinement rgb_tensor_cuda = rgb_tensor.to("cuda") mask_tensor_cuda = mask_tensor.to("cuda") rgb_tensor_cuda = refine_image_edges(rgb_tensor_cuda, mask_tensor_cuda) rgb_tensor = rgb_tensor_cuda.to(pipeline.device) background_tensor = torch.zeros_like(rgb_tensor) rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor) rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W") rgb_tensor = rgb_tensor / 255. to_img = ToPILImage() condition_image = to_img(rgb_tensor.cpu()) progress(1, desc="Condition image generated successfully.") return condition_image def generate_flux_condition(depth_img, text_prompt, mask, seed=42, edge_refinement=False, image_height=1024, image_width=1024, progress=gr.Progress()) -> Image.Image: """ Generate image condition using FLUX model with ControlNet based on depth image only. Note: FLUX.1-dev-ControlNet-Union-Pro-2.0 does not support normal control, only depth. :param depth_img: Depth image from the selected view. :param text_prompt: Text prompt for image generation. :param mask: A mask image to apply to guide the subsequent pipeline to focus on the foreground. :param seed: Random seed for image generation. :param image_height: Height of the output image. :param image_width: Width of the output image. :param progress: Progress callback for Gradio. :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: False). :return: Generated image condition (PIL Image). """ progress(0.1, desc="Loading FLUX pipeline...") pipeline = get_flux_pipe() progress(0.3, desc="FLUX pipeline loaded successfully.") # Enhanced prompt for better results positive_prompt = text_prompt + FLUX_SUFFIX negative_prompt = FLUX_NEGATIVE # Get image dimensions width, height = depth_img.size progress(0.5, desc="Generating image with FLUX (including onload and cpu offload)...") # Generate image using FLUX ControlNet with depth control # model_cpu_offload handles GPU loading automatically image = pipeline( prompt=positive_prompt, negative_prompt=negative_prompt, control_image=depth_img, width=width, height=height, controlnet_conditioning_scale=0.8, # Recommended for depth control_guidance_end=0.8, num_inference_steps=30, guidance_scale=3.5, generator=torch.Generator(device="cuda").manual_seed(seed), ).images[0] progress(0.9, desc="Applying mask and resizing...") # Convert to tensor and apply mask rgb_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to("cuda") mask_tensor = torch.from_numpy(np.array(mask)).float().unsqueeze(0).unsqueeze(0).to("cuda") mask_tensor = mask_tensor / 255.0 # Normalize mask to [0, 1] # Resize to target dimensions rgb_tensor = F.interpolate(rgb_tensor, (image_height, image_width), mode="bilinear", align_corners=False) mask_tensor = F.interpolate(mask_tensor, (image_height, image_width), mode="bilinear", align_corners=False) # Apply mask (blend with black background) background_tensor = torch.zeros_like(rgb_tensor) if edge_refinement: # replace edge with inner values rgb_tensor = refine_image_edges(rgb_tensor, mask_tensor) rgb_tensor = torch.lerp(background_tensor, rgb_tensor, mask_tensor) # Convert back to PIL Image rgb_tensor = rearrange(rgb_tensor, "1 C H W -> C H W") rgb_tensor = rgb_tensor / 255.0 to_img = ToPILImage() condition_image = to_img(rgb_tensor.cpu()) progress(1, desc="FLUX condition image generated successfully.") return condition_image def refine_image_edges(rgb_tensor, mask_tensor): """ Refine image edges using advanced morphological operations to remove white edges while preserving object boundaries. Algorithm: 1. Erode mask to get eroded_mask 2. Double erode mask to get double_eroded_mask 3. XOR eroded_mask and double_eroded_mask to get circle_valid_mask 4. Use circle_valid_mask to extract circle_rgb (clean edge values) 5. Dilate circle_rgb to cover the edge region 6. Final result: use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background :param rgb_tensor: RGB image tensor of shape (1, C, H, W) on CUDA device :param mask_tensor: Mask tensor of shape (1, 1, H, W) on CUDA device, normalized to [0, 1] :return: refined_rgb_tensor """ # Convert tensors to numpy for OpenCV processing rgb_np = rgb_tensor.squeeze().permute(1, 2, 0).cpu().numpy().astype(np.uint8) # (H, W, C) mask_np = mask_tensor.squeeze().cpu().numpy() # Remove batch and channel dimensions original_mask_np = (mask_np * 255).astype(np.uint8) # Convert to 0-255 range # Create morphological kernel (3x3 as requested) kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) # Step 1: Erode mask to get eroded_mask eroded_mask_np = cv2.erode(original_mask_np, kernel, iterations=3) # Step 2: Double erode mask to get double_eroded_mask double_eroded_mask_np = cv2.erode(eroded_mask_np, kernel, iterations=5) # Step 3: XOR eroded_mask and double_eroded_mask to get circle_valid_mask circle_valid_mask_np = cv2.bitwise_xor(eroded_mask_np, double_eroded_mask_np) # Step 4: Use circle_valid_mask to extract circle_rgb (clean edge values) circle_valid_mask_3c = cv2.cvtColor(circle_valid_mask_np, cv2.COLOR_GRAY2BGR) / 255.0 circle_rgb_np = (rgb_np * circle_valid_mask_3c).astype(np.uint8) # Step 5: Dilate circle_rgb to cover the edge region (using iterations=6 directly) dilated_circle_rgb_np = cv2.dilate(circle_rgb_np, kernel, iterations=8) # Step 6: Final composition # Use double_eroded_mask for original RGB foreground, dilated_circle_rgb for background double_eroded_mask_3c = cv2.cvtColor(double_eroded_mask_np, cv2.COLOR_GRAY2BGR) / 255.0 # Final result: original RGB where double_eroded_mask is valid, dilated_circle_rgb elsewhere refined_rgb_np = (rgb_np * double_eroded_mask_3c + dilated_circle_rgb_np * (1 - double_eroded_mask_3c)).astype(np.uint8) # Convert refined RGB back to tensor refined_rgb_tensor = torch.from_numpy(refined_rgb_np).float().permute(2, 0, 1).unsqueeze(0).to("cuda") return refined_rgb_tensor @spaces.GPU() def generate_image_condition(position_imgs, normal_imgs, mask_imgs, w2c, text_prompt, selected_view="First View", seed=42, model="SDXL", edge_refinement=True, progress=gr.Progress()): """ Generate the image condition based on the selected view's silhouette and text prompt. :param position_imgs: Position images from different views. :param normal_imgs: Normal images from different views. :param mask_imgs: Mask images from different views. :param w2c: World-to-camera transformation matrices. :param text_prompt: The text prompt for image generation. :param selected_view: The selected view for image generation. :param seed: Random seed for image generation. :param model: The image generation model type, supports "SDXL" and "FLUX". :param progress: Progress callback for Gradio. :param edge_refinement: Whether to apply edge refinement to smooth mask boundaries (default: True). :return: Generated condition image and status message. """ # If any input is a file path, load the tensor from file if isinstance(position_imgs, str): position_imgs = load_tensor_from_file(position_imgs, map_location="cuda") if isinstance(normal_imgs, str): normal_imgs = load_tensor_from_file(normal_imgs, map_location="cuda") if isinstance(mask_imgs, str): mask_imgs = load_tensor_from_file(mask_imgs, map_location="cuda") if isinstance(w2c, str): w2c = load_tensor_from_file(w2c, map_location="cuda") position_imgs = position_imgs.to("cuda") normal_imgs = normal_imgs.to("cuda") mask_imgs = mask_imgs.to("cuda") w2c = w2c.to("cuda") progress(0, desc="Handling geometry information...") silhouette = get_silhouette_image(position_imgs, normal_imgs, mask_imgs=mask_imgs, w2c=w2c, selected_view=selected_view) depth_img = silhouette[0] normal_img = silhouette[1] mask = silhouette[2] try: if model == "SDXL": condition = generate_sdxl_condition(depth_img, normal_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress) return condition, "SDXL condition generated successfully." elif model == "FLUX": # FLUX only supports depth control, not normal raise NotImplementedError("FLUX model not supported in HF space, please delete it and use it locally") condition = generate_flux_condition(depth_img, text_prompt, mask, seed, edge_refinement=edge_refinement, progress=progress) return condition, "FLUX condition generated successfully (depth-only control)." else: raise ValueError(f"Unsupported image generation model type: {model}. Supported models: 'SDXL', 'FLUX'.") finally: torch.cuda.empty_cache()