prithivMLmods's picture
Update app.py
d4884bc verified
raw
history blame
20.9 kB
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from huggingface_hub import hf_hub_download
from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
from PIL import Image, ImageDraw
import numpy as np
# --- Configuration and Model Loading ---
# Hub repository that provides both the ControlNet Union "promax" config and weights.
_CONTROLNET_REPO = "xinsir/controlnet-union-sdxl-1.0"

# Build the ControlNet Union architecture from the promax config file ...
config_file = hf_hub_download(_CONTROLNET_REPO, filename="config_promax.json")
config = ControlNetModel_Union.load_config(config_file)
controlnet_model = ControlNetModel_Union.from_config(config)

# ... then load the matching promax safetensors weights into it.
model_file = hf_hub_download(
    _CONTROLNET_REPO,
    filename="diffusion_pytorch_model_promax.safetensors",
)
sstate_dict = load_state_dict(model_file)
# NOTE(review): `_load_pretrained_model` is a private diffusers helper; the
# 5-way unpack assumes the installed diffusers version returns a 5-tuple whose
# first element is the loaded model -- re-check this after diffusers upgrades.
model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
    controlnet_model, sstate_dict, model_file, _CONTROLNET_REPO
)
model.to(device="cuda", dtype=torch.float16)

# VAE from madebyollin/sdxl-vae-fp16-fix, loaded in float16 and moved to CUDA.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
).to("cuda")
# --- Load Multiple Pipelines ---
def _load_fill_pipeline(repo_id):
    """Load an SDXL fill pipeline from ``repo_id`` that reuses the shared VAE and ControlNet.

    The pipeline is loaded in float16 (fp16 weight variant), moved to CUDA, and
    its scheduler is swapped for a TCDScheduler built from the original config.
    """
    fill_pipe = StableDiffusionXLFillPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        vae=vae,
        controlnet=model,
        variant="fp16",
    ).to("cuda")
    fill_pipe.scheduler = TCDScheduler.from_config(fill_pipe.scheduler.config)
    return fill_pipe


# Display name -> loaded pipeline. Insertion order is the order the model
# dropdown lists them in (its choices are `list(pipelines.keys())`).
pipelines = {}
pipe_v5 = _load_fill_pipeline("SG161222/RealVisXL_V5.0_Lightning")
pipelines["RealVisXL V5.0 Lightning"] = pipe_v5
pipe_v4 = _load_fill_pipeline("SG161222/RealVisXL_V4.0_Lightning")
pipelines["RealVisXL V4.0 Lightning"] = pipe_v4
# --- Helper Functions ---
# Percentages for the fixed "Resize input image" choices. Any other option
# (i.e. "Custom") uses the custom slider value instead.
_RESIZE_PERCENTAGES = {"Full": 100, "50%": 50, "33%": 33, "25%": 25}


