Spaces:
Running
on
Zero
Running
on
Zero
modified: gradio_app_asy.py
Browse files- gradio_app_asy.py +13 -12
gradio_app_asy.py
CHANGED
@@ -64,6 +64,7 @@ def download_file(repo_id, file_name):
|
|
64 |
# Load model function with dynamic paths based on the selected model
|
65 |
def load_target_model(frame, domain):
|
66 |
global model, clip_l, t5xxl, ae, lora_model
|
|
|
67 |
BASE_FLUX_CHECKPOINT=download_file(flux_repo_id, flux_file)
|
68 |
CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
|
69 |
T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
|
@@ -71,17 +72,15 @@ def load_target_model(frame, domain):
|
|
71 |
LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])
|
72 |
|
73 |
logger.info("Loading models...")
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
84 |
-
logger.info("Models loaded successfully.")
|
85 |
# Load LoRA weights
|
86 |
multiplier = 1.0
|
87 |
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
@@ -197,6 +196,8 @@ def infer(prompt, frame, seed=0):
|
|
197 |
generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
198 |
|
199 |
logger.info("Image generation completed.")
|
|
|
|
|
200 |
return generated_image
|
201 |
|
202 |
def update_domains(floor):
|
@@ -232,7 +233,7 @@ with gr.Blocks() as demo:
|
|
232 |
with gr.Column(scale=1):
|
233 |
# Status message box
|
234 |
status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
|
235 |
-
|
236 |
with gr.Row():
|
237 |
with gr.Column(scale=1):
|
238 |
# Input for the prompt
|
|
|
64 |
# Load model function with dynamic paths based on the selected model
|
65 |
def load_target_model(frame, domain):
|
66 |
global model, clip_l, t5xxl, ae, lora_model
|
67 |
+
|
68 |
BASE_FLUX_CHECKPOINT=download_file(flux_repo_id, flux_file)
|
69 |
CLIP_L_PATH = download_file(clip_repo_id, clip_l_file)
|
70 |
T5XXL_PATH = download_file(clip_repo_id, t5xxl_file)
|
|
|
72 |
LORA_WEIGHTS_PATH = download_file(lora_repo_id, lora_paths[frame])
|
73 |
|
74 |
logger.info("Loading models...")
|
75 |
+
_, model = flux_utils.load_flow_model(
|
76 |
+
BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
|
77 |
+
)
|
78 |
+
clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
79 |
+
clip_l.eval()
|
80 |
+
t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
81 |
+
t5xxl.eval()
|
82 |
+
ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
|
83 |
+
logger.info("Models loaded successfully.")
|
|
|
|
|
84 |
# Load LoRA weights
|
85 |
multiplier = 1.0
|
86 |
weights_sd = load_file(LORA_WEIGHTS_PATH)
|
|
|
196 |
generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
197 |
|
198 |
logger.info("Image generation completed.")
|
199 |
+
torch.cuda.empty_cache()
|
200 |
+
|
201 |
return generated_image
|
202 |
|
203 |
def update_domains(floor):
|
|
|
233 |
with gr.Column(scale=1):
|
234 |
# Status message box
|
235 |
status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
|
236 |
+
|
237 |
with gr.Row():
|
238 |
with gr.Column(scale=1):
|
239 |
# Input for the prompt
|