# app.py

import gradio as gr
from PIL import Image
import torch
import numpy as np
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
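# Assumed dependencies for running this file standalone (an assumption; the
# original does not pin versions):
#   pip install gradio torch transformers diffusers accelerate pillow numpy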

# Initialize SAM model and processor on CPU
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# Initialize Inpainting pipeline on CPU with a compatible model
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32
).to("cpu")
# No need for model_cpu_offload on CPU
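# Optional (an assumption, not part of the original setup): attention slicing
# can lower peak memory on CPU-only machines at a small speed cost.
# inpaint_pipeline.enable_attention_slicing()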

def mask_to_rgba(mask):
    """
    Converts a binary mask to an RGBA image for visualization.
    """
    bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]  # Green with transparency
    return bg_transparent

def generate_mask(image, input_points):
    """
    Generates a binary mask using SAM based on input points.
    
    Args:
        image (PIL.Image): The input image.
        input_points (list of lists): Points selected by the user, as [[x, y], ...].
        
    Returns:
        np.ndarray: Binary mask where the object is marked with 1s.
    """
    if not input_points:
        return None
    
    # Convert image to RGB if not already
    image = image.convert("RGB")
    
    # SamProcessor expects input_points nested per image: [[[x, y], ...]]
    points = [[list(point) for point in input_points]]
    
    # Prepare inputs for SAM
    inputs = sam_processor(image, input_points=points, return_tensors="pt").to("cpu")
    
    with torch.no_grad():
        outputs = sam_model(**inputs)
    
    # Post-process masks
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )
    
    if len(masks) == 0:
        return None
    
    # Select the mask with the highest IoU score
    best_mask = masks[0][0][outputs.iou_scores.argmax()]
    
    # SAM marks the selected object with True, so no inversion is needed:
    # object=1 (region to repaint), background=0
    binary_mask = best_mask.numpy().astype(np.uint8)
    
    return binary_mask

def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
    """
    Replaces the selected object in the image based on the prompt.
    
    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object.
        prompt (str): Text prompt describing the replacement.
        negative_prompt (str): Negative text prompt to refine generation.
        seed (int): Random seed for reproducibility.
        guidance_scale (float): Guidance scale for the inpainting model.
        
    Returns:
        PIL.Image: The augmented image with the object replaced.
    """
    if mask is None:
        return image
    
    mask_image = Image.fromarray((mask * 255).astype(np.uint8))
    
    generator = torch.Generator("cpu").manual_seed(int(seed))  # gr.Number may deliver a float
    
    try:
        result = inpaint_pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            negative_prompt=negative_prompt if negative_prompt else None,
            generator=generator,
            guidance_scale=guidance_scale
        ).images[0]
        # The pipeline renders at 512x512; resize back to the original image size
        return result.resize(image.size)
    except Exception as e:
        print(f"Inpainting error: {e}")
        return image

def visualize_mask(image, mask):
    """
    Overlays the mask on the image for visualization.
    
    Args:
        image (PIL.Image): The original image.
        mask (np.ndarray): Binary mask of the selected object.
        
    Returns:
        PIL.Image: Image with mask overlay.
    """
    if mask is None:
        return image
    
    mask_rgba = mask_to_rgba(mask)
    mask_pil = Image.fromarray(mask_rgba)
    overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
    return overlay.convert("RGB")

def process(image, points, prompt, negative_prompt, seed, guidance_scale):
    """
    Processes the image by replacing the selected object based on the prompt.
    
    Args:
        image (PIL.Image): Uploaded image.
        points (list of lists): Points selected on the image.
        prompt (str): Text prompt for replacement.
        negative_prompt (str): Negative text prompt.
        seed (int): Seed for reproducibility.
        guidance_scale (float): Guidance scale.
        
    Returns:
        Tuple of images: Original with mask overlay and augmented image.
    """
    mask = generate_mask(image, points)
    masked_image = visualize_mask(image, mask)
    augmented_image = replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale)
    return masked_image, augmented_image

# Define Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Object Replacement App")
    gr.Markdown(
        """
        Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
        """
    )
    
    # Gradio has no dedicated point-selection component, so clicks on the
    # uploaded image are collected into session state via its select event.
    points_state = gr.State([])
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image (click on the object to select points)",
                type="pil",
                interactive=True
            )
            points_display = gr.Textbox(label="Selected Points", interactive=False)
            clear_points_button = gr.Button("Clear Points")
            prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
            negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
            seed_input = gr.Number(label="Seed", value=42, precision=0)
            guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7.5)
            process_button = gr.Button("Replace Object")
        with gr.Column():
            masked_output = gr.Image(label="Selected Object Mask Overlay")
            augmented_output = gr.Image(label="Augmented Image")
    
    def add_point(points, evt: gr.SelectData):
        """Appends the clicked (x, y) coordinate to the selected points."""
        points = points + [list(evt.index)]
        return points, str(points)
    
    def clear_points():
        """Clears all selected points."""
        return [], ""
    
    # Record clicks on the image as selection points; reset them on a new upload
    image_input.select(fn=add_point, inputs=points_state, outputs=[points_state, points_display])
    image_input.upload(fn=clear_points, inputs=None, outputs=[points_state, points_display])
    clear_points_button.click(fn=clear_points, inputs=None, outputs=[points_state, points_display])
    
    # Bind the process function to the button click
    process_button.click(
        fn=process,
        inputs=[image_input, points_state, prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
        outputs=[masked_output, augmented_output]
    )
    
    gr.Markdown(
        """
        **Instructions:**
        1. **Upload Image:** Upload the image containing the object you want to replace.
        2. **Select Points:** Click on the image to select points on the object. Use multiple points for better mask accuracy; click "Clear Points" to start over.
        3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
        4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
        5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
        """
    )

# Launch the app
demo.launch()
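
# A minimal sketch of driving the same functions without the Gradio UI; the file
# name and click coordinates are placeholders, not values from the app:
#
#   img = Image.open("example.jpg").convert("RGB")
#   pts = [[250, 300]]                     # one click roughly on the object
#   mask = generate_mask(img, pts)
#   out = replace_object(img, mask, "a red sports car", "blurry, low quality", 42, 7.5)
#   out.save("augmented.jpg")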