Himanshu-AT committed on
Commit 61dc46d · Parent: 71f7331

update ui, add download button + set inpaint

Files changed (1): app.py (+88 -28)
app.py CHANGED
@@ -37,16 +37,54 @@ for model_name, model_path in lora_models.items():
 
 lora_models["None"] = None
 
+def calculate_optimal_dimensions(image: Image.Image):
+    # Extract the original dimensions
+    original_width, original_height = image.size
+
+    # Set constants
+    MIN_ASPECT_RATIO = 9 / 16
+    MAX_ASPECT_RATIO = 16 / 9
+    FIXED_DIMENSION = 1024
+
+    # Calculate the aspect ratio of the original image
+    original_aspect_ratio = original_width / original_height
+
+    # Determine which dimension to fix
+    if original_aspect_ratio > 1:  # Wider than tall
+        width = FIXED_DIMENSION
+        height = round(FIXED_DIMENSION / original_aspect_ratio)
+    else:  # Taller than wide
+        height = FIXED_DIMENSION
+        width = round(FIXED_DIMENSION * original_aspect_ratio)
+
+    # Ensure dimensions are multiples of 8
+    width = (width // 8) * 8
+    height = (height // 8) * 8
+
+    # Enforce aspect ratio limits
+    calculated_aspect_ratio = width / height
+    if calculated_aspect_ratio > MAX_ASPECT_RATIO:
+        width = (height * MAX_ASPECT_RATIO // 8) * 8
+    elif calculated_aspect_ratio < MIN_ASPECT_RATIO:
+        height = (width / MIN_ASPECT_RATIO // 8) * 8
+
+    # Ensure width and height remain above the minimum dimensions
+    width = max(width, 576) if width == FIXED_DIMENSION else width
+    height = max(height, 576) if height == FIXED_DIMENSION else height
+
+    return width, height
+
 @spaces.GPU(durations=300)
-def infer(edit_images, prompt, width, height, lora_model, seed=42, randomize_seed=False, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
+def infer(edit_images, prompt, lora_model, seed=42, randomize_seed=False, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
     # pipe.enable_xformers_memory_efficient_attention()
 
+
     if lora_model != "None":
         pipe.load_lora_weights(lora_models[lora_model])
         pipe.enable_lora()
 
     image = edit_images["background"]
-    # width, height = calculate_optimal_dimensions(image)
+    width, height = calculate_optimal_dimensions(image)
     mask = edit_images["layers"][0]
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
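
For reference, the sizing rule that calculate_optimal_dimensions implements: fix the longer side at 1024, scale the shorter side by the original aspect ratio, snap both to multiples of 8, then clamp the result to the [9/16, 16/9] aspect-ratio window. A minimal standalone sketch of the main path (optimal_dims is a hypothetical stand-in, not the committed function, and it omits the clamp and minimum-size branches):

    def optimal_dims(w, h):
        FIXED = 1024  # the longer side is pinned to 1024
        if w > h:  # wider than tall: fix the width
            w, h = FIXED, round(FIXED * h / w)
        else:      # taller than wide (or square): fix the height
            w, h = round(FIXED * w / h), FIXED
        return (w // 8) * 8, (h // 8) * 8  # snap to multiples of 8

    print(optimal_dims(4000, 3000))  # (1024, 768)
    print(optimal_dims(1080, 1920))  # (576, 1024)

One caveat in the committed helper: the clamp branches (height * MAX_ASPECT_RATIO // 8) * 8 and (width / MIN_ASPECT_RATIO // 8) * 8 mix ints with float ratios and therefore return floats, so pipelines that require integer width/height would need an int() around them.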
@@ -72,6 +110,13 @@ def infer(edit_images, prompt, width, height, lora_model, seed=42, randomize_seed=False, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
     return output_image_jpg, seed
     # return image, seed
 
+def download_image(image):
+    image.save("output.png", "PNG")
+    return "output.png"
+
+def set_image_as_inpaint(image):
+    return image
+
 examples = [
     "photography of a young woman, accent lighting, (front view:1.4), "
     # "a tiny astronaut hatching from an egg on the moon",
@@ -150,31 +195,46 @@ with gr.Blocks(css=css) as demo:
                     value=28,
                 )
 
-            with gr.Row():
+            # with gr.Row():
 
-                width = gr.Slider(
-                    label="width",
-                    minimum=512,
-                    maximum=3072,
-                    step=1,
-                    value=1024,
-                )
+            #     width = gr.Slider(
+            #         label="width",
+            #         minimum=512,
+            #         maximum=3072,
+            #         step=1,
+            #         value=1024,
+            #     )
 
-                height = gr.Slider(
-                    label="height",
-                    minimum=512,
-                    maximum=3072,
-                    step=1,
-                    value=1024,
-                )
+            #     height = gr.Slider(
+            #         label="height",
+            #         minimum=512,
+            #         maximum=3072,
+            #         step=1,
+            #         value=1024,
+            #     )
 
         gr.on(
             triggers=[run_button.click, prompt.submit],
             fn = infer,
-            inputs = [edit_image, prompt, width, height, lora_model, seed, randomize_seed, guidance_scale, num_inference_steps],
+            inputs = [edit_image, prompt, lora_model, seed, randomize_seed, guidance_scale, num_inference_steps],
             outputs = [result, seed]
         )
 
+        download_button = gr.Button("Download Image as PNG")
+        set_inpaint_button = gr.Button("Set Image as Inpaint")
+
+        download_button.click(
+            fn=download_image,
+            inputs=[result],
+            outputs=gr.File(label="Download Image")
+        )
+
+        set_inpaint_button.click(
+            fn=set_image_as_inpaint,
+            inputs=[result],
+            outputs=[edit_image]
+        )
+
 # demo.launch()
 PASSWORD = os.getenv("GRADIO_PASSWORD")
 USERNAME = os.getenv("GRADIO_USERNAME")
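
The wiring above uses gr.on to bind a single handler to several triggers (the Run button's click and the prompt's Enter submit), with the new button events attached separately. A minimal self-contained example of the gr.on pattern (Gradio 4+), independent of this app:

    import gradio as gr

    def greet(name):
        return f"Hello, {name}!"

    with gr.Blocks() as demo:
        name = gr.Textbox(label="Name")
        out = gr.Textbox(label="Greeting")
        run = gr.Button("Run")
        # One handler, two triggers: button click or Enter in the textbox.
        gr.on(triggers=[run.click, name.submit], fn=greet, inputs=[name], outputs=[out])

    demo.launch()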
@@ -262,12 +322,12 @@ demo.launch(auth=authenticate)
 
 # The mask_prompt is expected to be a comma-separated string of two integers,
 # e.g. "450,600" representing an (x,y) coordinate in the image.
-
+
 # The function converts the coordinate into the proper input format for SAM and returns a binary mask.
 # """
 # if mask_prompt.strip() == "":
 #     raise ValueError("No mask prompt provided.")
-
+
 # try:
 #     # Parse the mask_prompt into a coordinate
 #     coords = [int(x.strip()) for x in mask_prompt.split(",")]
@@ -275,33 +335,33 @@ demo.launch(auth=authenticate)
 #         raise ValueError("Expected two comma-separated integers (x,y).")
 # except Exception as e:
 #     raise ValueError("Invalid mask prompt. Please provide coordinates as 'x,y'. Error: " + str(e))
-
+
 # # The SAM processor expects a list of input points.
 # # Format the point as a list of lists; here we assume one point per image.
 # # (The Transformers SAM expects the points in [x, y] order.)
 # input_points = [coords]  # e.g. [[450,600]]
 # # Optionally, you can supply input_labels (1 for foreground, 0 for background)
 # input_labels = [1]
-
+
 # # Prepare the inputs for the SAM processor.
 # inputs = sam_processor(images=image,
 #                        input_points=[input_points],
 #                        input_labels=[input_labels],
 #                        return_tensors="pt")
-
+
 # # Move tensors to the same device as the model.
 # device = next(sam_model.parameters()).device
 # inputs = {k: v.to(device) for k, v in inputs.items()}
-
+
 # # Forward pass through SAM.
 # with torch.no_grad():
 #     outputs = sam_model(**inputs)
-
+
 # # The output contains predicted masks; we take the first mask from the first prompt.
 # # (Assuming outputs.pred_masks is of shape (batch_size, num_masks, H, W))
 # pred_masks = outputs.pred_masks  # Tensor of shape (1, num_masks, H, W)
 # mask = pred_masks[0][0].detach().cpu().numpy()
-
+
 # # Convert the mask to binary (0 or 255) using a threshold.
 # mask_bin = (mask > 0.5).astype(np.uint8) * 255
 # mask_pil = Image.fromarray(mask_bin)
@@ -387,14 +447,14 @@ demo.launch(auth=authenticate)
 #     mask_preview = gr.Image(label="Mask Preview", show_label=True)
 #     run_button = gr.Button("Run")
 #     result = gr.Image(label="Result", show_label=False)
-
+
 # # Button to preview the generated mask.
 # def on_generate_mask(image, mask_prompt):
 #     if image is None or mask_prompt.strip() == "":
 #         return None
 #     mask = generate_mask_with_sam(image, mask_prompt)
 #     return mask
-
+
 # generate_mask_btn.click(
 #     fn=on_generate_mask,
 #     inputs=[edit_image, mask_prompt],