prithivMLmods's picture
Update app.py
11d7c13 verified
raw
history blame
9.31 kB
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
# (Assume ControlNet manual load or from_pretrained is already working)
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw
import numpy as np
# --- Load ControlNet and SDXL Fill Pipeline ---
# Everything below runs once at import time: weights are fetched with
# `from_pretrained`, kept in float16, and moved to the CUDA device.
_DEVICE = "cuda"
_DTYPE = torch.float16

controlnet_model = ControlNetModel_Union.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0",
    variant="fp16",
    torch_dtype=_DTYPE,
).to(_DEVICE)

# Passed to the pipeline below via `vae=`, overriding the checkpoint's own VAE.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=_DTYPE,
).to(_DEVICE)

pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    controlnet=controlnet_model,
    vae=vae,
    variant="fp16",
    torch_dtype=_DTYPE,
).to(_DEVICE)

# Swap in a TCD scheduler built from the checkpoint's scheduler config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# --- Utility functions ---
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Tell whether the source leaves room to expand toward ``alignment``.

    "Left"/"Right" need the source to be narrower than the target, and
    "Top"/"Bottom" need it to be shorter. Any other alignment (e.g. "Middle")
    is always accepted.
    """
    blocked_horizontally = alignment in ("Left", "Right") and source_width >= target_width
    blocked_vertically = alignment in ("Top", "Bottom") and source_height >= target_height
    return not (blocked_horizontally or blocked_vertically)
def prepare_image_and_mask(image, width, height, overlap_percentage,
                           resize_option, custom_resize_percentage,
                           alignment, overlap_left, overlap_right,
                           overlap_top, overlap_bottom):
    """Place ``image`` on a white canvas and build the matching outpaint mask.

    Args:
        image: Input PIL image. ``None`` (nothing uploaded) raises ``gr.Error``.
        width, height: Canvas size in pixels.
        overlap_percentage: Inset of the kept rectangle, as a percentage of the
            placed image's width (x) / height (y); at least 1 px.
        resize_option: "Full", "50%", "33%" or "25%"; any other value means
            ``custom_resize_percentage`` is used.
        custom_resize_percentage: Scale (percent) applied for the custom option.
        alignment: "Middle", "Left", "Right" or "Top"; any other value is laid
            out like "Bottom".
        overlap_left, overlap_right, overlap_top, overlap_bottom: When truthy,
            that side of the kept rectangle is inset by the overlap amount;
            otherwise by 2 px (0 px on the side named by ``alignment``).

    Returns:
        ``(background, mask)``: an RGB canvas of size ``(width, height)`` with
        the placed image on white, and an "L" mask of the same size that is
        255 everywhere except a 0-filled rectangle over the kept region.

    Raises:
        gr.Error: if ``image`` is None.
    """
    if image is None:
        # Without this, the first attribute access below fails with an opaque
        # AttributeError for both the Generate and Preview buttons.
        raise gr.Error("Please upload an input image first.")

    canvas_size = (width, height)

    # Fit the image inside the canvas (aspect ratio preserved; may upscale).
    fit = min(width / image.width, height / image.height)
    placed = image.resize((int(image.width * fit), int(image.height * fit)), Image.LANCZOS)

    # Optional extra shrink, never below 64 px per side.
    percent = {"Full": 100, "50%": 50, "33%": 33, "25%": 25}.get(
        resize_option, custom_resize_percentage
    )
    new_w = max(int(placed.width * percent / 100), 64)
    new_h = max(int(placed.height * percent / 100), 64)
    placed = placed.resize((new_w, new_h), Image.LANCZOS)

    overlap_x = max(int(new_w * overlap_percentage / 100), 1)
    overlap_y = max(int(new_h * overlap_percentage / 100), 1)

    # Top-left corner of the placed image per alignment, clamped to the canvas.
    center_x = (width - new_w) // 2
    center_y = (height - new_h) // 2
    corners = {
        "Middle": (center_x, center_y),
        "Left": (0, center_y),
        "Right": (width - new_w, center_y),
        "Top": (center_x, 0),
    }
    margin_x, margin_y = corners.get(alignment, (center_x, height - new_h))
    margin_x = max(0, min(margin_x, width - new_w))
    margin_y = max(0, min(margin_y, height - new_h))

    background = Image.new("RGB", canvas_size, (255, 255, 255))
    background.paste(placed, (margin_x, margin_y))

    def _inset(enabled, overlap, flush):
        # Overlap when enabled; otherwise a 2 px gap, or none on the flush side.
        if enabled:
            return overlap
        return 0 if flush else 2

    left = margin_x + _inset(overlap_left, overlap_x, alignment == "Left")
    right = margin_x + new_w - _inset(overlap_right, overlap_x, alignment == "Right")
    top = margin_y + _inset(overlap_top, overlap_y, alignment == "Top")
    bottom = margin_y + new_h - _inset(overlap_bottom, overlap_y, alignment == "Bottom")

    mask = Image.new("L", canvas_size, 255)
    ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)
    return background, mask
def preview_image_and_mask(*args):
    """Render the prepared canvas with the region to be generated tinted red.

    Takes exactly the positional arguments of ``prepare_image_and_mask`` and
    returns an RGBA image. Wherever the mask is 255 (the area that ``infer``
    blacks out and later fills with generated pixels), a translucent red
    layer is blended on top; elsewhere the canvas shows through unchanged.
    """
    canvas, fill_mask = prepare_image_and_mask(*args)
    size = canvas.size
    tint = Image.new("RGBA", size, (255, 0, 0, 64))
    transparent = Image.new("RGBA", size, (0, 0, 0, 0))
    # composite(a, b, m) is b.copy() with a pasted through m.
    highlight = Image.composite(tint, transparent, fill_mask)
    return Image.alpha_composite(canvas.convert("RGBA"), highlight)
# --- Inference: returns [background, result] for the comparison slider ---
@spaces.GPU(duration=24)
def infer(image, width, height, overlap_percentage, num_inference_steps,
          resize_option, custom_resize_percentage, prompt_input,
          alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Outpaint ``image`` onto a ``width`` x ``height`` canvas.

    All layout arguments are forwarded unchanged to
    ``prepare_image_and_mask``, which produces the padded canvas and the mask.
    The masked region (mask == 255) is painted black to form the pipeline's
    input image. The pipeline is run for ``num_inference_steps``, and its
    final image is pasted back into the masked region only.

    Returns:
        ``[background, result]`` for the comparison slider: the padded input
        canvas and the outpainted image.

    Raises:
        gr.Error: if the pipeline yields no image.
    """
    background, mask = prepare_image_and_mask(
        image, width, height, overlap_percentage,
        resize_option, custom_resize_percentage,
        alignment, overlap_left, overlap_right,
        overlap_top, overlap_bottom,
    )

    # Pipeline input: the canvas with the region to generate painted black.
    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    final_prompt = f"{prompt_input} , high quality, 4k"
    # Positional args: device, then classifier-free guidance enabled.
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(final_prompt, "cuda", True)

    # The fill pipeline is consumed as an iterator; keep only its last image.
    last_image = None
    for frame in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
        num_inference_steps=num_inference_steps,
    ):
        last_image = frame
    if last_image is None:
        raise gr.Error("The outpainting pipeline produced no image.")

    # Copy generated pixels into the masked region only; unmasked pixels keep
    # the original canvas content.
    cnet_image.paste(last_image.convert("RGBA"), (0, 0), mask)
    return [background, cnet_image]
