vvaibhav commited on
Commit
a99346b
·
verified ·
1 Parent(s): 19001b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -53
app.py CHANGED
@@ -1,12 +1,14 @@
1
  # app.py
2
 
3
  import gradio as gr
4
- from PIL import Image
5
  import torch
6
  import numpy as np
7
  from transformers import SamModel, SamProcessor
8
  from diffusers import StableDiffusionInpaintPipeline
9
- import io
 
 
10
 
11
  # Initialize SAM model and processor on CPU
12
  sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
@@ -19,6 +21,10 @@ inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
19
  ).to("cpu")
20
  # No need for model_cpu_offload on CPU
21
 
 
 
 
 
22
  def mask_to_rgba(mask):
23
  """
24
  Converts a binary mask to an RGBA image for visualization.
@@ -126,70 +132,218 @@ def visualize_mask(image, mask):
126
  overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
127
  return overlay.convert("RGB")
128
 
129
- def process(image, points, prompt, negative_prompt, seed, guidance_scale):
130
  """
131
- Processes the image by replacing the selected object based on the prompt.
132
 
133
  Args:
134
- image (PIL.Image): Uploaded image.
135
- points (list of lists): Points selected on the image.
136
- prompt (str): Text prompt for replacement.
137
- negative_prompt (str): Negative text prompt.
138
- seed (int): Seed for reproducibility.
139
- guidance_scale (float): Guidance scale.
140
 
141
  Returns:
142
- Tuple of images: Original with mask overlay and augmented image.
143
  """
144
- mask = generate_mask(image, points)
145
- masked_image = visualize_mask(image, mask)
146
- augmented_image = replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale)
147
- return masked_image, augmented_image
 
 
 
 
 
148
 
149
- # Define Gradio Interface
150
- with gr.Blocks() as demo:
151
- gr.Markdown("# Object Replacement App")
152
- gr.Markdown(
153
- """
154
- Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
155
- """
156
- )
157
 
158
- with gr.Row():
159
- with gr.Column():
160
- image_input = gr.Image(label="Upload Image", type="pil", interactive=True, elem_id="image")
161
- points_input = gr.Points(
162
- label="Select Points on the Object",
163
- show_label=True,
164
- source="image", # Links Points to the Image component via elem_id
165
- interactive=True
166
- )
167
- prompt_input = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
168
- negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
169
- seed_input = gr.Number(label="Seed", value=42)
170
- guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1, maximum=20, value=7.5)
171
- process_button = gr.Button("Replace Object")
172
- with gr.Column():
173
- masked_output = gr.Image(label="Selected Object Mask Overlay")
174
- augmented_output = gr.Image(label="Augmented Image")
175
-
176
- # Bind the process function to the button click
177
- process_button.click(
178
- fn=process,
179
- inputs=[image_input, points_input, prompt_input, negative_prompt_input, seed_input, guidance_scale_input],
180
- outputs=[masked_output, augmented_output]
181
- )
182
 
183
- gr.Markdown(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  """
 
 
 
185
  **Instructions:**
186
- 1. **Upload Image:** Upload the image containing the object you want to replace.
187
- 2. **Select Points:** Click on the image to select points on the object. Use multiple points for better mask accuracy.
188
  3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
189
  4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
190
  5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
191
- """
192
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
  # Launch the app
195
- demo.launch()
 
1
  # app.py
2
 
3
  import gradio as gr
4
+ from PIL import Image, ImageDraw
5
  import torch
6
  import numpy as np
7
  from transformers import SamModel, SamProcessor
8
  from diffusers import StableDiffusionInpaintPipeline
9
+
10
+ # Constants
11
+ IMG_SIZE = 512
12
 
13
  # Initialize SAM model and processor on CPU
14
  sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
 
21
  ).to("cpu")
22
  # No need for model_cpu_offload on CPU
23
 
24
+ # Global variables to store points and the original image
25
+ input_points = []
26
+ input_image = None
27
+
28
  def mask_to_rgba(mask):
29
  """
30
  Converts a binary mask to an RGBA image for visualization.
 
132
  overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
133
  return overlay.convert("RGB")
134
 
135
+ def get_points(img, evt: gr.SelectData):
136
  """
137
+ Captures points selected by the user on the image.
138
 
139
  Args:
140
+ img (PIL.Image): The uploaded image.
141
+ evt (gr.SelectData): Event data containing the point coordinates.
 
 
 
 
142
 
143
  Returns:
144
+ Tuple: (Updated mask visualization, Updated image with crossmarks)
145
  """
146
+ global input_points
147
+ global input_image
148
+
149
+ # The first time this is called, save the untouched input image
150
+ if len(input_points) == 0:
151
+ input_image = img.copy()
152
+
153
+ x = evt.index[0]
154
+ y = evt.index[1]
155
 
156
+ input_points.append([x, y])
 
 
 
 
 
 
 
157
 
