File size: 9,612 Bytes
18c979d
22ba4da
a99346b
18c979d
 
 
a99346b
 
 
18c979d
a99346b
 
 
 
22ba4da
18c979d
22ba4da
18c979d
22ba4da
18c979d
22ba4da
18c979d
22ba4da
18c979d
22ba4da
 
 
18c979d
452ea00
18c979d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22ba4da
18c979d
 
 
 
22ba4da
 
 
 
 
 
18c979d
 
452ea00
18c979d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22ba4da
 
 
 
 
18c979d
 
a99346b
18c979d
a99346b
18c979d
a99346b
 
 
 
 
 
 
 
22ba4da
a99346b
18c979d
22ba4da
a99346b
 
 
 
 
 
 
 
 
 
 
 
18c979d
a99346b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22ba4da
a99346b
22ba4da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18c979d
22ba4da
 
a99346b
22ba4da
 
 
 
 
 
 
 
 
a99346b
22ba4da
a05e0f7
 
22ba4da
 
a99346b
22ba4da
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import functools

import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline

# Side length, in pixels, of the square canvas every image is resized to.
IMG_SIZE: int = 512

# Mutable state shared by all Gradio callbacks in this process:
#   input_points -- clicked [x, y] pixel coordinates, in click order
#   input_image  -- clean copy of the image those points were clicked on
input_points: list = []
input_image: "Image.Image | None" = None

@functools.lru_cache(maxsize=1)
def _load_sam():
    """Load the SAM model and processor once and reuse them for every click."""
    sam_model = SamModel.from_pretrained(
        "facebook/sam-vit-huge", torch_dtype=torch.float32
    ).to("cpu")
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
    return sam_model, sam_processor


def generate_mask(image, points):
    """
    Generates a mask using SAM based on input points.

    Args:
        image: PIL image the points were clicked on.
        points: sequence of [x, y] pixel coordinates; all of them are sent to
            SAM together as prompts for a single object.

    Returns:
        A 2-D ``int`` numpy array the size of ``image`` that is 1 everywhere
        SAM did NOT segment (the background) and 0 on the selected object,
        or ``None`` when there are no points or SAM returned no mask.
    """
    if not points:
        return None

    image = image.convert("RGB")
    # SamProcessor expects `input_points` nested as
    # (batch, point_batch, n_points, 2): one image, one prompt of all clicks.
    prompt_points = [[[x, y] for x, y in points]]

    sam_model, sam_processor = _load_sam()
    inputs = sam_processor(
        image, input_points=prompt_points, return_tensors="pt"
    ).to("cpu")

    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if len(masks) == 0:
        return None

    # SAM proposes several candidate masks per prompt; keep the one with the
    # highest predicted IoU for our single image / single prompt.
    best_idx = int(outputs.iou_scores[0, 0].argmax())
    object_mask = masks[0][0][best_idx].numpy().astype(bool)

    # Invert while still boolean: `~` on a bool array is logical NOT, whereas
    # on an int array it is bitwise NOT (0 -> -1, 1 -> -2).
    return (~object_mask).astype(int)

@functools.lru_cache(maxsize=1)
def _load_inpaint_pipeline():
    """Load the Stable Diffusion 2 inpainting pipeline once and reuse it."""
    return StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float32,
    ).to("cpu")


