aiqtech commited on
Commit
5b55adf
·
verified ·
1 Parent(s): 958fc4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -39
app.py CHANGED
@@ -37,37 +37,52 @@ def initialize_models():
37
  global segmenter, gd_model, gd_processor, pipe, translator
38
 
39
  try:
40
- # CPU에서 실행되는 번역 모델
41
- model = AutoModelForSeq2SeqLM.from_pretrained(
42
- model_name,
43
- low_cpu_mem_usage=True
 
44
  ).to('cpu')
45
  tokenizer = AutoTokenizer.from_pretrained(model_name)
46
  translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
 
47
 
48
- # GroundingDINO 모델
49
- gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
50
- gd_model = GroundingDinoForObjectDetection.from_pretrained(
51
- gd_model_path,
52
- torch_dtype=torch.float16,
53
- device_map=None # device_map을 None으로 설정
54
  )
 
55
 
56
- # Segmenter
57
- segmenter = BoxSegmenter(device='cpu')
58
 
59
- # FLUX 파이프라인
60
  pipe = FluxPipeline.from_pretrained(
61
  "black-forest-labs/FLUX.1-dev",
62
  torch_dtype=torch.float16,
63
- device_map=None, # device_map을 None으로 설정
64
- low_cpu_mem_usage=True
65
  )
66
- pipe.enable_attention_slicing()
 
67
 
68
  except Exception as e:
69
  print(f"Model initialization error: {str(e)}")
70
  raise
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # GPU 설정
73
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 명시적으로 cuda:0 지정
@@ -374,26 +389,42 @@ def on_change_bbox(prompts: dict[str, Any] | None):
374
  def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
375
  return gr.update(interactive=bool(img and prompt))
376
 
377
-
378
- @spaces.GPU()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
  def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
380
  aspect_ratio: str = "1:1", position: str = "bottom-center",
381
  scale_percent: float = 100, text_params: dict | None = None):
382
  try:
383
- # GPU 설정
384
- if torch.cuda.is_available():
385
- device = torch.device('cuda')
386
- # 모델들을 GPU로 이동
387
- gd_model.to(device)
388
- segmenter.to(device)
389
- pipe.to(device)
390
- else:
391
- device = torch.device('cpu')
392
 
393
- # ๋‚˜๋จธ์ง€ ์ฒ˜๋ฆฌ ๋กœ์ง...
 
 
394
 
 
 
 
 
395
  finally:
396
- # GPU 메모리 정리
 
397
  if torch.cuda.is_available():
398
  try:
399
  with torch.cuda.device('cuda'):
@@ -440,13 +471,7 @@ def update_box_button(img, box_input):
440
  except:
441
  return gr.update(interactive=False, variant="secondary")
442
 
443
- def process_image(img: Image.Image, max_size: int = 768) -> Image.Image:
444
- """์ด๋ฏธ์ง€ ํฌ๊ธฐ ์ตœ์ ํ™”"""
445
- if img.width > max_size or img.height > max_size:
446
- ratio = max_size / max(img.width, img.height)
447
- new_size = (int(img.width * ratio), int(img.height * ratio))
448
- return img.resize(new_size, Image.LANCZOS)
449
- return img
450
 
451
  # CSS 정의