def clear_result():
    """Return a component update that empties the target's value."""
    cleared = gr.update(value=None)
    return cleared
def preload_presets(ratio, w, h):
    """Translate an aspect-ratio preset into slider values.

    Returns ``(width, height, accordion_update)``. A known preset yields its
    fixed size and a no-op update. Anything else (i.e. "Custom") keeps the
    current ``w``/``h`` and returns an update that opens the accordion.
    """
    preset_sizes = {
        "9:16": (720, 1280),
        "16:9": (1280, 720),
        "1:1": (1024, 1024),
    }
    size = preset_sizes.get(ratio)
    if size is None:
        return w, h, gr.update(open=True)
    return size[0], size[1], gr.update()
def select_the_right_preset(w, h):
    """Map a ``(w, h)`` size back to its preset label, or "Custom"."""
    label_for_size = {
        (720, 1280): "9:16",
        (1280, 720): "16:9",
        (1024, 1024): "1:1",
    }
    return label_for_size.get((w, h), "Custom")
def toggle_custom_resize_slider(opt):
    """Make the target visible only when ``opt`` is "Custom"."""
    is_custom = opt == "Custom"
    return gr.update(visible=is_custom)
def update_history(img, history):
    """Return a new history list with ``img`` prepended.

    ``img`` is normally the comparison slider's value, i.e. the
    ``[background, result]`` pair produced by ``infer``. When a tuple/list is
    given, only its last element (the generated result) is stored, so the
    gallery gets one image per run instead of a 2-item pair. A ``None`` (or
    empty) value leaves the history unchanged. The incoming ``history`` list
    is never mutated.
    """
    history = list(history or [])
    if isinstance(img, (tuple, list)):
        img = img[-1] if img else None
    if img is None:
        return history
    return [img, *history]
