# app.py
import gradio as gr
from PIL import Image, ImageDraw
import torch
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
# Constants
IMG_SIZE = 512
# Initialize SAM model and processor on CPU
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
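# Note: SamProcessor expects point prompts nested per image and per prompt
# group, e.g. input_points=[[[450, 600]]] for one click on one image; see
# generate_mask below for how user clicks are wrapped into this format.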
# Initialize Inpainting pipeline on CPU with a compatible model
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32
).to("cpu")
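# Optional (a sketch; safe to omit): attention slicing trades a little speed
# for lower peak memory, which can help on RAM-constrained CPU machines.
# inpaint_pipeline.enable_attention_slicing()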
# No need for model_cpu_offload on CPU
# Global variables to store points and the original image
input_points = []
input_image = None
def mask_to_rgba(mask):
    """
    Converts a binary mask to an RGBA image for visualization.
    """
    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]  # Green with transparency
    return bg_transparent
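# Worked example: for mask = np.array([[0, 1], [0, 0]]), mask_to_rgba(mask)
# is fully transparent except at [0, 1], which becomes [0, 255, 0, 127]
# (half-transparent green).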
def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.
    Args:
        image (PIL.Image): The input image.
        input_points (list of lists): List of [x, y] points selected by the user.
    Returns:
        np.ndarray: Binary mask where the region to inpaint is marked with 1s.
    """
    if not input_points:
        return None
    # Convert image to RGB if not already
    image = image.convert("RGB")
    # SAM expects points nested per image: [[[x1, y1], [x2, y2], ...]]
    points = [input_points]
    # Prepare inputs for SAM
    inputs = sam_processor(image, input_points=points, return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = sam_model(**inputs)
    # Post-process masks
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    if len(masks) == 0:
        return None
    # Select the mask with the highest IoU score
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    # Invert so the background is 1 (inpainted by default) and the object is 0
    binary_mask = (~best_mask.numpy().astype(bool)).astype(int)
    return binary_mask
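# Note: with clicks placed on the subject, the returned mask is 0 on the
# subject and 1 on the background, matching run_inpaint's default behavior
# of infilling the background around the selected object.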
def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the selected region in the image based on the prompt.
    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask; pixels equal to 1 are repainted.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt to refine generation.
        seed (int): Random seed for reproducibility.
        guidance_scale (float): Guidance scale for the inpainting model.
    Returns:
        PIL.Image: The augmented image with the region replaced.
    """
    if mask is None:
        return image
    # The diffusers inpainting pipeline repaints the white (255) areas of the mask
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    generator = torch.Generator("cpu").manual_seed(seed)
    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
        return result
    except Exception as e:
        print(f"Inpainting error: {e}")
        return image
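# Hypothetical usage sketch (assumes a local car.png exists):
#   img = Image.open("car.png").convert("RGB").resize((IMG_SIZE, IMG_SIZE))
#   out = replace_object(img, mask, "a red sports car", "blurry", 42, 7.5)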
def visualize_mask(image, mask):
    """
    Overlays the mask on the image for visualization.
    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask (object = 0, background = 1).
    Returns:
        PIL.Image: Image with mask overlay.
    """
    if mask is None:
        return image
    # Highlight the selected object, i.e. the zero region of the mask
    mask_rgba = mask_to_rgba(1 - mask)
    mask_pil = Image.fromarray(mask_rgba)
    overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
    return overlay.convert("RGB")
def get_points(img, evt: gr.SelectData):
    """
    Captures points selected by the user on the image.
    Args:
        img (PIL.Image): The uploaded image.
        evt (gr.SelectData): Event data containing the point coordinates.
    Returns:
        Tuple: (Updated mask visualization, Updated image with crossmarks)
    """
    global input_points
    global input_image
    # The first time this is called, save the untouched input image
    if len(input_points) == 0:
        input_image = img.copy()
    x, y = evt.index[0], evt.index[1]
    input_points.append([x, y])
    # Run SAM to generate mask
    mask = generate_mask(input_image, input_points)
    # Mark selected points with a green crossmark
    draw = ImageDraw.Draw(img)
    size = 10
    for px, py in input_points:
        draw.line((px - size, py, px + size, py), fill="green", width=5)
        draw.line((px, py - size, px, py + size), fill="green", width=5)
    # Visualize the mask overlay
    masked_image = visualize_mask(input_image, mask)
    return masked_image, img
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Runs the inpainting process based on user inputs.
    Args:
        prompt (str): Prompt for infill.
        negative_prompt (str): Negative prompt.
        cfg (float): Classifier-Free Guidance Scale.
        seed (int): Random seed.
        invert (bool): Whether to infill the subject instead of the background.
    Returns:
        PIL.Image: The inpainted image.
    """
    global input_image
    global input_points
    if input_image is None or len(input_points) == 0:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")
    mask = generate_mask(input_image, input_points)
    if invert:
        # Flip the mask so the subject (0s) is infilled instead of the background (1s)
        mask = 1 - mask
    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e))
    return inpainted.resize((IMG_SIZE, IMG_SIZE))
def reset_points_func():
    """
    Resets the selected points and the input image.
    Returns:
        Tuple: (Reset mask visualization, Reset image, Empty inpainted image)
    """
    global input_points
    global input_image
    input_points = []
    input_image = None
    return None, None, None
def preprocess(input_img):
    """
    Preprocesses the uploaded image to ensure it is square and resized.
    Args:
        input_img (PIL.Image): The uploaded image.
    Returns:
        PIL.Image: The preprocessed image.
    """
    if input_img is None:
        return None
    # Make sure the image is square
    width, height = input_img.size
    if width != height:
        # Add white padding to make the image square
        new_size = max(width, height)
        new_image = Image.new("RGB", (new_size, new_size), 'white')
        left = (new_size - width) // 2
        top = (new_size - height) // 2
        new_image.paste(input_img, (left, top))
        input_img = new_image
    return input_img.resize((IMG_SIZE, IMG_SIZE))
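# Worked example: a 300x200 upload is centered on a white 300x300 canvas and
# then resized to 512x512, so SAM and the inpainting pipeline always receive
# square inputs without distorting the original content.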
def build_app():
    """
    Builds and launches the Gradio app.
    Returns:
        None
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Object Replacement App
            Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.

            **Instructions:**
            1. **Upload Image:** Click on the first image box to upload your image.
            2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
            3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
            4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
            5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
            6. **Reset:** Click the "Reset" button to clear selections and start over.
            """)
        with gr.Row():
            with gr.Column():
                # Image upload and point selection
                upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
                mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
                selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)
                # Capture points using the select event
                upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
                # Preprocess on upload; a "change" handler that writes back to the
                # same component would retrigger itself in a loop
                upload_image.upload(preprocess, inputs=[upload_image], outputs=[upload_image])
                # Text inputs and settings
                prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
                cfg = gr.Slider(
                    label="Classifier-Free Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5
                )
                seed = gr.Number(label="Seed", value=42, precision=0)
                invert = gr.Checkbox(label="Infill subject instead of background")
                # Buttons
                replace_button = gr.Button("Replace Object")
                reset_button = gr.Button("Reset")
            with gr.Column():
                # Output images
                augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)
        # Define button actions
        replace_button.click(
            fn=run_inpaint,
            inputs=[prompt, negative_prompt, cfg, seed, invert],
            outputs=[augmented_image]
        )
        reset_button.click(
            fn=reset_points_func,
            inputs=[],
            outputs=[mask_visualization, selected_image, augmented_image]
        )
        # Examples (optional)
        gr.Markdown(
            """
            ## EXAMPLES
            Click on an example to load it. Then, follow the instructions above.
            """)
        with gr.Row():
            examples = gr.Examples(
                examples=[
                    ["car.png", "a red sports car", "blurry, low quality", 42],
                    ["house.jpg", "a modern villa", "dark, overexposed", 123],
                    ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999]
                ],
                inputs=[
                    upload_image,
                    prompt,
                    negative_prompt,
                    seed
                ],
                label="Click to load examples",
                # Caching requires fn/outputs; these examples only populate inputs
                cache_examples=False
            )
    demo.queue(max_size=10).launch()
# Launch the app
if __name__ == "__main__":
    build_app()