# FurnitureDemo / app.py
import secrets
from typing import cast
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from gradio.components.image_editor import EditorValue
from PIL import Image, ImageFilter, ImageOps
DEVICE = "cuda"
MAX_SEED = np.iinfo(np.int32).max
# Side length of each square panel; keep it a multiple of 16, as the Flux
# pipeline expects (8x VAE downsampling times 2x2 latent patching).
FIXED_DIMENSION = 512 + (512 // 2)  # 768
FIXED_DIMENSION = (FIXED_DIMENSION // 16) * 16
SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture piece as a product shot versus styled in a room.
[LEFT] standalone product shot of the furniture on a white background.
[RIGHT] integrated example within a room scene."""
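# The LoRA works on a two-panel canvas: the left panel carries the product
# shot as conditioning, and only the masked area of the right panel is
# regenerated, effectively transferring the furniture into the room scene.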
if not torch.cuda.is_available():
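    # CPU-only fallback: echo the input images unchanged so the UI can be
    # exercised without a GPU.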
def _dummy_pipe(image: list[Image.Image], *args, **kwargs): # noqa: ARG001
return {"images": image}
pipe = _dummy_pipe
else:
state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
weight_name="pytorch_lora_weights3.safetensors",
return_alphas=True,
)
if not all(("lora" in key or "dora_scale" in key) for key in state_dict):
msg = "Invalid LoRA checkpoint."
raise ValueError(msg)
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to(DEVICE)
FluxFillPipeline.load_lora_into_transformer(
state_dict=state_dict,
network_alphas=network_alphas,
transformer=pipe.transformer,
)
pipe.to(DEVICE)
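    # A simpler alternative, assuming default alpha handling suffices, would be
    # pipe.load_lora_weights("blanchon/FluxFillFurniture",
    #                        weight_name="pytorch_lora_weights3.safetensors");
    # the manual route above keeps the checkpoint validation explicit.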
@spaces.GPU
def infer(
furniture_image: Image.Image,
room_image: EditorValue,
prompt: str = "",
seed: int = 42,
randomize_seed: bool = False,
guidance_scale: float = 3.5,
num_inference_steps: int = 28,
progress: gr.Progress = gr.Progress(track_tqdm=True), # noqa: ARG001, B008
):
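    """Inpaint the masked region of the room image with the given furniture.

    Builds a [furniture | room] side-by-side canvas, masks the user-drawn
    region of the room panel, and returns the room-panel crops plus the seed.
    """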
_room_image = room_image["background"]
if _room_image is None:
msg = "Room image is required"
raise ValueError(msg)
_room_image = cast(Image.Image, _room_image)
_room_image = ImageOps.fit(
_room_image,
(FIXED_DIMENSION, FIXED_DIMENSION),
method=Image.Resampling.LANCZOS,
centering=(0.5, 0.5),
)
_room_mask = room_image["layers"][0]
if _room_mask is None:
msg = "Room mask is required"
raise ValueError(msg)
_room_mask = cast(Image.Image, _room_mask)
_room_mask = ImageOps.fit(
_room_mask,
(FIXED_DIMENSION, FIXED_DIMENSION),
method=Image.Resampling.LANCZOS,
centering=(0.5, 0.5),
)
furniture_image = ImageOps.fit(
furniture_image,
(FIXED_DIMENSION, FIXED_DIMENSION),
method=Image.Resampling.LANCZOS,
centering=(0.5, 0.5),
)
_furniture_image = Image.new(
"RGB",
(FIXED_DIMENSION, FIXED_DIMENSION),
(255, 255, 255),
)
_furniture_image.paste(furniture_image, (0, 0))
_furniture_mask = Image.new(
"RGB", (FIXED_DIMENSION, FIXED_DIMENSION), (255, 255, 255)
)
image = Image.new(
"RGB",
(FIXED_DIMENSION * 2, FIXED_DIMENSION),
(255, 255, 255),
)
    # Compose the two panels side by side: furniture left, room right
image.paste(_furniture_image, (0, 0))
image.paste(_room_image, (FIXED_DIMENSION, 0))
mask = Image.new(
"RGB",
(FIXED_DIMENSION * 2, FIXED_DIMENSION),
(255, 255, 255),
)
    mask.paste(_furniture_mask, (0, 0))
    # Paint the drawn strokes black so that, after inversion, they become the
    # white (inpaint) region; pasting the white strokes directly would vanish
    # against the white canvas
    _strokes_black = Image.new("RGB", (FIXED_DIMENSION, FIXED_DIMENSION), (0, 0, 0))
    mask.paste(_strokes_black, (FIXED_DIMENSION, 0), _room_mask)
    # Invert the mask: the left (reference) panel turns black (preserved) and
    # the drawn strokes turn white (regenerated)
    mask = ImageOps.invert(mask)
    # Feather the mask so the inpainted region blends smoothly
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
    # Collapse to a single grayscale channel, as expected by mask_image
    mask = mask.convert("L")
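    # For debugging alignment, the composites can be saved locally, e.g.
    # image.save("composite.png") and mask.save("mask.png").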
if randomize_seed:
seed = secrets.randbelow(MAX_SEED)
    prompt = (prompt + ".\n" + SYSTEM_PROMPT) if prompt else SYSTEM_PROMPT
batch_size = 4
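    # The prompt and conditioning image are replicated so a single call yields
    # four variations; the mask is passed un-batched and is presumably
    # broadcast across the batch by the pipeline.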
results_images = pipe(
prompt=[prompt] * batch_size,
image=[image] * batch_size,
mask_image=mask,
height=FIXED_DIMENSION,
width=FIXED_DIMENSION * 2,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=torch.Generator("cpu").manual_seed(seed),
)["images"]
cropped_images = [
image.crop((FIXED_DIMENSION, 0, FIXED_DIMENSION * 2, FIXED_DIMENSION))
for image in results_images
]
return cropped_images, seed
intro_markdown = """
# AnyFurnish
AnyFurnish places your furniture into a room photo using FLUX.1 Fill [dev].
Upload a product shot, paint the target area in the room image, and run.
"""
css = """
#col-container {
margin: 0 auto;
max-width: 1000px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(intro_markdown)
with gr.Row():
with gr.Column():
with gr.Column():
furniture_image = gr.Image(
label="Furniture Image",
type="pil",
sources=["upload"],
image_mode="RGB",
height=300,
)
room_image = gr.ImageEditor(
label="Room Image - Draw mask for inpainting",
type="pil",
sources=["upload"],
image_mode="RGBA",
layers=False,
crop_size="1:1",
brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
height=300,
)
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter a custom furniture description (optional)",
container=False,
)
run_button = gr.Button("Run")
results = gr.Gallery(
label="Results",
format="png",
show_label=False,
columns=2,
height=600,
)
with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=30,
step=0.5,
                    value=30,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=28,
)
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer,
inputs=[
furniture_image,
room_image,
prompt,
seed,
randomize_seed,
guidance_scale,
num_inference_steps,
],
outputs=[results, seed],
)
demo.launch()