ginipick commited on
Commit
189dd29
ยท
verified ยท
1 Parent(s): b600457

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -158
app.py CHANGED
@@ -16,30 +16,54 @@ import os
16
  import random
17
  import gc
18
 
19
- # ์ƒ์ˆ˜ ์ •์˜ (๋งจ ์•ž์œผ๋กœ ์ด๋™)
20
  MAX_SEED = 2**32 - 1
21
  BASE_MODEL = "black-forest-labs/FLUX.1-dev"
22
  MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
23
  CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
24
 
25
- # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ ์„ค์ •
26
- torch.cuda.empty_cache()
27
- gc.collect()
28
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
29
- torch.backends.cudnn.benchmark = True
30
- torch.backends.cuda.matmul.allow_tf32 = True
31
- torch.backends.cuda.max_split_size_mb = 128
32
-
33
- # Hugging Face ํ† ํฐ ์„ค์ •
34
- HF_TOKEN = os.getenv("HF_TOKEN")
35
- if HF_TOKEN is None:
36
- raise ValueError("Please set the HF_TOKEN environment variable")
37
- login(token=HF_TOKEN)
38
 
39
- # CUDA ์„ค์ •
40
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
41
 
42
- # ์ „์—ญ ๋ณ€์ˆ˜๋กœ ๋ชจ๋ธ๋“ค์„ ์„ ์–ธ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  fashion_pipe = None
44
  translator = None
45
  mask_predictor = None
@@ -48,91 +72,31 @@ vt_model = None
48
  pt_model = None
49
  vt_inference = None
50
  pt_inference = None
 
 
51
 
52
- def clear_memory():
53
- if torch.cuda.is_available():
54
- torch.cuda.empty_cache()
55
- torch.cuda.synchronize()
56
- gc.collect()
57
-
58
- # ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜
59
- def initialize_models():
60
- global fashion_pipe
61
- if fashion_pipe is None:
62
- fashion_pipe = DiffusionPipeline.from_pretrained(
63
- BASE_MODEL,
64
- torch_dtype=torch.float16,
65
- use_auth_token=HF_TOKEN
66
- )
67
- fashion_pipe.to(device)
68
-
69
- # ์—ฌ๊ธฐ์„œ initialize_models ํ˜ธ์ถœ
70
- initialize_models()
71
-
72
-
73
-
74
- # ๋ชจ๋ธ ์‚ฌ์šฉ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ
75
- def unload_models():
76
- global fashion_pipe, translator, mask_predictor, densepose_predictor, vt_model, pt_model
77
- fashion_pipe = None
78
- translator = None
79
- mask_predictor = None
80
- densepose_predictor = None
81
- vt_model = None
82
- pt_model = None
83
- clear_memory()
84
-
85
- # Hugging Face ํ† ํฐ ์„ค์ •
86
- HF_TOKEN = os.getenv("HF_TOKEN")
87
- if HF_TOKEN is None:
88
- raise ValueError("Please set the HF_TOKEN environment variable")
89
- login(token=HF_TOKEN)
90
-
91
- # CUDA ์„ค์ •
92
- device = "cuda" if torch.cuda.is_available() else "cpu"
93
 
94
- # ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜
95
- def load_model_with_optimization(model_class, *args, **kwargs):
96
- torch.cuda.empty_cache()
97
- gc.collect()
98
- model = model_class(*args, **kwargs)
99
- if device == "cuda":
100
- model = model.half() # FP16์œผ๋กœ ๋ณ€ํ™˜
101
- return model.to(device)
102
-
103
- def load_lora(pipe, lora_path):
104
- try:
105
- pipe.unload_lora_weights() # ๊ธฐ์กด LoRA ๊ฐ€์ค‘์น˜ ์ œ๊ฑฐ
106
- except:
107
- pass
108
-
109
- try:
110
- pipe.load_lora_weights(lora_path)
111
- return pipe
112
- except Exception as e:
113
- print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
114
- return pipe
115
 
116
- # FLUX ๋ชจ๋ธ ์ดˆ๊ธฐํ™” (ํ•„์š”ํ•  ๋•Œ๋งŒ ๋กœ๋“œ)
117
- fashion_pipe = None
118
- def get_fashion_pipe():
119
  global fashion_pipe
120
  if fashion_pipe is None:
121
- torch.cuda.empty_cache()
122
  fashion_pipe = DiffusionPipeline.from_pretrained(
123
  BASE_MODEL,
124
  torch_dtype=torch.float16,
125
  use_auth_token=HF_TOKEN
126
  )
127
  try:
128
- fashion_pipe.enable_xformers_memory_efficient_attention() # ์ˆ˜์ •๋œ ๋ถ€๋ถ„
129
  except Exception as e:
130
  print(f"Warning: Could not enable memory efficient attention: {e}")
131
  fashion_pipe.enable_sequential_cpu_offload()
132
  return fashion_pipe
133
 
134
- # ๋ฒˆ์—ญ๊ธฐ ์ดˆ๊ธฐํ™” (ํ•„์š”ํ•  ๋•Œ๋งŒ ๋กœ๋“œ)
135
- translator = None
136
  def get_translator():
137
  global translator
138
  if translator is None:
@@ -141,8 +105,7 @@ def get_translator():
141
  device=device if device == "cuda" else -1)
142
  return translator
143
 
144
-
145
- # Leffa ๋ชจ๋ธ ๊ด€๋ จ ํ•จ์ˆ˜๋“ค
146
  def get_mask_predictor():
147
  global mask_predictor
148
  if mask_predictor is None:
@@ -152,6 +115,7 @@ def get_mask_predictor():
152
  )
153
  return mask_predictor
154
 
 
155
  def get_densepose_predictor():
156
  global densepose_predictor
157
  if densepose_predictor is None:
@@ -161,41 +125,60 @@ def get_densepose_predictor():
161
  )
162
  return densepose_predictor
163
 
 
164
  def get_vt_model():
165
  global vt_model, vt_inference
166
  if vt_model is None:
167
- torch.cuda.empty_cache()
168
- vt_model = load_model_with_optimization(
169
- LeffaModel,
170
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
171
  pretrained_model="./ckpts/virtual_tryon.pth"
172
  )
 
173
  vt_inference = LeffaInference(model=vt_model)
174
  return vt_model, vt_inference
175
 
 
176
  def get_pt_model():
177
  global pt_model, pt_inference
178
  if pt_model is None:
179
- torch.cuda.empty_cache()
180
- pt_model = load_model_with_optimization(
181
- LeffaModel,
182
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
183
  pretrained_model="./ckpts/pose_transfer.pth"
184
  )
 
185
  pt_inference = LeffaInference(model=pt_model)
186
  return pt_model, pt_inference
187
 
188
- # Leffa ์ฒดํฌํฌ์ธํŠธ ๋‹ค์šด๋กœ๋“œ
189
- snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
 
191
  def contains_korean(text):
192
  return any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in text)
193
 
 
 
194
  @spaces.GPU()
 
195
  def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
196
- clear_memory() # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
197
-
198
  try:
 
199
  if contains_korean(prompt):
200
  translator = get_translator()
201
  translated = translator(prompt)[0]['translation_text']
@@ -203,9 +186,10 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
203
  else:
204
  actual_prompt = prompt
205
 
206
- pipe = get_fashion_pipe()
 
207
 
208
- # ๋ชจ๋“œ์— ๋”ฐ๋ฅธ LoRA ๋กœ๋”ฉ ๋ฐ ํŠธ๋ฆฌ๊ฑฐ์›Œ๋“œ ์„ค์ •
209
  if mode == "Generate Model":
210
  pipe = load_lora(pipe, MODEL_LORA_REPO)
211
  trigger_word = "fashion photography, professional model"
@@ -213,19 +197,23 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
213
  pipe = load_lora(pipe, CLOTHES_LORA_REPO)
214
  trigger_word = "upper clothing, fashion item"
215
 
216
- # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ์ œํ•œ์„ ์œ„ํ•œ ํฌ๊ธฐ ์กฐ์ •
217
- width = min(width, 768) # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
218
- height = min(height, 768) # ์ตœ๋Œ€ ํฌ๊ธฐ ์ œํ•œ
 
219
 
 
220
  if randomize_seed:
221
  seed = random.randint(0, MAX_SEED)
222
  generator = torch.Generator(device="cuda").manual_seed(seed)
223
 
 
224
  progress(0, "Starting fashion generation...")
225
 
 
226
  image = pipe(
227
  prompt=f"{actual_prompt} {trigger_word}",
228
- num_inference_steps=min(steps, 30), # ์Šคํ… ์ˆ˜ ์ œํ•œ
229
  guidance_scale=cfg_scale,
230
  width=width,
231
  height=height,
@@ -233,51 +221,25 @@ def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width
233
  joint_attention_kwargs={"scale": lora_scale},
234
  ).images[0]
235
 
236
- clear_memory() # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
237
  return image, seed
238
 
239
  except Exception as e:
240
- clear_memory() # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ์—๋„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
241
- raise e
242
-
243
-
244
 
 
245
  def leffa_predict(src_image_path, ref_image_path, control_type):
