Update app.py

app.py CHANGED
@@ -19,7 +19,7 @@ import random
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# --- Helper Dataclasses (Identical to
+# --- Helper Dataclasses (Identical to previous version) ---
 @dataclass
 class BoundingBox:
     xmin: int
@@ -48,7 +48,7 @@ class DetectionResult:
                 ymax=detection_dict['box']['ymax']))
 
 
-# --- Helper Functions (Identical to
+# --- Helper Functions (Identical to previous version) ---
 def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
     contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
     if not contours:
@@ -127,7 +127,7 @@ def make_diptych(image):
     return Image.fromarray(diptych_np)
 
 
-# --- Custom Attention Processor (
+# --- Custom Attention Processor (Identical to previous version) ---
 class CustomFluxAttnProcessor2_0:
     def __init__(self, height=44, width=88, attn_enforce=1.0):
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -197,7 +197,6 @@ class CustomFluxAttnProcessor2_0:
 print("--- Loading Models: This may take a few minutes and requires >40GB VRAM ---")
 controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", torch_dtype=torch.bfloat16)
 pipe = FluxControlNetInpaintingPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16).to(device)
-# pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
 
 pipe.transformer.to(torch.bfloat16)
 pipe.controlnet.to(torch.bfloat16)
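The loading block above keeps the whole FLUX.1-dev pipeline in bfloat16 on a single GPU, which is why the banner warns about >40GB of VRAM. For smaller GPUs, the standard diffusers offloading helpers are a possible fallback. The sketch below is an assumption, not part of this commit, and it presumes FluxControlNetInpaintingPipeline behaves like a regular diffusers.DiffusionPipeline (its import path is not shown in this hunk, so the class is passed in):

# Sketch only, not part of this commit: a lower-VRAM loading path.
import torch
from diffusers import FluxControlNetModel

def load_pipeline_low_vram(pipeline_cls):
    controlnet = FluxControlNetModel.from_pretrained(
        "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
        torch_dtype=torch.bfloat16,
    )
    pipe = pipeline_cls.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        controlnet=controlnet,
        torch_dtype=torch.bfloat16,
    )
    # Trade speed for memory: submodules move to the GPU only while in use,
    # and the VAE decodes the latent in slices.
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_slicing()
    return pipe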
@@ -213,21 +212,21 @@ print("--- All models loaded successfully! ---")
 
 def get_duration(
     input_image: Image.Image,
-
-
-    attn_enforce: float
-    ctrl_scale: float
-    width: int
-    height: int
-    pixel_offset: int
-    num_steps: int
-    guidance: float
-    real_guidance: float
-    seed: int
-    randomize_seed: bool
+    do_segmentation: bool,
+    full_prompt: str,
+    attn_enforce: float,
+    ctrl_scale: float,
+    width: int,
+    height: int,
+    pixel_offset: int,
+    num_steps: int,
+    guidance: float,
+    real_guidance: float,
+    seed: int,
+    randomize_seed: bool,
     progress=gr.Progress(track_tqdm=True)
 ):
-    if width > 768
+    if width > 768 or height > 768:
         return 210
     else:
         return 120
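get_duration mirrors the new run_diptych_prompting signature and returns a time budget in seconds based on the requested resolution. On ZeroGPU Spaces a helper like this is typically handed to the spaces.GPU decorator as a dynamic duration callback; that wiring is not visible in this diff, so the snippet below is only an assumed sketch of how the two functions relate:

# Assumed wiring, not shown in this diff: ZeroGPU calls the duration callback
# with the same arguments as the decorated function and allocates that many
# seconds of GPU time before running it.
import spaces

@spaces.GPU(duration=get_duration)
def run_diptych_prompting(input_image, subject_name, do_segmentation, full_prompt,
                          attn_enforce, ctrl_scale, width, height, pixel_offset,
                          num_steps, guidance, real_guidance, seed, randomize_seed,
                          progress=None):
    ...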
@@ -236,17 +235,18 @@ def get_duration(
 def run_diptych_prompting(
     input_image: Image.Image,
     subject_name: str,
-
-
-
-
-
-
-
-
-
-
-
+    do_segmentation: bool,
+    full_prompt: str,
+    attn_enforce: float,
+    ctrl_scale: float,
+    width: int,
+    height: int,
+    pixel_offset: int,
+    num_steps: int,
+    guidance: float,
+    real_guidance: float,
+    seed: int,
+    randomize_seed: bool,
     progress=gr.Progress(track_tqdm=True)
 ):
     if randomize_seed:
@@ -255,40 +255,42 @@ def run_diptych_prompting(
         actual_seed = seed
 
     if input_image is None: raise gr.Error("Please upload a reference image.")
-    if not
-    if not target_prompt: raise gr.Error("Please provide a target prompt.")
+    if not full_prompt: raise gr.Error("Full Prompt is empty. Please fill out the prompt fields.")
 
-    # 1. Prepare dimensions
+    # 1. Prepare dimensions and reference image
     padded_width = width + pixel_offset * 2
     padded_height = height + pixel_offset * 2
     diptych_size = (padded_width * 2, padded_height)
-
-    # 2. Prepare prompts and images
-    progress(0, desc="Resizing and segmenting reference image...")
-    base_prompt = f"a photo of {subject_name}"
-    diptych_text_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, {base_prompt}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
-
     reference_image = input_image.resize((padded_width, padded_height)).convert("RGB")
-
+
+    # 2. Process reference image based on segmentation flag
+    progress(0, desc="Preparing reference image...")
+    if do_segmentation:
+        if not subject_name:
+            raise gr.Error("Subject Name is required when 'Do Segmentation' is checked.")
+        progress(0.05, desc="Segmenting reference image...")
+        processed_image = segment_image(reference_image, subject_name, object_detector, segmentator, segment_processor)
+    else:
+        processed_image = reference_image
 
+    # 3. Create diptych and mask
     progress(0.2, desc="Creating diptych and mask...")
     mask_image = np.concatenate([np.zeros((padded_height, padded_width, 3)), np.ones((padded_height, padded_width, 3)) * 255], axis=1)
     mask_image = Image.fromarray(mask_image.astype(np.uint8))
-    diptych_image_prompt = make_diptych(
+    diptych_image_prompt = make_diptych(processed_image)
 
-    #
+    # 4. Setup Attention Processor
     progress(0.3, desc="Setting up attention processors...")
     new_attn_procs = base_attn_procs.copy()
     for k in new_attn_procs:
-        # Use full diptych dimensions for the attention processor
         new_attn_procs[k] = CustomFluxAttnProcessor2_0(height=padded_height // 16, width=padded_width * 2 // 16, attn_enforce=attn_enforce)
     pipe.transformer.set_attn_processor(new_attn_procs)
 
-    #
+    # 5. Run Inference
     progress(0.4, desc="Running diffusion process...")
     generator = torch.Generator(device="cuda").manual_seed(actual_seed)
-
-    prompt=
+    full_diptych_result = pipe(
+        prompt=full_prompt,
         height=diptych_size[1],
         width=diptych_size[0],
         control_image=diptych_image_prompt,
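Two details in this hunk are easy to miss: the inpainting mask keeps the left (reference) panel at 0 and marks the right panel with 255 for regeneration, and the attention processor is sized to the FLUX token grid, where each token covers a 16x16 pixel block (8x VAE downsampling times 2x2 patch packing), which is why the code divides by 16. A small standalone check of that arithmetic, using the UI defaults of 768x768 with an 8 px offset as assumed example values:

import numpy as np

# Standalone check of the sizes used above (assumed defaults: 768x768, offset 8).
width, height, pixel_offset = 768, 768, 8
padded_width = width + pixel_offset * 2            # 784
padded_height = height + pixel_offset * 2          # 784
diptych_size = (padded_width * 2, padded_height)   # (1568, 784)

# Inpainting mask: left panel kept (0), right panel regenerated (255).
mask = np.concatenate(
    [np.zeros((padded_height, padded_width, 3)),
     np.ones((padded_height, padded_width, 3)) * 255],
    axis=1,
)
assert mask.shape == (padded_height, padded_width * 2, 3)

# Attention grid passed to CustomFluxAttnProcessor2_0: one token per 16x16 block.
attn_height = padded_height // 16                  # 49
attn_width = padded_width * 2 // 16                # 98
print(diptych_size, attn_height, attn_width)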
@@ -301,14 +303,13 @@ def run_diptych_prompting(
         true_guidance_scale=real_guidance
     ).images[0]
 
-    #
+    # 6. Final cropping
     progress(0.95, desc="Finalizing image...")
-
-
-    # Crop the pixel offset padding
-    result = result.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))
+    final_image = full_diptych_result.crop((padded_width, 0, padded_width * 2, padded_height))
+    final_image = final_image.crop((pixel_offset, pixel_offset, padded_width - pixel_offset, padded_height - pixel_offset))
 
-
+    # 7. Return all outputs
+    return final_image, processed_image, full_diptych_result, full_prompt
 
 
 # --- Gradio UI Definition ---
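The new cropping logic replaces the old single crop: it first cuts the generated right panel out of the diptych, then trims the pixel_offset padding that was added before generation. A quick standalone check of the box arithmetic with Pillow, again assuming the 768x768 / 8 px defaults:

from PIL import Image

# Two-step crop on a dummy diptych-sized image (assumed 784x784 panels, 8 px offset).
padded_width, padded_height, pixel_offset = 784, 784, 8
full_diptych_result = Image.new("RGB", (padded_width * 2, padded_height))

# Step 1: keep only the right panel, i.e. the generated side of the diptych.
right_panel = full_diptych_result.crop((padded_width, 0, padded_width * 2, padded_height))
assert right_panel.size == (784, 784)

# Step 2: trim the pixel_offset padding that was added before generation.
final_image = right_panel.crop((pixel_offset, pixel_offset,
                                padded_width - pixel_offset,
                                padded_height - pixel_offset))
assert final_image.size == (768, 768)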
@@ -318,18 +319,29 @@ css = '''
 with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     gr.Markdown(
         """
-        # Diptych Prompting: Zero-Shot Subject-Driven Image Generation
+        # Diptych Prompting: Zero-Shot Subject-Driven & Style-Driven Image Generation
         ### Gradio Demo for the paper "[Large-Scale Text-to-Image Model with Inpainting is a Zero-Shot Subject-Driven Image Generator](https://diptychprompting.github.io/)"
+        This demo implements both subject-driven generation and style transfer with advanced controls.
         """
     )
     with gr.Row():
         with gr.Column(scale=1):
-            input_image = gr.Image(type="pil", label="
-
-
+            input_image = gr.Image(type="pil", label="Reference Image")
+
+            with gr.Group() as subject_driven_group:
+                subject_name = gr.Textbox(label="Subject Name", placeholder="e.g., a plush bear")
+
+            target_prompt = gr.Textbox(label="Target Prompt", placeholder="e.g., riding a skateboard on the moon")
+
             run_button = gr.Button("Generate Image", variant="primary")
+
             with gr.Accordion("Advanced Settings", open=False):
+                mode = gr.Radio(["Subject-Driven", "Style-Driven (unstable)"], label="Generation Mode", value="Subject-Driven")
+                with gr.Group(visible=False) as style_driven_group:
+                    original_style_description = gr.Textbox(label="Original Image Description", placeholder="e.g., in watercolor painting style")
+                do_segmentation = gr.Checkbox(label="Do Segmentation", value=True)
                 attn_enforce = gr.Slider(minimum=1.0, maximum=2.0, value=1.3, step=0.05, label="Attention Enforcement")
+                full_prompt = gr.Textbox(label="Full Prompt (Auto-generated, editable)", lines=3)
                 ctrl_scale = gr.Slider(minimum=0.5, maximum=1.0, value=0.95, step=0.01, label="ControlNet Scale")
                 num_steps = gr.Slider(minimum=20, maximum=50, value=28, step=1, label="Inference Steps")
                 guidance = gr.Slider(minimum=1.0, maximum=10.0, value=3.5, step=0.1, label="Distilled Guidance Scale")
@@ -339,9 +351,85 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                 pixel_offset = gr.Slider(minimum=0, maximum=32, value=8, step=1, label="Padding (Pixel Offset)")
                 seed = gr.Slider(minimum=0, maximum=9223372036854775807, value=42, step=1, label="Seed")
                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+
         with gr.Column(scale=1):
             output_image = gr.Image(type="pil", label="Generated Image")
+            with gr.Accordion("Other Outputs", open=False) as other_outputs_accordion:
+                processed_ref_image = gr.Image(label="Processed Reference (Left Panel)")
+                full_diptych_image = gr.Image(label="Full Diptych Output")
+                final_prompt_used = gr.Textbox(label="Final Prompt Used")
 
+    # --- UI Event Handlers ---
+
+    def toggle_mode_visibility(mode_choice):
+        """Hides/shows the relevant input textboxes based on mode."""
+        if mode_choice == "Subject-Driven":
+            return gr.update(visible=True), gr.update(visible=False)
+        else:
+            return gr.update(visible=False), gr.update(visible=True)
+
+    def update_derived_fields(mode_choice, subject, style_desc, target):
+        """Updates the full prompt and segmentation checkbox based on other inputs."""
+        if mode_choice == "Subject-Driven":
+            prompt = f"A diptych with two side-by-side images of same {subject}. On the left, a photo of {subject}. On the right, replicate this {subject} exactly but as {target}"
+            return gr.update(value=prompt), gr.update(value=True)
+        else:  # Style-Driven
+            prompt = f"A diptych with two side-by-side images of same style. On the left, {style_desc}. On the right, replicate this style exactly but as {target}"
+            return gr.update(value=prompt), gr.update(value=False)
+
+    # --- UI Connections ---
+
+    # When mode changes, toggle visibility of the specific prompt fields
+    mode.change(
+        fn=toggle_mode_visibility,
+        inputs=mode,
+        outputs=[subject_driven_group, style_driven_group],
+        queue=False
+    )
+
+    # A list of all inputs that affect the full prompt or segmentation checkbox
+    prompt_component_inputs = [mode, subject_name, original_style_description, target_prompt]
+    # A list of the UI elements that are derived from the above inputs
+    derived_outputs = [full_prompt, do_segmentation]
+
+    # When any prompt component changes, update the derived fields
+    for component in prompt_component_inputs:
+        # Use .then() to chain the update after the visibility toggle for the mode radio
+        if component == mode:
+            component.change(update_derived_fields, inputs=prompt_component_inputs, outputs=derived_outputs, queue=False)
+        else:
+            component.input(update_derived_fields, inputs=prompt_component_inputs, outputs=derived_outputs, queue=False)
+
+    run_button.click(
+        fn=run_diptych_prompting,
+        inputs=[
+            input_image, subject_name, do_segmentation, full_prompt, attn_enforce,
+            ctrl_scale, width, height, pixel_offset, num_steps, guidance,
+            real_guidance, seed, randomize_seed
+        ],
+        outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used]
+    )
+    def run_subject_driven_example(input_image, subject_name, target_prompt):
+        # Construct the full prompt for subject-driven mode
+        full_prompt = f"A diptych with two side-by-side images of same {subject_name}. On the left, a photo of {subject_name}. On the right, replicate this {subject_name} exactly but as {target_prompt}"
+
+        # Call the main function with all arguments, using defaults for subject-driven mode
+        return run_diptych_prompting(
+            input_image=input_image,
+            subject_name=subject_name,
+            do_segmentation=True,
+            full_prompt=full_prompt,
+            attn_enforce=1.3,
+            ctrl_scale=0.95,
+            width=768,
+            height=768,
+            pixel_offset=8,
+            num_steps=28,
+            guidance=3.5,
+            real_guidance=4.5,
+            seed=42,
+            randomize_seed=False,
+        )
     gr.Examples(
         examples=[
             ["./assets/cat_squished.png", "a cat toy", "a cat toy riding a skate"],
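update_derived_fields and run_subject_driven_example build the same Subject-Driven prompt template, so the editable "Full Prompt" textbox and the cached examples stay in sync. For illustration, the first bundled example produces the following string (a plain reconstruction of the f-string above):

# Illustration: the Full Prompt generated for the first bundled example
# ("a cat toy" / "a cat toy riding a skate") in Subject-Driven mode.
subject = "a cat toy"
target = "a cat toy riding a skate"
prompt = (
    f"A diptych with two side-by-side images of same {subject}. "
    f"On the left, a photo of {subject}. "
    f"On the right, replicate this {subject} exactly but as {target}"
)
print(prompt)
# -> A diptych with two side-by-side images of same a cat toy. On the left,
#    a photo of a cat toy. On the right, replicate this a cat toy exactly
#    but as a cat toy riding a skate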
@@ -349,16 +437,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
             ["./assets/bear_plushie.jpg", "a bear plushie", "a bear plushie drinking bubble tea"]
         ],
         inputs=[input_image, subject_name, target_prompt],
-        outputs=output_image,
-        fn=
-        cache_examples="lazy"
-    )
-
-    run_button.click(
-        fn=run_diptych_prompting,
-        inputs=[input_image, subject_name, target_prompt, attn_enforce, ctrl_scale, width, height, pixel_offset, num_steps, guidance, real_guidance, seed, randomize_seed],
-        outputs=output_image
+        outputs=[output_image, processed_ref_image, full_diptych_image, final_prompt_used],
+        fn=run_subject_driven_example,
+        cache_examples="lazy"
     )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True, debug=True)