rahul7star committed on
Commit
4238a80
·
verified ·
1 Parent(s): 7b02a6f

Update flux_train_ui.py

Browse files
Files changed (1) hide show
  1. flux_train_ui.py +123 -41
flux_train_ui.py CHANGED
@@ -16,6 +16,23 @@ import json
16
  import yaml
17
  from slugify import slugify
18
  from transformers import AutoProcessor, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  sys.path.insert(0, "ai-toolkit")
21
  from toolkit.job import get_job
@@ -172,47 +189,90 @@ def start_training(
172
  use_more_advanced_options,
173
  more_advanced_options,
174
  ):
 
 
 
 
 
 
 
 
175
  push_to_hub = True
176
- print("flux ttain invoke ====================")
177
  if not lora_name:
 
178
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
 
 
179
  try:
180
- if whoami()["auth"]["accessToken"]["role"] == "write" or "repo.write" in whoami()["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
181
- gr.Info(f"Starting training locally {whoami()['name']}. Your LoRA will be available locally and in Hugging Face after it finishes.")
 
 
 
182
  else:
183
  push_to_hub = False
 
184
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
185
- except:
186
  push_to_hub = False
 
187
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
188
-
189
- print("Started training")
190
  slugged_lora_name = slugify(lora_name)
 
191
 
192
  # Load the default config
193
- with open("config/examples/train_lora_flux_24gb.yaml", "r") as f:
194
- config = yaml.safe_load(f)
 
 
 
 
 
 
 
195
 
196
  # Update the config with user inputs
197
- config["config"]["name"] = slugged_lora_name
198
- config["config"]["process"][0]["model"]["low_vram"] = low_vram
199
- config["config"]["process"][0]["train"]["skip_first_sample"] = True
200
- config["config"]["process"][0]["train"]["steps"] = int(steps)
201
- config["config"]["process"][0]["train"]["lr"] = float(lr)
202
- config["config"]["process"][0]["network"]["linear"] = int(rank)
203
- config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
204
- config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
205
- config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
206
- if(push_to_hub):
 
 
 
 
 
 
 
 
 
 
 
 
207
  try:
208
  username = whoami()["name"]
209
- except:
 
 
 
 
 
210
  raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
211
- config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
212
- config["config"]["process"][0]["save"]["hf_private"] = True
213
  if concept_sentence:
214
  config["config"]["process"][0]["trigger_word"] = concept_sentence
215
-
 
 
216
  if sample_1 or sample_2 or sample_3:
217
  config["config"]["process"][0]["train"]["disable_sampling"] = False
218
  config["config"]["process"][0]["sample"]["sample_every"] = steps
@@ -224,33 +284,56 @@ def start_training(
224
  config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
225
  if sample_3:
226
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
 
227
  else:
228
  config["config"]["process"][0]["train"]["disable_sampling"] = True
229
- if(model_to_train == "schnell"):
 
 
 
230
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
231
  config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
232
  config["config"]["process"][0]["sample"]["sample_steps"] = 4
233
- if(use_more_advanced_options):
234
- more_advanced_options_dict = yaml.safe_load(more_advanced_options)
235
- config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
236
- print(config)
237
-
 
 
 
 
 
 
 
 
238
  # Save the updated config
239
- # generate a random name for the config
240
  random_config_name = str(uuid.uuid4())
241
  os.makedirs("tmp", exist_ok=True)
242
  config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
243
- with open(config_path, "w") as f:
244
- yaml.dump(config, f)
245
-
246
-
247
-
248
- print(f"[INFO] Starting training with config: {config_path}")
249
- # run the job locally
250
- job = get_job(config_path)
251
- job.run()
252
- job.cleanup()
253
-
 
 
 
 
 
 
 
 
 
 
 
254
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
255
 
256
 
@@ -267,7 +350,6 @@ def start_training(
267
 
268
 
269
 
270
-
271
  config_yaml = '''
272
  device: cuda:0
273
  model:
 
16
  import yaml
17
  from slugify import slugify
18
  from transformers import AutoProcessor, AutoModelForCausalLM
19
+ import logging
20
+ import os
21
+ import yaml
22
+ import uuid
23
+ from slugify import slugify
24
+ import gradio as gr # Assuming gr is from gradio for error/warning handling
25
+
26
+ # Configure logging
27
+ logging.basicConfig(
28
+ level=logging.DEBUG,
29
+ format='%(asctime)s - %(levelname)s - %(message)s',
30
+ handlers=[
31
+ logging.StreamHandler(), # Output to console
32
+ logging.FileHandler('training.log') # Save logs to a file
33
+ ]
34
+ )
35
+ logger = logging.getLogger(__name__)
36
 
37
  sys.path.insert(0, "ai-toolkit")
38
  from toolkit.job import get_job
 
189
  use_more_advanced_options,
190
  more_advanced_options,
191
  ):
192
+ logger.info("Starting training process")
193
+ logger.debug(f"Input parameters: lora_name={lora_name}, concept_sentence={concept_sentence}, "
194
+ f"steps={steps}, lr={lr}, rank={rank}, model_to_train={model_to_train}, "
195
+ f"low_vram={low_vram}, dataset_folder={dataset_folder}, "
196
+ f"sample_1={sample_1}, sample_2={sample_2}, sample_3={sample_3}, "
197
+ f"use_more_advanced_options={use_more_advanced_options}, "
198
+ f"more_advanced_options={more_advanced_options}")
199
+
200
  push_to_hub = True
201
+ logger.info("Checking LoRA name")
202
  if not lora_name:
203
+ logger.error("LoRA name is empty or None")
204
  raise gr.Error("You forgot to insert your LoRA name! This name has to be unique.")
205
+
206
+ # Check Hugging Face permissions
207
  try:
208
+ user_info = whoami()
209
+ logger.debug(f"Hugging Face user info: {user_info}")
210
+ if user_info["auth"]["accessToken"]["role"] == "write" or \
211
+ "repo.edit" in user_info["auth"]["accessToken"]["fineGrained"]["scoped"][0]["permissions"]:
212
+ logger.info(f"Starting training locally for user: {user_info['name']}. LoRA will be available locally and on Hugging Face.")
213
  else:
214
  push_to_hub = False
215
+ logger.warning("No write access to Hugging Face. Training locally only.")
216
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
217
+ except Exception as e:
218
  push_to_hub = False
219
+ logger.error(f"Error checking Hugging Face permissions: {str(e)}")
220
  gr.Warning("Started training locally. Your LoRa will only be available locally because you didn't login with a `write` token to Hugging Face")
221
+
222
+ logger.info("Training started")
223
  slugged_lora_name = slugify(lora_name)
224
+ logger.debug(f"Slugged LoRA name: {slugged_lora_name}")
225
 
226
  # Load the default config
227
+ config_path_default = "config/examples/train_lora_flux_24gb.yaml"
228
+ logger.info(f"Loading default config from: {config_path_default}")
229
+ try:
230
+ with open(config_path_default, "r") as f:
231
+ config = yaml.safe_load(f)
232
+ logger.debug(f"Loaded config: {config}")
233
+ except Exception as e:
234
+ logger.error(f"Failed to load config from {config_path_default}: {str(e)}")
235
+ raise
236
 
237
  # Update the config with user inputs
238
+ logger.info("Updating config with user inputs")
239
+ try:
240
+ config["config"]["name"] = slugged_lora_name
241
+ config["config"]["process"][0]["model"]["low_vram"] = low_vram
242
+ config["config"]["process"][0]["train"]["skip_first_sample"] = True
243
+ config["config"]["process"][0]["train"]["steps"] = int(steps)
244
+ config["config"]["process"][0]["train"]["lr"] = float(lr)
245
+ config["config"]["process"][0]["network"]["linear"] = int(rank)
246
+ config["config"]["process"][0]["network"]["linear_alpha"] = int(rank)
247
+ config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_folder
248
+ config["config"]["process"][0]["save"]["push_to_hub"] = push_to_hub
249
+ logger.debug(f"Updated config fields: name={slugged_lora_name}, low_vram={low_vram}, steps={steps}, "
250
+ f"lr={lr}, rank={rank}, dataset_folder={dataset_folder}, push_to_hub={push_to_hub}")
251
+ except KeyError as e:
252
+ logger.error(f"Config structure error: Missing key {str(e)}")
253
+ raise
254
+ except Exception as e:
255
+ logger.error(f"Error updating config: {str(e)}")
256
+ raise
257
+
258
+ # Handle Hugging Face repository settings
259
+ if push_to_hub:
260
  try:
261
  username = whoami()["name"]
262
+ logger.debug(f"Hugging Face username: {username}")
263
+ config["config"]["process"][0]["save"]["hf_repo_id"] = f"{username}/{slugged_lora_name}"
264
+ config["config"]["process"][0]["save"]["hf_private"] = True
265
+ logger.debug(f"Set Hugging Face repo: {username}/{slugged_lora_name}")
266
+ except Exception as e:
267
+ logger.error(f"Error retrieving Hugging Face username: {str(e)}")
268
  raise gr.Error("Error trying to retrieve your username. Are you sure you are logged in with Hugging Face?")
269
+
270
+ # Handle concept sentence
271
  if concept_sentence:
272
  config["config"]["process"][0]["trigger_word"] = concept_sentence
273
+ logger.debug(f"Set trigger_word: {concept_sentence}")
274
+
275
+ # Handle sampling prompts
276
  if sample_1 or sample_2 or sample_3:
277
  config["config"]["process"][0]["train"]["disable_sampling"] = False
278
  config["config"]["process"][0]["sample"]["sample_every"] = steps
 
284
  config["config"]["process"][0]["sample"]["prompts"].append(sample_2)
285
  if sample_3:
286
  config["config"]["process"][0]["sample"]["prompts"].append(sample_3)
287
+ logger.debug(f"Sampling enabled with prompts: {config['config']['process'][0]['sample']['prompts']}")
288
  else:
289
  config["config"]["process"][0]["train"]["disable_sampling"] = True
290
+ logger.debug("Sampling disabled")
291
+
292
+ # Handle model selection
293
+ if model_to_train == "schnell":
294
  config["config"]["process"][0]["model"]["name_or_path"] = "black-forest-labs/FLUX.1-schnell"
295
  config["config"]["process"][0]["model"]["assistant_lora_path"] = "ostris/FLUX.1-schnell-training-adapter"
296
  config["config"]["process"][0]["sample"]["sample_steps"] = 4
297
+ logger.debug("Using schnell model configuration")
298
+
299
+ # Handle advanced options
300
+ if use_more_advanced_options:
301
+ try:
302
+ more_advanced_options_dict = yaml.safe_load(more_advanced_options)
303
+ logger.debug(f"Advanced options parsed: {more_advanced_options_dict}")
304
+ config["config"]["process"][0] = recursive_update(config["config"]["process"][0], more_advanced_options_dict)
305
+ logger.debug(f"Config after advanced options update: {config}")
306
+ except Exception as e:
307
+ logger.error(f"Error parsing or applying advanced options: {str(e)}")
308
+ raise
309
+
310
  # Save the updated config
311
+ logger.info("Saving updated config")
312
  random_config_name = str(uuid.uuid4())
313
  os.makedirs("tmp", exist_ok=True)
314
  config_path = f"tmp/{random_config_name}-{slugged_lora_name}.yaml"
315
+ try:
316
+ with open(config_path, "w") as f:
317
+ yaml.dump(config, f)
318
+ logger.info(f"Config saved to: {config_path}")
319
+ except Exception as e:
320
+ logger.error(f"Error saving config to {config_path}: {str(e)}")
321
+ raise
322
+
323
+ # Run the training job
324
+ logger.info(f"Starting training job with config: {config_path}")
325
+ try:
326
+ job = get_job(config_path)
327
+ logger.debug("Job object created successfully")
328
+ job.run()
329
+ logger.info("Training job completed")
330
+ job.cleanup()
331
+ logger.info("Job cleanup completed")
332
+ except Exception as e:
333
+ logger.error(f"Error during training job execution: {str(e)}")
334
+ raise
335
+
336
+ logger.info(f"Training completed successfully. Model saved as {slugged_lora_name}")
337
  return f"Training completed successfully. Model saved as {slugged_lora_name}"
338
 
339
 
 
350
 
351
 
352
 
 
353
  config_yaml = '''
354
  device: cuda:0
355
  model: