Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,65 +1,36 @@
|
|
| 1 |
-
# app.py
|
| 2 |
-
|
| 3 |
import gradio as gr
|
|
|
|
| 4 |
from PIL import Image, ImageDraw
|
| 5 |
import torch
|
| 6 |
-
import numpy as np
|
| 7 |
from transformers import SamModel, SamProcessor
|
| 8 |
from diffusers import StableDiffusionInpaintPipeline
|
| 9 |
|
| 10 |
# Constants
|
| 11 |
IMG_SIZE = 512
|
| 12 |
|
| 13 |
-
# Initialize SAM model and processor on CPU
|
| 14 |
-
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
|
| 15 |
-
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
| 16 |
-
|
| 17 |
-
# Initialize Inpainting pipeline on CPU with a compatible model
|
| 18 |
-
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 19 |
-
"stabilityai/stable-diffusion-2-inpainting",
|
| 20 |
-
torch_dtype=torch.float32
|
| 21 |
-
).to("cpu")
|
| 22 |
-
# No need for model_cpu_offload on CPU
|
| 23 |
-
|
| 24 |
# Global variables to store points and the original image
|
| 25 |
input_points = []
|
| 26 |
input_image = None
|
| 27 |
|
| 28 |
-
def
|
| 29 |
"""
|
| 30 |
-
|
| 31 |
"""
|
| 32 |
-
|
| 33 |
-
bg_transparent[mask == 1] = [0, 255, 0, 127] # Green with transparency
|
| 34 |
-
return bg_transparent
|
| 35 |
-
|
| 36 |
-
def generate_mask(image, input_points):
|
| 37 |
-
"""
|
| 38 |
-
Generates a binary mask using SAM based on input points.
|
| 39 |
-
|
| 40 |
-
Args:
|
| 41 |
-
image (PIL.Image): The input image.
|
| 42 |
-
input_points (list of lists): List of points selected by the user.
|
| 43 |
-
|
| 44 |
-
Returns:
|
| 45 |
-
np.ndarray: Binary mask where the object is marked with 1s.
|
| 46 |
-
"""
|
| 47 |
-
if not input_points:
|
| 48 |
return None
|
| 49 |
-
|
| 50 |
-
# Convert image to RGB if not already
|
| 51 |
image = image.convert("RGB")
|
|
|
|
| 52 |
|
| 53 |
-
#
|
| 54 |
-
|
|
|
|
| 55 |
|
| 56 |
-
# Prepare inputs for SAM
|
| 57 |
inputs = sam_processor(image, points=points, return_tensors="pt").to("cpu")
|
| 58 |
|
| 59 |
with torch.no_grad():
|
| 60 |
outputs = sam_model(**inputs)
|
| 61 |
|
| 62 |
-
# Post-process masks
|
| 63 |
masks = sam_processor.image_processor.post_process_masks(
|
| 64 |
outputs.pred_masks.cpu(),
|
| 65 |
inputs["original_sizes"].cpu(),
|
|
@@ -69,32 +40,24 @@ def generate_mask(image, input_points):
|
|
| 69 |
if len(masks) == 0:
|
| 70 |
return None
|
| 71 |
|
| 72 |
-
# Select the mask with the highest IoU score
|
| 73 |
best_mask = masks[0][0][outputs.iou_scores.argmax()]
|
| 74 |
-
|
| 75 |
-
# Invert mask: object=1, background=0
|
| 76 |
binary_mask = ~best_mask.numpy().astype(bool).astype(int)
|
| 77 |
|
| 78 |
return binary_mask
|
| 79 |
|
| 80 |
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
|
| 81 |
"""
|
| 82 |
-
Replaces the
|
| 83 |
-
|
| 84 |
-
Args:
|
| 85 |
-
image (PIL.Image): The original image.
|
| 86 |
-
mask (np.ndarray): Binary mask of the selected object.
|
| 87 |
-
prompt (str): Text prompt describing the replacement.
|
| 88 |
-
negative_prompt (str): Negative text prompt to refine generation.
|
| 89 |
-
seed (int): Random seed for reproducibility.
|
| 90 |
-
guidance_scale (float): Guidance scale for the inpainting model.
|
| 91 |
-
|
| 92 |
-
Returns:
|
| 93 |
-
PIL.Image: The augmented image with the object replaced.
|
| 94 |
"""
|
| 95 |
if mask is None:
|
| 96 |
return image
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
mask_image = Image.fromarray((mask * 255).astype(np.uint8))
|
| 99 |
|
| 100 |
generator = torch.Generator("cpu").manual_seed(seed)
|
|
@@ -116,46 +79,33 @@ def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
|
|
| 116 |
def visualize_mask(image, mask):
|
| 117 |
"""
|
| 118 |
Overlays the mask on the image for visualization.
|
| 119 |
-
|
| 120 |
-
Args:
|
| 121 |
-
image (PIL.Image): The original image.
|
| 122 |
-
mask (np.ndarray): Binary mask of the selected object.
|
| 123 |
-
|
| 124 |
-
Returns:
|
| 125 |
-
PIL.Image: Image with mask overlay.
|
| 126 |
"""
|
| 127 |
if mask is None:
|
| 128 |
return image
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
return overlay.convert("RGB")
|
| 134 |
|
| 135 |
def get_points(img, evt: gr.SelectData):
|
| 136 |
"""
|
| 137 |
Captures points selected by the user on the image.
|
| 138 |
-
|
| 139 |
-
Args:
|
| 140 |
-
img (PIL.Image): The uploaded image.
|
| 141 |
-
evt (gr.SelectData): Event data containing the point coordinates.
|
| 142 |
-
|
| 143 |
-
Returns:
|
| 144 |
-
Tuple: (Updated mask visualization, Updated image with crossmarks)
|
| 145 |
"""
|
| 146 |
global input_points
|
| 147 |
global input_image
|
| 148 |
|
| 149 |
-
# The first time this is called, save the untouched input image
|
| 150 |
if len(input_points) == 0:
|
| 151 |
input_image = img.copy()
|
| 152 |
|
| 153 |
x = evt.index[0]
|
| 154 |
y = evt.index[1]
|
| 155 |
-
|
| 156 |
input_points.append([x, y])
|
| 157 |
|
| 158 |
-
#
|
| 159 |
mask = generate_mask(input_image, input_points)
|
| 160 |
|
| 161 |
# Mark selected points with a green crossmark
|
|
@@ -174,16 +124,6 @@ def get_points(img, evt: gr.SelectData):
|
|
| 174 |
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
|
| 175 |
"""
|
| 176 |
Runs the inpainting process based on user inputs.
|
| 177 |
-
|
| 178 |
-
Args:
|
| 179 |
-
prompt (str): Prompt for infill.
|
| 180 |
-
negative_prompt (str): Negative prompt.
|
| 181 |
-
cfg (float): Classifier-Free Guidance Scale.
|
| 182 |
-
seed (int): Random seed.
|
| 183 |
-
invert (bool): Whether to infill the subject instead of the background.
|
| 184 |
-
|
| 185 |
-
Returns:
|
| 186 |
-
PIL.Image: The inpainted image.
|
| 187 |
"""
|
| 188 |
global input_image
|
| 189 |
global input_points
|
|
@@ -209,9 +149,6 @@ def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
|
|
| 209 |
def reset_points_func():
|
| 210 |
"""
|
| 211 |
Resets the selected points and the input image.
|
| 212 |
-
|
| 213 |
-
Returns:
|
| 214 |
-
Tuple: (Reset mask visualization, Reset image, Empty inpainted image)
|
| 215 |
"""
|
| 216 |
global input_points
|
| 217 |
global input_image
|
|
@@ -222,19 +159,11 @@ def reset_points_func():
|
|
| 222 |
def preprocess(input_img):
|
| 223 |
"""
|
| 224 |
Preprocesses the uploaded image to ensure it is square and resized.
|
| 225 |
-
|
| 226 |
-
Args:
|
| 227 |
-
input_img (PIL.Image): The uploaded image.
|
| 228 |
-
|
| 229 |
-
Returns:
|
| 230 |
-
PIL.Image: The preprocessed image.
|
| 231 |
"""
|
| 232 |
if input_img is None:
|
| 233 |
return None
|
| 234 |
|
| 235 |
-
# Make sure the image is square
|
| 236 |
width, height = input_img.size
|
| 237 |
-
|
| 238 |
if width != height:
|
| 239 |
# Add white padding to make the image square
|
| 240 |
new_size = max(width, height)
|
|
@@ -246,104 +175,142 @@ def preprocess(input_img):
|
|
| 246 |
|
| 247 |
return input_img.resize((IMG_SIZE, IMG_SIZE))
|
| 248 |
|
| 249 |
-
|
| 250 |
-
"""
|
| 251 |
-
Builds and launches the Gradio app.
|
| 252 |
-
|
| 253 |
-
Args:
|
| 254 |
-
get_processed_inputs (function): Function to process inputs for SAM.
|
| 255 |
-
inpaint (function): Function to perform inpainting.
|
| 256 |
-
|
| 257 |
-
Returns:
|
| 258 |
-
None
|
| 259 |
-
"""
|
| 260 |
-
with gr.Blocks() as demo:
|
| 261 |
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
"""
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
**Instructions:**
|
| 268 |
-
1. **Upload Image:** Click on the first image box to upload your image.
|
| 269 |
-
2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
|
| 270 |
-
3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
|
| 271 |
-
4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
|
| 272 |
-
5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
|
| 273 |
-
6. **Reset:** Click the "Reset" button to clear selections and start over.
|
| 274 |
""")
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
|
| 285 |
-
|
| 286 |
-
# Preprocess image on change
|
| 287 |
-
upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
|
| 288 |
-
|
| 289 |
-
# Text inputs and settings
|
| 290 |
-
prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
|
| 291 |
-
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
|
| 292 |
-
cfg = gr.Slider(
|
| 293 |
-
label="Classifier-Free Guidance Scale",
|
| 294 |
-
minimum=1.0,
|
| 295 |
-
maximum=20.0,
|
| 296 |
-
value=7.5,
|
| 297 |
-
step=0.5
|
| 298 |
-
)
|
| 299 |
-
seed = gr.Number(label="Seed", value=42, precision=0)
|
| 300 |
-
invert = gr.Checkbox(label="Infill subject instead of background")
|
| 301 |
-
|
| 302 |
-
# Buttons
|
| 303 |
-
replace_button = gr.Button("Replace Object")
|
| 304 |
-
reset_button = gr.Button("Reset")
|
| 305 |
-
with gr.Column():
|
| 306 |
-
# Output images
|
| 307 |
-
augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)
|
| 308 |
-
|
| 309 |
-
# Define button actions
|
| 310 |
-
replace_button.click(
|
| 311 |
-
fn=run_inpaint,
|
| 312 |
-
inputs=[prompt, negative_prompt, cfg, seed, invert],
|
| 313 |
-
outputs=[augmented_image]
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
reset_button.click(
|
| 317 |
-
fn=reset_points_func,
|
| 318 |
-
inputs=[],
|
| 319 |
-
outputs=[mask_visualization, selected_image, augmented_image]
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
# Examples (optional)
|
| 323 |
-
gr.Markdown(
|
| 324 |
-
"""
|
| 325 |
-
## EXAMPLES
|
| 326 |
-
Click on an example to load it. Then, follow the instructions above.
|
| 327 |
-
""")
|
| 328 |
-
|
| 329 |
-
with gr.Row():
|
| 330 |
-
examples = gr.Examples(
|
| 331 |
-
examples=[
|
| 332 |
-
["car.png", "a red sports car", "blurry, low quality", 42],
|
| 333 |
-
["house.jpg", "a modern villa", "dark, overexposed", 123],
|
| 334 |
-
["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999]
|
| 335 |
],
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
],
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
from PIL import Image, ImageDraw
|
| 4 |
import torch
|
|
|
|
| 5 |
from transformers import SamModel, SamProcessor
|
| 6 |
from diffusers import StableDiffusionInpaintPipeline
|
| 7 |
|
| 8 |
# Constants
|
| 9 |
IMG_SIZE = 512
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# Global variables to store points and the original image
|
| 12 |
input_points = []
|
| 13 |
input_image = None
|
| 14 |
|
| 15 |
+
def generate_mask(image, points):
|
| 16 |
"""
|
| 17 |
+
Generates a mask using SAM based on input points.
|
| 18 |
"""
|
| 19 |
+
if not points:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
return None
|
| 21 |
+
|
|
|
|
| 22 |
image = image.convert("RGB")
|
| 23 |
+
points = [tuple(point) for point in points]
|
| 24 |
|
| 25 |
+
# Initialize SAM model and processor on CPU
|
| 26 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
|
| 27 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
| 28 |
|
|
|
|
| 29 |
inputs = sam_processor(image, points=points, return_tensors="pt").to("cpu")
|
| 30 |
|
| 31 |
with torch.no_grad():
|
| 32 |
outputs = sam_model(**inputs)
|
| 33 |
|
|
|
|
| 34 |
masks = sam_processor.image_processor.post_process_masks(
|
| 35 |
outputs.pred_masks.cpu(),
|
| 36 |
inputs["original_sizes"].cpu(),
|
|
|
|
| 40 |
if len(masks) == 0:
|
| 41 |
return None
|
| 42 |
|
|
|
|
| 43 |
best_mask = masks[0][0][outputs.iou_scores.argmax()]
|
|
|
|
|
|
|
| 44 |
binary_mask = ~best_mask.numpy().astype(bool).astype(int)
|
| 45 |
|
| 46 |
return binary_mask
|
| 47 |
|
| 48 |
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
|
| 49 |
"""
|
| 50 |
+
Replaces the object in the image based on the mask and prompt.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
"""
|
| 52 |
if mask is None:
|
| 53 |
return image
|
| 54 |
|
| 55 |
+
# Initialize Inpainting pipeline on CPU with a compatible model
|
| 56 |
+
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 57 |
+
"stabilityai/stable-diffusion-2-inpainting",
|
| 58 |
+
torch_dtype=torch.float32
|
| 59 |
+
).to("cpu")
|
| 60 |
+
|
| 61 |
mask_image = Image.fromarray((mask * 255).astype(np.uint8))
|
| 62 |
|
| 63 |
generator = torch.Generator("cpu").manual_seed(seed)
|
|
|
|
| 79 |
def visualize_mask(image, mask):
|
| 80 |
"""
|
| 81 |
Overlays the mask on the image for visualization.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
"""
|
| 83 |
if mask is None:
|
| 84 |
return image
|
| 85 |
|
| 86 |
+
bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
|
| 87 |
+
bg_transparent[mask == 1] = [0, 255, 0, 127] # Green with transparency
|
| 88 |
+
mask_rgba = Image.fromarray(bg_transparent)
|
| 89 |
+
|
| 90 |
+
overlay = Image.alpha_composite(image.convert("RGBA"), mask_rgba)
|
| 91 |
return overlay.convert("RGB")
|
| 92 |
|
| 93 |
def get_points(img, evt: gr.SelectData):
|
| 94 |
"""
|
| 95 |
Captures points selected by the user on the image.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
"""
|
| 97 |
global input_points
|
| 98 |
global input_image
|
| 99 |
|
|
|
|
| 100 |
if len(input_points) == 0:
|
| 101 |
input_image = img.copy()
|
| 102 |
|
| 103 |
x = evt.index[0]
|
| 104 |
y = evt.index[1]
|
| 105 |
+
|
| 106 |
input_points.append([x, y])
|
| 107 |
|
| 108 |
+
# Generate mask based on selected points
|
| 109 |
mask = generate_mask(input_image, input_points)
|
| 110 |
|
| 111 |
# Mark selected points with a green crossmark
|
|
|
|
| 124 |
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
|
| 125 |
"""
|
| 126 |
Runs the inpainting process based on user inputs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
"""
|
| 128 |
global input_image
|
| 129 |
global input_points
|
|
|
|
| 149 |
def reset_points_func():
|
| 150 |
"""
|
| 151 |
Resets the selected points and the input image.
|
|
|
|
|
|
|
|
|
|
| 152 |
"""
|
| 153 |
global input_points
|
| 154 |
global input_image
|
|
|
|
| 159 |
def preprocess(input_img):
|
| 160 |
"""
|
| 161 |
Preprocesses the uploaded image to ensure it is square and resized.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
"""
|
| 163 |
if input_img is None:
|
| 164 |
return None
|
| 165 |
|
|
|
|
| 166 |
width, height = input_img.size
|
|
|
|
| 167 |
if width != height:
|
| 168 |
# Add white padding to make the image square
|
| 169 |
new_size = max(width, height)
|
|
|
|
| 175 |
|
| 176 |
return input_img.resize((IMG_SIZE, IMG_SIZE))
|
| 177 |
|
| 178 |
+
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
+
gr.Markdown(
|
| 181 |
+
"""
|
| 182 |
+
# Object Replacement App
|
| 183 |
+
Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
|
| 184 |
+
|
| 185 |
+
**Instructions:**
|
| 186 |
+
1. **Upload Image:** Click on the first image box to upload your image.
|
| 187 |
+
2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
|
| 188 |
+
3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
|
| 189 |
+
4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
|
| 190 |
+
5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
|
| 191 |
+
6. **Reset:** Click the "Reset" button to clear selections and start over.
|
| 192 |
+
""")
|
| 193 |
+
|
| 194 |
+
with gr.Row():
|
| 195 |
+
with gr.Column():
|
| 196 |
+
# Image upload and point selection
|
| 197 |
+
upload_image = gr.Image(
|
| 198 |
+
label="Upload Image",
|
| 199 |
+
type="pil",
|
| 200 |
+
interactive=True,
|
| 201 |
+
height=IMG_SIZE,
|
| 202 |
+
width=IMG_SIZE
|
| 203 |
+
)
|
| 204 |
+
mask_visualization = gr.Image(
|
| 205 |
+
label="Selected Object Mask Overlay",
|
| 206 |
+
interactive=False,
|
| 207 |
+
height=IMG_SIZE,
|
| 208 |
+
width=IMG_SIZE
|
| 209 |
+
)
|
| 210 |
+
selected_image = gr.Image(
|
| 211 |
+
label="Image with Selected Points",
|
| 212 |
+
type="pil",
|
| 213 |
+
interactive=False,
|
| 214 |
+
height=IMG_SIZE,
|
| 215 |
+
width=IMG_SIZE,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# Capture points using the select event
|
| 219 |
+
upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
|
| 220 |
+
|
| 221 |
+
# Preprocess image on change
|
| 222 |
+
upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
|
| 223 |
+
|
| 224 |
+
# Text inputs and settings
|
| 225 |
+
prompt = gr.Textbox(
|
| 226 |
+
label="Replacement Prompt",
|
| 227 |
+
placeholder="e.g., a red sports car",
|
| 228 |
+
lines=2
|
| 229 |
+
)
|
| 230 |
+
negative_prompt = gr.Textbox(
|
| 231 |
+
label="Negative Prompt",
|
| 232 |
+
placeholder="e.g., blurry, low quality",
|
| 233 |
+
lines=2
|
| 234 |
+
)
|
| 235 |
+
cfg = gr.Slider(
|
| 236 |
+
label="Classifier-Free Guidance Scale",
|
| 237 |
+
minimum=1.0,
|
| 238 |
+
maximum=20.0,
|
| 239 |
+
value=7.5,
|
| 240 |
+
step=0.5
|
| 241 |
+
)
|
| 242 |
+
seed = gr.Number(
|
| 243 |
+
label="Seed",
|
| 244 |
+
value=42,
|
| 245 |
+
precision=0
|
| 246 |
+
)
|
| 247 |
+
invert = gr.Checkbox(
|
| 248 |
+
label="Infill subject instead of background"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
# Buttons
|
| 252 |
+
replace_button = gr.Button("Replace Object")
|
| 253 |
+
reset_button = gr.Button("Reset")
|
| 254 |
+
with gr.Column():
|
| 255 |
+
# Output images
|
| 256 |
+
augmented_image = gr.Image(
|
| 257 |
+
label="Augmented Image",
|
| 258 |
+
type="pil",
|
| 259 |
+
interactive=False,
|
| 260 |
+
height=IMG_SIZE,
|
| 261 |
+
width=IMG_SIZE,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# Define button actions
|
| 265 |
+
replace_button.click(
|
| 266 |
+
fn=run_inpaint,
|
| 267 |
+
inputs=[prompt, negative_prompt, cfg, seed, invert],
|
| 268 |
+
outputs=[augmented_image]
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
reset_button.click(
|
| 272 |
+
fn=reset_points_func,
|
| 273 |
+
inputs=[],
|
| 274 |
+
outputs=[mask_visualization, selected_image, augmented_image]
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Examples (optional)
|
| 278 |
+
gr.Markdown(
|
| 279 |
"""
|
| 280 |
+
## EXAMPLES
|
| 281 |
+
Click on an example to load it. Then, follow the instructions above.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
""")
|
| 283 |
+
|
| 284 |
+
with gr.Row():
|
| 285 |
+
examples = gr.Examples(
|
| 286 |
+
examples=[
|
| 287 |
+
[
|
| 288 |
+
"car.png",
|
| 289 |
+
"a red sports car",
|
| 290 |
+
"blurry, low quality",
|
| 291 |
+
42
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
],
|
| 293 |
+
[
|
| 294 |
+
"house.jpg",
|
| 295 |
+
"a modern villa",
|
| 296 |
+
"dark, overexposed",
|
| 297 |
+
123
|
| 298 |
],
|
| 299 |
+
[
|
| 300 |
+
"tree.png",
|
| 301 |
+
"a blooming cherry tree",
|
| 302 |
+
"underexposed, low contrast",
|
| 303 |
+
999
|
| 304 |
+
]
|
| 305 |
+
],
|
| 306 |
+
inputs=[
|
| 307 |
+
upload_image,
|
| 308 |
+
prompt,
|
| 309 |
+
negative_prompt,
|
| 310 |
+
seed
|
| 311 |
+
],
|
| 312 |
+
label="Click to load examples",
|
| 313 |
+
cache_examples=False # Set to False to avoid the error
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
demo.queue(max_size=10).launch()
|