EndlessSora committed
Commit dc8acb8 · 1 Parent(s): cbaad8e

improve memory usage for zero GPUs
app.py CHANGED
@@ -60,6 +60,38 @@ def download_models():
     exit()
 
 
+def init_pipeline(model_version, enable_realism, enable_anti_blur):
+    loaded_pipeline_config["enable_realism"] = enable_realism
+    loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
+    loaded_pipeline_config["model_version"] = model_version
+
+    pipeline = loaded_pipeline_config['pipeline']
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
+    print(f'loading model from {model_path}')
+
+    pipeline = InfUFluxPipeline(
+        base_model_path='./models/FLUX.1-dev',
+        infu_model_path=model_path,
+        insightface_root_path='./models/InfiniteYou/supports/insightface',
+        image_proj_num_tokens=8,
+        infu_flux_version='v1.0',
+        model_version=model_version,
+    )
+
+    loaded_pipeline_config['pipeline'] = pipeline
+
+    pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
+    loras = []
+    if enable_realism: loras.append(['realism', 1.0])
+    if enable_anti_blur: loras.append(['anti_blur', 1.0])
+    pipeline.load_loras_state_dict(loras)
+
+    return pipeline
+
+
 def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
     if (
         loaded_pipeline_config['pipeline'] is not None
@@ -74,34 +106,34 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
     loaded_pipeline_config["model_version"] = model_version
 
     pipeline = loaded_pipeline_config['pipeline']
-    if pipeline is None or pipeline.model_version != model_version:
-        del loaded_pipeline_config['pipeline']
-        del pipeline
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
-        print(f'loading model from {model_path}')
-
-        pipeline = InfUFluxPipeline(
-            base_model_path='./models/FLUX.1-dev',
-            infu_model_path=model_path,
-            insightface_root_path='./models/InfiniteYou/supports/insightface',
-            image_proj_num_tokens=8,
-            infu_flux_version='v1.0',
-            model_version=model_version,
-        )
+    if pipeline is None or pipeline.model_version != model_version:
+        print(f'Switching model to {model_version}')
+        pipeline.model_version = model_version
+        if model_version == 'aes_stage2':
+            pipeline.infusenet_sim.cpu()
+            pipeline.image_proj_model_sim.cpu()
+            torch.cuda.empty_cache()
+            pipeline.infusenet_aes.to(pipeline.pipe.device)
+            pipeline.pipe.controlnet = pipeline.infusenet_aes
+            pipeline.image_proj_model_aes.to(pipeline.pipe.device)
+            pipeline.image_proj_model = pipeline.image_proj_model_aes
+        else:
+            pipeline.infusenet_aes.cpu()
+            pipeline.image_proj_model_aes.cpu()
+            torch.cuda.empty_cache()
+            pipeline.infusenet_sim.to(pipeline.pipe.device)
+            pipeline.pipe.controlnet = pipeline.infusenet_sim
+            pipeline.image_proj_model_sim.to(pipeline.pipe.device)
+            pipeline.image_proj_model = pipeline.image_proj_model_sim
 
     loaded_pipeline_config['pipeline'] = pipeline
 
     pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
     loras = []
-    if enable_realism:
-        loras.append(['./models/InfiniteYou/supports/optional_loras/flux_realism_lora.safetensors', 'realism', 1.0])
-    if enable_anti_blur:
-        loras.append(['./models/InfiniteYou/supports/optional_loras/flux_anti_blur_lora.safetensors', 'anti_blur', 1.0])
-    pipeline.load_loras(loras)
+    if enable_realism: loras.append(['realism', 1.0])
+    if enable_anti_blur: loras.append(['anti_blur', 1.0])
+    pipeline.load_loras_state_dict(loras)
+
     return pipeline
 
 
@@ -238,7 +270,7 @@ with gr.Blocks() as demo:
        inputs=[ui_id_image, ui_control_image, ui_prompt_text, ui_seed, ui_enable_realism, ui_enable_anti_blur, ui_model_version],
        outputs=[image_output],
        fn=generate_examples,
-       cache_examples=False
+       cache_examples=True
    )
 
    ui_btn_generate.click(
@@ -309,10 +341,9 @@ huggingface_hub.login(os.getenv('PRIVATE_HF_TOKEN'))
 
    download_models()
 
-   prepare_pipeline(model_version='sim_stage1', enable_realism=True, enable_anti_blur=True)
-   prepare_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
+   init_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
 
-   demo.queue()
+   # demo.queue()
    demo.launch()
    # demo.launch(server_name='0.0.0.0') # IPv4
    # demo.launch(server_name='[::]') # IPv6
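
Reviewer note: the rewritten `prepare_pipeline` no longer tears down and rebuilds `InfUFluxPipeline` when the user switches between `sim_stage1` and `aes_stage2`; both InfuseNet variants and image projectors stay instantiated, and only the requested pair is moved onto the GPU. Below is a minimal sketch of that swap pattern in plain PyTorch; `ToyNet`, `activate`, `net_a`, and `net_b` are illustrative names, not identifiers from this repo.

```python
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    """Stand-in for a large module such as an InfuseNet variant."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

def activate(active: nn.Module, standby: nn.Module, device: torch.device) -> nn.Module:
    # Park the unused variant on CPU first so its VRAM is released
    # before the requested variant is moved onto the device.
    standby.cpu()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    active.to(device)
    return active

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net_a, net_b = ToyNet(), ToyNet()
    current = activate(net_a, net_b, device)  # net_a resident, net_b parked on CPU
    current = activate(net_b, net_a, device)  # switch: net_b resident, net_a parked
    print(next(current.parameters()).device)
```

Offloading the standby module before onloading the active one keeps the two from ever occupying VRAM at the same time, which is the property that matters on a ZeroGPU allocation.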
pipelines/pipeline_flux_infusenet.py CHANGED
@@ -261,9 +261,6 @@ class FluxInfuseNetPipeline(FluxControlNetPipeline):
             images.
         """
 
-        # CPU offload controlnet
-        self.controlnet.cpu()
-
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
@@ -307,6 +304,11 @@ class FluxInfuseNetPipeline(FluxControlNetPipeline):
         device = self._execution_device
         dtype = self.transformer.dtype
 
+        # CPU offload controlnet, move back T5 to GPU
+        self.controlnet.cpu()
+        torch.cuda.empty_cache()
+        self.text_encoder_2.to(device)
+
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -599,11 +601,6 @@ class FluxInfuseNetPipeline(FluxControlNetPipeline):
 
         if XLA_AVAILABLE:
             xm.mark_step()
-
-        # CPU offload controlnet, move back T5 to GPU
-        self.controlnet.cpu()
-        torch.cuda.empty_cache()
-        self.text_encoder_2.to(device)
 
         if output_type == "latent":
             image = latents
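
Reviewer note: this change moves the ControlNet offload and the T5 onload from the tail of the previous `__call__` to just before prompt encoding in the current call, so the ControlNet is already on CPU while the large text encoder runs. A rough sketch of that phase-based residency follows; `text_encoder` and `controlnet` here are toy stand-in layers, not the diffusers classes.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def encode_prompt_offloaded(text_encoder: nn.Module, controlnet: nn.Module,
                            tokens: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Phase 1: park the ControlNet on CPU so only the text encoder holds VRAM
    # while the prompt is encoded.
    controlnet.cpu()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    text_encoder.to(device)
    return text_encoder(tokens.to(device))

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text_encoder = nn.Embedding(1000, 32)  # stand-in for the T5 encoder
    controlnet = nn.Linear(32, 32)         # stand-in for the InfuseNet ControlNet
    tokens = torch.randint(0, 1000, (1, 8))
    embeds = encode_prompt_offloaded(text_encoder, controlnet, tokens, device)
    print(embeds.shape, next(controlnet.parameters()).device)
```

In the actual pipeline, moving the ControlNet back to the GPU for denoising is handled outside `__call__` (by `prepare_pipeline` in app.py), so `__call__` itself only has to guarantee the encoder and ControlNet never peak together.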
pipelines/pipeline_infu_flux.py CHANGED
@@ -137,26 +137,33 @@ class InfUFluxPipeline:
 
         # Load pipeline
         try:
-            infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
-            self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
+            infusenet_path = os.path.join(os.path.dirname(infu_model_path), 'aes_stage2', 'InfuseNetModel')
+            self.infusenet_aes = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
+            infusenet_path = os.path.join(os.path.dirname(infu_model_path), 'sim_stage1', 'InfuseNetModel')
+            self.infusenet_sim = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
        except:
            print("No InfiniteYou model found. Downloading from HuggingFace `ByteDance/InfiniteYou` to `./models/InfiniteYou` ...")
            snapshot_download(repo_id='ByteDance/InfiniteYou', local_dir='./models/InfiniteYou', local_dir_use_symlinks=False)
-            infu_model_path = os.path.join('./models/InfiniteYou', f'infu_flux_{infu_flux_version}', model_version)
+            infu_model_path = os.path.join('./models/InfiniteYou', f'infu_flux_{infu_flux_version}', 'aes_stage2')
+            infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
+            self.infusenet_aes = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
+            infu_model_path = os.path.join('./models/InfiniteYou', f'infu_flux_{infu_flux_version}', 'sim_stage1')
            infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
-            self.infusenet = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
+            self.infusenet_sim = FluxControlNetModel.from_pretrained(infusenet_path, torch_dtype=torch.bfloat16)
        insightface_root_path = './models/InfiniteYou/supports/insightface'
+        self.infusenet_sim.cpu()
+        torch.cuda.empty_cache()
        try:
            pipe = FluxInfuseNetPipeline.from_pretrained(
                base_model_path,
-                controlnet=self.infusenet,
+                controlnet=self.infusenet_aes,
                torch_dtype=torch.bfloat16,
            )
        except:
            try:
                pipe = FluxInfuseNetPipeline.from_single_file(
                    base_model_path,
-                    controlnet=self.infusenet,
+                    controlnet=self.infusenet_aes,
                    torch_dtype=torch.bfloat16,
                )
            except Exception as e:
@@ -168,8 +175,9 @@ class InfUFluxPipeline:
                print('\nIf you are using other models, please download them to a local directory and use `base_model_path` to specify the correct path.')
                exit()
        pipe.to('cuda', torch.bfloat16)
-        # CPU offload controlnet in advance
+        # CPU offload controlnet and T5 in advance
        pipe.controlnet.cpu()
+        pipe.text_encoder_2.cpu()
        torch.cuda.empty_cache()
        # pipe.enable_model_cpu_offload()
        self.pipe = pipe
@@ -187,14 +195,33 @@ class InfUFluxPipeline:
            output_dim=4096,
            ff_mult=4,
        )
-        image_proj_model_path = os.path.join(infu_model_path, 'image_proj_model.bin')
+        image_proj_model_path = os.path.join(os.path.dirname(infu_model_path), 'aes_stage2', 'image_proj_model.bin')
        ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
        image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
        del ipm_state_dict
        image_proj_model.to('cuda', torch.bfloat16)
        image_proj_model.eval()
+        self.image_proj_model_aes = image_proj_model
+
+        image_proj_model = Resampler(
+            dim=1280,
+            depth=4,
+            dim_head=64,
+            heads=20,
+            num_queries=num_tokens,
+            embedding_dim=image_emb_dim,
+            output_dim=4096,
+            ff_mult=4,
+        )
+        image_proj_model_path = os.path.join(os.path.dirname(infu_model_path), 'sim_stage1', 'image_proj_model.bin')
+        ipm_state_dict = torch.load(image_proj_model_path, map_location="cpu")
+        image_proj_model.load_state_dict(ipm_state_dict['image_proj'])
+        del ipm_state_dict
+        image_proj_model.to('cpu', torch.bfloat16)
+        image_proj_model.eval()
+        self.image_proj_model_sim = image_proj_model
 
-        self.image_proj_model = image_proj_model
+        self.image_proj_model = self.image_proj_model_aes
 
        # Load face encoder
        self.app_640 = FaceAnalysis(name='antelopev2',
@@ -211,12 +238,34 @@ class InfUFluxPipeline:
 
        self.arcface_model = init_recognition_model('arcface', device='cuda')
 
+        # Load LoRAs in advance
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+        self.loras_state_dict = {}
+        self.loras_state_dict['realism'] = self.pipe._fetch_state_dict(os.path.join(os.path.dirname(insightface_root_path), 'optional_loras', 'flux_realism_lora.safetensors'),
+            weight_name=None, use_safetensors=True, local_files_only=None, cache_dir=None, force_download=False, proxies=None, token=None, revision=None, subfolder=None, user_agent=user_agent, allow_pickle=True)
+        self.loras_state_dict['anti_blur'] = self.pipe._fetch_state_dict(os.path.join(os.path.dirname(insightface_root_path), 'optional_loras', 'flux_anti_blur_lora.safetensors'),
+            weight_name=None, use_safetensors=True, local_files_only=None, cache_dir=None, force_download=False, proxies=None, token=None, revision=None, subfolder=None, user_agent=user_agent, allow_pickle=True)
+
+    def load_loras_state_dict(self, loras):
+        names, scales = [],[]
+        for lora_name, lora_scale in loras:
+            print(f"loading lora state dict of {lora_name}")
+            self.pipe.load_lora_weights(self.loras_state_dict[lora_name], adapter_name=lora_name)
+            names.append(lora_name)
+            scales.append(lora_scale)
+
+        if len(names) > 0:
+            self.pipe.set_adapters(names, adapter_weights=scales)
+
    def load_loras(self, loras):
        names, scales = [],[]
        for lora_path, lora_name, lora_scale in loras:
            if lora_path != "":
                print(f"loading lora {lora_path}")
-                self.pipe.load_lora_weights(lora_path, adapter_name = lora_name)
+                self.pipe.load_lora_weights(lora_path, adapter_name=lora_name)
                names.append(lora_name)
                scales.append(lora_scale)
 
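
Reviewer note: the new `load_loras_state_dict` path avoids re-reading the LoRA `.safetensors` files from disk every time the realism/anti-blur toggles change: the state dicts are fetched once in `__init__` (via the underscore-prefixed diffusers helper `_fetch_state_dict`) and then handed to `load_lora_weights`, which accepts an in-memory state dict as well as a path. A sketch of the same caching idea using only public APIs (`safetensors.torch.load_file`); the `pipe` argument is assumed to be a diffusers pipeline with LoRA support, as the FLUX pipelines are.

```python
from safetensors.torch import load_file

# Paths taken from the previous app.py LoRA list.
LORA_FILES = {
    "realism": "./models/InfiniteYou/supports/optional_loras/flux_realism_lora.safetensors",
    "anti_blur": "./models/InfiniteYou/supports/optional_loras/flux_anti_blur_lora.safetensors",
}

def preload_lora_state_dicts(files=LORA_FILES):
    # Read each file once; the tensors stay in host RAM, not VRAM.
    return {name: load_file(path, device="cpu") for name, path in files.items()}

def apply_loras(pipe, cached, selected):
    # `selected` is a list of (adapter_name, scale) pairs, mirroring the new
    # load_loras_state_dict() signature above.
    names, scales = [], []
    for name, scale in selected:
        pipe.load_lora_weights(cached[name], adapter_name=name)
        names.append(name)
        scales.append(scale)
    if names:
        pipe.set_adapters(names, adapter_weights=scales)
```

Caching the state dicts trades host RAM for repeated disk I/O and keeps adapter toggling cheap, consistent with the rest of the commit's CPU-staging approach.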