yiren98 commited on
Commit
35cb2b8
·
1 Parent(s): b1be519

modified: gradio_app_asy.py

Browse files
Files changed (1) hide show
  1. gradio_app_asy.py +13 -12
gradio_app_asy.py CHANGED
@@ -64,6 +64,7 @@ def download_file(repo_id, file_name):
64
  # Load model function with dynamic paths based on the selected model
65
  def load_target_model(frame, domain):
66
  global model, clip_l, t5xxl, ae, lora_model
 
67
  BASE_FLUX_CHECKPOINT=download_file(flux_repo_id, flux_file)
68
  CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
69
  T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
@@ -71,17 +72,15 @@ def load_target_model(frame, domain):
71
  LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])
72
 
73
  logger.info("Loading models...")
74
- # try:
75
- if model is None is None or clip_l is None or t5xxl is None or ae is None:
76
- _, model = flux_utils.load_flow_model(
77
- BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
78
- )
79
- clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
80
- clip_l.eval()
81
- t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
82
- t5xxl.eval()
83
- ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
84
- logger.info("Models loaded successfully.")
85
  # Load LoRA weights
86
  multiplier = 1.0
87
  weights_sd = load_file(LORA_WEIGHTS_PATH)
@@ -197,6 +196,8 @@ def infer(prompt, frame, seed=0):
197
  generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
198
 
199
  logger.info("Image generation completed.")
 
 
200
  return generated_image
201
 
202
  def update_domains(floor):
@@ -232,7 +233,7 @@ with gr.Blocks() as demo:
232
  with gr.Column(scale=1):
233
  # Status message box
234
  status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
235
-
236
  with gr.Row():
237
  with gr.Column(scale=1):
238
  # Input for the prompt
 
64
  # Load model function with dynamic paths based on the selected model
65
  def load_target_model(frame, domain):
66
  global model, clip_l, t5xxl, ae, lora_model
67
+
68
  BASE_FLUX_CHECKPOINT=download_file(flux_repo_id, flux_file)
69
  CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
70
  T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
 
72
  LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])
73
 
74
  logger.info("Loading models...")
75
+ _, model = flux_utils.load_flow_model(
76
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
77
+ )
78
+ clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
79
+ clip_l.eval()
80
+ t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
81
+ t5xxl.eval()
82
+ ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
83
+ logger.info("Models loaded successfully.")
 
 
84
  # Load LoRA weights
85
  multiplier = 1.0
86
  weights_sd = load_file(LORA_WEIGHTS_PATH)
 
196
  generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
197
 
198
  logger.info("Image generation completed.")
199
+ torch.cuda.empty_cache()
200
+
201
  return generated_image
202
 
203
  def update_domains(floor):
 
233
  with gr.Column(scale=1):
234
  # Status message box
235
  status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
236
+
237
  with gr.Row():
238
  with gr.Column(scale=1):
239
  # Input for the prompt