246
- global mask_predictor, densepose_predictor, vt_model, pt_model, vt_inference, pt_inference
247
-
248
- clear_memory()
249
-
250
  try:
251
- # ํ•„์š”ํ•œ ๏ฟฝ๏ฟฝ๋ธ ์ดˆ๊ธฐํ™”
252
- if control_type == "virtual_tryon" and vt_model is None:
253
- vt_model = LeffaModel(
254
- pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
255
- pretrained_model="./ckpts/virtual_tryon.pth"
256
- )
257
- vt_model.to(device)
258
- vt_inference = LeffaInference(model=vt_model)
259
-
260
- elif control_type == "pose_transfer" and pt_model is None:
261
- pt_model = LeffaModel(
262
- pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
263
- pretrained_model="./ckpts/pose_transfer.pth"
264
- )
265
- pt_model.to(device)
266
- pt_inference = LeffaInference(model=pt_model)
267
-
268
- if mask_predictor is None:
269
- mask_predictor = AutoMasker(
270
- densepose_path="./ckpts/densepose",
271
- schp_path="./ckpts/schp",
272
- )
273
-
274
- if densepose_predictor is None:
275
- densepose_predictor = DensePosePredictor(
276
- config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
277
- weights_path="./ckpts/densepose/model_final_162be9.pkl",
278
- )
279
-
280
- # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ
281
  src_image = Image.open(src_image_path)
282
  ref_image = Image.open(ref_image_path)
283
  src_image = resize_and_center(src_image, 768, 1024)
@@ -289,22 +251,18 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
289
  # Mask ์ƒ์„ฑ
290
  if control_type == "virtual_tryon":
291
  src_image = src_image.convert("RGB")
292
- mask = mask_predictor(src_image, "upper")["mask"]
293
  else:
294
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
295
 
296
  # DensePose ์˜ˆ์ธก
297
- src_image_iuv_array = densepose_predictor.predict_iuv(src_image_array)
298
- src_image_seg_array = densepose_predictor.predict_seg(src_image_array)
299
- src_image_iuv = Image.fromarray(src_image_iuv_array)
300
- src_image_seg = Image.fromarray(src_image_seg_array)
301
 
302
  if control_type == "virtual_tryon":
303
- densepose = src_image_seg
304
- inference = vt_inference
305
  else:
306
- densepose = src_image_iuv
307
- inference = pt_inference
308
 
309
  # Leffa ๋ณ€ํ™˜ ๋ฐ ์ถ”๋ก 
310
  transform = LeffaTransform()
@@ -317,20 +275,22 @@ def leffa_predict(src_image_path, ref_image_path, control_type):
317
  data = transform(data)
318
 
319
  output = inference(data)
320
- gen_image = output["generated_image"][0]
321
-
322
- clear_memory()
323
- return np.array(gen_image)
324
 
325
  except Exception as e:
326
- clear_memory()
327
- raise e
328
 
 
329
  def leffa_predict_vt(src_image_path, ref_image_path):
330
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
331
 
 
332
  def leffa_predict_pt(src_image_path, ref_image_path):
333
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
 
 
 
334
 
335
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค
336
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo:
 
16
  import random
17
  import gc
18
 
19
+ # ์ƒ์ˆ˜ ์ •์˜
20
  MAX_SEED = 2**32 - 1
21
  BASE_MODEL = "black-forest-labs/FLUX.1-dev"
22
  MODEL_LORA_REPO = "Motas/Flux_Fashion_Photography_Style"
23
  CLOTHES_LORA_REPO = "prithivMLmods/Canopus-Clothing-Flux-LoRA"
24
 
25
+ # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ๋ฅผ ์œ„ํ•œ ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ
26
+ def safe_model_call(func):
27
+ def wrapper(*args, **kwargs):
28
+ try:
29
+ clear_memory()
30
+ result = func(*args, **kwargs)
31
+ clear_memory()
32
+ return result
33
+ except Exception as e:
34
+ clear_memory()
35
+ print(f"Error in {func.__name__}: {str(e)}")
36
+ raise
37
+ return wrapper
38
 
39
+ # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ ํ•จ์ˆ˜
40
+ def clear_memory():
41
+ if torch.cuda.is_available():
42
+ torch.cuda.empty_cache()
43
+ torch.cuda.synchronize()
44
+ gc.collect()
45
 
46
+ def setup_environment():
47
+ # ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ ์„ค์ •
48
+ torch.cuda.empty_cache()
49
+ gc.collect()
50
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
51
+ torch.backends.cudnn.benchmark = True
52
+ torch.backends.cuda.matmul.allow_tf32 = True
53
+ torch.backends.cuda.max_split_size_mb = 128
54
+
55
+ # Hugging Face ํ† ํฐ ์„ค์ •
56
+ global HF_TOKEN
57
+ HF_TOKEN = os.getenv("HF_TOKEN")
58
+ if HF_TOKEN is None:
59
+ raise ValueError("Please set the HF_TOKEN environment variable")
60
+ login(token=HF_TOKEN)
61
+
62
+ # CUDA ์„ค์ •
63
+ global device
64
+ device = "cuda" if torch.cuda.is_available() else "cpu"
65
+
66
+ # ์ „์—ญ ๋ณ€์ˆ˜ ์ดˆ๊ธฐํ™”
67
  fashion_pipe = None
68
  translator = None
69
  mask_predictor = None
 
72
  pt_model = None
73
  vt_inference = None
74
  pt_inference = None
75
+ device = None
76
+ HF_TOKEN = None
77
 
78
+ # ํ™˜๊ฒฝ ์„ค์ • ์‹คํ–‰
79
+ setup_environment()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ # ๋ชจ๋ธ ๊ด€๋ฆฌ ํ•จ์ˆ˜๋“ค
83
+ def initialize_fashion_pipe():
 
84
  global fashion_pipe
85
  if fashion_pipe is None:
86
+ clear_memory()
87
  fashion_pipe = DiffusionPipeline.from_pretrained(
88
  BASE_MODEL,
89
  torch_dtype=torch.float16,
90
  use_auth_token=HF_TOKEN
91
  )
92
  try:
93
+ fashion_pipe.enable_xformers_memory_efficient_attention()
94
  except Exception as e:
95
  print(f"Warning: Could not enable memory efficient attention: {e}")
96
  fashion_pipe.enable_sequential_cpu_offload()
97
  return fashion_pipe
98
 
99
+ @safe_model_call
 
100
  def get_translator():
101
  global translator
102
  if translator is None:
 
105
  device=device if device == "cuda" else -1)
106
  return translator
107
 
108
+ @safe_model_call
 
109
  def get_mask_predictor():
110
  global mask_predictor
111
  if mask_predictor is None:
 
115
  )
116
  return mask_predictor
117
 
118
+ @safe_model_call
119
  def get_densepose_predictor():
120
  global densepose_predictor
121
  if densepose_predictor is None:
 
125
  )
126
  return densepose_predictor
127
 
128
+ @safe_model_call
129
  def get_vt_model():
130
  global vt_model, vt_inference
131
  if vt_model is None:
132
+ vt_model = LeffaModel(
 
 
133
  pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
134
  pretrained_model="./ckpts/virtual_tryon.pth"
135
  )
136
+ vt_model = vt_model.half().to(device)
137
  vt_inference = LeffaInference(model=vt_model)
138
  return vt_model, vt_inference
139
 
140
+ @safe_model_call
141
  def get_pt_model():
142
  global pt_model, pt_inference
143
  if pt_model is None:
144
+ pt_model = LeffaModel(
 
 
145
  pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
146
  pretrained_model="./ckpts/pose_transfer.pth"
147
  )
148
+ pt_model = pt_model.half().to(device)
149
  pt_inference = LeffaInference(model=pt_model)
150
  return pt_model, pt_inference
151
 
152
+ def load_lora(pipe, lora_path):
153
+ try:
154
+ pipe.unload_lora_weights()
155
+ except:
156
+ pass
157
+ try:
158
+ pipe.load_lora_weights(lora_path)
159
+ return pipe
160
+ except Exception as e:
161
+ print(f"Warning: Failed to load LoRA weights from {lora_path}: {e}")
162
+ return pipe
163
+
164
+ # ์ดˆ๊ธฐ ์„ค์ • ํ•จ์ˆ˜
165
+ def setup():
166
+ # Leffa ์ฒดํฌํฌ์ธํŠธ ๋‹ค์šด๋กœ๋“œ
167
+ snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
168
+ # ๊ธฐ๋ณธ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
169
+ initialize_fashion_pipe()
170
 
171
+ # ์œ ํ‹ธ๋ฆฌํ‹ฐ ํ•จ์ˆ˜
172
  def contains_korean(text):
173
  return any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in text)
174
 
175
+
176
+ # ๋ฉ”์ธ ๊ธฐ๋Šฅ ํ•จ์ˆ˜๋“ค
177
  @spaces.GPU()
