ginipick commited on
Commit
c0f2e23
ยท
verified ยท
1 Parent(s): deb56c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -17
app.py CHANGED
@@ -25,7 +25,6 @@ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
25
  import gc
26
  import torch.cuda.amp as amp
27
 
28
- # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ ํ•จ์ˆ˜ ๊ฐ•ํ™”
29
  def clear_memory():
30
  """๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ•จ์ˆ˜"""
31
  if torch.cuda.is_available():
@@ -34,7 +33,11 @@ def clear_memory():
34
  gc.collect()
35
 
36
  # ์ž๋™ ํ˜ผํ•ฉ ์ •๋ฐ€๋„(Automatic Mixed Precision) ์„ค์ •
37
- scaler = amp.GradScaler()
 
 
 
 
38
 
39
 
40
  model_name = "Helsinki-NLP/opus-mt-ko-en"
@@ -85,17 +88,16 @@ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_
85
  gd_model = gd_model.to(device=device)
86
  assert isinstance(gd_model, GroundingDinoForObjectDetection)
87
 
 
88
  # FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
89
  pipe = FluxPipeline.from_pretrained(
90
  "black-forest-labs/FLUX.1-dev",
91
  torch_dtype=torch.float16, # A100์— ์ตœ์ ํ™”๋œ float16 ์‚ฌ์šฉ
92
- use_auth_token=HF_TOKEN,
93
- device_map="balanced"
94
  )
95
  pipe.enable_attention_slicing(slice_size="auto") # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ตœ์ ํ™”
96
- pipe.enable_sequential_cpu_offload() # CPU ์˜คํ”„๋กœ๋”ฉ ํ™œ์„ฑํ™”
97
-
98
 
 
99
  pipe.load_lora_weights(
100
  hf_hub_download(
101
  "ByteDance/Hyper-SD",
@@ -105,6 +107,13 @@ pipe.load_lora_weights(
105
  )
106
  pipe.fuse_lora(lora_scale=0.125)
107
 
 
 
 
 
 
 
 
108
  os.environ["CUDA_VISIBLE_DEVICES"] = "0" # ๋‹จ์ผ GPU ์‚ฌ์šฉ
109
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" # CUDA ๋ฉ”๋ชจ๋ฆฌ ํ• ๋‹น ์„ค์ •
110
 
@@ -195,17 +204,19 @@ def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
195
  width, height = adjust_size_to_multiple_of_8(width, height)
196
 
197
  with timer("Background generation"):
198
- with torch.cuda.amp.autocast(): # ์ž๋™ ํ˜ผํ•ฉ ์ •๋ฐ€๋„ ์‚ฌ์šฉ
199
- image = pipe(
200
- prompt=prompt,
201
- width=width,
202
- height=height,
203
- num_inference_steps=8,
204
- guidance_scale=4.0,
205
- max_length=77,
206
- ).images[0]
207
-
208
- return image
 
 
209
  except Exception as e:
210
  print(f"Background generation error: {str(e)}")
211
  return Image.new('RGB', (512, 512), 'white')
 
25
  import gc
26
  import torch.cuda.amp as amp
27
 
 
28
  def clear_memory():
29
  """๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ํ•จ์ˆ˜"""
30
  if torch.cuda.is_available():
 
33
  gc.collect()
34
 
35
  # ์ž๋™ ํ˜ผํ•ฉ ์ •๋ฐ€๋„(Automatic Mixed Precision) ์„ค์ •
36
+ if torch.cuda.is_available():
37
+ scaler = torch.amp.GradScaler('cuda')
38
+ else:
39
+ scaler = None
40
+
41
 
42
 
43
  model_name = "Helsinki-NLP/opus-mt-ko-en"
 
88
  gd_model = gd_model.to(device=device)
89
  assert isinstance(gd_model, GroundingDinoForObjectDetection)
90
 
91
+
92
  # FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”
93
  pipe = FluxPipeline.from_pretrained(
94
  "black-forest-labs/FLUX.1-dev",
95
  torch_dtype=torch.float16, # A100์— ์ตœ์ ํ™”๋œ float16 ์‚ฌ์šฉ
96
+ use_auth_token=HF_TOKEN
 
97
  )
98
  pipe.enable_attention_slicing(slice_size="auto") # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ตœ์ ํ™”
 
 
99
 
100
+ # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
101
  pipe.load_lora_weights(
102
  hf_hub_download(
103
  "ByteDance/Hyper-SD",
 
107
  )
108
  pipe.fuse_lora(lora_scale=0.125)
109
 
110
+ # GPU ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™”
111
+ if torch.cuda.is_available():
112
+ pipe.to("cuda")
113
+ pipe.enable_vae_slicing() # VAE ์Šฌ๋ผ์ด์‹ฑ ํ™œ์„ฑํ™”
114
+
115
+
116
+
117
  os.environ["CUDA_VISIBLE_DEVICES"] = "0" # ๋‹จ์ผ GPU ์‚ฌ์šฉ
118
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" # CUDA ๋ฉ”๋ชจ๋ฆฌ ํ• ๋‹น ์„ค์ •
119
 
 
204
  width, height = adjust_size_to_multiple_of_8(width, height)
205
 
206
  with timer("Background generation"):
207
+ with torch.inference_mode(): # inference_mode ์‚ฌ์šฉ
208
+ with torch.cuda.amp.autocast():
209
+ image = pipe(
210
+ prompt=prompt,
211
+ width=width,
212
+ height=height,
213
+ num_inference_steps=8,
214
+ guidance_scale=4.0,
215
+ max_length=77,
216
+ ).images[0]
217
+
218
+ clear_memory() # ์ฆ‰์‹œ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
219
+ return image
220
  except Exception as e:
221
  print(f"Background generation error: {str(e)}")
222
  return Image.new('RGB', (512, 512), 'white')