# FurnitureDemo / app.py
import secrets
from typing import cast
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxFillPipeline
from gradio.components.image_editor import EditorValue
from PIL import Image, ImageFilter, ImageOps
DEVICE = "cuda"
MAX_SEED = np.iinfo(np.int32).max
# FIXED_DIMENSION = 900
FIXED_DIMENSION = 512 + (512 // 2)
# Snap down to a multiple of 16, which the Flux pipeline expects for its latent/patch grid.
FIXED_DIMENSION = (FIXED_DIMENSION // 16) * 16
SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture in as a product shot versus styled in a room.
[LEFT] standalone product shot image the furniture on a white background.
[RIGHT] integrated example within a room scene."""
if not torch.cuda.is_available():
    # CPU-only fallback: echo the inputs so the UI stays testable without a GPU.
    def _dummy_pipe(image: list[Image.Image], *args, **kwargs):  # noqa: ARG001
        return {"images": image}

    pipe = _dummy_pipe
else:
    state_dict, network_alphas = FluxFillPipeline.lora_state_dict(
        pretrained_model_name_or_path_or_dict="blanchon/FluxFillFurniture",
        weight_name="pytorch_lora_weights3.safetensors",
        return_alphas=True,
    )
    if not all(("lora" in key or "dora_scale" in key) for key in state_dict):
        msg = "Invalid LoRA checkpoint."
        raise ValueError(msg)

    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
    ).to(DEVICE)
    FluxFillPipeline.load_lora_into_transformer(
        state_dict=state_dict,
        network_alphas=network_alphas,
        transformer=pipe.transformer,
    )
    pipe.to(DEVICE)
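
# Optional memory-saving variant (an assumption, not part of the original app):
# on GPUs with limited VRAM, diffusers pipelines can offload idle submodules to
# the CPU instead of keeping the whole model resident via .to(DEVICE):
#
#     pipe.enable_model_cpu_offload()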

@spaces.GPU
def infer(
    furniture_image: Image.Image,
    room_image: EditorValue,
    prompt: str = "",
    seed: int = 42,
    randomize_seed: bool = False,
    guidance_scale: float = 3.5,
    num_inference_steps: int = 28,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
):
    _room_image = room_image["background"]
    if _room_image is None:
        msg = "Room image is required"
        raise ValueError(msg)
    _room_image = cast(Image.Image, _room_image)
    _room_image = ImageOps.fit(
        _room_image,
        (FIXED_DIMENSION, FIXED_DIMENSION),
        method=Image.Resampling.LANCZOS,
        centering=(0.5, 0.5),
    )

    _room_mask = room_image["layers"][0]
    if _room_mask is None:
        msg = "Room mask is required"
        raise ValueError(msg)
    _room_mask = cast(Image.Image, _room_mask)
    _room_mask = ImageOps.fit(
        _room_mask,
        (FIXED_DIMENSION, FIXED_DIMENSION),
        method=Image.Resampling.LANCZOS,
        centering=(0.5, 0.5),
    )

    furniture_image = ImageOps.fit(
        furniture_image,
        (FIXED_DIMENSION, FIXED_DIMENSION),
        method=Image.Resampling.LANCZOS,
        centering=(0.5, 0.5),
    )
    # Place the furniture on a white square canvas. Its half of the mask stays
    # fully white so the furniture panel is preserved after the inversion below.
    _furniture_image = Image.new(
        "RGB",
        (FIXED_DIMENSION, FIXED_DIMENSION),
        (255, 255, 255),
    )
    _furniture_image.paste(furniture_image, (0, 0))
    _furniture_mask = Image.new(
        "RGB", (FIXED_DIMENSION, FIXED_DIMENSION), (255, 255, 255)
    )

    image = Image.new(
        "RGB",
        (FIXED_DIMENSION * 2, FIXED_DIMENSION),
        (255, 255, 255),
    )
    # Compose the two panels side by side: furniture left, room right.
    image.paste(_furniture_image, (0, 0))
    image.paste(_room_image, (FIXED_DIMENSION, 0))
    mask = Image.new(
        "RGB",
        (FIXED_DIMENSION * 2, FIXED_DIMENSION),
        (255, 255, 255),
    )
    mask.paste(_furniture_mask, (0, 0))
    mask.paste(_room_mask, (FIXED_DIMENSION, 0), _room_mask)
    # Invert the mask
    mask = ImageOps.invert(mask)
    # Blur the mask to soften the inpainting boundary
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
    # Ensure a 3-channel mask
    mask = mask.convert("RGB")
    if randomize_seed:
        seed = secrets.randbelow(MAX_SEED)

    prompt = prompt + ".\n" + SYSTEM_PROMPT if prompt else SYSTEM_PROMPT

    batch_size = 4
    results_images = pipe(
        prompt=[prompt] * batch_size,
        image=[image] * batch_size,
        mask_image=mask,
        height=FIXED_DIMENSION,
        width=FIXED_DIMENSION * 2,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cpu").manual_seed(seed),
    )["images"]
    # Keep only the right-hand (room) panel of each generated image.
    cropped_images = [
        image.crop((FIXED_DIMENSION, 0, FIXED_DIMENSION * 2, FIXED_DIMENSION))
        for image in results_images
    ]
    return cropped_images, seed
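
# Minimal headless-usage sketch (an assumption, not part of the original app):
# `infer` can be exercised without the UI by handing it an EditorValue-shaped
# dict. The file names below are hypothetical placeholders.
#
#     furniture = Image.open("furniture.png").convert("RGB")
#     room = Image.open("room.png").convert("RGBA")
#     strokes = Image.new("RGBA", room.size, (0, 0, 0, 0))  # paint mask strokes here
#     editor_value = {"background": room, "layers": [strokes], "composite": None}
#     images, used_seed = infer(furniture, editor_value, prompt="a cozy armchair")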
intro_markdown = """
# AnyFurnish
AnyFurnish is a tool that allows you to generate furniture images using Flux.1 Fill Dev.
"""
css = """
#col-container {
margin: 0 auto;
max-width: 1000px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(intro_markdown)
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    furniture_image = gr.Image(
                        label="Furniture Image",
                        type="pil",
                        sources=["upload"],
                        image_mode="RGB",
                        height=300,
                    )
                    room_image = gr.ImageEditor(
                        label="Room Image - Draw mask for inpainting",
                        type="pil",
                        sources=["upload"],
                        image_mode="RGBA",
                        layers=False,
                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                        height=300,
                    )
                prompt = gr.Text(
                    label="Prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter a custom furniture description (optional)",
                    container=False,
                )
                run_button = gr.Button("Run")
            results = gr.Gallery(
                label="Results",
                format="png",
                show_label=False,
                columns=2,
                height=600,
            )
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=30,
                    step=0.5,
                    value=30,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            furniture_image,
            room_image,
            prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[results, seed],
    )
demo.launch()
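
# Deployment note (an assumption, not part of the original file): on Spaces the
# bare launch() above is sufficient; for local testing with concurrent users one
# might enable the request queue first:
#
#     demo.queue().launch()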