ginipick committed
Commit 7ae084c · verified · 1 Parent(s): 9a8a6d5

Update app.py

Files changed (1):
  1. app.py (+68 -59)

app.py CHANGED
@@ -196,26 +196,23 @@ def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
         ratio = max_size / max(width, height)
         width = int(width * ratio)
         height = int(height * ratio)
-        width, height = adjust_size_to_multiple_of_8(width, height)
-
-        with timer("Background generation"):
-            try:
-                with torch.inference_mode():
-                    image = pipe(
-                        prompt=prompt,
-                        width=width,
-                        height=height,
-                        num_inference_steps=8,
-                        guidance_scale=4.0
-                    ).images[0]
-            except Exception as e:
-                print(f"Pipeline error: {str(e)}")
-                return Image.new('RGB', (width, height), 'white')
-
+
+        with torch.inference_mode():
+            image = pipe(
+                prompt=prompt,
+                width=width,
+                height=height,
+                num_inference_steps=8,
+                guidance_scale=4.0
+            ).images[0]
+
         return image
+
     except Exception as e:
         print(f"Background generation error: {str(e)}")
-        return Image.new('RGB', (512, 512), 'white')
+        return Image.new('RGB', (width, height), 'white')
+    finally:
+        clear_memory()
 
 def create_position_grid():
     return """
@@ -273,23 +270,20 @@ def combine_with_background(foreground: Image.Image, background: Image.Image,
     result.paste(scaled_foreground, (x, y), scaled_foreground)
     return result
 
-@spaces.GPU(duration=30)  # reduced from 120 s to 30 s
+@spaces.GPU(duration=30)
 def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
-    time_log: list[str] = []
     try:
-        if isinstance(prompt, str):
-            t0 = time.time()
-            bbox = gd_detect(img, prompt)
-            time_log.append(f"detect: {time.time() - t0}")
-            if not bbox:
-                print(time_log[0])
-                raise gr.Error("No object detected")
-        else:
-            bbox = prompt
-        t0 = time.time()
-        mask = segmenter(img, bbox)
-        time_log.append(f"segment: {time.time() - t0}")
-        return mask, bbox, time_log
+        with torch.inference_mode():
+            if isinstance(prompt, str):
+                bbox = gd_detect(img, prompt)
+                if not bbox:
+                    raise gr.Error("No object detected in image")
+            else:
+                bbox = prompt
+
+            mask = segmenter(img, bbox)
+            return mask, bbox, []
+
     except Exception as e:
         print(f"GPU process error: {str(e)}")
         raise
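
For context, @spaces.GPU(duration=30) is the Hugging Face ZeroGPU decorator: the wrapped function only holds a GPU for the requested window, so a shorter duration is scheduled sooner but must cover the whole call. A minimal sketch of the pattern, assuming a ZeroGPU Space; the function name and body here are illustrative, not taken from app.py:

import spaces
import torch

@spaces.GPU(duration=30)  # request a GPU slot of roughly 30 seconds
def run_on_gpu(pipe, prompt: str):
    # inference_mode skips autograd bookkeeping, mirroring the hunk above
    with torch.inference_mode():
        return pipe(prompt=prompt, num_inference_steps=8).images[0]
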
@@ -345,37 +339,40 @@ def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
                    aspect_ratio: str = "1:1", position: str = "bottom-center",
                    scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
     try:
-        if img is None or prompt.strip() == "":
+        if img is None or not prompt or prompt.isspace():
             raise gr.Error("Please provide both image and prompt")
 
-        print(f"Processing with position: {position}, scale: {scale_percent}")
-
-        try:
-            prompt = translate_to_english(prompt)
-            if bg_prompt:
-                bg_prompt = translate_to_english(bg_prompt)
-        except Exception as e:
-            print(f"Translation error (continuing with original text): {str(e)}")
+        # limit the input image size
+        max_size = 1024
+        if img.width > max_size or img.height > max_size:
+            ratio = max_size / max(img.width, img.height)
+            img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
 
-        results, _ = _process(img, prompt, bg_prompt, aspect_ratio)
+        # translate the prompts
+        translated_prompt = translate_to_english(prompt)
+        translated_bg_prompt = translate_to_english(bg_prompt) if bg_prompt else None
 
-        if bg_prompt:
-            try:
-                combined = combine_with_background(
-                    foreground=results[2],
-                    background=results[1],
-                    position=position,
-                    scale_percent=scale_percent
-                )
-                print(f"Combined image created with position: {position}")
-                return combined, results[2]
-            except Exception as e:
-                print(f"Combination error: {str(e)}")
-                return results[1], results[2]
-
-        return results[1], results[2]
+        # run the image pipeline
+        with torch.inference_mode():
+            results, _ = _process(img, translated_prompt, translated_bg_prompt, aspect_ratio)
+
+            if translated_bg_prompt:
+                try:
+                    combined = combine_with_background(
+                        foreground=results[2],
+                        background=results[1],
+                        position=position,
+                        scale_percent=scale_percent
+                    )
+                    return combined, results[2]
+                except Exception as e:
+                    print(f"Background combination error: {e}")
+                    return results[1], results[2]
+
+            return results[1], results[2]
+
     except Exception as e:
-        print(f"Error in process_prompt: {str(e)}")
+        print(f"Process error: {str(e)}")
         raise gr.Error(str(e))
     finally:
         clear_memory()
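
process_prompt is the handler the Blocks UI calls: its six parameters map onto Gradio input components and its two return values onto two output images. A minimal, self-contained sketch of how such wiring could look; every component name and the stub handler below are assumptions for illustration, not the actual layout in app.py:

import gradio as gr

def process_prompt_stub(img, prompt, bg_prompt, aspect_ratio, position, scale_percent):
    # stand-in for app.py's process_prompt: returns the input image twice
    return img, img

with gr.Blocks() as sketch:
    input_image = gr.Image(type="pil", label="Input image")
    prompt_text = gr.Textbox(label="Object prompt")
    bg_prompt_text = gr.Textbox(label="Background prompt")
    aspect_ratio = gr.Dropdown(["1:1", "16:9", "9:16"], value="1:1", label="Aspect ratio")
    position = gr.Radio(["bottom-center", "center"], value="bottom-center", label="Position")
    scale_slider = gr.Slider(10, 200, value=100, label="Scale (%)")
    process_btn = gr.Button("Process")
    combined_output = gr.Image(label="Combined result")
    object_output = gr.Image(label="Extracted object")

    process_btn.click(
        fn=process_prompt_stub,
        inputs=[input_image, prompt_text, bg_prompt_text, aspect_ratio, position, scale_slider],
        outputs=[combined_output, object_output],
    )
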
@@ -482,6 +479,18 @@ button.primary:hover {
 }
 """
 
+###-------------- Required for ZeroGPU / shared memory management --------------###
+def clear_memory():
+    gc.collect()
+    if torch.cuda.is_available():
+        try:
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        except:
+            pass
+
+
+
 # UI layout
 # In the UI section, move process_btn up and remove the position_grid.click handler
 
@@ -627,7 +636,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     )
 
 
-demo.queue(max_size=10)  # limit the queue size
+demo.queue(max_size=5)  # limit the queue size
 demo.launch(
     server_name="0.0.0.0",
     server_port=7860,
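
queue(max_size=...) caps how many requests may wait in line before new ones are rejected, which keeps memory pressure bounded on a shared Space. A minimal sketch of the queue-then-launch pattern; the placeholder layout and any launch arguments beyond those visible in the diff are assumptions:

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")  # stands in for the app's actual layout

demo.queue(max_size=5)  # new requests are rejected once 5 are already waiting
demo.launch(server_name="0.0.0.0", server_port=7860)
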
 