def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the object in the image based on the mask and prompt.

    Args:
        image: PIL image to edit.
        mask: 2-D array, same size as ``image``; non-zero pixels are
            repainted, zero pixels are kept. ``None`` means "nothing to do".
        prompt: text describing what to paint into the masked region.
        negative_prompt: optional text to steer away from; empty means none.
        seed: integer seed for the CPU random generator (reproducibility).
        guidance_scale: classifier-free guidance strength.

    Returns:
        The inpainted PIL image. If the pipeline raises, the error is
        printed and the unmodified ``image`` is returned (best effort).
    """
    if mask is None:
        return image

    inpaint_pipeline = _load_inpaint_pipeline()

    # Threshold instead of `mask * 255`, which wraps around in uint8 for any
    # value other than 0/1. White (255) = repaint, black (0) = keep.
    mask_image = Image.fromarray((np.asarray(mask) > 0).astype(np.uint8) * 255)

    generator = torch.Generator("cpu").manual_seed(int(seed))

    try:
        return inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt or None,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images[0]
    except Exception as e:
        print(f"Inpainting error: {e}")
        return image

def visualize_mask(image, mask):
    """
    Blend a translucent green layer over ``image`` wherever ``mask == 1``.

    Returns ``image`` untouched when ``mask`` is None; otherwise an RGB
    image with the highlighted pixels composited on top.
    """
    if mask is None:
        return image

    overlay_rgba = np.zeros((*mask.shape, 4), dtype=np.uint8)
    # Pure green at roughly half opacity.
    overlay_rgba[mask == 1] = (0, 255, 0, 127)

    blended = Image.alpha_composite(
        image.convert("RGBA"), Image.fromarray(overlay_rgba)
    )
    return blended.convert("RGB")

def get_points(img, evt: gr.SelectData):
    """
    Record a click on the image, re-run SAM and refresh both previews.

    On the first click a clean copy of ``img`` is stored in ``input_image``;
    every click appends its [x, y] position to ``input_points``.

    Returns:
        (mask overlay on the clean image, ``img`` with a green cross drawn
        at every recorded point).
    """
    global input_points, input_image

    if not input_points:
        # First click: keep an unmarked copy for segmentation and inpainting.
        input_image = img.copy()

    click_x, click_y = evt.index[0], evt.index[1]
    input_points.append([click_x, click_y])

    mask = generate_mask(input_image, input_points)

    # Draw a plus-shaped marker (arms of `arm` px) at each recorded point.
    arm = 10
    pen = ImageDraw.Draw(img)
    for px, py in input_points:
        pen.line((px - arm, py, px + arm, py), fill="green", width=5)
        pen.line((px, py - arm, px, py + arm), fill="green", width=5)

    return visualize_mask(input_image, mask), img

def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
    """
    Runs the inpainting process based on user inputs.

    Segments the object at the clicked points with SAM, then repaints either
    the background (default) or, when ``invert`` is set, the selected subject.

    Args:
        prompt: replacement text prompt.
        negative_prompt: optional negative prompt.
        cfg: classifier-free guidance scale.
        seed: generator seed.
        invert: if true, repaint the selected subject instead of the background.

    Returns:
        The result resized to ``IMG_SIZE`` x ``IMG_SIZE``.

    Raises:
        gr.Error: if no points were selected or inpainting fails.
    """
    if input_image is None or not input_points:
        raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")

    # generate_mask marks the background with 1 and the subject with 0.
    mask = generate_mask(input_image, input_points)

    if invert and mask is not None:
        # Logical (not bitwise) flip so the subject becomes the repaint region.
        mask = (~np.asarray(mask, dtype=bool)).astype(int)

    try:
        inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
    except Exception as e:
        raise gr.Error(str(e)) from e

    return inpainted.resize((IMG_SIZE, IMG_SIZE))

def reset_points_func():
    """
    Forget every clicked point and the stored source image.

    Returns:
        ``(None, None, None)`` so the mask overlay, the point-marked preview
        and the augmented output are all cleared.
    """
    global input_points, input_image
    input_points, input_image = [], None
    return (None,) * 3

def preprocess(input_img):
    """
    Make the uploaded image square, then resize it to IMG_SIZE x IMG_SIZE.

    Non-square images are centered on a white RGB canvas whose side equals
    the longer edge. Returns None when no image is given.
    """
    if input_img is None:
        return None

    w, h = input_img.size
    if w == h:
        square = input_img
    else:
        side = max(w, h)
        square = Image.new("RGB", (side, side), 'white')
        offset = ((side - w) // 2, (side - h) // 2)
        square.paste(input_img, offset)

    return square.resize((IMG_SIZE, IMG_SIZE))

# Build the Gradio UI. The left column holds the upload, the two previews,
# the prompt/settings inputs and the buttons; the right column shows the result.
# NOTE(review): all callbacks read/write the module-level globals
# `input_points` / `input_image`, so concurrent users of one server share a
# single selection — consider per-session gr.State if that matters.
with gr.Blocks() as demo:

    gr.Markdown(
    """
    # Object Replacement App
    Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
    
    **Instructions:**
    1. **Upload Image:** Click on the first image box to upload your image.
    2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
    3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
    4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
    5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
    6. **Reset:** Click the "Reset" button to clear selections and start over.
    """)
    
    with gr.Row():
        with gr.Column():
            # Image upload and point selection
            upload_image = gr.Image(
                label="Upload Image", 
                type="pil", 
                interactive=True,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            # Read-only preview: SAM mask blended over the clicked image.
            mask_visualization = gr.Image(
                label="Selected Object Mask Overlay", 
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE
            )
            # Read-only preview: the upload with a marker at every click.
            selected_image = gr.Image(
                label="Image with Selected Points", 
                type="pil", 
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
            
            # Capture points using the select event: each click appends a
            # point, re-runs SAM and refreshes both previews (see get_points).
            upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
            
            # Preprocess image on change (pad to square, resize to IMG_SIZE).
            # NOTE(review): the output is written back into the same component,
            # which may fire `change` again — confirm it settles after one
            # same-size resize. Loading a new image does not clear previously
            # clicked points; the user must press Reset first.
            upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
            
            # Text inputs and settings
            prompt = gr.Textbox(
                label="Replacement Prompt", 
                placeholder="e.g., a red sports car", 
                lines=2
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt", 
                placeholder="e.g., blurry, low quality", 
                lines=2
            )
            cfg = gr.Slider(
                label="Classifier-Free Guidance Scale", 
                minimum=1.0, 
                maximum=20.0, 
                value=7.5, 
                step=0.5
            )
            seed = gr.Number(
                label="Seed", 
                value=42, 
                precision=0
            )
            # Passed to run_inpaint as `invert`: unchecked repaints the
            # background around the selection, checked repaints the selection.
            invert = gr.Checkbox(
                label="Infill subject instead of background"
            )
            
            # Buttons
            replace_button = gr.Button("Replace Object")
            reset_button = gr.Button("Reset")
        with gr.Column():
            # Output images
            augmented_image = gr.Image(
                label="Augmented Image", 
                type="pil", 
                interactive=False,
                height=IMG_SIZE,
                width=IMG_SIZE,
            )
    
    # Define button actions
    replace_button.click(
        fn=run_inpaint,
        inputs=[prompt, negative_prompt, cfg, seed, invert],
        outputs=[augmented_image]
    )
    
    # Reset clears the stored points/image and blanks the three image panes
    # (the uploaded image itself is left in place).
    reset_button.click(
        fn=reset_points_func,
        inputs=[],
        outputs=[mask_visualization, selected_image, augmented_image]
    )
    
    # Examples (optional)
    gr.Markdown(
        """
        ## EXAMPLES
        Click on an example to load it. Then, follow the instructions above.
        """)
    
    # Each example row fills, in order: upload_image, prompt, negative_prompt,
    # seed. cfg and invert keep their current values. The image files are
    # presumably shipped next to this script — TODO confirm.
    with gr.Row():
        examples = gr.Examples(
            examples=[
                [
                    "car.png", 
                    "a red sports car", 
                    "blurry, low quality", 
                    42
                ],
                [
                    "monalisa.png", 
                    "a rockstar", 
                    "dark, overexposed", 
                    123
                ],
            ],
            inputs=[
                upload_image,
                prompt,
                negative_prompt,
                seed
            ],
            label="Click to load examples",
            cache_examples=False  # no fn/outputs are given, so there is nothing to precompute
        )
    
if __name__ == "__main__":
    # Queue incoming requests (at most 10 waiting) and start the web server.
    # Guarded so importing this module (e.g. for tests or reload mode) does
    # not start a server as a side effect.
    demo.queue(max_size=10).launch()