# app.py

import gradio as gr
from PIL import Image, ImageDraw
import torch
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline

# Constants
IMG_SIZE = 512

# Initialize SAM model and processor on CPU
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Initialize Inpainting pipeline on CPU with a compatible model
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32
).to("cpu")
# No need for model_cpu_offload on CPU
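
# Both models run on CPU here for portability. A minimal alternative sketch,
# assuming a CUDA-capable machine is available, would move them to the GPU
# (the hard-coded "cpu" strings further down would need to follow the same device):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   sam_model.to(device)
#   inpaint_pipeline.to(device)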

# Global variables to store points and the original image
input_points = []
input_image = None

def mask_to_rgba(mask):
    """
    Converts a binary mask to an RGBA image for visualization.
    """
    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]  # Green with transparency
    return bg_transparent

def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.
    
    Args:
        image (PIL.Image): The input image.
        input_points (list of lists): List of points selected by the user.
        
    Returns:
        np.ndarray: Binary mask where the background is marked with 1s and the selected object with 0s.
    """
    if not input_points:
        return None
    
    # Convert image to RGB if not already
    image = image.convert("RGB")
    
    # SAM's processor expects one list of [x, y] points per image (a batch of one here)
    points = [input_points]
    
    # Prepare inputs for SAM (note the keyword is input_points, not points)
    inputs = sam_processor(image, input_points=points, return_tensors="pt").to("cpu")
    
    with torch.no_grad():
        outputs = sam_model(**inputs)
    
    # Post-process masks
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    
    if len(masks) == 0:
        return None
    
    # Select the mask with the highest IoU score
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    
    # Invert the mask so the background is 1 and the selected object is 0
    # (the parentheses matter: ~ must apply to the boolean array, not to ints)
    binary_mask = (~best_mask.numpy().astype(bool)).astype(int)
    
    return binary_mask
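
# Standalone usage sketch (the file name "example.jpg" and the point coordinates
# are hypothetical): generate a mask for a single clicked point and save it to disk.
#
#   img = preprocess(Image.open("example.jpg"))
#   m = generate_mask(img, [[256, 256]])
#   Image.fromarray((m * 255).astype(np.uint8)).save("mask.png")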

def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Inpaints the region of the image where mask == 1, guided by the prompt.
    
    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask; pixels equal to 1 are regenerated.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt to refine generation.
        seed (int): Random seed for reproducibility.
        guidance_scale (float): Guidance scale for the inpainting model.
        
    Returns:
        PIL.Image: The augmented image with the object replaced.
    """
    if mask is None:
        return image
    
    # The inpainting pipeline repaints the white (255) region of mask_image and
    # preserves the black (0) region, so mask == 1 marks the pixels to be replaced
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    
    generator = torch.Generator("cpu").manual_seed(seed)
    
    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
        return result
    except Exception as e:
        print(f"Inpainting error: {e}")
        return image

def visualize_mask(image, mask):
    """
    Overlays the mask on the image for visualization.
    
    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object.
        
    Returns:
        PIL.Image: Image with mask overlay.
    """
    if mask is None:
        return image
    
    mask_rgba = mask_to_rgba(mask)
    mask_pil = Image.fromarray(mask_rgba)
    overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
    return overlay.convert("RGB")

def get_points(img, evt: gr.SelectData):
    """
    Captures points selected by the user on the image.
    
    Args:
        img (PIL.Image): The uploaded image.
        evt (gr.SelectData): Event data containing the point coordinates.
        
    Returns:
        Tuple: (Updated mask visualization, Updated image with crossmarks)
    """
    global input_points
    global input_image
    
    # The first time this is called, save the untouched input image
    if len(input_points) == 0:
        input_image = img.copy()
    
    x = evt.index[0]
    y = evt.index[1]

    input_points.append([x, y])
    
    # Run SAM to generate mask
    mask = generate_mask(input_image, input_points)
    
    # Mark selected points with a green crossmark
    draw = ImageDraw.Draw(img)
    size = 10
    for point in input_points:
        px, py = point
        draw.line((px - size, py, px + size, py), fill="green", width=5)
        draw.line((px, py - size, px, py + size), fill="green", width=5)
    
    # Visualize the mask overlay
    masked_image = visualize_mask(input_image, mask)
    
    return masked_image, img

def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Runs the inpainting process based on user inputs.
    
    Args:
        prompt (str): Prompt for infill.
        negative_prompt (str): Negative prompt.
        cfg (float): Classifier-Free Guidance Scale.
        seed (int): Random seed.
        invert (bool): Whether to infill the subject instead of the background.
        
    Returns:
        PIL.Image: The inpainted image.
    """
    global input_image
    global input_points
    
    if input_image is None or len(input_points) == 0:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")
    
    mask = generate_mask(input_image, input_points)
    
    # By default the background (mask == 1) is infilled; flip the mask to infill
    # the selected subject instead (~ on a 0/1 int array would yield -1/-2, so use 1 - mask)
    if invert:
        mask = 1 - mask
    
    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e))
    
    return inpainted.resize((IMG_SIZE, IMG_SIZE))

def reset_points_func():
    """
    Resets the selected points and the input image.
    
    Returns:
        Tuple: (Reset mask visualization, Reset image, Empty inpainted image)
    """
    global input_points
    global input_image
    input_points = []
    input_image = None
    return None, None, None

def preprocess(input_img):
    """
    Preprocesses the uploaded image to ensure it is square and resized.
    
    Args:
        input_img (PIL.Image): The uploaded image.
        
    Returns:
        PIL.Image: The preprocessed image.
    """
    if input_img is None:
        return None
    
    # Make sure the image is square
    width, height = input_img.size
    
    if width != height:
        # Add white padding to make the image square
        new_size = max(width, height)
        new_image = Image.new("RGB", (new_size, new_size), 'white')
        left = (new_size - width) // 2
        top = (new_size - height) // 2
        new_image.paste(input_img, (left, top))
        input_img = new_image
    
    return input_img.resize((IMG_SIZE, IMG_SIZE))
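
# For example, a 640x480 upload is pasted onto a 640x640 white canvas with
# 80 px of padding above and below, then resized to 512x512.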

def build_app():
    """
    Builds and launches the Gradio app.
    """
    with gr.Blocks() as demo:

        gr.Markdown(
        """
        # Object Replacement App
        Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
        
        **Instructions:**
        1. **Upload Image:** Click on the first image box to upload your image.
        2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
        3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
        4. **Adjust Settings:** Set the seed for reproducibility, adjust the guidance scale, and tick "Infill subject instead of background" if you want to regenerate the selected object itself rather than the background around it.
        5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
        6. **Reset:** Click the "Reset" button to clear selections and start over.
        """)
        
        with gr.Row():
            with gr.Column():
                # Image upload and point selection
                upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
                mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
                selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)
                
                # Capture points using the select event
                upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
                
                # Pad/resize the image once when it is uploaded (using the upload event
                # rather than change avoids re-triggering the handler when preprocess
                # writes its result back to the same component)
                upload_image.upload(preprocess, inputs=[upload_image], outputs=[upload_image])
                
                # Text inputs and settings
                prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
                cfg = gr.Slider(
                    label="Classifier-Free Guidance Scale", 
                    minimum=1.0, 
                    maximum=20.0, 
                    value=7.5, 
                    step=0.5
                )
                seed = gr.Number(label="Seed", value=42, precision=0)
                invert = gr.Checkbox(label="Infill subject instead of background")
                
                # Buttons
                replace_button = gr.Button("Replace Object")
                reset_button = gr.Button("Reset")
            with gr.Column():
                # Output images
                augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)
        
        # Define button actions
        replace_button.click(
            fn=run_inpaint,
            inputs=[prompt, negative_prompt, cfg, seed, invert],
            outputs=[augmented_image]
        )
        
        reset_button.click(
            fn=reset_points_func,
            inputs=[],
            outputs=[mask_visualization, selected_image, augmented_image]
        )
        
        # Examples (optional)
        gr.Markdown(
            """
            ## EXAMPLES
            Click on an example to load it. Then, follow the instructions above.
            """)
        
        with gr.Row():
            examples = gr.Examples(
                examples=[
                    ["car.png", "a red sports car", "blurry, low quality", 42],
                    ["house.jpg", "a modern villa", "dark, overexposed", 123],
                    ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999]
                ],
                inputs=[
                    upload_image,
                    prompt,
                    negative_prompt,
                    seed
                ],
                label="Click to load examples",
                cache_examples=True
            )
        
    demo.queue(max_size=10).launch()

# Launch the app
if __name__ == "__main__":
    build_app()
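
# To try the app locally (assuming the dependencies imported above are installed):
#   python app.py
# then open the local URL Gradio prints (http://127.0.0.1:7860 by default).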