# --- UI layout and event wiring ---
css = ".gradio-container { width: 1200px !important; }"
title = "<h1 align='center'>Diffusers Image Outpaint Lightning</h1>"
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        # Left column: inputs, size presets, advanced settings and examples.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            prompt_input = gr.Textbox(label="Prompt (Optional)")
            run_button = gr.Button("Generate")
            # Labels must match the ones handled by preload_presets and
            # select_the_right_preset.
            target_ratio = gr.Radio(["9:16","16:9","1:1","Custom"], value="9:16", label="Expected Ratio")
            alignment_dropdown = gr.Dropdown(["Middle","Left","Right","Top","Bottom"], value="Middle", label="Alignment")
            # `adv` is an output of preload_presets, which opens it for "Custom".
            with gr.Accordion("Advanced settings", open=False) as adv:
                # Initial 720 x 1280 matches the default "9:16" preset above.
                width_slider = gr.Slider(720,1536,step=8, value=720, label="Target Width")
                height_slider = gr.Slider(720,1536,step=8, value=1280, label="Target Height")
                num_steps = gr.Slider(4,12,step=1, value=8, label="Steps")
                overlap_pct = gr.Slider(1,50,step=1, value=10, label="Mask overlap (%)")
                overlap_top = gr.Checkbox(label="Overlap Top", value=True)
                overlap_right = gr.Checkbox(label="Overlap Right", value=True)
                overlap_left = gr.Checkbox(label="Overlap Left", value=True)
                overlap_bottom= gr.Checkbox(label="Overlap Bottom", value=True)
                resize_opt = gr.Radio(["Full","50%","33%","25%","Custom"], value="Full", label="Resize input image")
                # Hidden until resize_opt is "Custom" (toggle_custom_resize_slider).
                custom_resize = gr.Slider(1,100,step=1, value=50, visible=False, label="Custom resize (%)")
                preview_btn = gr.Button("Preview alignment and mask")
            gr.Examples(
                examples=[
                    ["./examples/example_1.webp",1280,720,"Middle"],
                    ["./examples/example_2.jpg",1440,810,"Left"],
                    ["./examples/example_3.jpg",1024,1024,"Top"],
                    ["./examples/example_3.jpg",1024,1024,"Bottom"]
                ],
                inputs=[input_image,width_slider,height_slider,alignment_dropdown]
            )
        # Right column: before/after comparison, run history, mask preview.
        with gr.Column():
            result = ImageSlider(label="Comparison", interactive=False, type="pil", slider_color="pink")
            history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain")
            preview_image = gr.Image(label="Preview")
    # Callbacks
    # Two independent listeners on the same button: one clears the slider,
    # the other runs inference.
    run_button.click(clear_result, None, result)
    # The `inputs` order must match infer's positional parameters.
    # NOTE(review): `result` is passed whole, so update_history receives the
    # slider's (before, after) value rather than just the generated image.
    run_button.click(
        infer,
        inputs=[ input_image, width_slider, height_slider, overlap_pct, num_steps,
                 resize_opt, custom_resize, prompt_input, alignment_dropdown,
                 overlap_left, overlap_right, overlap_top, overlap_bottom],
        outputs=result
    ).then(update_history, inputs=[result, history_gallery], outputs=history_gallery)
    # Preset <-> size sync: choosing a preset sets the sliders; moving a
    # slider re-derives the preset label.
    target_ratio.change(preload_presets, [target_ratio, width_slider, height_slider], [width_slider, height_slider, adv])
    width_slider.change(select_the_right_preset, [width_slider, height_slider], target_ratio)
    height_slider.change(select_the_right_preset, [width_slider, height_slider], target_ratio)
    resize_opt.change(toggle_custom_resize_slider, resize_opt, custom_resize)
    # The `inputs` order must match prepare_image_and_mask's parameters.
    preview_btn.click(preview_image_and_mask,
                      [input_image, width_slider, height_slider, overlap_pct, resize_opt, custom_resize, alignment_dropdown,
                       overlap_left, overlap_right, overlap_top, overlap_bottom],
                      preview_image)
demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)