Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -37,11 +37,11 @@ class ModelVersion:
|
|
37 |
ENABLE_ANTI_BLUR_DEFAULT = False
|
38 |
ENABLE_REALISM_DEFAULT = False
|
39 |
|
40 |
-
pipeline = None
|
41 |
loaded_pipeline_config = {
|
42 |
"model_version": "aes_stage2",
|
43 |
"enable_realism": False,
|
44 |
"enable_anti_blur": False,
|
|
|
45 |
}
|
46 |
|
47 |
|
@@ -61,22 +61,21 @@ def download_models():
|
|
61 |
|
62 |
|
63 |
def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
|
64 |
-
global pipeline
|
65 |
-
|
66 |
if (
|
67 |
-
pipeline
|
68 |
and loaded_pipeline_config["enable_realism"] == enable_realism
|
69 |
and loaded_pipeline_config["enable_anti_blur"] == enable_anti_blur
|
70 |
and model_version == loaded_pipeline_config["model_version"]
|
71 |
):
|
72 |
-
return
|
73 |
|
74 |
loaded_pipeline_config["enable_realism"] = enable_realism
|
75 |
loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
|
76 |
loaded_pipeline_config["model_version"] = model_version
|
77 |
|
|
|
78 |
if pipeline is None or pipeline.model_version != model_version:
|
79 |
-
del pipeline
|
80 |
|
81 |
model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
|
82 |
print(f'loading model from {model_path}')
|
@@ -90,6 +89,8 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
|
|
90 |
model_version=model_version,
|
91 |
)
|
92 |
|
|
|
|
|
93 |
pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
|
94 |
loras = []
|
95 |
if enable_realism:
|
@@ -97,7 +98,7 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
|
|
97 |
if enable_anti_blur:
|
98 |
loras.append(['./models/InfiniteYou/supports/optional_loras/flux_anti_blur_lora.safetensors', 'anti_blur', 1.0])
|
99 |
pipeline.load_loras(loras)
|
100 |
-
|
101 |
|
102 |
# @spaces.GPU
|
103 |
def generate_image(
|
@@ -116,9 +117,7 @@ def generate_image(
|
|
116 |
enable_anti_blur,
|
117 |
model_version
|
118 |
):
|
119 |
-
|
120 |
-
|
121 |
-
prepare_pipeline(model_version=model_version, enable_realism=enable_realism, enable_anti_blur=enable_anti_blur)
|
122 |
|
123 |
if seed == 0:
|
124 |
seed = torch.seed() & 0xFFFFFFFF
|
|
|
37 |
ENABLE_ANTI_BLUR_DEFAULT = False
|
38 |
ENABLE_REALISM_DEFAULT = False
|
39 |
|
|
|
40 |
loaded_pipeline_config = {
|
41 |
"model_version": "aes_stage2",
|
42 |
"enable_realism": False,
|
43 |
"enable_anti_blur": False,
|
44 |
+
'pipeline': None
|
45 |
}
|
46 |
|
47 |
|
|
|
61 |
|
62 |
|
63 |
def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
|
|
|
|
|
64 |
if (
|
65 |
+
loaded_pipeline_config['pipeline'] is not None
|
66 |
and loaded_pipeline_config["enable_realism"] == enable_realism
|
67 |
and loaded_pipeline_config["enable_anti_blur"] == enable_anti_blur
|
68 |
and model_version == loaded_pipeline_config["model_version"]
|
69 |
):
|
70 |
+
return loaded_pipeline_config['pipeline']
|
71 |
|
72 |
loaded_pipeline_config["enable_realism"] = enable_realism
|
73 |
loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
|
74 |
loaded_pipeline_config["model_version"] = model_version
|
75 |
|
76 |
+
pipeline = loaded_pipeline_config['pipeline']
|
77 |
if pipeline is None or pipeline.model_version != model_version:
|
78 |
+
del loaded_pipeline_config['pipeline']
|
79 |
|
80 |
model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
|
81 |
print(f'loading model from {model_path}')
|
|
|
89 |
model_version=model_version,
|
90 |
)
|
91 |
|
92 |
+
loaded_pipeline_config['pipeline'] = pipeline
|
93 |
+
|
94 |
pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
|
95 |
loras = []
|
96 |
if enable_realism:
|
|
|
98 |
if enable_anti_blur:
|
99 |
loras.append(['./models/InfiniteYou/supports/optional_loras/flux_anti_blur_lora.safetensors', 'anti_blur', 1.0])
|
100 |
pipeline.load_loras(loras)
|
101 |
+
return pipeline
|
102 |
|
103 |
# @spaces.GPU
|
104 |
def generate_image(
|
|
|
117 |
enable_anti_blur,
|
118 |
model_version
|
119 |
):
|
120 |
+
pipeline = prepare_pipeline(model_version=model_version, enable_realism=enable_realism, enable_anti_blur=enable_anti_blur)
|
|
|
|
|
121 |
|
122 |
if seed == 0:
|
123 |
seed = torch.seed() & 0xFFFFFFFF
|