452
  css = """
@@ -804,14 +829,15 @@ if __name__ == "__main__":
804
  queue=True
805
  )
806
 
807
- # Gradio 실행 설정 수정
808
  demo.launch(
809
  server_name="0.0.0.0",
810
  server_port=7860,
811
  share=False,
812
- max_threads=2,
813
  enable_queue=True,
814
  cache_examples=False,
815
  show_error=True,
816
- show_tips=False
 
817
  )
 
37
  global segmenter, gd_model, gd_processor, pipe, translator
38
 
39
  try:
40
+ # 번역 모델 - 가벼운 버전 사용
41
+ model = AutoModelForSeq2SeqLM.from_pretrained(
42
+ model_name,
43
+ low_cpu_mem_usage=True,
44
+ torch_dtype=torch.float16
45
  ).to('cpu')
46
  tokenizer = AutoTokenizer.from_pretrained(model_name)
47
  translator = pipeline("translation", model=model, tokenizer=tokenizer, device=-1)
48
+ del model # 명시적 메모리 해제
49
 
50
+ # GroundingDINO - 더 작은 모델 사용
51
+ gd_processor = GroundingDinoProcessor.from_pretrained(
52
+ "IDEA-Research/grounding-dino-base", # ๋” ์ž‘์€ base ๋ชจ๋ธ
53
+ torch_dtype=torch.float16
 
 
54
  )
55
+ gd_model = None # 필요할 때 로드
56
 
57
+ # Segmenter - 기본 설정
58
+ segmenter = None # 필요할 때 로드
59
 
60
+ # FLUX 파이프라인 - 메모리 효율적 설정
61
  pipe = FluxPipeline.from_pretrained(
62
  "black-forest-labs/FLUX.1-dev",
63
  torch_dtype=torch.float16,
64
+ low_cpu_mem_usage=True,
65
+ use_safetensors=True
66
  )
67
+ pipe.enable_attention_slicing(slice_size=1)
68
+ pipe.enable_sequential_cpu_offload()
69
 
70
  except Exception as e:
71
  print(f"Model initialization error: {str(e)}")
72
  raise
73
+
74
+ def load_model_on_demand(model_type: str):
75
+ """ํ•„์š”ํ•  ๋•Œ๋งŒ ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๋Š” ํ•จ์ˆ˜"""
76
+ global gd_model, segmenter
77
+
78
+ if model_type == "gd" and gd_model is None:
79
+ gd_model = GroundingDinoForObjectDetection.from_pretrained(
80
+ "IDEA-Research/grounding-dino-base",
81
+ torch_dtype=torch.float16,
82
+ low_cpu_mem_usage=True
83
+ )
84
+ elif model_type == "segmenter" and segmenter is None:
85
+ segmenter = BoxSegmenter(device='cpu')
86
 
87
  # GPU ์„ค์ •
88
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
 
389
  def on_change_prompt(img: Image.Image | None, prompt: str | None, bg_prompt: str | None = None):
390
  return gr.update(interactive=bool(img and prompt))
391
 
392
+ def process_image(img: Image.Image) -> Image.Image:
393
+ """์ด๋ฏธ์ง€ ์ „์ฒ˜๋ฆฌ ์ตœ์ ํ™”"""
394
+ # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
395
+ max_size = 512 # ๋” ์ž‘์€ ํฌ๊ธฐ๋กœ ์ œํ•œ
396
+ if img.width > max_size or img.height > max_size:
397
+ ratio = max_size / max(img.width, img.height)
398
+ new_size = (int(img.width * ratio), int(img.height * ratio))
399
+ img = img.resize(new_size, Image.LANCZOS)
400
+
401
+ # 메모리 효율을 위한 이미지 모드 변환
402
+ if img.mode in ['RGBA', 'LA']:
403
+ background = Image.new('RGB', img.size, (255, 255, 255))
404
+ background.paste(img, mask=img.split()[-1])
405
+ img = background
406
+
407
+ return img
408
+
409
+ @spaces.GPU(duration=15) # 더 짧은 시간 제한
410
  def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None,
411
  aspect_ratio: str = "1:1", position: str = "bottom-center",
412
  scale_percent: float = 100, text_params: dict | None = None):
413
  try:
414
+ # 이미지 전처리
415
+ img = process_image(img)
 
 
 
 
 
 
 
416
 
417
+ # 필요한 모델만 로드
418
+ load_model_on_demand("gd")
419
+ load_model_on_demand("segmenter")
420
 
421
+ with torch.cuda.amp.autocast(): # 메모리 효율을 위한 mixed precision
422
+ # 처리 로직...
423
+ pass
424
+
425
  finally:
426
+ # 메모리 정리
427
+ clear_memory()
428
  if torch.cuda.is_available():
429
  try:
430
  with torch.cuda.device('cuda'):
 
471
  except:
472
  return gr.update(interactive=False, variant="secondary")
473
 
474
+
 
 
 
 
 
 
475
 
476
  # CSS ์ •์˜
477
  css = """
 
829
  queue=True
830
  )
831
 
832
+ demo.queue(max_size=1) # 큐 크기 제한
833
  demo.launch(
834
  server_name="0.0.0.0",
835
  server_port=7860,
836
  share=False,
837
+ max_threads=1, # 스레드 수 제한
838
  enable_queue=True,
839
  cache_examples=False,
840
  show_error=True,
841
+ show_tips=False,
842
+ quiet=True
843
  )