158
+ # Run SAM to generate mask
159
+ mask = generate_mask(input_image, input_points)
160
+
161
+ # Mark selected points with a green crossmark
162
+ draw = ImageDraw.Draw(img)
163
+ size = 10
164
+ for point in input_points:
165
+ px, py = point
166
+ draw.line((px - size, py, px + size, py), fill="green", width=5)
167
+ draw.line((px, py - size, px, py + size), fill="green", width=5)
168
+
169
+ # Visualize the mask overlay
170
+ masked_image = visualize_mask(input_image, mask)
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ return masked_image, img
173
+
174
+ def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
175
+ """
176
+ Runs the inpainting process based on user inputs.
177
+
178
+ Args:
179
+ prompt (str): Prompt for infill.
180
+ negative_prompt (str): Negative prompt.
181
+ cfg (float): Classifier-Free Guidance Scale.
182
+ seed (int): Random seed.
183
+ invert (bool): Whether to infill the subject instead of the background.
184
+
185
+ Returns:
186
+ PIL.Image: The inpainted image.
187
+ """
188
+ global input_image
189
+ global input_points
190
+
191
+ if input_image is None or len(input_points) == 0:
192
+ raise gr.Error("No points provided. Click on the image to select the object to segment with SAM.")
193
+
194
+ mask = generate_mask(input_image, input_points)
195
+
196
+ if invert:
197
+ what = 'subject'
198
+ mask = ~mask
199
+ else:
200
+ what = 'background'
201
+
202
+ try:
203
+ inpainted = replace_object(input_image, mask, prompt, negative_prompt, seed, cfg)
204
+ except Exception as e:
205
+ raise gr.Error(str(e))
206
+
207
+ return inpainted.resize((IMG_SIZE, IMG_SIZE))
208
+
209
+ def reset_points_func():
210
+ """
211
+ Resets the selected points and the input image.
212
+
213
+ Returns:
214
+ Tuple: (Reset mask visualization, Reset image, Empty inpainted image)
215
+ """
216
+ global input_points
217
+ global input_image
218
+ input_points = []
219
+ input_image = None
220
+ return None, None, None
221
+
222
+ def preprocess(input_img):
223
+ """
224
+ Preprocesses the uploaded image to ensure it is square and resized.
225
+
226
+ Args:
227
+ input_img (PIL.Image): The uploaded image.
228
+
229
+ Returns:
230
+ PIL.Image: The preprocessed image.
231
+ """
232
+ if input_img is None:
233
+ return None
234
+
235
+ # Make sure the image is square
236
+ width, height = input_img.size
237
+
238
+ if width != height:
239
+ # Add white padding to make the image square
240
+ new_size = max(width, height)
241
+ new_image = Image.new("RGB", (new_size, new_size), 'white')
242
+ left = (new_size - width) // 2
243
+ top = (new_size - height) // 2
244
+ new_image.paste(input_img, (left, top))
245
+ input_img = new_image
246
+
247
+ return input_img.resize((IMG_SIZE, IMG_SIZE))
248
+
249
+ def build_app(get_processed_inputs, inpaint):
250
+ """
251
+ Builds and launches the Gradio app.
252
+
253
+ Args:
254
+ get_processed_inputs (function): Function to process inputs for SAM.
255
+ inpaint (function): Function to perform inpainting.
256
+
257
+ Returns:
258
+ None
259
+ """
260
+ with gr.Blocks() as demo:
261
+
262
+ gr.Markdown(
263
  """
264
+ # Object Replacement App
265
+ Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
266
+
267
  **Instructions:**
268
+ 1. **Upload Image:** Click on the first image box to upload your image.
269
+ 2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
270
  3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
271
  4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
272
  5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
273
+ 6. **Reset:** Click the "Reset" button to clear selections and start over.
274
+ """)
275
+
276
+ with gr.Row():
277
+ with gr.Column():
278
+ # Image upload and point selection
279
+ upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
280
+ mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
281
+ selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)
282
+
283
+ # Capture points using the select event
284
+ upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
285
+
286
+ # Preprocess image on change
287
+ upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
288
+
289
+ # Text inputs and settings
290
+ prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
291
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
292
+ cfg = gr.Slider(
293
+ label="Classifier-Free Guidance Scale",
294
+ minimum=1.0,
295
+ maximum=20.0,
296
+ value=7.5,
297
+ step=0.5
298
+ )
299
+ seed = gr.Number(label="Seed", value=42, precision=0)
300
+ invert = gr.Checkbox(label="Infill subject instead of background")
301
+
302
+ # Buttons
303
+ replace_button = gr.Button("Replace Object")
304
+ reset_button = gr.Button("Reset")
305
+ with gr.Column():
306
+ # Output images
307
+ augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)
308
+
309
+ # Define button actions
310
+ replace_button.click(
311
+ fn=run_inpaint,
312
+ inputs=[prompt, negative_prompt, cfg, seed, invert],
313
+ outputs=[augmented_image]
314
+ )
315
+
316
+ reset_button.click(
317
+ fn=reset_points_func,
318
+ inputs=[],
319
+ outputs=[mask_visualization, selected_image, augmented_image]
320
+ )
321
+
322
+ # Examples (optional)
323
+ gr.Markdown(
324
+ """
325
+ ## EXAMPLES
326
+ Click on an example to load it. Then, follow the instructions above.
327
+ """)
328
+
329
+ with gr.Row():
330
+ examples = gr.Examples(
331
+ examples=[
332
+ ["car.png", "a red sports car", "blurry, low quality", 42],
333
+ ["house.jpg", "a modern villa", "dark, overexposed", 123],
334
+ ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999]
335
+ ],
336
+ inputs=[
337
+ upload_image,
338
+ prompt,
339
+ negative_prompt,
340
+ seed
341
+ ],
342
+ label="Click to load examples",
343
+ cache_examples=True
344
+ )
345
+
346
+ demo.queue(max_size=10).launch()
347
 
348
  # Launch the app
349
+ build_app(None, None)