File size: 9,612 Bytes
18c979d 22ba4da a99346b 18c979d a99346b 18c979d a99346b 22ba4da 18c979d 22ba4da 18c979d 22ba4da 18c979d 22ba4da 18c979d 22ba4da 18c979d 22ba4da 18c979d 452ea00 18c979d 22ba4da 18c979d 22ba4da 18c979d 452ea00 18c979d 22ba4da 18c979d a99346b 18c979d a99346b 18c979d a99346b 22ba4da a99346b 18c979d 22ba4da a99346b 18c979d a99346b 22ba4da a99346b 22ba4da 18c979d 22ba4da a99346b 22ba4da a99346b 22ba4da a05e0f7 22ba4da a99346b 22ba4da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
from transformers import SamModel, SamProcessor
# Side length, in pixels, of the square canvas every uploaded image is resized to.
IMG_SIZE = 512

# Module-level state shared by the Gradio callbacks below (plain module
# globals, so it is not per-user):
#   input_image  -- copy of the image taken on the first click, or None
#   input_points -- list of [x, y] click coordinates collected so far
input_image = None
input_points = []
@lru_cache(maxsize=1)
def _load_sam(checkpoint="facebook/sam-vit-huge"):
    """Load the SAM model and processor on CPU once and reuse them across calls."""
    model = SamModel.from_pretrained(checkpoint, torch_dtype=torch.float32).to("cpu")
    processor = SamProcessor.from_pretrained(checkpoint)
    return model, processor


def generate_mask(image, points):
    """
    Generates a mask using SAM based on input points.

    Args:
        image: PIL image to segment (converted to RGB before use).
        points: sequence of [x, y] pixel coordinates; all points prompt the
            same single object.

    Returns:
        A 2-D boolean numpy array matching the image size that is True on the
        pixels SAM did NOT assign to the selected object (the inverse of the
        best-scoring predicted mask), or None when no points were given or SAM
        returned no masks.
    """
    if not points:
        return None
    image = image.convert("RGB")

    # Loading SAM is expensive, so the model/processor pair is cached.
    sam_model, sam_processor = _load_sam()

    # SamProcessor takes prompts via `input_points`, nested as
    # [image][point][x, y]; putting every click in one group asks SAM for a
    # single mask prompted by all of them.
    input_points = [[list(point) for point in points]]
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to("cpu")
    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None

    # Pick the candidate mask with the highest predicted IoU for the only
    # image / prompt group.
    best_mask = masks[0][0][outputs.iou_scores[0, 0].argmax()]
    # Invert as a *boolean* array: bitwise ~ on an int array would yield
    # -1/-2 instead of 1/0.
    return ~best_mask.numpy().astype(bool)