178
+ @safe_model_call
179
  def generate_fashion(prompt, mode, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
 
 
180
  try:
181
+ # ํ•œ๊ธ€ ์ฒ˜๋ฆฌ
182
  if contains_korean(prompt):
183
  translator = get_translator()
184
  translated = translator(prompt)[0]['translation_text']
 
186
  else:
187
  actual_prompt = prompt
188
 
189
+ # ํŒŒ์ดํ”„๋ผ์ธ ๊ฐ€์ ธ์˜ค๊ธฐ
190
+ pipe = initialize_fashion_pipe()
191
 
192
+ # LoRA ์„ค์ •
193
  if mode == "Generate Model":
194
  pipe = load_lora(pipe, MODEL_LORA_REPO)
195
  trigger_word = "fashion photography, professional model"
 
197
  pipe = load_lora(pipe, CLOTHES_LORA_REPO)
198
  trigger_word = "upper clothing, fashion item"
199
 
200
+ # ํŒŒ๋ผ๋ฏธํ„ฐ ์ œํ•œ
201
+ width = min(width, 768)
202
+ height = min(height, 768)
203
+ steps = min(steps, 30)
204
 
205
+ # ์‹œ๋“œ ์„ค์ •
206
  if randomize_seed:
207
  seed = random.randint(0, MAX_SEED)
208
  generator = torch.Generator(device="cuda").manual_seed(seed)
209
 
210
+ # ์ง„ํ–‰๋ฅ  ํ‘œ์‹œ
211
  progress(0, "Starting fashion generation...")
212
 
213
+ # ์ด๋ฏธ์ง€ ์ƒ์„ฑ
214
  image = pipe(
215
  prompt=f"{actual_prompt} {trigger_word}",
216
+ num_inference_steps=steps,
217
  guidance_scale=cfg_scale,
218
  width=width,
219
  height=height,
 
221
  joint_attention_kwargs={"scale": lora_scale},
222
  ).images[0]
223
 
 
224
  return image, seed
225
 
226
  except Exception as e:
227
+ print(f"Error in generate_fashion: {str(e)}")
228
+ raise
 
 
229
 
230
+ @safe_model_call
231
  def leffa_predict(src_image_path, ref_image_path, control_type):
 
 
 
 
232
  try:
233
+ # ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
234
+ if control_type == "virtual_tryon":
235
+ model, inference = get_vt_model()
236
+ else:
237
+ model, inference = get_pt_model()
238
+
239
+ mask_pred = get_mask_predictor()
240
+ dense_pred = get_densepose_predictor()
241
+
242
+ # ์ด๋ฏธ์ง€ ๋กœ๋“œ ๋ฐ ์ „์ฒ˜๋ฆฌ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  src_image = Image.open(src_image_path)
244
  ref_image = Image.open(ref_image_path)
245
  src_image = resize_and_center(src_image, 768, 1024)
 
251
  # Mask ์ƒ์„ฑ
252
  if control_type == "virtual_tryon":
253
  src_image = src_image.convert("RGB")
254
+ mask = mask_pred(src_image, "upper")["mask"]
255
  else:
256
  mask = Image.fromarray(np.ones_like(src_image_array) * 255)
257
 
258
  # DensePose ์˜ˆ์ธก
259
+ src_image_iuv_array = dense_pred.predict_iuv(src_image_array)
260
+ src_image_seg_array = dense_pred.predict_seg(src_image_array)
 
 
261
 
262
  if control_type == "virtual_tryon":
263
+ densepose = Image.fromarray(src_image_seg_array)
 
264
  else:
265
+ densepose = Image.fromarray(src_image_iuv_array)
 
266
 
267
  # Leffa ๋ณ€ํ™˜ ๋ฐ ์ถ”๋ก 
268
  transform = LeffaTransform()
 
275
  data = transform(data)
276
 
277
  output = inference(data)
278
+ return np.array(output["generated_image"][0])
 
 
 
279
 
280
  except Exception as e:
281
+ print(f"Error in leffa_predict: {str(e)}")
282
+ raise
283
 
284
+ @safe_model_call
285
  def leffa_predict_vt(src_image_path, ref_image_path):
286
  return leffa_predict(src_image_path, ref_image_path, "virtual_tryon")
287
 
288
+ @safe_model_call
289
  def leffa_predict_pt(src_image_path, ref_image_path):
290
  return leffa_predict(src_image_path, ref_image_path, "pose_transfer")
291
+
292
+ # ์ดˆ๊ธฐ ์„ค์ • ์‹คํ–‰
293
+ setup()
294
 
295
  # Gradio ์ธํ„ฐํŽ˜์ด์Šค
296
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange") as demo: