openfree commited on
Commit
c23ced1
ยท
verified ยท
1 Parent(s): a0fcee4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -29
app.py CHANGED
@@ -71,19 +71,23 @@ def translate_to_english(text: str) -> str:
71
  print(f"Translation error: {str(e)}")
72
  return text
73
 
 
74
  print("Initializing FLUX pipeline...")
75
  try:
76
  pipe = FluxPipeline.from_pretrained(
77
  "black-forest-labs/FLUX.1-dev",
78
- torch_dtype=torch.float16,
79
  use_auth_token=HF_TOKEN,
80
- safety_checker=None # ์•ˆ์ „์„ฑ ๊ฒ€์‚ฌ ๋น„ํ™œ์„ฑํ™”
 
 
81
  )
82
  print("FLUX pipeline initialized successfully")
83
 
84
- # ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™” ์„ค์ •
85
- pipe.enable_attention_slicing(slice_size="auto")
86
  pipe.enable_model_cpu_offload() # CPU ์˜คํ”„๋กœ๋”ฉ ํ™œ์„ฑํ™”
 
87
  print("Pipeline optimization settings applied")
88
 
89
  except Exception as e:
@@ -93,34 +97,28 @@ except Exception as e:
93
  # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ๋ถ€๋ถ„ ์ˆ˜์ •
94
  print("Loading LoRA weights...")
95
  try:
96
- # LoRA ํŒŒ์ผ ๊ฒฝ๋กœ ํ™•์ธ
97
  lora_path = hf_hub_download(
98
  repo_id="openfree/myt-flux-fantasy",
99
- filename="myt-flux-fantasy.safetensors", # ์ •ํ™•ํ•œ ํŒŒ์ผ๋ช… ํ™•์ธ ํ•„์š”
100
  use_auth_token=HF_TOKEN
101
  )
102
  print(f"LoRA weights downloaded to: {lora_path}")
103
 
104
- # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ
105
- pipe.load_lora_weights(lora_path)
106
  pipe.fuse_lora(lora_scale=0.125)
 
 
 
 
 
107
  print("LoRA weights loaded and fused successfully")
108
 
109
  except Exception as e:
110
  print(f"Error loading LoRA weights: {str(e)}")
111
- print(f"Full error details: {repr(e)}")
112
- raise ValueError("Failed to load LoRA weights. Please check your HF_TOKEN and model access.")
113
 
114
- # GPU ์ด๋™ ๋ถ€๋ถ„ ์ˆ˜์ •
115
- if torch.cuda.is_available():
116
- try:
117
- print("Moving pipeline to GPU...")
118
- pipe = pipe.to("cuda:0")
119
- print("Pipeline successfully moved to GPU")
120
- print(f"Current device: {pipe.device}")
121
- except Exception as e:
122
- print(f"Warning: Could not move pipeline to CUDA: {str(e)}")
123
- print("Falling back to CPU")
124
 
125
 
126
  # ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
@@ -139,6 +137,7 @@ def save_generated_image(image, prompt):
139
  image.save(filepath)
140
  return filepath
141
 
 
142
  @spaces.GPU(duration=60)
