prithivMLmods's picture
Update app.py
03b41ea verified
raw
history blame
22.1 kB
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
# Remove ImageSlider import
# from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
from PIL import Image, ImageDraw
import numpy as np
# --- Model Loading ---
# Everything below executes once at import time: weights are fetched with
# hf_hub_download / from_pretrained and moved to "cuda", so importing this
# module needs Hub access (or a warm cache) and a CUDA device.

# ControlNet Union ("promax" variant): build the module from its config...
config_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="config_promax.json",
)
config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)
# ...then load the matching safetensors weights into it.
model_file = hf_hub_download(
    "xinsir/controlnet-union-sdxl-1.0",
    filename="diffusion_pytorch_model_promax.safetensors",
)
# `sstate_dict` (sic) holds the raw ControlNet weights read from model_file.
sstate_dict = load_state_dict(model_file)
# NOTE(review): `_load_pretrained_model` is a private diffusers method; the
# 5-element unpacking depends on the installed diffusers version's return
# value — re-check when upgrading. Only the first element (the model) is kept.
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
    controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
)
model.to(device="cuda", dtype=torch.float16)
# --- SDXL fill pipeline ---
# VAE from madebyollin/sdxl-vae-fp16-fix, loaded in float16, passed to the pipeline below.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
# RealVisXL V5.0 Lightning checkpoint (fp16 weight variant) wired to the
# custom VAE and the ControlNet loaded above.
pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=model,
    variant="fp16",
).to("cuda")
# Replace the checkpoint's scheduler with TCD, built from the existing scheduler config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# --- Helper Functions (Mostly Unchanged) ---
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Tell whether an edge alignment still leaves room to grow the source.

    "Left"/"Right" pin the source horizontally, so they need the source to be
    narrower than the target; "Top"/"Bottom" likewise need it to be shorter.
    Any other alignment (e.g. "Middle") is always accepted.

    Returns:
        True when expansion along the pinned axis is possible, else False.
    """
    if alignment in ("Left", "Right"):
        # Written as `not (a >= b)` to match the original comparison exactly.
        return not (source_width >= target_width)
    if alignment in ("Top", "Bottom"):
        return not (source_height >= target_height)
    return True
def _resize_percentage(resize_option, custom_resize_percentage):
    """Translate a resize-option label into a percentage of the fitted source size."""
    presets = {"Full": 100, "50%": 50, "33%": 33, "25%": 25}
    # Any other label (the UI uses "Custom") falls back to the slider value.
    return presets.get(resize_option, custom_resize_percentage)


def _alignment_margins(alignment, canvas_size, source_size):
    """Return the top-left (x, y) offset at which the source is pasted.

    Offsets are clamped to be non-negative and no larger than the free space,
    so the source never starts off the canvas.

    Raises:
        ValueError: if *alignment* is not a supported value.
    """
    free_x = canvas_size[0] - source_size[0]
    free_y = canvas_size[1] - source_size[1]
    offsets = {
        "Middle": (free_x // 2, free_y // 2),
        "Left": (0, free_y // 2),
        "Right": (free_x, free_y // 2),
        "Top": (free_x // 2, 0),
        "Bottom": (free_x // 2, free_y),
    }
    if alignment not in offsets:
        raise ValueError(
            f"Unsupported alignment {alignment!r}; expected one of {list(offsets)}"
        )
    margin_x, margin_y = offsets[alignment]
    return max(0, min(margin_x, free_x)), max(0, min(margin_y, free_y))


def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Place *image* on a white ``width`` x ``height`` canvas and build its outpaint mask.

    The source is scaled (aspect ratio preserved) to fit inside the target,
    shrunk by ``resize_option`` ("Full", "50%", "33%", "25%"; anything else
    uses ``custom_resize_percentage``) with each side kept at >= 64 px, and
    pasted according to ``alignment``.

    Args:
        image: Source PIL image.
        width, height: Target canvas size in pixels.
        overlap_percentage: Percent of the placed source's width/height that
            is regenerated along each enabled edge so new content blends in
            (at least 1 px).
        resize_option, custom_resize_percentage: How much to shrink the
            fitted source (see above).
        alignment: "Middle", "Left", "Right", "Top" or "Bottom".
        overlap_left, overlap_right, overlap_top, overlap_bottom: Whether the
            overlap band is applied on that edge of the source.

    Returns:
        ``(background, mask)``. ``background`` is the RGB canvas with the
        source pasted in. ``mask`` is an ``L``-mode image of the same size:
        255 (white) where content is generated, 0 (black) over the part of the
        source that is kept.

    Raises:
        ValueError: if ``alignment`` is not a supported value.
    """
    target_size = (width, height)

    # Fit the source inside the target while preserving its aspect ratio.
    scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
    new_width = int(image.width * scale_factor)
    new_height = int(image.height * scale_factor)
    source = image.resize((new_width, new_height), Image.LANCZOS)

    # Shrink the fitted source by the chosen percentage, never below 64 px per side.
    resize_factor = _resize_percentage(resize_option, custom_resize_percentage) / 100
    new_width = max(int(source.width * resize_factor), 64)
    new_height = max(int(source.height * resize_factor), 64)
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # Size of the band inside each source edge that gets regenerated (>= 1 px).
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    margin_x, margin_y = _alignment_margins(alignment, target_size, (new_width, new_height))

    background = Image.new('RGB', target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    # Bounds of the kept (black) rectangle. With overlap disabled on an edge,
    # only a 2 px strip of the source along that edge is regenerated.
    white_gaps_patch = 2
    left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
    right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
    top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
    bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch

    # On the side the alignment pins to the canvas border there is nothing to
    # outpaint, so a disabled overlap keeps the source right up to that edge.
    if alignment == "Left" and not overlap_left:
        left_overlap = margin_x
    elif alignment == "Right" and not overlap_right:
        right_overlap = margin_x + new_width
    elif alignment == "Top" and not overlap_top:
        top_overlap = margin_y
    elif alignment == "Bottom" and not overlap_bottom:
        bottom_overlap = margin_y + new_height

    # Keep the rectangle on the canvas.
    left_overlap = max(0, left_overlap)
    top_overlap = max(0, top_overlap)
    right_overlap = min(target_size[0], right_overlap)
    bottom_overlap = min(target_size[1], bottom_overlap)

    # Start fully white (generate everywhere), then paint the kept area black.
    mask = Image.new('L', target_size, 255)
    if right_overlap > left_overlap and bottom_overlap > top_overlap:
        # Pillow draws the rectangle inclusive of both corner points.
        ImageDraw.Draw(mask).rectangle(
            [(left_overlap, top_overlap), (right_overlap, bottom_overlap)],
            fill=0,
        )

    # Deliberately no inversion here. The mask is already white over the area
    # to generate; inverting it would mark the source itself for regeneration.
    return background, mask
def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Render the prepared canvas with a translucent red tint over the mask's white area.

    Takes the same arguments as ``prepare_image_and_mask`` and returns an
    RGBA image. It is the canvas that function builds, tinted red wherever
    its mask is 255.
    """
    background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
    preview = background.convert('RGBA')
    # Translucent red: alpha 100 of 255 (~40% opacity).
    red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 100))
    # Fully transparent layer that carries the tint only where the mask is white.
    tint_layer = Image.new('RGBA', background.size, (0, 0, 0, 0))
    tint_layer.paste(red_overlay, (0, 0), mask)
    # Blend the tint over the canvas. Pasting the overlay straight onto the
    # canvas would replace those pixels with the translucent red itself,
    # discarding the underlying image instead of tinting it.
    return Image.alpha_composite(preview, tint_layer)
def _placed_source_size(image, width, height, resize_option, custom_resize_percentage):
    """Return the (width, height) the source occupies once placed on the canvas.

    Mirrors the sizing steps of ``prepare_image_and_mask``: fit inside the
    target, apply the resize option, and keep each side >= 64 px. This lets
    the alignment be validated before the canvas is built. Keep the two in sync.
    """
    scale_factor = min(width / image.width, height / image.height)
    fitted_width = int(image.width * scale_factor)
    fitted_height = int(image.height * scale_factor)
    percentage = {"Full": 100, "50%": 50, "33%": 33, "25%": 25}.get(
        resize_option, custom_resize_percentage
    )
    resize_factor = percentage / 100
    return (
        max(int(fitted_width * resize_factor), 64),
        max(int(fitted_height * resize_factor), 64),
    )


@spaces.GPU(duration=24)
def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Outpaint *image* onto a ``width`` x ``height`` canvas and return the result.

    Args:
        image: Source PIL image (``None`` when nothing was uploaded).
        width, height: Target canvas size.
        overlap_percentage, resize_option, custom_resize_percentage, alignment,
        overlap_left, overlap_right, overlap_top, overlap_bottom: Forwarded to
            ``prepare_image_and_mask``. ``alignment`` may first be replaced by
            "Middle" (see below).
        num_inference_steps: Denoising steps passed to the pipeline.
        prompt_input: Optional user prompt; ", high quality, 4k" is appended.

    Returns:
        The first image of the pipeline output (``.images[0]``).

    Raises:
        gr.Error: if no input image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    # An edge alignment is pointless when the placed source already spans the
    # canvas along that axis, so fall back to centring. The check must use the
    # placed source size: the prepared canvas is always exactly (width, height),
    # so comparing it to the target would trigger on every edge alignment.
    source_w, source_h = _placed_source_size(image, width, height, resize_option, custom_resize_percentage)
    if not can_expand(source_w, source_h, width, height, alignment):
        if alignment in ("Left", "Right"):
            axis, source_dim, target_dim = "width", source_w, width
        else:
            axis, source_dim, target_dim = "height", source_h, height
        print(f"Warning: Source {axis} ({source_dim}) >= target {axis} ({target_dim}) with {alignment} alignment. Forcing Middle alignment.")
        alignment = "Middle"

    background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)

    # ControlNet conditioning image: the canvas with the source pasted in.
    cnet_image = background.copy()

    final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(final_prompt, "cuda", True)

    # NOTE(review): StableDiffusionXLFillPipeline.__call__ is defined in
    # pipeline_fill_sd_xl and is not visible here. Confirm it accepts
    # `mask_image`, `control_image` and `output_type`, and returns an object
    # with `.images` (rather than, e.g., yielding images).
    output_image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=background,  # initial canvas state
        mask_image=mask,  # white = area to generate
        control_image=cnet_image,
        num_inference_steps=num_inference_steps,
        output_type="pil",
    ).images[0]

    # NOTE(review): this assumes the pipeline preserves the unmasked pixels of
    # `background`. If it does not, paste `output_image` onto a copy of
    # `background` through `mask` before returning.
    return output_image
def clear_result():
    """Return an update that empties the result image component."""
    emptied = gr.update(value=None)
    return emptied
# --- UI Helper Functions (Unchanged) ---
def preload_presets(target_ratio, ui_width, ui_height):
    """Sync the size sliders and the settings accordion with the chosen ratio.

    Args:
        target_ratio: Selected ratio label ("9:16", "16:9", "1:1" or "Custom").
        ui_width, ui_height: Current slider values.

    Returns:
        ``(width, height, accordion_update)``. A preset writes its fixed size
        and leaves the accordion as it is (``gr.update()`` changes nothing).
        "Custom" keeps the current sizes and opens the accordion. Any other
        value leaves all three outputs unchanged instead of returning ``None``.
    """
    presets = {
        "9:16": (720, 1280),
        "16:9": (1280, 720),
        "1:1": (1024, 1024),
    }
    if target_ratio in presets:
        preset_width, preset_height = presets[target_ratio]
        return preset_width, preset_height, gr.update()
    if target_ratio == "Custom":
        # Keep the sliders where they are; just reveal them.
        return ui_width, ui_height, gr.update(open=True)
    # Unexpected label: a no-op for every output keeps the 3-output contract.
    return ui_width, ui_height, gr.update()
def select_the_right_preset(user_width, user_height):
    """Map the current slider sizes back to a ratio label.

    Returns "9:16", "16:9" or "1:1" when (width, height) equals that preset's
    size, otherwise "Custom".
    """
    known_presets = (
        ((720, 1280), "9:16"),
        ((1280, 720), "16:9"),
        ((1024, 1024), "1:1"),
    )
    for (preset_w, preset_h), label in known_presets:
        if user_width == preset_w and user_height == preset_h:
            return label
    return "Custom"
def toggle_custom_resize_slider(resize_option):
    """Show the custom-percentage slider only while "Custom" resizing is selected."""
    is_custom = resize_option == "Custom"
    return gr.update(visible=is_custom)
def update_history(new_image, history):
    """Return the history gallery contents with *new_image* prepended.

    Args:
        new_image: Latest result. It is added only if it is a
            ``PIL.Image.Image``; anything else (e.g. ``None``) is ignored.
        history: Current gallery value, or ``None`` when the gallery is empty.

    Returns:
        A new list with ``new_image`` first (when added) followed by the
        previous entries. The ``history`` object passed in is not modified.
    """
    previous = [] if history is None else list(history)
    if not isinstance(new_image, Image.Image):
        return previous
    return [new_image, *previous]
# --- Gradio UI Definition ---
# Stylesheet passed to gr.Blocks(css=...). It sets a fixed, centred page
# width, centres the h1 title and hides the footer. It also sizes the result
# image and history thumbnails, targeting the elem_ids "result-image" and
# "history-gallery".
css = """
.gradio-container {
width: 1200px !important;
margin: auto !important; /* Center the container */
}
h1 { text-align: center; }
footer { visibility: hidden; }
/* Ensure result image takes reasonable space */
#result-image img {
max-height: 768px; /* Adjust max height as needed */
object-fit: contain;
width: auto;
height: auto;
}
#history-gallery .thumbnail-item { /* Style history items */
height: 100px !important;
}
#history-gallery .gallery {
grid-template-rows: repeat(auto-fill, 100px) !important;
}
"""
# Page heading, rendered at the top of the UI via gr.HTML(title).
title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>"""
# UI layout: inputs and settings in the left column, result and history in the
# right column. Gradio places components in the order they are created inside
# each `with` context, so statement order here defines the page layout.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.HTML(title)
        with gr.Row():
            with gr.Column(scale=1): # Left column for inputs
                input_image = gr.Image(
                    type="pil",
                    label="Input Image",
                    height=400 # Display height of the upload widget
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        prompt_input = gr.Textbox(label="Prompt (Optional)", placeholder="Describe the scene to expand...")
                    with gr.Column(scale=1):
                        run_button = gr.Button("Generate", variant="primary") # Primary (highlighted) button
                with gr.Row():
                    # Preset choices must match the labels handled by preload_presets / select_the_right_preset.
                    target_ratio = gr.Radio(
                        label="Target Ratio",
                        choices=["9:16", "16:9", "1:1", "Custom"],
                        value="9:16",
                        scale=2
                    )
                    # Choices must match the alignments handled by prepare_image_and_mask.
                    alignment_dropdown = gr.Dropdown(
                        choices=["Middle", "Left", "Right", "Top", "Bottom"],
                        value="Middle",
                        label="Align Source Image"
                    )
                with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
                    with gr.Row():
                        width_slider = gr.Slider(
                            label="Target Width",
                            minimum=512,
                            maximum=2048,
                            step=64, # Multiples of 64, as is common for SD/SDXL sizes
                            value=720,
                        )
                        height_slider = gr.Slider(
                            label="Target Height",
                            minimum=512,
                            maximum=2048,
                            step=64,
                            value=1280,
                        )
                    num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=12, step=1, value=4) # Low step counts for the Lightning checkpoint + TCD scheduler
                    with gr.Group():
                        overlap_percentage = gr.Slider(
                            label="Mask overlap (%)",
                            minimum=1,
                            maximum=50,
                            value=12, # Default overlap
                            step=1
                        )
                        # Per-edge toggles for the overlap band.
                        with gr.Row():
                            overlap_top = gr.Checkbox(label="Top", value=True)
                            overlap_right = gr.Checkbox(label="Right", value=True)
                            overlap_bottom = gr.Checkbox(label="Bottom", value=True)
                            overlap_left = gr.Checkbox(label="Left", value=True)
                    with gr.Row():
                        resize_option = gr.Radio(
                            label="Resize input within target",
                            choices=["Full", "50%", "33%", "25%", "Custom"],
                            value="Full"
                        )
                        custom_resize_percentage = gr.Slider(
                            label="Custom resize (%)",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            visible=False # Shown only when resize_option == "Custom"
                        )
                    preview_button = gr.Button("Preview Mask & Alignment")
                    preview_image = gr.Image(label="Mask Preview (Red = Outpaint Area)", type="pil", interactive=False)
                # Each example row fills: image, prompt, width, height, alignment.
                gr.Examples(
                    examples=[
                        ["./examples/example_1.webp", "A wide landscape view of the mountains", 1280, 720, "Middle"],
                        ["./examples/example_2.jpg", "Full body shot of the astronaut on the moon", 720, 1280, "Middle"],
                        ["./examples/example_3.jpg", "Expanding the sky and ground around the subject", 1024, 1024, "Middle"],
                        ["./examples/example_3.jpg", "Expanding downwards from the subject", 1024, 1024, "Top"], # Align subject Top
                        ["./examples/example_3.jpg", "Expanding upwards from the subject", 1024, 1024, "Bottom"], # Align subject Bottom
                    ],
                    inputs=[input_image, prompt_input, width_slider, height_slider, alignment_dropdown],
                    label="Examples (Click to load)"
                )
            with gr.Column(scale=1): # Right column for output
                # Single result image (used in place of a before/after ImageSlider).
                result = gr.Image(label="Generated Image", type="pil", interactive=False, elem_id="result-image")
                use_as_input_button = gr.Button("Use Result as Input Image", visible=False) # Revealed after the first generation
                history_gallery = gr.Gallery(
                    label="History",
                    columns=6,
                    object_fit="contain",
                    interactive=False,
                    height=110, # Fixed height for the row
                    elem_id="history-gallery"
                )
    # --- Event Handling ---
    def use_output_as_input(output_image):
        """Sets the generated output as the new input image."""
        # output_image is the current value of the `result` component.
        return gr.update(value=output_image)
    use_as_input_button.click(
        fn=use_output_as_input,
        inputs=[result], # Input is the result image component
        outputs=[input_image] # Output updates the input image component
    )
    # Preset ratios write fixed sizes into the sliders; "Custom" opens the accordion.
    target_ratio.change(
        fn=preload_presets,
        inputs=[target_ratio, width_slider, height_slider],
        outputs=[width_slider, height_slider, settings_panel],
        queue=False
    )
    # Link sliders back to the ratio selector (sizes matching no preset select "Custom").
    width_slider.change(
        fn=select_the_right_preset,
        inputs=[width_slider, height_slider],
        outputs=[target_ratio],
        queue=False
    )
    height_slider.change(
        fn=select_the_right_preset,
        inputs=[width_slider, height_slider],
        outputs=[target_ratio],
        queue=False
    )
    resize_option.change(
        fn=toggle_custom_resize_slider,
        inputs=[resize_option],
        outputs=[custom_resize_percentage],
        queue=False
    )
    # Order must match infer's positional parameters:
    # (image, width, height, overlap_percentage, num_inference_steps, resize_option,
    #  custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right,
    #  overlap_top, overlap_bottom).
    gen_inputs = [
        input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
        resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
        overlap_left, overlap_right, overlap_top, overlap_bottom
    ]
    # Generation chain: clear result -> infer -> prepend to history -> reveal
    # "Use as Input". The prompt_input.submit chain below is an identical copy;
    # keep the two in sync.
    # NOTE(review): per Gradio's docs, `.then` runs even if the previous step
    # raised (`.success` would not), so the button is also revealed after a
    # failed generation — confirm whether that is intended.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=[result],
        queue=False # Run clearing immediately
    ).then(
        fn=infer,
        inputs=gen_inputs,
        outputs=[result],
    ).then(
        # Reads the `result` value written by the infer step above.
        fn=lambda res_img, hist: update_history(res_img, hist),
        inputs=[result, history_gallery],
        outputs=[history_gallery],
        queue=False
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=[use_as_input_button],
        queue=False
    )
    prompt_input.submit(
        fn=clear_result,
        inputs=None,
        outputs=[result],
        queue=False
    ).then(
        fn=infer,
        inputs=gen_inputs,
        outputs=[result],
    ).then(
        fn=lambda res_img, hist: update_history(res_img, hist),
        inputs=[result, history_gallery],
        outputs=[history_gallery],
        queue=False
    ).then(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=[use_as_input_button],
        queue=False
    )
    # Order must match preview_image_and_mask's positional parameters.
    preview_button.click(
        fn=preview_image_and_mask,
        inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
                overlap_left, overlap_right, overlap_top, overlap_bottom],
        outputs=preview_image,
        queue=False # CPU-only preview; skip the queue
    )
# Queue requests (max_size=10) and start the server; show_error surfaces exceptions in the UI.
demo.queue(max_size=10).launch(ssr_mode=False, show_error=True)