@lru_cache(maxsize=1)
def _load_inpaint_pipeline(checkpoint="stabilityai/stable-diffusion-2-inpainting"):
    """Load the Stable Diffusion inpainting pipeline on CPU once and reuse it."""
    return StableDiffusionInpaintPipeline.from_pretrained(
        checkpoint,
        torch_dtype=torch.float32,
    ).to("cpu")


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the object in the image based on the mask and prompt.

    Args:
        image: PIL image to edit (converted to RGB for the pipeline).
        mask: 2-D array-like; non-zero / True pixels become white in the mask
            image handed to the inpainting pipeline. None disables inpainting.
        prompt: text describing what to paint into the masked region.
        negative_prompt: optional text to steer away from; empty means None.
        seed: integer-valued seed for the CPU random generator.
        guidance_scale: classifier-free guidance scale.

    Returns:
        The inpainted PIL image. The input image is returned unchanged when
        ``mask`` is None or when the pipeline call raises; in the latter case
        the error is printed.
    """
    if mask is None:
        return image

    # Loading the pipeline is expensive, so it is cached.
    inpaint_pipeline = _load_inpaint_pipeline()

    # Normalize any bool or 0/1 mask to a clean 0/255 single-channel image.
    mask_array = np.asarray(mask).astype(bool)
    mask_image = Image.fromarray(mask_array.astype(np.uint8) * 255)
    # Number widgets may deliver floats; torch generators need an int seed.
    generator = torch.Generator("cpu").manual_seed(int(seed))
    # The pipeline expects 3-channel input; an RGBA upload would otherwise
    # reach it unchanged.
    rgb_image = image.convert("RGB")

    try:
        return inpaint_pipeline(
            prompt=prompt,
            image=rgb_image,
            mask_image=mask_image,
            negative_prompt=negative_prompt or None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        # Best-effort: keep the original image rather than failing the request.
        print(f"Inpainting error: {e}")
        return image
def visualize_mask(image, mask):
    """
    Return ``image`` with a translucent green tint over the masked pixels.

    Pixels where ``mask == 1`` are alpha-composited with RGBA (0, 255, 0, 127).
    All other pixels get a fully transparent overlay. The composite is
    returned as an RGB image. When ``mask`` is None the input image is
    returned untouched.
    """
    if mask is not None:
        overlay_rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
        overlay_rgba[mask == 1] = (0, 255, 0, 127)  # green, ~50% opacity
        composited = Image.alpha_composite(
            image.convert("RGBA"),
            Image.fromarray(overlay_rgba),
        )
        return composited.convert("RGB")
    return image
def get_points(img, evt: gr.SelectData):
    """
    Gradio ``select`` handler: record the clicked point and refresh previews.

    On the first click of a selection, a copy of ``img`` is stored as the
    segmentation input. Every click appends [x, y] to the global point list
    and recomputes the SAM mask from all points so far.

    Returns:
        A tuple ``(overlay, marked)``: the stored image with the mask overlay,
        and ``img`` with a green cross drawn at every stored point.
    """
    global input_points, input_image

    # Snapshot a clean copy so the crosses drawn on `img` below never end up
    # in the image that is segmented.
    if not input_points:
        input_image = img.copy()

    click_x, click_y = evt.index[0], evt.index[1]
    input_points.append([click_x, click_y])

    mask = generate_mask(input_image, input_points)

    # Draw a green cross (arms of `half_arm` pixels) at every selected point.
    half_arm = 10
    pen = ImageDraw.Draw(img)
    for px, py in input_points:
        pen.line((px - half_arm, py, px + half_arm, py), fill="green", width=5)
        pen.line((px, py - half_arm, px, py + half_arm), fill="green", width=5)

    return visualize_mask(input_image, mask), img
def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Runs the inpainting process based on user inputs.

    Reads the module-level ``input_image`` / ``input_points`` selected via
    clicks, recomputes the SAM mask, optionally inverts it, and inpaints.

    Args:
        prompt: replacement text prompt.
        negative_prompt: optional negative prompt.
        cfg: classifier-free guidance scale.
        seed: random seed for the generator.
        invert: when truthy, the mask from ``generate_mask`` is inverted
            before inpainting.

    Returns:
        The inpainted image resized to IMG_SIZE x IMG_SIZE.

    Raises:
        gr.Error: if no image/points have been selected, or if inpainting
            raises.
    """
    if input_image is None or not input_points:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    mask = generate_mask(input_image, input_points)
    if invert and mask is not None:
        # logical_not is correct for both boolean and 0/1 integer masks,
        # whereas bitwise ~ turns an integer 0/1 mask into -1/-2 and fails
        # outright on None.
        mask = np.logical_not(mask)

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e)) from e
    return inpainted.resize((IMG_SIZE, IMG_SIZE))
def reset_points_func():
    """
    Forget every selected point and the stored image copy.

    Returns:
        Three ``None`` values, one per output component to clear (the
        reset button wires them to the mask overlay, the point preview and
        the augmented image).
    """
    global input_image, input_points
    input_image = None
    input_points = []
    cleared = None
    return cleared, cleared, cleared
def preprocess(input_img):
    """
    Preprocesses the uploaded image to ensure it is square and resized.

    The image is converted to RGB. If it is not square, it is centered on a
    white square canvas whose side equals its longer edge. The result is
    resized to IMG_SIZE x IMG_SIZE.

    Args:
        input_img: PIL image, or None when the component was cleared.

    Returns:
        The square RGB PIL image, or None if ``input_img`` is None.
    """
    if input_img is None:
        return None

    # Normalize the mode first so square and padded inputs both come out RGB.
    # Previously a square RGBA PNG kept its alpha channel while a non-square
    # one was flattened to RGB by the paste.
    input_img = input_img.convert("RGB")

    width, height = input_img.size
    if width != height:
        # Add white padding to make the image square
        new_size = max(width, height)
        new_image = Image.new("RGB", (new_size, new_size), "white")
        left = (new_size - width) // 2
        top = (new_size - height) // 2
        new_image.paste(input_img, (left, top))
        input_img = new_image
    return input_img.resize((IMG_SIZE, IMG_SIZE))
# Build the Gradio UI: an upload/point-selection column with prompt settings,
# an output column, button wiring, and clickable examples.
with gr.Blocks() as demo:
    # NOTE: the markdown body below is a runtime string; its lines are kept at
    # column 0 on purpose (leading spaces would change the rendered Markdown).
    gr.Markdown(
        """
# Object Replacement App
Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
**Instructions:**
1. **Upload Image:** Click on the first image box to upload your image.
2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
6. **Reset:** Click the "Reset" button to clear selections and start over.
""")
    with gr.Row():
        with gr.Column():
            # Image upload and point selection
            upload_image = gr.Image(
                label="Upload Image",
                type="pil",
                interactive=True,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            # Read-only preview of the SAM mask overlay produced by get_points.
            mask_visualization = gr.Image(
                label="Selected Object Mask Overlay",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            # Read-only preview of the upload with a cross at each clicked point.
            selected_image = gr.Image(
                label="Image with Selected Points",
                type="pil",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            # Capture points using the select event
            upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
            # Preprocess image on change: the handler writes back into the same
            # component, replacing the upload with its square IMG_SIZE version.
            # NOTE(review): confirm this self-update does not re-trigger
            # `change` repeatedly in the Gradio version in use.
            upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
            # Text inputs and settings
            prompt = gr.Textbox(
                label="Replacement Prompt",
                placeholder="e.g., a red sports car",
                lines=2
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="e.g., blurry, low quality",
                lines=2
            )
            cfg = gr.Slider(
                label="Classifier-Free Guidance Scale",
                minimum=1.0,
                maximum=20.0,
                value=7.5,
                step=0.5
            )
            # precision=0 requests integer seeds from the number widget.
            seed = gr.Number(
                label="Seed",
                value=42,
                precision=0
            )
            # When checked, run_inpaint inverts the SAM-derived mask before inpainting.
            invert = gr.Checkbox(
                label="Infill subject instead of background"
            )
            # Buttons
            replace_button = gr.Button("Replace Object")
            reset_button = gr.Button("Reset")
        with gr.Column():
            # Output images
            augmented_image = gr.Image(
                label="Augmented Image",
                type="pil",
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
    # Define button actions
    replace_button.click(
        fn=run_inpaint,
        inputs=[prompt, negative_prompt, cfg, seed, invert],
        outputs=[augmented_image]
    )
    # reset_points_func returns three Nones, clearing these three outputs.
    reset_button.click(
        fn=reset_points_func,
        inputs=[],
        outputs=[mask_visualization, selected_image, augmented_image]
    )
    # Examples (optional)
    gr.Markdown(
        """
## EXAMPLES
Click on an example to load it. Then, follow the instructions above.
""")
    with gr.Row():
        # Each example row fills (image, prompt, negative prompt, seed); the image
        # files are expected next to this script.
        examples = gr.Examples(
            examples=[
                [
                    "car.png",
                    "a red sports car",
                    "blurry, low quality",
                    42
                ],
                [
                    "monalisa.png",
                    "a rockstar",
                    "dark, overexposed",
                    123
                ],
            ],
            inputs=[
                upload_image,
                prompt,
                negative_prompt,
                seed
            ],
            label="Click to load examples",
            # Examples only populate the inputs. No `fn`/`outputs` are given, so
            # there is nothing to pre-compute; presumably caching would error for
            # that reason — confirm.
            cache_examples=False # Set to False to avoid the error
        )
# Enable Gradio's request queue (at most 10 pending events) and start the web
# server. A stray trailing "|" (extraction residue that made this line a
# syntax error) has been removed.
demo.queue(max_size=10).launch()