vvaibhav commited on
Commit
18c979d
·
verified ·
1 Parent(s): 13cac7d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import torch
6
+ import numpy as np
7
+ from transformers import SamModel, SamProcessor
8
+ from diffusers import StableDiffusionInpaintPipeline
9
+ import io
10
+
11
+ # Initialize SAM model and processor
12
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda")
13
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
14
+
15
+ # Initialize Inpainting pipeline
16
+ inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
17
+ "runwayml/stable-diffusion-inpainting",
18
+ torch_dtype=torch.float16
19
+ ).to("cuda")
20
+ inpaint_pipeline.enable_model_cpu_offload()
21
+
22
+ def mask_to_rgba(mask):
23
+ """
24
+ Converts a binary mask to an RGBA image for visualization.
25
+ """
26
+ bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
27
+ bg_transparent[mask == 1] = [0, 255, 0, 127] # Green with transparency
28
+ return bg_transparent
29
+
30
+ def generate_mask(image, input_points):
31
+ """
32
+ Generates a binary mask using SAM based on input points.
33
+
34
+ Args:
35
+ image (PIL.Image): The input image.
36
+ input_points (list of lists): List of points selected by the user.
37
+
38
+ Returns:
39
+ np.ndarray: Binary mask where the object is marked with 1s.
40
+ """
41
+ if not input_points:
42
+ return None
43
+
44
+ # Convert image to RGB if not already
45
+ image = image.convert("RGB")
46
+
47
+ # Flatten the list of points
48
+ points = [tuple(point) for point in input_points]
49
+
50
+ # Prepare inputs for SAM
51
+ inputs = sam_processor(image, points=points, return_tensors="pt").to("cuda")
52
+
53
+ with torch.no_grad():
54
+ outputs = sam_model(**inputs)
55
+
56
+ # Post-process masks
57
+ masks = sam_processor.image_processor.post_process_masks(
58
+ outputs.pred_masks.cpu(),
59
+ inputs["original_sizes"].cpu(),
60
+ inputs["reshaped_input_sizes"].cpu()
61
+ )
62
+
63
+ if len(masks) == 0:
64
+ return None
65
+
66
+ # Select the mask with the highest IoU score
67
+ best_mask = masks[0][0][outputs.iou_scores.argmax()]
68
+
69
+ # Invert mask: object=1, background=0
70
+ binary_mask = ~best_mask.numpy().astype(bool).astype(int)
71
+
72
+ return binary_mask
73
+
74
+ def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
75
+ """
76
+ Replaces the selected object in the image based on the prompt.
77
+
78
+ Args:
79
+ image (PIL.Image): The original image.
80
+ mask (np.ndarray): Binary mask of the selected object.
81
+ prompt (str): Text prompt describing the replacement.
82
+ negative_prompt (str): Negative text prompt to refine generation.
83
+ seed (int): Random seed for reproducibility.
84
+ guidance_scale (float): Guidance scale for the inpainting model.
85
+
86
+ Returns:
87
+ PIL.Image: The augmented image with the object replaced.
88
+ """
89
+ if mask is None:
90
+ return image
91
+
92
+ mask_image = Image.fromarray((mask * 255).astype(np.uint8))
93
+
94
+ generator = torch.Generator("cuda").manual_seed(seed)
95
+
96
+ try:
97
+ result = inpaint_pipeline(
98
+ prompt=prompt,
99
+ image=image,
100
+ mask_image=mask_image,
101
+ negative_prompt=negative_prompt if negative_prompt else None,
102
+ generator=generator,
103
+ guidance_scale=guidance_scale
104
+ ).images[0]
105
+ return result
106
+ except Exception as e:
107
+ print(f"Inpainting error: {e}")
108
+ return image
109
+
110
+ def visualize_mask(image, mask):
111
+ """
112
+ Overlays the mask on the image for visualization.
113
+
114
+ Args:
115
+ image (PIL.Image): The original image.
116
+ mask (np.ndarray): Binary mask of the selected object.
117
+
118
+ Returns:
119
+ PIL.Image: Image with mask overlay.
120
+ """
121
+ if mask is None:
122
+ return image
123
+
124
+ mask_rgba = mask_to_rgba(mask)
125
+ mask_pil = Image.fromarray(mask_rgba)
126
+ overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
127
+ return overlay.convert("RGB")
128
+
129
+ def process(image, points, prompt, negative_prompt, seed, guidance_scale):
130
+ """
131
+ Processes the image by replacing the selected object based on the prompt.
132
+
133
+ Args:
134
+ image (PIL.Image): Uploaded image.
135
+ points (list of lists): Points selected on the image.
136
+ prompt (str): Text prompt for replacement.
137
+ negative_prompt (str): Negative text prompt.
138
+ seed (int): Seed for reproducibility.
139
+ guidance_scale (float): Guidance scale.
140
+
141
+ Returns:
142
+ Tuple of images: Original with mask overlay and augmented image.
143
+ """
144
+ mask = generate_mask(image, points)
145
+ masked_image = visualize_mask(image, mask)
146
+ augmented_image = replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale)
147
+ return masked_image, augmented_image
148
+
149
+ # Define Gradio Interface
150
+ with gr.Blocks() as demo:
151
+ gr.Markdown("# Object Replacement App")
152
+ gr.Markdown(
153
+ """
154
+ Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
155
+ """
156
+ )
157
+
158
+ with gr.Row():
159
+ with gr.Column():
160
+ image_input = gr.Image(label="Upload Image", type="pil")
161
+ prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
162
+ negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
163
+ seed_input = gr.Number(label="Seed", value=42)
164
+ guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7.5)
165
+ process_button = gr.Button("Replace Object")
166
+ with gr.Column():
167
+ masked_output = gr.Image(label="Selected Object Mask Overlay")
168
+ augmented_output = gr.Image(label="Augmented Image")
169
+
170
+ image_input.change(fn=lambda img: img, inputs=image_input, outputs=masked_output)
171
+
172
+ process_button.click(
173
+ fn=process,
174
+ inputs=[image_input, gr.State(), prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
175
+ outputs=[masked_output, augmented_output]
176
+ )
177
+
178
+ gr.Markdown(
179
+ """
180
+ **Instructions:**
181
+ 1. **Upload Image:** Upload the image containing the object you want to replace.
182
+ 2. **Select Points:** Click on the image to select points on the object. Use multiple points for better mask accuracy.
183
+ 3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
184
+ 4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
185
+ 5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
186
+ """
187
+ )
188
+
189
+ # Launch the app
190
+ demo.launch()