import secrets
from pathlib import Path
from typing import cast

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from gradio.components.image_editor import EditorValue
from PIL import Image, ImageFilter, ImageOps

DEVICE = "cuda"

EXAMPLES_DIR = Path(__file__).parent / "examples"

MAX_SEED = np.iinfo(np.int32).max

SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture as a product shot versus styled in a room.
[LEFT] standalone product shot image the furniture on a white background.
[RIGHT] integrated example within a room scene."""

MASK_CONTEXT_PADDING = 16 * 8

if not torch.cuda.is_available():
    # No GPU available: fall back to a stub that echoes the input image back,
    # so the UI stays usable for layout testing.
    def _dummy_pipe(image: Image.Image, *args, **kwargs):  # noqa: ARG001
        return {"images": [image]}

    pipe = _dummy_pipe
else:
    state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
        pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
        weight_name="pytorch_lora_weights3.safetensors",
        return_alphas=True,
    )
    if not all(("lora" in key or "dora_scale" in key) for key in state_dict):
        msg = "Invalid LoRA checkpoint."
        raise ValueError(msg)

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to(DEVICE)
    FluxFillPipeline.load_lora_into_transformer(
        state_dict=state_dict,
        network_alphas=network_alphas,
        transformer=pipe.transformer,
    )
    pipe.to(DEVICE)


def make_example(image_path: Path, mask_path: Path) -> EditorValue:
    """Build a gr.ImageEditor value (background, mask layer, composite) from files."""
    background_image = Image.open(image_path)
    background_image = background_image.convert("RGB")
    background = np.array(background_image)

    mask_image = Image.open(mask_path)
    mask_image = mask_image.convert("RGB")
    mask = np.array(mask_image)
    mask = mask[:, :, 0]
    mask = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    if background.shape[0] != mask.shape[0] or background.shape[1] != mask.shape[1]:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    layer = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    composite = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }


def remove_padding(
    image: Image.Image, original_size: tuple[int, int]
) -> Image.Image:
    """Crop the centered `original_size` region out of a padded image."""
    # Get current dimensions
    padded_width, padded_height = image.size
    original_width, original_height = original_size

    # Calculate the centered cropping box
    left = (padded_width - original_width) // 2
    top = (padded_height - original_height) // 2
    right = left + original_width
    bottom = top + original_height

    # Crop back to the original size
    return image.crop((left, top, right, bottom))
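

# A hypothetical usage sketch (never called in this app): how `make_example`
# could collect gr.Examples rows for the room editor. The per-example file
# names "image.png" and "mask.png" are assumptions, not the actual layout
# of EXAMPLES_DIR.
def _example_editor_values() -> list[EditorValue]:
    return [
        make_example(path / "image.png", path / "mask.png")
        for path in sorted(EXAMPLES_DIR.iterdir())
        if (path / "image.png").exists() and (path / "mask.png").exists()
    ]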


@spaces.GPU(duration=150)
def infer(
    furniture_image_input: Image.Image,
    room_image_input: EditorValue,
    furniture_prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 20,
    max_dimension: int = 720,
    num_images_per_prompt: int = 2,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
):
    """Inpaint the masked room region so it contains the reference furniture."""
    # Ensure max_dimension is a multiple of 16 (for the VAE)
    max_dimension = (max_dimension // 16) * 16

    room_image = room_image_input["background"]
    if room_image is None:
        msg = "Room image is required"
        raise ValueError(msg)
    room_image = cast("Image.Image", room_image)

    room_mask = room_image_input["layers"][0]
    if room_mask is None:
        msg = "Room mask is required"
        raise ValueError(msg)
    room_mask = cast("Image.Image", room_mask)

    mask_bbox_x_min, mask_bbox_y_min, mask_bbox_x_max, mask_bbox_y_max = (
        room_mask.getbbox(alpha_only=False)
    )

    # Expand the box by MASK_CONTEXT_PADDING (128 px) of context,
    # clamped to the image bounds
    mask_bbox_x_min = max(mask_bbox_x_min - MASK_CONTEXT_PADDING, 0)
    mask_bbox_y_min = max(mask_bbox_y_min - MASK_CONTEXT_PADDING, 0)
    mask_bbox_x_max = min(mask_bbox_x_max + MASK_CONTEXT_PADDING, room_mask.width)
    mask_bbox_y_max = min(mask_bbox_y_max + MASK_CONTEXT_PADDING, room_mask.height)

    bbox_longest_side = max(
        mask_bbox_x_max - mask_bbox_x_min,
        mask_bbox_y_max - mask_bbox_y_min,
    )

    room_image_cropped = room_image.crop((
        mask_bbox_x_min,
        mask_bbox_y_min,
        mask_bbox_x_max,
        mask_bbox_y_max,
    ))
    room_image_cropped = ImageOps.pad(
        room_image_cropped,
        (bbox_longest_side, bbox_longest_side),
        # White padding
        color=(255, 255, 255),
        centering=(0.5, 0.5),
    )
    room_image_cropped = ImageOps.fit(
        room_image_cropped,
        (max_dimension, max_dimension),
        method=Image.Resampling.LANCZOS,
        centering=(0.5, 0.5),
    )

    room_mask_cropped = room_mask.crop((
        mask_bbox_x_min,
        mask_bbox_y_min,
        mask_bbox_x_max,
        mask_bbox_y_max,
    ))
    room_mask_cropped = ImageOps.pad(
        room_mask_cropped,
        (max_dimension, max_dimension),
        # White padding
        color=(255, 255, 255),
        centering=(0.5, 0.5),
    )
    room_mask_cropped = ImageOps.fit(
        room_mask_cropped,
        (max_dimension, max_dimension),
        method=Image.Resampling.LANCZOS,
        centering=(0.5, 0.5),
    )

    furniture_image = ImageOps.pad(
        furniture_image_input,
        (max_dimension, max_dimension),
        # White padding
        color=(255, 255, 255),
        centering=(0.5, 0.5),
    )

    furniture_mask = Image.new("RGB", (max_dimension, max_dimension), (255, 255, 255))

    # Two-panel canvas: furniture reference on the left, room crop on the right
    image = Image.new(
        "RGB",
        (max_dimension * 2, max_dimension),
        (255, 255, 255),
    )
    image.paste(furniture_image, (0, 0))
    image.paste(room_image_cropped, (max_dimension, 0))

    mask = Image.new(
        "RGB",
        (max_dimension * 2, max_dimension),
        (255, 255, 255),
    )
    mask.paste(furniture_mask, (0, 0))
    mask.paste(room_mask_cropped, (max_dimension, 0), room_mask_cropped)
    # Invert the mask so the drawn region becomes the area to inpaint
    mask = ImageOps.invert(mask)
    # Blur the mask to soften the inpainting boundary
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
    # Convert to a single-channel grayscale mask
    mask = mask.convert("L")

    if randomize_seed:
        seed = secrets.randbelow(MAX_SEED)

    prompt = (
        furniture_prompt + ".\n" + SYSTEM_PROMPT if furniture_prompt else SYSTEM_PROMPT
    )

    results_images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask,
        height=max_dimension,
        width=max_dimension * 2,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=torch.Generator("cpu").manual_seed(seed),
    )["images"]
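
    # Map each generated right panel back into the original room photo. The
    # forward transform was crop-to-bbox -> pad to a bbox_longest_side square
    # -> resize to max_dimension; the loop below inverts it: resize the panel
    # back to the bbox square, strip the centered padding with remove_padding,
    # then paste the result at the bbox origin.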
    final_images = []
    for result_image in results_images:
        final_image = room_image.copy()
        # Keep only the right panel (the generated room view)
        image_generated = result_image.crop((
            max_dimension,
            0,
            max_dimension * 2,
            max_dimension,
        ))
        # Downscale back to the bbox_longest_side square
        image_generated = image_generated.resize(
            (bbox_longest_side, bbox_longest_side)
        )
        # Crop back to the bbox (remove the padding)
        image_generated = remove_padding(
            image_generated,
            (
                mask_bbox_x_max - mask_bbox_x_min,
                mask_bbox_y_max - mask_bbox_y_min,
            ),
        )
        # Paste the generated crop back where it was taken from the room image
        final_image.paste(image_generated, (mask_bbox_x_min, mask_bbox_y_min))
        final_images.append(final_image)

    return final_images, seed


intro_markdown = r"""