ginipick committed on
Commit deb56c6 · verified · 1 Parent(s): 2a729b9

Update app.py

Files changed (1): app.py +48 -42
app.py CHANGED
@@ -23,12 +23,18 @@ from diffusers import FluxPipeline
 from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
 import gc
+import torch.cuda.amp as amp
 
+# Strengthened memory management
 def clear_memory():
     """Memory cleanup function"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
     gc.collect()
-    torch.cuda.empty_cache()
-
+
+# Automatic Mixed Precision (AMP) setup
+scaler = amp.GradScaler()
 
 
 model_name = "Helsinki-NLP/opus-mt-ko-en"
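Note that GradScaler scales gradients during training, while this app appears to run inference only, so the autocast contexts introduced further down are what actually engage mixed precision. A standalone sketch of the strengthened cleanup helper, assuming only torch and gc:

import gc
import torch

def clear_memory():
    """Release cached CUDA memory, then run the Python garbage collector."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # return cached blocks to the driver
        torch.cuda.synchronize()   # wait for in-flight kernels to finish
    gc.collect()

clear_memory()
if torch.cuda.is_available():
    print(f"allocated after cleanup: {torch.cuda.memory_allocated() / 1e6:.1f} MB")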
@@ -79,13 +85,17 @@ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_
 gd_model = gd_model.to(device=device)
 assert isinstance(gd_model, GroundingDinoForObjectDetection)
 
-# Revised FLUX pipeline initialization
+# FLUX pipeline initialization
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    torch_dtype=torch.float16,  # float16 to reduce memory usage
+    torch_dtype=torch.float16,  # float16, suited to the A100
     use_auth_token=HF_TOKEN,
-    device_map="balanced"  # 'balanced' instead of 'auto'
+    device_map="balanced"
 )
+pipe.enable_attention_slicing(slice_size="auto")  # reduce peak memory per attention op
+pipe.enable_sequential_cpu_offload()  # enable CPU offloading
+
+
 pipe.load_lora_weights(
     hf_hub_download(
         "ByteDance/Hyper-SD",
@@ -95,7 +105,8 @@ pipe.load_lora_weights(
 )
 pipe.fuse_lora(lora_scale=0.125)
 
-
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use a single GPU
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"  # CUDA allocator setting
 
 
 class timer:
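Both environment variables are read when CUDA initializes, so exporting them after the pipelines above are already on the GPU has no effect; the usual pattern is to set them at the very top of the script. A minimal sketch:

import os

# Must run before anything imports torch or initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch  # imported only after the environment is configured

print(torch.cuda.device_count())  # reflects CUDA_VISIBLE_DEVICES when CUDA is present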
@@ -171,37 +182,32 @@ def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int,
     return base_size, base_size
 
 def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
-    """Background image generation function"""
     try:
-        # Compute dimensions from the selected aspect ratio
         width, height = calculate_dimensions(aspect_ratio)
-
-        # Round to multiples of 8
         width, height = adjust_size_to_multiple_of_8(width, height)
 
-        # Prompt preprocessing
-        if not prompt or prompt.strip() == "":
-            prompt = "plain white background"
-
+        # Cap dimensions to respect A100 memory limits
+        max_size = 768
+        if width > max_size or height > max_size:
+            ratio = max_size / max(width, height)
+            width = int(width * ratio)
+            height = int(height * ratio)
+            width, height = adjust_size_to_multiple_of_8(width, height)
+
         with timer("Background generation"):
-            try:
+            with torch.cuda.amp.autocast():  # automatic mixed precision
                 image = pipe(
                     prompt=prompt,
                     width=width,
                     height=height,
                     num_inference_steps=8,
                     guidance_scale=4.0,
-                    max_length=77,  # CLIP text encoder's maximum length
+                    max_length=77,
                 ).images[0]
-            except Exception as e:
-                print(f"Pipeline error: {str(e)}")
-                # On error, fall back to a plain white background
-                image = Image.new('RGB', (width, height), 'white')
 
         return image
     except Exception as e:
         print(f"Background generation error: {str(e)}")
-        # Last-resort fallback: return a plain white background
         return Image.new('RGB', (512, 512), 'white')
 
 # Revised FLUX pipeline initialization
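The capping logic composes with adjust_size_to_multiple_of_8, and the arithmetic is easy to verify in isolation; a self-contained sketch (adjust_size_to_multiple_of_8 is reimplemented here from its name, so treat that as an assumption):

def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
    # Round each side down to the nearest multiple of 8.
    return (width // 8) * 8, (height // 8) * 8

def cap_size(width: int, height: int, max_size: int = 768) -> tuple[int, int]:
    # Scale the longer side down to max_size, keeping the aspect ratio.
    if width > max_size or height > max_size:
        ratio = max_size / max(width, height)
        width, height = int(width * ratio), int(height * ratio)
    return adjust_size_to_multiple_of_8(width, height)

assert cap_size(1024, 576) == (768, 432)  # 16:9 input scaled into the cap
assert cap_size(512, 512) == (512, 512)   # already within the cap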
@@ -296,34 +302,32 @@ def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Im
 
 def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
     try:
-        if img.width > 2048 or img.height > 2048:
-            orig_res = max(img.width, img.height)
-            img.thumbnail((2048, 2048))
-            if isinstance(prompt, tuple):
-                x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
-                prompt = (x0, y0, x1, y1)
+        # Limit the input image size
+        max_size = 1024
+        if img.width > max_size or img.height > max_size:
+            ratio = max_size / max(img.width, img.height)
+            new_size = (int(img.width * ratio), int(img.height * ratio))
+            img = img.resize(new_size, Image.LANCZOS)
+
+        # Monitor memory usage
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats()
 
-        mask, bbox, time_log = _gpu_process(img, prompt)
-        masked_alpha = apply_mask(img, mask, defringe=True)
+        with torch.cuda.amp.autocast():
+            mask, bbox, time_log = _gpu_process(img, prompt)
+            masked_alpha = apply_mask(img, mask, defringe=True)
 
         if bg_prompt:
-            # Generate only the background; skip compositing
             background = generate_background(bg_prompt, aspect_ratio)
-            combined = background  # return only the background image
+            combined = background
         else:
             combined = Image.alpha_composite(Image.new("RGBA", masked_alpha.size, "white"), masked_alpha)
 
-        thresholded = mask.point(lambda p: 255 if p > 10 else 0)
-        bbox = thresholded.getbbox()
-        to_dl = masked_alpha.crop(bbox)
-
-        temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
-        to_dl.save(temp, format="PNG")
-        temp.close()
+        clear_memory()  # intermediate memory cleanup
 
         return (img, combined, masked_alpha), gr.DownloadButton(value=temp.name, interactive=True)
-
     except Exception as e:
+        clear_memory()
         raise gr.Error(f"Processing failed: {str(e)}")
 
 def on_change_bbox(prompts: dict[str, Any] | None):
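The new body still ends with gr.DownloadButton(value=temp.name, ...), yet this hunk deletes the block that created temp, so that name no longer exists at the return. A sketch of the crop-and-save step that would have to run before the return, factored into a helper (the function name is hypothetical; mask and masked_alpha are the PIL images produced above):

import tempfile
from PIL import Image

def save_cutout(masked_alpha: Image.Image, mask: Image.Image) -> str:
    # Keep clearly masked pixels, crop to their bounding box, save a PNG.
    thresholded = mask.point(lambda p: 255 if p > 10 else 0)
    to_dl = masked_alpha.crop(thresholded.getbbox())
    temp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
    to_dl.save(temp, format="PNG")
    temp.close()
    return temp.name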
@@ -683,12 +687,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     </div>
     </div>
     """)
-demo.queue(max_size=20)  # limit queue size
+
+demo.queue(max_size=10)  # limit queue size
 demo.launch(
     server_name="0.0.0.0",
     server_port=7860,
     share=False,
     enable_queue=True,
-    max_threads=4,  # limit thread count
-    allowed_paths=["examples"]
+    max_threads=2,  # limit thread count
+    allowed_paths=["examples"],
+    memory_limit=0.8  # cap memory usage (80%)
 )
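For comparison, a minimal queue-and-launch sketch restricted to arguments that current Gradio releases document (an assumption pinned to Gradio 4.x, where enable_queue was removed and where memory_limit does not appear among launch() parameters, so both are omitted):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("launch smoke test")

demo.queue(max_size=10)  # bound the number of queued requests
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=2,               # cap worker threads
    allowed_paths=["examples"],  # extra paths the server may serve
)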
 