vvaibhav commited on
Commit
22ba4da
·
verified ·
1 Parent(s): a99346b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -190
app.py CHANGED
@@ -1,65 +1,36 @@
1
- # app.py
2
-
3
  import gradio as gr
 
4
  from PIL import Image, ImageDraw
5
  import torch
6
- import numpy as np
7
  from transformers import SamModel, SamProcessor
8
  from diffusers import StableDiffusionInpaintPipeline
9
 
10
  # Constants
11
  IMG_SIZE = 512
12
 
13
- # Initialize SAM model and processor on CPU
14
- sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
15
- sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
16
-
17
- # Initialize Inpainting pipeline on CPU with a compatible model
18
- inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
19
- "stabilityai/stable-diffusion-2-inpainting",
20
- torch_dtype=torch.float32
21
- ).to("cpu")
22
- # No need for model_cpu_offload on CPU
23
-
24
  # Global variables to store points and the original image
25
  input_points = []
26
  input_image = None
27
 
28
- def mask_to_rgba(mask):
29
  """
30
- Converts a binary mask to an RGBA image for visualization.
31
  """
32
- bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
33
- bg_transparent[mask == 1] = [0, 255, 0, 127] # Green with transparency
34
- return bg_transparent
35
-
36
- def generate_mask(image, input_points):
37
- """
38
- Generates a binary mask using SAM based on input points.
39
-
40
- Args:
41
- image (PIL.Image): The input image.
42
- input_points (list of lists): List of points selected by the user.
43
-
44
- Returns:
45
- np.ndarray: Binary mask where the object is marked with 1s.
46
- """
47
- if not input_points:
48
  return None
49
-
50
- # Convert image to RGB if not already
51
  image = image.convert("RGB")
 
52
 
53
- # Flatten the list of points
54
- points = [tuple(point) for point in input_points]
 
55
 
56
- # Prepare inputs for SAM
57
  inputs = sam_processor(image, points=points, return_tensors="pt").to("cpu")
58
 
59
  with torch.no_grad():
60
  outputs = sam_model(**inputs)
61
 
62
- # Post-process masks
63
  masks = sam_processor.image_processor.post_process_masks(
64
  outputs.pred_masks.cpu(),
65
  inputs["original_sizes"].cpu(),
@@ -69,32 +40,24 @@ def generate_mask(image, input_points):
69
  if len(masks) == 0:
70
  return None
71
 
72
- # Select the mask with the highest IoU score
73
  best_mask = masks[0][0][outputs.iou_scores.argmax()]
74
-
75
- # Invert mask: object=1, background=0
76
  binary_mask = ~best_mask.numpy().astype(bool).astype(int)
77
 
78
  return binary_mask
79
 
80
  def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
81
  """
82
- Replaces the selected object in the image based on the prompt.
83
-
84
- Args:
85
- image (PIL.Image): The original image.
86
- mask (np.ndarray): Binary mask of the selected object.
87
- prompt (str): Text prompt describing the replacement.
88
- negative_prompt (str): Negative text prompt to refine generation.
89
- seed (int): Random seed for reproducibility.
90
- guidance_scale (float): Guidance scale for the inpainting model.
91
-
92
- Returns:
93
- PIL.Image: The augmented image with the object replaced.
94
  """
95
  if mask is None:
96
  return image
97
 
 
 
 
 
 
 
98
  mask_image = Image.fromarray((mask * 255).astype(np.uint8))
99
 
100
  generator = torch.Generator("cpu").manual_seed(seed)
@@ -116,46 +79,33 @@ def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
116
  def visualize_mask(image, mask):
117
  """
118
  Overlays the mask on the image for visualization.
119
-
120
- Args:
121
- image (PIL.Image): The original image.
122
- mask (np.ndarray): Binary mask of the selected object.
123
-
124
- Returns:
125
- PIL.Image: Image with mask overlay.
126
  """
127
  if mask is None:
128
  return image
129
 
130
- mask_rgba = mask_to_rgba(mask)
131
- mask_pil = Image.fromarray(mask_rgba)
132
- overlay = Image.alpha_composite(image.convert("RGBA"), mask_pil)
 
 
133
  return overlay.convert("RGB")
134
 
135
  def get_points(img, evt: gr.SelectData):
136
  """
137
  Captures points selected by the user on the image.
138
-
139
- Args:
140
- img (PIL.Image): The uploaded image.
141
- evt (gr.SelectData): Event data containing the point coordinates.
142
-
143
- Returns:
144
- Tuple: (Updated mask visualization, Updated image with crossmarks)
145
  """
146
  global input_points
147
  global input_image
148
 
149
- # The first time this is called, save the untouched input image
150
  if len(input_points) == 0:
151
  input_image = img.copy()
152
 
153
  x = evt.index[0]
154
  y = evt.index[1]
155
-
156
  input_points.append([x, y])
157
 
158
- # Run SAM to generate mask
159
  mask = generate_mask(input_image, input_points)
160
 
161
  # Mark selected points with a green crossmark
@@ -174,16 +124,6 @@ def get_points(img, evt: gr.SelectData):
174
  def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
175
  """
176
  Runs the inpainting process based on user inputs.
177
-
178
- Args:
179
- prompt (str): Prompt for infill.
180
- negative_prompt (str): Negative prompt.
181
- cfg (float): Classifier-Free Guidance Scale.
182
- seed (int): Random seed.
183
- invert (bool): Whether to infill the subject instead of the background.
184
-
185
- Returns:
186
- PIL.Image: The inpainted image.
187
  """
188
  global input_image
189
  global input_points
@@ -209,9 +149,6 @@ def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
209
  def reset_points_func():
210
  """
211
  Resets the selected points and the input image.
212
-
213
- Returns:
214
- Tuple: (Reset mask visualization, Reset image, Empty inpainted image)
215
  """
216
  global input_points
217
  global input_image
@@ -222,19 +159,11 @@ def reset_points_func():
222
  def preprocess(input_img):
223
  """
224
  Preprocesses the uploaded image to ensure it is square and resized.
225
-
226
- Args:
227
- input_img (PIL.Image): The uploaded image.
228
-
229
- Returns:
230
- PIL.Image: The preprocessed image.
231
  """
232
  if input_img is None:
233
  return None
234
 
235
- # Make sure the image is square
236
  width, height = input_img.size
237
-
238
  if width != height:
239
  # Add white padding to make the image square
240
  new_size = max(width, height)
@@ -246,104 +175,142 @@ def preprocess(input_img):
246
 
247
  return input_img.resize((IMG_SIZE, IMG_SIZE))
248
 
249
- def build_app(get_processed_inputs, inpaint):
250
- """
251
- Builds and launches the Gradio app.
252
-
253
- Args:
254
- get_processed_inputs (function): Function to process inputs for SAM.
255
- inpaint (function): Function to perform inpainting.
256
-
257
- Returns:
258
- None
259
- """
260
- with gr.Blocks() as demo:
261
 
262
- gr.Markdown(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  """
264
- # Object Replacement App
265
- Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
266
-
267
- **Instructions:**
268
- 1. **Upload Image:** Click on the first image box to upload your image.
269
- 2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
270
- 3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
271
- 4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
272
- 5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
273
- 6. **Reset:** Click the "Reset" button to clear selections and start over.
274
  """)
275
-
276
- with gr.Row():
277
- with gr.Column():
278
- # Image upload and point selection
279
- upload_image = gr.Image(label="Upload Image", type="pil", interactive=True)
280
- mask_visualization = gr.Image(label="Selected Object Mask Overlay", interactive=False)
281
- selected_image = gr.Image(label="Image with Selected Points", type="pil", interactive=False)
282
-
283
- # Capture points using the select event
284
- upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
285
-
286
- # Preprocess image on change
287
- upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
288
-
289
- # Text inputs and settings
290
- prompt = gr.Textbox(label="Replacement Prompt", placeholder="e.g., a red sports car", lines=2)
291
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="e.g., blurry, low quality", lines=2)
292
- cfg = gr.Slider(
293
- label="Classifier-Free Guidance Scale",
294
- minimum=1.0,
295
- maximum=20.0,
296
- value=7.5,
297
- step=0.5
298
- )
299
- seed = gr.Number(label="Seed", value=42, precision=0)
300
- invert = gr.Checkbox(label="Infill subject instead of background")
301
-
302
- # Buttons
303
- replace_button = gr.Button("Replace Object")
304
- reset_button = gr.Button("Reset")
305
- with gr.Column():
306
- # Output images
307
- augmented_image = gr.Image(label="Augmented Image", type="pil", interactive=False)
308
-
309
- # Define button actions
310
- replace_button.click(
311
- fn=run_inpaint,
312
- inputs=[prompt, negative_prompt, cfg, seed, invert],
313
- outputs=[augmented_image]
314
- )
315
-
316
- reset_button.click(
317
- fn=reset_points_func,
318
- inputs=[],
319
- outputs=[mask_visualization, selected_image, augmented_image]
320
- )
321
-
322
- # Examples (optional)
323
- gr.Markdown(
324
- """
325
- ## EXAMPLES
326
- Click on an example to load it. Then, follow the instructions above.
327
- """)
328
-
329
- with gr.Row():
330
- examples = gr.Examples(
331
- examples=[
332
- ["car.png", "a red sports car", "blurry, low quality", 42],
333
- ["house.jpg", "a modern villa", "dark, overexposed", 123],
334
- ["tree.png", "a blooming cherry tree", "underexposed, low contrast", 999]
335
  ],
336
- inputs=[
337
- upload_image,
338
- prompt,
339
- negative_prompt,
340
- seed
341
  ],
342
- label="Click to load examples",
343
- cache_examples=True
344
- )
345
-
346
- demo.queue(max_size=10).launch()
347
-
348
- # Launch the app
349
- build_app(None, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
  from PIL import Image, ImageDraw
4
  import torch
 
5
  from transformers import SamModel, SamProcessor
6
  from diffusers import StableDiffusionInpaintPipeline
7
 
8
  # Constants
9
  IMG_SIZE = 512
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Global variables to store points and the original image
12
  input_points = []
13
  input_image = None
14
 
15
+ def generate_mask(image, points):
16
  """
17
+ Generates a mask using SAM based on input points.
18
  """
19
+ if not points:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  return None
21
+
 
22
  image = image.convert("RGB")
23
+ points = [tuple(point) for point in points]
24
 
25
+ # Initialize SAM model and processor on CPU
26
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge", torch_dtype=torch.float32).to("cpu")
27
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
28
 
 
29
  inputs = sam_processor(image, points=points, return_tensors="pt").to("cpu")
30
 
31
  with torch.no_grad():
32
  outputs = sam_model(**inputs)
33
 
 
34
  masks = sam_processor.image_processor.post_process_masks(
35
  outputs.pred_masks.cpu(),
36
  inputs["original_sizes"].cpu(),
 
40
  if len(masks) == 0:
41
  return None
42
 
 
43
  best_mask = masks[0][0][outputs.iou_scores.argmax()]
 
 
44
  binary_mask = ~best_mask.numpy().astype(bool).astype(int)
45
 
46
  return binary_mask
47
 
48
  def replace_object(image, mask, prompt, negative_prompt, seed, guidance_scale):
49
  """
50
+ Replaces the object in the image based on the mask and prompt.
 
 
 
 
 
 
 
 
 
 
 
51
  """
52
  if mask is None:
53
  return image
54
 
55
+ # Initialize Inpainting pipeline on CPU with a compatible model
56
+ inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
57
+ "stabilityai/stable-diffusion-2-inpainting",
58
+ torch_dtype=torch.float32
59
+ ).to("cpu")
60
+
61
  mask_image = Image.fromarray((mask * 255).astype(np.uint8))
62
 
63
  generator = torch.Generator("cpu").manual_seed(seed)
 
79
  def visualize_mask(image, mask):
80
  """
81
  Overlays the mask on the image for visualization.
 
 
 
 
 
 
 
82
  """
83
  if mask is None:
84
  return image
85
 
86
+ bg_transparent = np.zeros(mask.shape + (4,), dtype=np.uint8)
87
+ bg_transparent[mask == 1] = [0, 255, 0, 127] # Green with transparency
88
+ mask_rgba = Image.fromarray(bg_transparent)
89
+
90
+ overlay = Image.alpha_composite(image.convert("RGBA"), mask_rgba)
91
  return overlay.convert("RGB")
92
 
93
  def get_points(img, evt: gr.SelectData):
94
  """
95
  Captures points selected by the user on the image.
 
 
 
 
 
 
 
96
  """
97
  global input_points
98
  global input_image
99
 
 
100
  if len(input_points) == 0:
101
  input_image = img.copy()
102
 
103
  x = evt.index[0]
104
  y = evt.index[1]
105
+
106
  input_points.append([x, y])
107
 
108
+ # Generate mask based on selected points
109
  mask = generate_mask(input_image, input_points)
110
 
111
  # Mark selected points with a green crossmark
 
124
  def run_inpaint(prompt, negative_prompt, cfg, seed, invert):
125
  """
126
  Runs the inpainting process based on user inputs.
 
 
 
 
 
 
 
 
 
 
127
  """
128
  global input_image
129
  global input_points
 
149
  def reset_points_func():
150
  """
151
  Resets the selected points and the input image.
 
 
 
152
  """
153
  global input_points
154
  global input_image
 
159
  def preprocess(input_img):
160
  """
161
  Preprocesses the uploaded image to ensure it is square and resized.
 
 
 
 
 
 
162
  """
163
  if input_img is None:
164
  return None
165
 
 
166
  width, height = input_img.size
 
167
  if width != height:
168
  # Add white padding to make the image square
169
  new_size = max(width, height)
 
175
 
176
  return input_img.resize((IMG_SIZE, IMG_SIZE))
177
 
178
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
179
 
180
+ gr.Markdown(
181
+ """
182
+ # Object Replacement App
183
+ Upload an image, select points on the object you want to replace, provide a text prompt for the replacement, and view the augmented image.
184
+
185
+ **Instructions:**
186
+ 1. **Upload Image:** Click on the first image box to upload your image.
187
+ 2. **Select Points:** Click on the image to select points on the object you wish to replace. Use multiple points for better mask accuracy.
188
+ 3. **Enter Prompts:** Provide a replacement prompt and optionally a negative prompt to refine the output.
189
+ 4. **Adjust Settings:** Set the seed for reproducibility and adjust the guidance scale as needed.
190
+ 5. **Replace Object:** Click the "Replace Object" button to generate the augmented image.
191
+ 6. **Reset:** Click the "Reset" button to clear selections and start over.
192
+ """)
193
+
194
+ with gr.Row():
195
+ with gr.Column():
196
+ # Image upload and point selection
197
+ upload_image = gr.Image(
198
+ label="Upload Image",
199
+ type="pil",
200
+ interactive=True,
201
+ height=IMG_SIZE,
202
+ width=IMG_SIZE
203
+ )
204
+ mask_visualization = gr.Image(
205
+ label="Selected Object Mask Overlay",
206
+ interactive=False,
207
+ height=IMG_SIZE,
208
+ width=IMG_SIZE
209
+ )
210
+ selected_image = gr.Image(
211
+ label="Image with Selected Points",
212
+ type="pil",
213
+ interactive=False,
214
+ height=IMG_SIZE,
215
+ width=IMG_SIZE,
216
+ )
217
+
218
+ # Capture points using the select event
219
+ upload_image.select(get_points, inputs=[upload_image], outputs=[mask_visualization, selected_image])
220
+
221
+ # Preprocess image on change
222
+ upload_image.change(preprocess, inputs=[upload_image], outputs=[upload_image])
223
+
224
+ # Text inputs and settings
225
+ prompt = gr.Textbox(
226
+ label="Replacement Prompt",
227
+ placeholder="e.g., a red sports car",
228
+ lines=2
229
+ )
230
+ negative_prompt = gr.Textbox(
231
+ label="Negative Prompt",
232
+ placeholder="e.g., blurry, low quality",
233
+ lines=2
234
+ )
235
+ cfg = gr.Slider(
236
+ label="Classifier-Free Guidance Scale",
237
+ minimum=1.0,
238
+ maximum=20.0,
239
+ value=7.5,
240
+ step=0.5
241
+ )
242
+ seed = gr.Number(
243
+ label="Seed",
244
+ value=42,
245
+ precision=0
246
+ )
247
+ invert = gr.Checkbox(
248
+ label="Infill subject instead of background"
249
+ )
250
+
251
+ # Buttons
252
+ replace_button = gr.Button("Replace Object")
253
+ reset_button = gr.Button("Reset")
254
+ with gr.Column():
255
+ # Output images
256
+ augmented_image = gr.Image(
257
+ label="Augmented Image",
258
+ type="pil",
259
+ interactive=False,
260
+ height=IMG_SIZE,
261
+ width=IMG_SIZE,
262
+ )
263
+
264
+ # Define button actions
265
+ replace_button.click(
266
+ fn=run_inpaint,
267
+ inputs=[prompt, negative_prompt, cfg, seed, invert],
268
+ outputs=[augmented_image]
269
+ )
270
+
271
+ reset_button.click(
272
+ fn=reset_points_func,
273
+ inputs=[],
274
+ outputs=[mask_visualization, selected_image, augmented_image]
275
+ )
276
+
277
+ # Examples (optional)
278
+ gr.Markdown(
279
  """
280
+ ## EXAMPLES
281
+ Click on an example to load it. Then, follow the instructions above.
 
 
 
 
 
 
 
 
282
  """)
283
+
284
+ with gr.Row():
285
+ examples = gr.Examples(
286
+ examples=[
287
+ [
288
+ "car.png",
289
+ "a red sports car",
290
+ "blurry, low quality",
291
+ 42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  ],
293
+ [
294
+ "house.jpg",
295
+ "a modern villa",
296
+ "dark, overexposed",
297
+ 123
298
  ],
299
+ [
300
+ "tree.png",
301
+ "a blooming cherry tree",
302
+ "underexposed, low contrast",
303
+ 999
304
+ ]
305
+ ],
306
+ inputs=[
307
+ upload_image,
308
+ prompt,
309
+ negative_prompt,
310
+ seed
311
+ ],
312
+ label="Click to load examples",
313
+ cache_examples=False # Set to False to avoid the error
314
+ )
315
+
316
+ demo.queue(max_size=10).launch()