def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Place ``image`` on a white canvas and build the matching outpaint mask.

    Steps:
    1. Scale the image to fit inside ``(width, height)``, keeping its aspect ratio.
    2. Shrink it by ``resize_option`` ("Full", "50%", "33%", "25%"; any other
       value means "use ``custom_resize_percentage``"). Each side is kept at
       least 64 px.
    3. Paste it onto a white RGB canvas at the position given by ``alignment``
       ("Middle", "Left", "Right", "Top", "Bottom"). Unknown values fall back
       to "Middle".

    The mask is an "L" image: 255 (white) marks pixels to generate, and
    0 (black) marks pixels to keep. The kept area is the pasted image, inset on
    every edge:
    - Edges whose ``overlap_*`` flag is set are inset by ``overlap_percentage``
      percent of the pasted size (at least 1 px).
    - Other edges are inset by 2 px, except the edge named by ``alignment``,
      which gets no inset.

    Returns:
        ``(background, mask)``, both of size ``(width, height)``.
    """
    target_size = (width, height)

    # Fit the source inside the target canvas while preserving its aspect ratio.
    scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
    source = image.resize(
        (int(image.width * scale_factor), int(image.height * scale_factor)),
        Image.LANCZOS,
    )

    # Shrink the fitted image by the requested percentage, never below 64 px a side.
    resize_percentage = _RESIZE_PERCENTAGES.get(resize_option, custom_resize_percentage)
    resize_factor = resize_percentage / 100
    new_width = max(int(source.width * resize_factor), 64)
    new_height = max(int(source.height * resize_factor), 64)
    source = source.resize((new_width, new_height), Image.LANCZOS)

    # Width of the band, per overlapping edge, that is handed back to the model (>= 1 px).
    overlap_x = max(int(new_width * (overlap_percentage / 100)), 1)
    overlap_y = max(int(new_height * (overlap_percentage / 100)), 1)

    # Top-left paste position for the requested alignment, clamped onto the canvas.
    free_x = target_size[0] - new_width
    free_y = target_size[1] - new_height
    margin_x, margin_y = {
        "Left": (0, free_y // 2),
        "Right": (free_x, free_y // 2),
        "Top": (free_x // 2, 0),
        "Bottom": (free_x // 2, free_y),
    }.get(alignment, (free_x // 2, free_y // 2))  # "Middle" or anything unexpected
    margin_x = max(0, min(margin_x, free_x))
    margin_y = max(0, min(margin_y, free_y))

    background = Image.new('RGB', target_size, (255, 255, 255))
    background.paste(source, (margin_x, margin_y))

    # Mask convention: 255 (white) = generate, 0 (black) = keep.
    mask = Image.new('L', target_size, 255)
    mask_draw = ImageDraw.Draw(mask)

    # Inset for edges without overlap (the original author's intent: avoid tiny
    # gaps at the image border). The edge the image is aligned to gets none.
    white_gaps_patch = 2

    def _inset(overlap_enabled, overlap_px, aligned_edge):
        # Pixels to trim from one edge of the kept rectangle.
        if overlap_enabled:
            return overlap_px
        return 0 if alignment == aligned_edge else white_gaps_patch

    # Exclusive bounds of the area to keep: [left, right) x [top, bottom).
    left_black = margin_x + _inset(overlap_left, overlap_x, "Left")
    right_black = margin_x + new_width - _inset(overlap_right, overlap_x, "Right")
    top_black = margin_y + _inset(overlap_top, overlap_y, "Top")
    bottom_black = margin_y + new_height - _inset(overlap_bottom, overlap_y, "Bottom")

    # Keep the bounds on the canvas and never inverted.
    left_black = min(left_black, target_size[0])
    top_black = min(top_black, target_size[1])
    right_black = min(max(left_black, right_black), target_size[0])
    bottom_black = min(max(top_black, bottom_black), target_size[1])

    if right_black > left_black and bottom_black > top_black:
        # ImageDraw.rectangle includes its second corner, while right/bottom
        # above are exclusive. Step back one pixel so the kept area is exactly
        # [left, right) x [top, bottom), symmetric with the left/top insets.
        mask_draw.rectangle(
            [(left_black, top_black), (right_black - 1, bottom_black - 1)],
            fill=0,
        )
    return background, mask
@spaces.GPU(duration=24)
def infer(selected_model_name, image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Outpaint ``image`` onto a ``width`` x ``height`` canvas with the selected model.

    Args:
        selected_model_name: Key into ``pipelines`` (the model dropdown value).
        image: Input PIL image, or ``None`` if nothing was uploaded.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        prompt_input: Optional user prompt. "high quality, 4k" is appended, or
            used alone when the prompt is empty.
        width, height, overlap_percentage, resize_option,
        custom_resize_percentage, alignment, overlap_left, overlap_right,
        overlap_top, overlap_bottom: Forwarded unchanged to
            ``prepare_image_and_mask`` to build the canvas and mask.

    Returns:
        The first image of the pipeline output (requested with ``output_type="pil"``).

    Raises:
        gr.Error: No image was uploaded, the model name is unknown, or
            generation failed. For failures, the original exception is chained
            and its traceback is printed to the server log.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")
    pipe = pipelines.get(selected_model_name)
    if pipe is None:
        raise gr.Error(f"Unknown model: {selected_model_name}")

    try:
        background, mask = prepare_image_and_mask(
            image, width, height, overlap_percentage, resize_option,
            custom_resize_percentage, alignment,
            overlap_left, overlap_right, overlap_top, overlap_bottom,
        )

        # ControlNet input: the canvas with every region to be generated
        # (white in the mask) painted black.
        cnet_image = background.copy()
        black_fill = Image.new('RGB', mask.size, (0, 0, 0))
        cnet_image.paste(black_fill, (0, 0), mask)

        final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(final_prompt, "cuda", True)

        # Fresh random seed per run. dtype=int64 makes high=2**32 valid even
        # where NumPy's default integer is 32-bit; int() gives torch a plain int.
        seed = int(np.random.randint(0, 2**32, dtype=np.int64))
        generator = torch.Generator(device="cuda").manual_seed(seed)

        # NOTE(review): assumes StableDiffusionXLFillPipeline.__call__ (in
        # pipeline_fill_sd_xl, not shown here) accepts mask_image /
        # control_image / generator / output_type and returns an object with
        # `.images` -- confirm against that module.
        result_image = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            image=background,  # canvas containing the original image
            mask_image=mask,  # white = fill, black = keep
            control_image=cnet_image,  # ControlNet input image
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="pil",
        ).images[0]
    except Exception as e:
        import traceback

        print(f"Error during inference: {e}")
        traceback.print_exc()
        # Surface the failure in the UI. Returning a debug overlay here made a
        # failed run look like a real result and put it into the history gallery.
        raise gr.Error(f"Inference failed: {e}") from e

    return result_image
def clear_result():
    """Return a Gradio update that empties the result Image component."""
    empty_update = gr.update(value=None)
    return empty_update
def preload_presets(target_ratio, ui_width, ui_height):
    """Return ``(width, height, accordion_update)`` for the chosen aspect ratio.

    - "9:16", "16:9", "1:1": return that preset's size and collapse the
      accordion (``open=False``).
    - "Custom": keep the current slider values and expand the accordion.
    - Any other value: return the current values with a no-op update.
    """
    preset_sizes = {
        "9:16": (720, 1280),
        "16:9": (1280, 720),
        "1:1": (1024, 1024),
    }
    if target_ratio == "Custom":
        return ui_width, ui_height, gr.update(open=True)
    size = preset_sizes.get(target_ratio)
    if size is None:
        return ui_width, ui_height, gr.update()
    preset_width, preset_height = size
    return preset_width, preset_height, gr.update(open=False)
def select_the_right_preset(user_width, user_height):
    """Map a slider ``(width, height)`` pair back to its aspect-ratio label.

    Returns "9:16" for 720x1280, "16:9" for 1280x720, "1:1" for 1024x1024,
    and "Custom" for any other size.
    """
    preset_by_size = {
        (720, 1280): "9:16",
        (1280, 720): "16:9",
        (1024, 1024): "1:1",
    }
    return preset_by_size.get((user_width, user_height), "Custom")
def toggle_custom_resize_slider(resize_option):
    """Return a Gradio update that shows the custom-resize slider only when "Custom" is chosen."""
    show_slider = resize_option == "Custom"
    return gr.update(visible=show_slider)
def update_history(new_image, history, max_history=12):
    """Return the gallery history with ``new_image`` added at the front.

    Args:
        new_image: Latest result. ``None`` (a cleared or failed run) leaves the
            history unchanged.
        history: Current gallery items, or ``None`` when the gallery is empty.
            This list is never modified.
        max_history: Maximum number of items kept; the oldest are dropped.

    Returns:
        A new list, newest first, with at most ``max_history`` items.
        ``history`` itself is returned when ``new_image`` is ``None``.
    """
    if new_image is None:
        return history
    # Build a fresh list rather than `insert`-ing into the caller's list.
    return [new_image, *(history or [])][:max_history]
# --- CSS and Title ---
# Custom CSS passed to gr.Blocks(css=...). It centers <h1> headings and caps
# the .gradio-container width at 1280px, centered via auto margins.
css = """
h1 {
text-align: center;
display: block;
}
.gradio-container {
max-width: 1280px !important;
margin: auto !important;
}
"""
# HTML header rendered at the top of the page through gr.HTML(title).
title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
<p align="center">Expand images using ControlNet Union and Lightning models. Choose a base model below.</p>
"""
# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column():
        gr.HTML(title)
        with gr.Row():
            with gr.Column(scale=2):  # Input column
                input_image = gr.Image(
                    type="pil",
                    label="Input Image"
                )
                # --- Model Selector ---
                # Choices come straight from `pipelines`, so only models loaded
                # at startup can be picked.
                model_selector = gr.Dropdown(
                    label="Select Model",
                    choices=list(pipelines.keys()),
                    value="RealVisXL V5.0 Lightning",  # Default model
                )
                with gr.Row():
                    with gr.Column(scale=2):
                        prompt_input = gr.Textbox(label="Prompt (Describe the desired output)", placeholder="e.g., beautiful landscape, photorealistic")
                    with gr.Column(scale=1, min_width=120):
                        run_button = gr.Button("Generate", variant="primary")
                with gr.Row():
                    target_ratio = gr.Radio(
                        label="Target Ratio",
                        choices=["9:16", "16:9", "1:1", "Custom"],
                        value="9:16",  # Default ratio
                        scale=2
                    )
                    alignment_dropdown = gr.Dropdown(
                        choices=["Middle", "Left", "Right", "Top", "Bottom"],
                        value="Middle",
                        label="Align Input Image"
                    )
                with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
                    with gr.Column():
                        with gr.Row():
                            # step=8 keeps dimensions on multiples of 8. With
                            # minimum=512 the old step=64 grid (512, 576, ...,
                            # 704, 768, ...) did not contain 720, which is both
                            # the default width and the 9:16 preset.
                            width_slider = gr.Slider(
                                label="Target Width",
                                minimum=512,
                                maximum=1536,
                                step=8,
                                value=720,  # Default width (9:16 preset)
                            )
                            height_slider = gr.Slider(
                                label="Target Height",
                                minimum=512,
                                maximum=1536,
                                step=8,
                                value=1280,  # Default height (9:16 preset)
                            )
                        num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
                        with gr.Group():
                            overlap_percentage = gr.Slider(
                                label="Mask overlap (%)",
                                info="Percentage of the input image edge to keep (reduces seams)",
                                minimum=1,
                                maximum=50,
                                value=10,  # Default overlap
                                step=1
                            )
                            gr.Markdown("Select edges to apply overlap:")
                            with gr.Row():
                                overlap_top = gr.Checkbox(label="Top", value=True)
                                overlap_right = gr.Checkbox(label="Right", value=True)
                                overlap_left = gr.Checkbox(label="Left", value=True)
                                overlap_bottom = gr.Checkbox(label="Bottom", value=True)
                        with gr.Row():
                            resize_option = gr.Radio(
                                label="Resize input image before placing",
                                info="Scale the input image relative to its fitted size",
                                choices=["Full", "50%", "33%", "25%", "Custom"],
                                value="Full"  # Default resize option
                            )
                            custom_resize_percentage = gr.Slider(
                                label="Custom resize (%)",
                                minimum=1,
                                maximum=100,
                                step=1,
                                value=50,
                                visible=False  # Shown only when resize_option is "Custom"
                            )
                # Each row fills: image, model, width, height, alignment.
                # NOTE(review): 810 (example 2) is not a multiple of 8 -- confirm
                # the fill pipeline copes with it or switch to a size like 816.
                gr.Examples(
                    examples=[
                        ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", 1280, 720, "Middle"],
                        ["./examples/example_2.jpg", "RealVisXL V4.0 Lightning", 1440, 810, "Left"],
                        ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024, "Top"],
                        ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024, "Bottom"],
                    ],
                    inputs=[input_image, model_selector, width_slider, height_slider, alignment_dropdown],
                    label="Examples (Prompt is optional)"
                )
            with gr.Column(scale=3):  # Output column
                result = gr.Image(
                    interactive=False,
                    label="Generated Image",
                    format="png",
                )
                history_gallery = gr.Gallery(
                    label="History",
                    columns=4,
                    object_fit="contain",
                    interactive=False,
                    show_label=True,
                    allow_preview=True,
                    preview=True
                )

    # --- Event Listeners ---
    # Picking a ratio sets both sliders (or, for "Custom", opens the accordion).
    target_ratio.change(
        fn=preload_presets,
        inputs=[target_ratio, width_slider, height_slider],
        outputs=[width_slider, height_slider, settings_panel],
        queue=False
    )
    # Moving either slider re-derives the ratio label from the current size.
    width_slider.change(
        fn=select_the_right_preset,
        inputs=[width_slider, height_slider],
        outputs=[target_ratio],
        queue=False
    )
    height_slider.change(
        fn=select_the_right_preset,
        inputs=[width_slider, height_slider],
        outputs=[target_ratio],
        queue=False
    )
    # Show/hide the custom resize slider.
    resize_option.change(
        fn=toggle_custom_resize_slider,
        inputs=[resize_option],
        outputs=[custom_resize_percentage],
        queue=False
    )

    # Positional order must match infer's parameter list.
    infer_inputs = [
        model_selector, input_image, width_slider, height_slider, overlap_percentage,
        num_inference_steps, resize_option, custom_resize_percentage, prompt_input,
        alignment_dropdown, overlap_left, overlap_right, overlap_top, overlap_bottom
    ]

    # --- Run Button Click ---
    # Clear the result first, then generate, then prepend to the history.
    # If infer fails, `result` stays empty and update_history keeps the history as-is.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=[result],
        queue=False
    ).then(
        fn=infer,
        inputs=infer_inputs,
        outputs=[result],
    ).then(
        fn=update_history,
        inputs=[result, history_gallery],
        outputs=[history_gallery],
    )

    # --- Prompt Submit (Enter Key) ---
    # Same three-step chain as the Generate button.
    prompt_input.submit(
        fn=clear_result,
        inputs=None,
        outputs=[result],
        queue=False
    ).then(
        fn=infer,
        inputs=infer_inputs,
        outputs=[result],
    ).then(
        fn=update_history,
        inputs=[result, history_gallery],
        outputs=[history_gallery],
    )
# --- Launch App ---
# The examples above load ./examples/example_1.webp, ./examples/example_2.jpg and
# ./examples/example_3.jpg. Make sure those files exist, or update/remove gr.Examples.
demo.queue(max_size=20)  # queue() configures `demo` in place (and returns it)
demo.launch(share=False, ssr_mode=False, show_error=True)