aiqtech committed
Commit 0212dde · verified · 1 Parent(s): 36fedb3

Update app.py

Files changed (1)
  1. app.py +49 -59
app.py CHANGED
@@ -23,63 +23,55 @@ import gc
 from PIL import Image, ImageDraw, ImageFont
 
 def clear_memory():
-    """Memory cleanup function - adjusted for the Spaces GPU environment"""
+    """Memory cleanup function"""
     gc.collect()
-    if torch.cuda.is_available():
-        try:
-            with torch.cuda.device('cuda:0'):  # explicitly specify cuda:0
-                torch.cuda.empty_cache()
-        except Exception as e:
-            print(f"GPU memory management warning: {e}")
+    torch.cuda.empty_cache()
+
+    # Clear unused caches
+    if hasattr(torch.cuda, 'empty_cache'):
+        torch.cuda.empty_cache()
+    if hasattr(torch.cuda, 'ipc_collect'):
+        torch.cuda.ipc_collect()
 
 def initialize_models():
-    """Model initialization function - adjusted for the Spaces GPU environment"""
     global segmenter, gd_model, gd_processor, pipe, translator
 
     try:
-        # Clean up GPU memory
         clear_memory()
 
-        # Run the translation model on CPU only
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cpu')
+        # Translation model run on CPU only
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
+                                                      low_cpu_mem_usage=True).to('cpu')
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
+        translator = pipeline("translation", model=model, tokenizer=tokenizer,
+                              device=-1)
 
-        # Initialize the GroundingDINO model
+        # GroundingDINO model
         gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
         gd_model = GroundingDinoForObjectDetection.from_pretrained(
             gd_model_path,
             torch_dtype=torch.float16,
-            device_map='cuda:0'  # explicitly specify cuda:0
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
 
-        # Initialize the segmenter
-        segmenter = BoxSegmenter(device='cuda:0')  # explicitly specify cuda:0
+        # Segmenter
+        segmenter = BoxSegmenter(device='cpu')
 
-        # Initialize the FLUX pipeline
+        # FLUX pipeline
         pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             torch_dtype=torch.float16,
-            token=HF_TOKEN
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
-        pipe.enable_attention_slicing(slice_size="auto")
+        pipe.enable_attention_slicing()
+        pipe.enable_model_cpu_offload()
 
-        # Load LoRA weights
-        pipe.load_lora_weights(
-            hf_hub_download(
-                "ByteDance/Hyper-SD",
-                "Hyper-FLUX.1-dev-8steps-lora.safetensors",
-                token=HF_TOKEN
-            )
-        )
-        pipe.fuse_lora(lora_scale=0.125)
-
-        if torch.cuda.is_available():
-            pipe = pipe.to('cuda:0')  # explicitly specify cuda:0
-
     except Exception as e:
         print(f"Model initialization error: {str(e)}")
         raise
+
 # GPU setup
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # explicitly specify cuda:0
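
The net effect of this hunk is to drop explicit cuda:0 pinning (and the Hyper-SD LoRA fusion) in favor of accelerate-managed placement plus CPU offload. A minimal sketch of that loading pattern, assuming diffusers with accelerate installed and access to the gated FLUX.1-dev weights:

import torch
from diffusers import FluxPipeline

# Load without any .to('cuda') call; offloading manages placement at run time.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.enable_attention_slicing()    # compute attention in slices to lower peak memory
pipe.enable_model_cpu_offload()    # keep submodules on CPU, moving each to GPU only while it runs

Note that enable_model_cpu_offload() handles device placement itself, so it is normally used instead of a device_map rather than alongside one; the sketch omits device_map='auto' for that reason.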
 
@@ -234,32 +226,29 @@ def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int,
         return base_size * 4 // 3, base_size
     return base_size, base_size
 
-@spaces.GPU(duration=20)  # reduced from 40s to 20s
+@spaces.GPU(duration=20)
 def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
     try:
         width, height = calculate_dimensions(aspect_ratio)
         width, height = adjust_size_to_multiple_of_8(width, height)
 
-        max_size = 768
+        # Cap the maximum size
+        max_size = 512  # reduced from 768
         if width > max_size or height > max_size:
             ratio = max_size / max(width, height)
             width = int(width * ratio)
             height = int(height * ratio)
             width, height = adjust_size_to_multiple_of_8(width, height)
 
-        with timer("Background generation"):
-            try:
-                with torch.inference_mode():
-                    image = pipe(
-                        prompt=prompt,
-                        width=width,
-                        height=height,
-                        num_inference_steps=8,
-                        guidance_scale=4.0
-                    ).images[0]
-            except Exception as e:
-                print(f"Pipeline error: {str(e)}")
-                return Image.new('RGB', (width, height), 'white')
+        with torch.inference_mode():
+            image = pipe(
+                prompt=prompt,
+                width=width,
+                height=height,
+                num_inference_steps=4,  # reduced from 8
+                guidance_scale=4.0,
+                batch_size=1
+            ).images[0]
 
         return image
     except Exception as e:
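
The sizing logic above is easy to check in isolation. A self-contained sketch, where adjust_size_to_multiple_of_8 is re-implemented from its apparent contract (the app's own definition is not shown in this diff) and clamp_dimensions is a hypothetical helper name:

def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
    # Round each side down to the nearest multiple of 8 (floor of 8)
    return max(8, width - width % 8), max(8, height - height % 8)

def clamp_dimensions(width: int, height: int, max_size: int = 512) -> tuple[int, int]:
    # Shrink proportionally so the longer side fits max_size, then re-round
    if width > max_size or height > max_size:
        ratio = max_size / max(width, height)
        width, height = int(width * ratio), int(height * ratio)
    return adjust_size_to_multiple_of_8(width, height)

print(clamp_dimensions(1024, 576))  # -> (512, 288)
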
@@ -460,7 +449,7 @@ def update_box_button(img, box_input):
     except:
         return gr.update(interactive=False, variant="secondary")
 
-def process_image(img: Image.Image, max_size: int = 1024) -> Image.Image:
+def process_image(img: Image.Image, max_size: int = 768) -> Image.Image:
     """Optimize image size"""
     if img.width > max_size or img.height > max_size:
         ratio = max_size / max(img.width, img.height)
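
The same downscaling can also be expressed with Pillow's thumbnail(), which preserves aspect ratio and never enlarges; a sketch with the hypothetical name resize_capped, not the app's process_image:

from PIL import Image

def resize_capped(img: Image.Image, max_size: int = 768) -> Image.Image:
    out = img.copy()  # thumbnail() works in place, so keep the original intact
    out.thumbnail((max_size, max_size), Image.LANCZOS)
    return out
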
@@ -824,14 +813,15 @@ if __name__ == "__main__":
         queue=True
     )
 
-    demo.queue(max_size=3)
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False,
-        max_threads=2,
-        enable_queue=True,
-        cache_examples=False,
-        show_error=True,
-        show_tips=False
-    )
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        max_threads=2,
+        enable_queue=True,
+        cache_examples=False,
+        show_error=True,
+        show_tips=False,
+        max_size=1,  # limit queue size
+        memory_limit="48Gi"  # set a memory limit
+    )
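
One hedged caveat on the new launch() call: in standard Gradio releases the pending-request cap is an argument of queue() (as the removed demo.queue(max_size=3) line used it), and launch() has no memory_limit parameter, so the two new keywords may be ignored or rejected. A conservative sketch of an equivalent setup, with demo standing in for the app's Blocks:

import gradio as gr

with gr.Blocks() as demo:
    pass  # the app's UI would be defined here

demo.queue(max_size=1)  # cap pending requests to bound memory use
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,
    max_threads=2,
    show_error=True,
)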
 
 