aiqtech committed on
Commit
36fedb3
·
verified ·
1 Parent(s): b785f0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -21
app.py CHANGED
@@ -22,9 +22,19 @@ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
22
  import gc
23
  from PIL import Image, ImageDraw, ImageFont
24
 
25
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ถ€๋ถ„ ์ˆ˜์ •
 
 
 
 
 
 
 
 
 
26
  def initialize_models():
27
- global segmenter, gd_model, gd_processor, pipe
 
28
 
29
  try:
30
  # GPU 메모리 정리
@@ -39,14 +49,12 @@ def initialize_models():
39
  gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
40
  gd_model = GroundingDinoForObjectDetection.from_pretrained(
41
  gd_model_path,
42
- torch_dtype=torch.float16, # float32 ๋Œ€์‹  float16 ์‚ฌ์šฉ
43
- device_map="auto" # ์ž๋™ ๋””๋ฐ”์ด์Šค ๋งคํ•‘
44
  )
45
 
46
  # Segmenter ์ดˆ๊ธฐํ™”
47
- segmenter = BoxSegmenter(device="cpu")
48
- if torch.cuda.is_available():
49
- segmenter.to(device)
50
 
51
  # FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
52
  pipe = FluxPipeline.from_pretrained(
@@ -61,27 +69,17 @@ def initialize_models():
61
  hf_hub_download(
62
  "ByteDance/Hyper-SD",
63
  "Hyper-FLUX.1-dev-8steps-lora.safetensors",
64
- use_auth_token=HF_TOKEN
65
  )
66
  )
67
  pipe.fuse_lora(lora_scale=0.125)
68
 
69
  if torch.cuda.is_available():
70
- pipe = pipe.to("cuda:0")
71
 
72
  except Exception as e:
73
  print(f"Model initialization error: {str(e)}")
74
  raise
75
-
76
- def clear_memory():
77
- """메모리 정리 강화 함수"""
78
- gc.collect()
79
- torch.cuda.empty_cache()
80
-
81
- if torch.cuda.is_available():
82
- with torch.cuda.device(0):
83
- torch.cuda.reset_peak_memory_stats()
84
- torch.cuda.empty_cache()
85
  # GPU ์„ค์ •
86
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # 명시적으로 cuda:0 지정
87
 
@@ -605,8 +603,20 @@ def update_process_button(img, prompt):
605
  interactive=bool(img and prompt),
606
  variant="primary" if bool(img and prompt) else "secondary"
607
  )
608
-
609
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
611
  initialize_models()
612
 
@@ -814,7 +824,6 @@ if __name__ == "__main__":
814
  queue=True
815
  )
816
 
817
- # Gradio ์•ฑ ์‹คํ–‰
818
  demo.queue(max_size=3)
819
  demo.launch(
820
  server_name="0.0.0.0",
 
22
  import gc
23
  from PIL import Image, ImageDraw, ImageFont
24
 
25
+ def clear_memory():
26
+ """๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ•จ์ˆ˜ - Spaces GPU ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ์ˆ˜์ •"""
27
+ gc.collect()
28
+ if torch.cuda.is_available():
29
+ try:
30
+ with torch.cuda.device('cuda:0'): # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
31
+ torch.cuda.empty_cache()
32
+ except Exception as e:
33
+ print(f"GPU memory management warning: {e}")
34
+
35
  def initialize_models():
36
+ """๋ชจ๋ธ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜ - Spaces GPU ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ์ˆ˜์ •"""
37
+ global segmenter, gd_model, gd_processor, pipe, translator
38
 
39
  try:
40
  # GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
 
49
  gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
50
  gd_model = GroundingDinoForObjectDetection.from_pretrained(
51
  gd_model_path,
52
+ torch_dtype=torch.float16,
53
+ device_map='cuda:0' # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
54
  )
55
 
56
  # Segmenter ์ดˆ๊ธฐํ™”
57
+ segmenter = BoxSegmenter(device='cuda:0') # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
 
 
58
 
59
  # FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
60
  pipe = FluxPipeline.from_pretrained(
 
69
  hf_hub_download(
70
  "ByteDance/Hyper-SD",
71
  "Hyper-FLUX.1-dev-8steps-lora.safetensors",
72
+ token=HF_TOKEN
73
  )
74
  )
75
  pipe.fuse_lora(lora_scale=0.125)
76
 
77
  if torch.cuda.is_available():
78
+ pipe = pipe.to('cuda:0') # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
79
 
80
  except Exception as e:
81
  print(f"Model initialization error: {str(e)}")
82
  raise
 
 
 
 
 
 
 
 
 
 
83
  # GPU ์„ค์ •
84
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์ง€์ •
85
 
 
603
  interactive=bool(img and prompt),
604
  variant="primary" if bool(img and prompt) else "secondary"
605
  )
 
606
  if __name__ == "__main__":
607
+ # CUDA ์„ค์ •
608
+ if torch.cuda.is_available():
609
+ try:
610
+ torch.cuda.set_device('cuda:0') # ๋ช…์‹œ์ ์œผ๋กœ cuda:0 ์„ค์ •
611
+ torch.backends.cudnn.benchmark = True
612
+ torch.backends.cuda.matmul.allow_tf32 = True
613
+ except Exception as e:
614
+ print(f"CUDA setup warning: {e}")
615
+
616
+ # HF ํ† ํฐ ์„ค์ •
617
+ if HF_TOKEN:
618
+ login(token=HF_TOKEN, add_to_git_credential=False)
619
+
620
  # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
621
  initialize_models()
622
 
 
824
  queue=True
825
  )
826
 
 
827
  demo.queue(max_size=3)
828
  demo.launch(
829
  server_name="0.0.0.0",