143
  def generate_image(
144
  prompt: str,
@@ -153,18 +152,19 @@ def generate_image(
153
  try:
154
  print(f"\nStarting image generation with prompt: {prompt}")
155
 
156
- # ํ”„๋กฌํ”„ํŠธ ๋ฒˆ์—ญ
 
 
157
  translated_prompt = translate_to_english(prompt)
158
  print(f"Translated prompt: {translated_prompt}")
159
 
160
  if randomize_seed:
161
  seed = random.randint(0, MAX_SEED)
162
- print(f"Using seed: {seed}")
163
 
164
  generator = torch.Generator(device=device).manual_seed(seed)
165
 
166
- print("Starting inference...")
167
- with torch.inference_mode():
168
  image = pipe(
169
  prompt=translated_prompt,
170
  width=width,
@@ -172,20 +172,20 @@ def generate_image(
172
  num_inference_steps=num_inference_steps,
173
  guidance_scale=guidance_scale,
174
  generator=generator,
 
175
  ).images[0]
176
- print("Image generation completed successfully")
177
 
178
  filepath = save_generated_image(image, translated_prompt)
179
- print(f"Image saved to: {filepath}")
 
 
180
 
181
  return image, seed
182
 
183
  except Exception as e:
184
  print(f"Error in generate_image: {str(e)}")
185
- print(f"Full error details: {repr(e)}")
186
- raise gr.Error(f"Image generation failed: {str(e)}")
187
- finally:
188
  clear_memory()
 
189
 
190
  def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
191
  """ํ…์ŠคํŠธ์— ์™ธ๊ณฝ์„ ์„ ์ถ”๊ฐ€ํ•˜๋Š” ํ•จ์ˆ˜"""
 
71
  print(f"Translation error: {str(e)}")
72
  return text
73
 
74
+ # FLUX ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” ๋ถ€๋ถ„ ์ˆ˜์ •
75
  print("Initializing FLUX pipeline...")
76
  try:
77
  pipe = FluxPipeline.from_pretrained(
78
  "black-forest-labs/FLUX.1-dev",
79
+ torch_dtype=torch.float16, # ๋ฐ˜์ •๋ฐ€๋„ ์‚ฌ์šฉ
80
  use_auth_token=HF_TOKEN,
81
+ safety_checker=None,
82
+ variant="fp16", # fp16 ๋ณ€ํ˜• ์‚ฌ์šฉ
83
+ device_map="auto" # ์ž๋™ ์žฅ์น˜ ๋งคํ•‘
84
  )
85
  print("FLUX pipeline initialized successfully")
86
 
87
+ # ๋ฉ”๋ชจ๋ฆฌ ์ตœ์ ํ™” ์„ค์ • ๊ฐ•ํ™”
88
+ pipe.enable_attention_slicing(slice_size=1) # ๋” ์ž‘์€ ์Šฌ๋ผ์ด์Šค ํฌ๊ธฐ
89
  pipe.enable_model_cpu_offload() # CPU ์˜คํ”„๋กœ๋”ฉ ํ™œ์„ฑํ™”
90
+ pipe.enable_sequential_cpu_offload() # ์ˆœ์ฐจ์  CPU ์˜คํ”„๋กœ๋”ฉ
91
  print("Pipeline optimization settings applied")
92
 
93
  except Exception as e:
 
97
  # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ๋ถ€๋ถ„ ์ˆ˜์ •
98
  print("Loading LoRA weights...")
99
  try:
 
100
  lora_path = hf_hub_download(
101
  repo_id="openfree/myt-flux-fantasy",
102
+ filename="myt-flux-fantasy.safetensors",
103
  use_auth_token=HF_TOKEN
104
  )
105
  print(f"LoRA weights downloaded to: {lora_path}")
106
 
107
+ # LoRA ๊ฐ€์ค‘์น˜ ๋กœ๋“œ (๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์  ๋ฐฉ์‹)
108
+ pipe.load_lora_weights(lora_path, adapter_name="fantasy")
109
  pipe.fuse_lora(lora_scale=0.125)
110
+
111
+ # ๋ถˆํ•„์š”ํ•œ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
112
+ torch.cuda.empty_cache()
113
+ gc.collect()
114
+
115
  print("LoRA weights loaded and fused successfully")
116
 
117
  except Exception as e:
118
  print(f"Error loading LoRA weights: {str(e)}")
119
+ raise ValueError("Failed to load LoRA weights")
120
+
121
 
 
 
 
 
 
 
 
 
 
 
122
 
123
 
124
  # ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์„ค์ •
 
137
  image.save(filepath)
138
  return filepath
139
 
140
+ # generate_image ํ•จ์ˆ˜ ์ˆ˜์ •
141
  @spaces.GPU(duration=60)
142
  def generate_image(
143
  prompt: str,
 
152
  try:
153
  print(f"\nStarting image generation with prompt: {prompt}")
154
 
155
+ # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
156
+ clear_memory()
157
+
158
  translated_prompt = translate_to_english(prompt)
159
  print(f"Translated prompt: {translated_prompt}")
160
 
161
  if randomize_seed:
162
  seed = random.randint(0, MAX_SEED)
 
163
 
164
  generator = torch.Generator(device=device).manual_seed(seed)
165
 
166
+ # ๋ฐฐ์น˜ ํฌ๊ธฐ 1๋กœ ๊ณ ์ •ํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ ์ตœ์†Œํ™”
167
+ with torch.inference_mode(), torch.cuda.amp.autocast():
168
  image = pipe(
169
  prompt=translated_prompt,
170
  width=width,
 
172
  num_inference_steps=num_inference_steps,
173
  guidance_scale=guidance_scale,
174
  generator=generator,
175
+ num_images_per_prompt=1,
176
  ).images[0]
 
177
 
178
  filepath = save_generated_image(image, translated_prompt)
179
+
180
+ # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
181
+ clear_memory()
182
 
183
  return image, seed
184
 
185
  except Exception as e:
186
  print(f"Error in generate_image: {str(e)}")
 
 
 
187
  clear_memory()
188
+ raise gr.Error(f"Image generation failed: {str(e)}")
189
 
190
  def add_text_with_stroke(draw, text, x, y, font, text_color, stroke_width):
191
  """ํ…์ŠคํŠธ์— ์™ธ๊ณฝ์„ ์„ ์ถ”๊ฐ€ํ•˜๋Š” ํ•จ์ˆ˜"""