MegaTronX committed
Commit 5c739d0 · verified · Parent: c6e7ef9

Update app.py

Files changed (1)
  1. app.py +1119 -1119
app.py CHANGED
@@ -1,1119 +1,1119 @@
1
- import os
2
- import sys
3
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
4
- os.environ['GRADIO_ANALYTICS_ENABLED'] = '0'
5
- sys.path.insert(0, os.getcwd())
6
- sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))
7
- import subprocess
8
- import gradio as gr
9
- from PIL import Image
10
- import torch
11
- import uuid
12
- import shutil
13
- import json
14
- import yaml
15
- from slugify import slugify
16
- from transformers import AutoProcessor, AutoModelForCausalLM
17
- from gradio_logsview import LogsView, LogsViewRunner
18
- from huggingface_hub import hf_hub_download, HfApi
19
- from library import flux_train_utils, huggingface_util
20
- from argparse import Namespace
21
- import train_network
22
- import toml
23
- import re
24
- MAX_IMAGES = 150
25
-
26
- with open('models.yaml', 'r') as file:
27
- models = yaml.safe_load(file)
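For reference, the rest of the app reads `file`, `repo` and `base` (plus optional `license`, `license_name`, `license_link`) from each entry, so the parsed `models` dict is expected to look roughly like the sketch below; the entry name and values here are hypothetical, not taken from the actual models.yaml:

# Rough shape of yaml.safe_load("models.yaml") assumed by readme()/download()/gen_sh()
# (illustrative values only)
models = {
    "flux-dev": {
        "file": "flux1-dev.sft",                  # checkpoint filename passed to hf_hub_download
        "repo": "cocktailpeanut/xulf-dev",        # HF repo the checkpoint is pulled from
        "base": "black-forest-labs/FLUX.1-dev",   # written as base_model into the README
        "license": "other",                       # optional license metadata for the model card
    },
}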
28
-
29
- def readme(base_model, lora_name, instance_prompt, sample_prompts):
30
-
31
- # model license
32
- model_config = models[base_model]
33
- model_file = model_config["file"]
34
- base_model_name = model_config["base"]
35
- license = None
36
- license_name = None
37
- license_link = None
38
- license_items = []
39
- if "license" in model_config:
40
- license = model_config["license"]
41
- license_items.append(f"license: {license}")
42
- if "license_name" in model_config:
43
- license_name = model_config["license_name"]
44
- license_items.append(f"license_name: {license_name}")
45
- if "license_link" in model_config:
46
- license_link = model_config["license_link"]
47
- license_items.append(f"license_link: {license_link}")
48
- license_str = "\n".join(license_items)
49
- print(f"license_items={license_items}")
50
- print(f"license_str = {license_str}")
51
-
52
- # tags
53
- tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ]
54
-
55
- # widgets
56
- widgets = []
57
- sample_image_paths = []
58
- output_name = slugify(lora_name)
59
- samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
60
- try:
61
- for filename in os.listdir(samples_dir):
62
- # Filename Schema: [name]_[steps]_[index]_[timestamp].png
63
- match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
64
- if match:
65
- steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3))
66
- sample_image_paths.append((steps, index, f"sample/{filename}"))
67
-
68
- # Sort by numeric index
69
- sample_image_paths.sort(key=lambda x: x[0], reverse=True)
70
-
71
- final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
72
- final_sample_image_paths.sort(key=lambda x: x[1])
73
- for i, prompt in enumerate(sample_prompts):
74
- _, _, image_path = final_sample_image_paths[i]
75
- widgets.append(
76
- {
77
- "text": prompt,
78
- "output": {
79
- "url": image_path
80
- },
81
- }
82
- )
83
- except:
84
- print(f"no samples")
85
- dtype = "torch.bfloat16"
86
- # Construct the README content
87
- readme_content = f"""---
88
- tags:
89
- {yaml.dump(tags, indent=4).strip()}
90
- {"widget:" if os.path.isdir(samples_dir) else ""}
91
- {yaml.dump(widgets, indent=4).strip() if widgets else ""}
92
- base_model: {base_model_name}
93
- {"instance_prompt: " + instance_prompt if instance_prompt else ""}
94
- {license_str}
95
- ---
96
-
97
- # {lora_name}
98
-
99
- A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)
100
-
101
- <Gallery />
102
-
103
- ## Trigger words
104
-
105
- {"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
106
-
107
- ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.
108
-
109
- Weights for this model are available in Safetensors format.
110
-
111
- """
112
- return readme_content
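A quick sketch of how the sample-filename regex above splits a name that follows the `[name]_[steps]_[index]_[timestamp].png` schema (the filename is made up):

import re

fname = "my-lora_000500_02_20240914120000.png"   # hypothetical sample image
m = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", fname)
if m:
    steps, index, timestamp = (int(g) for g in m.groups())
    # steps=500, index=2, timestamp=20240914120000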
113
-
114
- def account_hf():
115
- try:
116
- with open("HF_TOKEN", "r") as file:
117
- token = file.read()
118
- api = HfApi(token=token)
119
- try:
120
- account = api.whoami()
121
- return { "token": token, "account": account['name'] }
122
- except:
123
- return None
124
- except:
125
- return None
126
-
127
- """
128
- hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
129
- """
130
- def logout_hf():
131
- os.remove("HF_TOKEN")
132
- global current_account
133
- current_account = account_hf()
134
- print(f"current_account={current_account}")
135
- return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
136
-
137
-
138
- """
139
- hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
140
- """
141
- def login_hf(hf_token):
142
- api = HfApi(token=hf_token)
143
- try:
144
- account = api.whoami()
145
- if account != None:
146
- if "name" in account:
147
- with open("HF_TOKEN", "w") as file:
148
- file.write(hf_token)
149
- global current_account
150
- current_account = account_hf()
151
- return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
152
- return gr.update(), gr.update(), gr.update(), gr.update()
153
- except:
154
- print(f"incorrect hf_token")
155
- return gr.update(), gr.update(), gr.update(), gr.update()
156
-
157
- def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
158
- src = lora_rows
159
- repo_id = f"{repo_owner}/{repo_name}"
160
- gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None)
161
- args = Namespace(
162
- huggingface_repo_id=repo_id,
163
- huggingface_repo_type="model",
164
- huggingface_repo_visibility=repo_visibility,
165
- huggingface_path_in_repo="",
166
- huggingface_token=hf_token,
167
- async_upload=False
168
- )
169
- print(f"upload_hf args={args}")
170
- huggingface_util.upload(args=args, src=src)
171
- gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None)
172
-
173
- def load_captioning(uploaded_files, concept_sentence):
174
- uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
175
- txt_files = [file for file in uploaded_files if file.endswith('.txt')]
176
- txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
177
- updates = []
178
- if len(uploaded_images) <= 1:
179
- raise gr.Error(
180
- "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
181
- )
182
- elif len(uploaded_images) > MAX_IMAGES:
183
- raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
184
- # Update for the captioning_area
185
- # for _ in range(3):
186
- updates.append(gr.update(visible=True))
187
- # Update visibility and image for each captioning row and image
188
- for i in range(1, MAX_IMAGES + 1):
189
- # Determine if the current row and image should be visible
190
- visible = i <= len(uploaded_images)
191
-
192
- # Update visibility of the captioning row
193
- updates.append(gr.update(visible=visible))
194
-
195
- # Update for image component - display image if available, otherwise hide
196
- image_value = uploaded_images[i - 1] if visible else None
197
- updates.append(gr.update(value=image_value, visible=visible))
198
-
199
- corresponding_caption = False
200
- if(image_value):
201
- base_name = os.path.splitext(os.path.basename(image_value))[0]
202
- if base_name in txt_files_dict:
203
- with open(txt_files_dict[base_name], 'r') as file:
204
- corresponding_caption = file.read()
205
-
206
- # Update value of captioning area
207
- text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None
208
- updates.append(gr.update(value=text_value, visible=visible))
209
-
210
- # Update for the sample caption area
211
- updates.append(gr.update(visible=True))
212
- updates.append(gr.update(visible=True))
213
-
214
- return updates
215
-
216
- def hide_captioning():
217
- return gr.update(visible=False), gr.update(visible=False)
218
-
219
- def resize_image(image_path, output_path, size):
220
- with Image.open(image_path) as img:
221
- width, height = img.size
222
- if width < height:
223
- new_width = size
224
- new_height = int((size/width) * height)
225
- else:
226
- new_height = size
227
- new_width = int((size/height) * width)
228
- print(f"resize {image_path} : {new_width}x{new_height}")
229
- img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
230
- img_resized.save(output_path)
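resize_image scales the short edge to `size` and keeps the aspect ratio, e.g. for a hypothetical 2000x3000 portrait source and size=512:

size, width, height = 512, 2000, 3000         # width < height, so width becomes the short edge
new_width = size                              # 512
new_height = int((size / width) * height)     # int(0.256 * 3000) = 768
# create_dataset() below calls this with output_path == image_path, overwriting the file in place.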
231
-
232
- def create_dataset(destination_folder, size, *inputs):
233
- print("Creating dataset")
234
- images = inputs[0]
235
- if not os.path.exists(destination_folder):
236
- os.makedirs(destination_folder)
237
-
238
- for index, image in enumerate(images):
239
- # copy the images to the datasets folder
240
- new_image_path = shutil.copy(image, destination_folder)
241
-
242
- # if it's a caption text file skip the next bit
243
- ext = os.path.splitext(new_image_path)[-1].lower()
244
- if ext == '.txt':
245
- continue
246
-
247
- # resize the images
248
- resize_image(new_image_path, new_image_path, size)
249
-
250
- # copy the captions
251
-
252
- original_caption = inputs[index + 1]
253
-
254
- image_file_name = os.path.basename(new_image_path)
255
- caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
256
- caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
257
- print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")
258
- # if caption_path exists, do not write
259
- if os.path.exists(caption_path):
260
- print(f"{caption_path} already exists. use the existing .txt file")
261
- else:
262
- print(f"{caption_path} create a .txt caption file")
263
- with open(caption_path, 'w') as file:
264
- file.write(original_caption)
265
-
266
- print(f"destination_folder {destination_folder}")
267
- return destination_folder
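As wired up in the UI below, create_dataset receives the images list followed by one caption per image, so a call looks roughly like this (paths and captions are hypothetical):

folder = create_dataset(
    "datasets/my-lora", 512,
    ["/tmp/img0.png", "/tmp/img1.png"],       # inputs[0]: uploaded images
    "trtcrd, a cat toy on a desk",            # inputs[1]: caption for img0
    "trtcrd, a cat toy in the grass",         # inputs[2]: caption for img1
)
# datasets/my-lora/ now holds img0.png, img0.txt, img1.png, img1.txt (existing .txt files are kept).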
268
-
269
-
270
- def run_captioning(images, concept_sentence, *captions):
271
- print(f"run_captioning")
272
- print(f"concept sentence {concept_sentence}")
273
- print(f"captions {captions}")
274
- #Load internally to not consume resources for training
275
- device = "cuda" if torch.cuda.is_available() else "cpu"
276
- print(f"device={device}")
277
- torch_dtype = torch.float16
278
- model = AutoModelForCausalLM.from_pretrained(
279
- "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
280
- ).to(device)
281
- processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
282
-
283
- captions = list(captions)
284
- for i, image_path in enumerate(images):
285
- print(captions[i])
286
- if isinstance(image_path, str): # If image is a file path
287
- image = Image.open(image_path).convert("RGB")
288
-
289
- prompt = "<DETAILED_CAPTION>"
290
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
291
- print(f"inputs {inputs}")
292
-
293
- generated_ids = model.generate(
294
- input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
295
- )
296
- print(f"generated_ids {generated_ids}")
297
-
298
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
299
- print(f"generated_text: {generated_text}")
300
- parsed_answer = processor.post_process_generation(
301
- generated_text, task=prompt, image_size=(image.width, image.height)
302
- )
303
- print(f"parsed_answer = {parsed_answer}")
304
- caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
305
- print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
306
- if concept_sentence:
307
- caption_text = f"{concept_sentence} {caption_text}"
308
- captions[i] = caption_text
309
-
310
- yield captions
311
- model.to("cpu")
312
- del model
313
- del processor
314
- if torch.cuda.is_available():
315
- torch.cuda.empty_cache()
316
-
317
- def recursive_update(d, u):
318
- for k, v in u.items():
319
- if isinstance(v, dict) and v:
320
- d[k] = recursive_update(d.get(k, {}), v)
321
- else:
322
- d[k] = v
323
- return d
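recursive_update merges nested dicts, overwriting leaf values, e.g.:

base = {"a": {"x": 1, "y": 2}, "b": 3}
patch = {"a": {"y": 20}, "c": 4}
recursive_update(base, patch)
# base == {"a": {"x": 1, "y": 20}, "b": 3, "c": 4}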
324
-
325
- def download(base_model):
326
- model = models[base_model]
327
- model_file = model["file"]
328
- repo = model["repo"]
329
-
330
- # download unet
331
- if base_model == "flux-dev" or base_model == "flux-schnell":
332
- unet_folder = "models/unet"
333
- else:
334
- unet_folder = f"models/unet/{repo}"
335
- unet_path = os.path.join(unet_folder, model_file)
336
- if not os.path.exists(unet_path):
337
- os.makedirs(unet_folder, exist_ok=True)
338
- gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
339
- print(f"download {base_model}")
340
- hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file)
341
-
342
- # download vae
343
- vae_folder = "models/vae"
344
- vae_path = os.path.join(vae_folder, "ae.sft")
345
- if not os.path.exists(vae_path):
346
- os.makedirs(vae_folder, exist_ok=True)
347
- gr.Info(f"Downloading vae")
348
- print(f"downloading ae.sft...")
349
- hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft")
350
-
351
- # download clip
352
- clip_folder = "models/clip"
353
- clip_l_path = os.path.join(clip_folder, "clip_l.safetensors")
354
- if not os.path.exists(clip_l_path):
355
- os.makedirs(clip_folder, exist_ok=True)
356
- gr.Info(f"Downloading clip...")
357
- print(f"download clip_l.safetensors")
358
- hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors")
359
-
360
- # download t5xxl
361
- t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors")
362
- if not os.path.exists(t5xxl_path):
363
- print(f"download t5xxl_fp16.safetensors")
364
- gr.Info(f"Downloading t5xxl...")
365
- hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors")
366
-
367
-
368
- def resolve_path(p):
369
- current_dir = os.path.dirname(os.path.abspath(__file__))
370
- norm_path = os.path.normpath(os.path.join(current_dir, p))
371
- return f"\"{norm_path}\""
372
- def resolve_path_without_quotes(p):
373
- current_dir = os.path.dirname(os.path.abspath(__file__))
374
- norm_path = os.path.normpath(os.path.join(current_dir, p))
375
- return norm_path
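The two helpers differ only in quoting: resolve_path is for paths spliced into the generated train.sh/train.bat command line, resolve_path_without_quotes for paths used directly from Python. Assuming the app lives in /home/user/fluxgym (hypothetical):

resolve_path("outputs/my-lora")                 # '"/home/user/fluxgym/outputs/my-lora"'  (quoted for the shell script)
resolve_path_without_quotes("outputs/my-lora")  # '/home/user/fluxgym/outputs/my-lora'    (plain, for open()/os.*)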
376
-
377
- def gen_sh(
378
- base_model,
379
- output_name,
380
- resolution,
381
- seed,
382
- workers,
383
- learning_rate,
384
- network_dim,
385
- max_train_epochs,
386
- save_every_n_epochs,
387
- timestep_sampling,
388
- guidance_scale,
389
- vram,
390
- sample_prompts,
391
- sample_every_n_steps,
392
- *advanced_components
393
- ):
394
-
395
- print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")
396
-
397
- output_dir = resolve_path(f"outputs/{output_name}")
398
- sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")
399
-
400
- line_break = "\\"
401
- file_type = "sh"
402
- if sys.platform == "win32":
403
- line_break = "^"
404
- file_type = "bat"
405
-
406
- ############# Sample args ########################
407
- sample = ""
408
- if len(sample_prompts) > 0 and sample_every_n_steps > 0:
409
- sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}"""
410
-
411
-
412
- ############# Optimizer args ########################
413
- # if vram == "8G":
414
- # optimizer = f"""--optimizer_type adafactor {line_break}
415
- # --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
416
- # --split_mode {line_break}
417
- # --network_args "train_blocks=single" {line_break}
418
- # --lr_scheduler constant_with_warmup {line_break}
419
- # --max_grad_norm 0.0 {line_break}"""
420
- if vram == "16G":
421
- # 16G VRAM
422
- optimizer = f"""--optimizer_type adafactor {line_break}
423
- --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
424
- --lr_scheduler constant_with_warmup {line_break}
425
- --max_grad_norm 0.0 {line_break}"""
426
- elif vram == "12G":
427
- # 12G VRAM
428
- optimizer = f"""--optimizer_type adafactor {line_break}
429
- --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
430
- --split_mode {line_break}
431
- --network_args "train_blocks=single" {line_break}
432
- --lr_scheduler constant_with_warmup {line_break}
433
- --max_grad_norm 0.0 {line_break}"""
434
- else:
435
- # 20G+ VRAM
436
- optimizer = f"--optimizer_type adamw8bit {line_break}"
437
-
438
-
439
- #######################################################
440
- model_config = models[base_model]
441
- model_file = model_config["file"]
442
- repo = model_config["repo"]
443
- if base_model == "flux-dev" or base_model == "flux-schnell":
444
- model_folder = "models/unet"
445
- else:
446
- model_folder = f"models/unet/{repo}"
447
- model_path = os.path.join(model_folder, model_file)
448
- pretrained_model_path = resolve_path(model_path)
449
-
450
- clip_path = resolve_path("models/clip/clip_l.safetensors")
451
- t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
452
- ae_path = resolve_path("models/vae/ae.sft")
453
- sh = f"""accelerate launch {line_break}
454
- --mixed_precision bf16 {line_break}
455
- --num_cpu_threads_per_process 1 {line_break}
456
- sd-scripts/flux_train_network.py {line_break}
457
- --pretrained_model_name_or_path {pretrained_model_path} {line_break}
458
- --clip_l {clip_path} {line_break}
459
- --t5xxl {t5_path} {line_break}
460
- --ae {ae_path} {line_break}
461
- --cache_latents_to_disk {line_break}
462
- --save_model_as safetensors {line_break}
463
- --sdpa --persistent_data_loader_workers {line_break}
464
- --max_data_loader_n_workers {workers} {line_break}
465
- --seed {seed} {line_break}
466
- --gradient_checkpointing {line_break}
467
- --mixed_precision bf16 {line_break}
468
- --save_precision bf16 {line_break}
469
- --network_module networks.lora_flux {line_break}
470
- --network_dim {network_dim} {line_break}
471
- {optimizer}{sample}
472
- --learning_rate {learning_rate} {line_break}
473
- --cache_text_encoder_outputs {line_break}
474
- --cache_text_encoder_outputs_to_disk {line_break}
475
- --fp8_base {line_break}
476
- --highvram {line_break}
477
- --max_train_epochs {max_train_epochs} {line_break}
478
- --save_every_n_epochs {save_every_n_epochs} {line_break}
479
- --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break}
480
- --output_dir {output_dir} {line_break}
481
- --output_name {output_name} {line_break}
482
- --timestep_sampling {timestep_sampling} {line_break}
483
- --discrete_flow_shift 3.1582 {line_break}
484
- --model_prediction_type raw {line_break}
485
- --guidance_scale {guidance_scale} {line_break}
486
- --loss_type l2 {line_break}"""
487
-
488
-
489
-
490
- ############# Advanced args ########################
491
- global advanced_component_ids
492
- global original_advanced_component_values
493
-
494
- # check dirty
495
- print(f"original_advanced_component_values = {original_advanced_component_values}")
496
- advanced_flags = []
497
- for i, current_value in enumerate(advanced_components):
498
- # print(f"compare {advanced_component_ids[i]}: old={original_advanced_component_values[i]}, new={current_value}")
499
- if original_advanced_component_values[i] != current_value:
500
- # dirty
501
- if current_value == True:
502
- # Boolean
503
- advanced_flags.append(advanced_component_ids[i])
504
- else:
505
- # string
506
- advanced_flags.append(f"{advanced_component_ids[i]} {current_value}")
507
-
508
- if len(advanced_flags) > 0:
509
- advanced_flags_str = f" {line_break}\n ".join(advanced_flags)
510
- sh = sh + "\n " + advanced_flags_str
511
-
512
- return sh
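Only advanced components whose value differs from the default captured at UI build time are appended: booleans become bare flags, everything else becomes flag plus value. With two illustrative components (the option names here are just examples, not a claim about which sd-scripts flags exist):

original_advanced_component_values = [False, "0"]           # defaults when the UI was built
advanced_component_ids = ["--flip_aug", "--min_snr_gamma"]  # illustrative option strings
advanced_components = (True, "5")                           # what the user set
# gen_sh then appends to the script:
#   --flip_aug
#   --min_snr_gamma 5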
513
-
514
- def gen_toml(
515
- dataset_folder,
516
- resolution,
517
- class_tokens,
518
- num_repeats
519
- ):
520
- toml = f"""[general]
521
- shuffle_caption = false
522
- caption_extension = '.txt'
523
- keep_tokens = 1
524
-
525
- [[datasets]]
526
- resolution = {resolution}
527
- batch_size = 1
528
- keep_tokens = 1
529
-
530
- [[datasets.subsets]]
531
- image_dir = '{resolve_path_without_quotes(dataset_folder)}'
532
- class_tokens = '{class_tokens}'
533
- num_repeats = {num_repeats}"""
534
- return toml
535
-
536
- def update_total_steps(max_train_epochs, num_repeats, images):
537
- try:
538
- num_images = len(images)
539
- total_steps = max_train_epochs * num_images * num_repeats
540
- print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
541
- return gr.update(value = total_steps)
542
- except:
543
- print("")
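The expected-steps field is just the product of the three inputs (batch_size is fixed at 1 in the generated dataset.toml), e.g.:

max_train_epochs, num_images, num_repeats = 16, 12, 10
total_steps = max_train_epochs * num_images * num_repeats   # 16 * 12 * 10 = 1920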
544
-
545
- def set_repo(lora_rows):
546
- selected_name = os.path.basename(lora_rows)
547
- return gr.update(value=selected_name)
548
-
549
- def get_loras():
550
- try:
551
- outputs_path = resolve_path_without_quotes(f"outputs")
552
- files = os.listdir(outputs_path)
553
- folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"]
554
- folders.sort(key=lambda file: os.path.getctime(file), reverse=True)
555
- return folders
556
- except Exception as e:
557
- return []
558
-
559
- def get_samples(lora_name):
560
- output_name = slugify(lora_name)
561
- try:
562
- samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
563
- files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
564
- files.sort(key=lambda file: os.path.getctime(file), reverse=True)
565
- return files
566
- except:
567
- return []
568
-
569
- def start_training(
570
- base_model,
571
- lora_name,
572
- train_script,
573
- train_config,
574
- sample_prompts,
575
- ):
576
- # write custom script and toml
577
- if not os.path.exists("models"):
578
- os.makedirs("models", exist_ok=True)
579
- if not os.path.exists("outputs"):
580
- os.makedirs("outputs", exist_ok=True)
581
- output_name = slugify(lora_name)
582
- output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
583
- if not os.path.exists(output_dir):
584
- os.makedirs(output_dir, exist_ok=True)
585
-
586
- download(base_model)
587
-
588
- file_type = "sh"
589
- if sys.platform == "win32":
590
- file_type = "bat"
591
-
592
- sh_filename = f"train.{file_type}"
593
- sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
594
- with open(sh_filepath, 'w', encoding="utf-8") as file:
595
- file.write(train_script)
596
- gr.Info(f"Generated train script at {sh_filename}")
597
-
598
-
599
- dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
600
- with open(dataset_path, 'w', encoding="utf-8") as file:
601
- file.write(train_config)
602
- gr.Info(f"Generated dataset.toml")
603
-
604
- sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
605
- with open(sample_prompts_path, 'w', encoding='utf-8') as file:
606
- file.write(sample_prompts)
607
- gr.Info(f"Generated sample_prompts.txt")
608
-
609
- # Train
610
- if sys.platform == "win32":
611
- command = sh_filepath
612
- else:
613
- command = f"bash \"{sh_filepath}\""
614
-
615
- # Use Popen to run the command and capture output in real-time
616
- env = os.environ.copy()
617
- env['PYTHONIOENCODING'] = 'utf-8'
618
- env['LOG_LEVEL'] = 'DEBUG'
619
- runner = LogsViewRunner()
620
- cwd = os.path.dirname(os.path.abspath(__file__))
621
- gr.Info(f"Started training")
622
- yield from runner.run_command([command], cwd=cwd)
623
- yield runner.log(f"Runner: {runner}")
624
-
625
- # Generate Readme
626
- config = toml.loads(train_config)
627
- concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens']
628
- print(f"concept_sentence={concept_sentence}")
629
- print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
630
- sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
631
- with open(sample_prompts_path, "r", encoding="utf-8") as f:
632
- lines = f.readlines()
633
- sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
634
- md = readme(base_model, lora_name, concept_sentence, sample_prompts)
635
- readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
636
- with open(readme_path, "w", encoding="utf-8") as f:
637
- f.write(md)
638
-
639
- gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None)
640
-
641
-
642
- def update(
643
- base_model,
644
- lora_name,
645
- resolution,
646
- seed,
647
- workers,
648
- class_tokens,
649
- learning_rate,
650
- network_dim,
651
- max_train_epochs,
652
- save_every_n_epochs,
653
- timestep_sampling,
654
- guidance_scale,
655
- vram,
656
- num_repeats,
657
- sample_prompts,
658
- sample_every_n_steps,
659
- *advanced_components,
660
- ):
661
- output_name = slugify(lora_name)
662
- dataset_folder = str(f"datasets/{output_name}")
663
- sh = gen_sh(
664
- base_model,
665
- output_name,
666
- resolution,
667
- seed,
668
- workers,
669
- learning_rate,
670
- network_dim,
671
- max_train_epochs,
672
- save_every_n_epochs,
673
- timestep_sampling,
674
- guidance_scale,
675
- vram,
676
- sample_prompts,
677
- sample_every_n_steps,
678
- *advanced_components,
679
- )
680
- toml = gen_toml(
681
- dataset_folder,
682
- resolution,
683
- class_tokens,
684
- num_repeats
685
- )
686
- return gr.update(value=sh), gr.update(value=toml), dataset_folder
687
-
688
- """
689
- demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account])
690
- """
691
- def loaded():
692
- global current_account
693
- current_account = account_hf()
694
- print(f"current_account={current_account}")
695
- if current_account != None:
696
- return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
697
- else:
698
- return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
699
-
700
- def update_sample(concept_sentence):
701
- return gr.update(value=concept_sentence)
702
-
703
- def refresh_publish_tab():
704
- loras = get_loras()
705
- return gr.Dropdown(label="Trained LoRAs", choices=loras)
706
-
707
- def init_advanced():
708
- # if basic_args
709
- basic_args = {
710
- 'pretrained_model_name_or_path',
711
- 'clip_l',
712
- 't5xxl',
713
- 'ae',
714
- 'cache_latents_to_disk',
715
- 'save_model_as',
716
- 'sdpa',
717
- 'persistent_data_loader_workers',
718
- 'max_data_loader_n_workers',
719
- 'seed',
720
- 'gradient_checkpointing',
721
- 'mixed_precision',
722
- 'save_precision',
723
- 'network_module',
724
- 'network_dim',
725
- 'learning_rate',
726
- 'cache_text_encoder_outputs',
727
- 'cache_text_encoder_outputs_to_disk',
728
- 'fp8_base',
729
- 'highvram',
730
- 'max_train_epochs',
731
- 'save_every_n_epochs',
732
- 'dataset_config',
733
- 'output_dir',
734
- 'output_name',
735
- 'timestep_sampling',
736
- 'discrete_flow_shift',
737
- 'model_prediction_type',
738
- 'guidance_scale',
739
- 'loss_type',
740
- 'optimizer_type',
741
- 'optimizer_args',
742
- 'lr_scheduler',
743
- 'sample_prompts',
744
- 'sample_every_n_steps',
745
- 'max_grad_norm',
746
- 'split_mode',
747
- 'network_args'
748
- }
749
-
750
- # generate a UI config
751
- # if not in basic_args, create a simple form
752
- parser = train_network.setup_parser()
753
- flux_train_utils.add_flux_train_arguments(parser)
754
- args_info = {}
755
- for action in parser._actions:
756
- if action.dest != 'help': # Skip the default help argument
757
- # if the dest is included in basic_args
758
- args_info[action.dest] = {
759
- "action": action.option_strings, # Option strings like '--use_8bit_adam'
760
- "type": action.type, # Type of the argument
761
- "help": action.help, # Help message
762
- "default": action.default, # Default value, if any
763
- "required": action.required # Whether the argument is required
764
- }
765
- temp = []
766
- for key in args_info:
767
- temp.append({ 'key': key, 'action': args_info[key] })
768
- temp.sort(key=lambda x: x['key'])
769
- advanced_component_ids = []
770
- advanced_components = []
771
- for item in temp:
772
- key = item['key']
773
- action = item['action']
774
- if key in basic_args:
775
- print("")
776
- else:
777
- action_type = str(action['type'])
778
- component = None
779
- with gr.Column(min_width=300):
780
- if action_type == "None":
781
- # radio
782
- component = gr.Checkbox()
783
- # elif action_type == "<class 'str'>":
784
- # component = gr.Textbox()
785
- # elif action_type == "<class 'int'>":
786
- # component = gr.Number(precision=0)
787
- # elif action_type == "<class 'float'>":
788
- # component = gr.Number()
789
- # elif "int_or_float" in action_type:
790
- # component = gr.Number()
791
- else:
792
- component = gr.Textbox(value="")
793
- if component != None:
794
- component.interactive = True
795
- component.elem_id = action['action'][0]
796
- component.label = component.elem_id
797
- component.elem_classes = ["advanced"]
798
- if action['help'] != None:
799
- component.info = action['help']
800
- advanced_components.append(component)
801
- advanced_component_ids.append(component.elem_id)
802
- return advanced_components, advanced_component_ids
803
-
804
-
805
- theme = gr.themes.Monochrome(
806
- text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
807
- font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
808
- )
809
- css = """
810
- @keyframes rotate {
811
- 0% {
812
- transform: rotate(0deg);
813
- }
814
- 100% {
815
- transform: rotate(360deg);
816
- }
817
- }
818
- #advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; }
819
- h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;}
820
- h3{margin-top: 0}
821
- .tabitem{border: 0px}
822
- .group_padding{}
823
- nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); }
824
- nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; }
825
- nav img { height: 40px; width: 40px; border-radius: 40px; }
826
- nav img.rotate { animation: rotate 2s linear infinite; }
827
- .flexible { flex-grow: 1; }
828
- .tast-details { margin: 10px 0 !important; }
829
- .toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); }
830
- .toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; }
831
- .toast-body { border: none !important; }
832
- #terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); }
833
- #terminal .generating { border: none !important; }
834
- #terminal label { position: absolute !important; }
835
- .tabs { margin-top: 50px; }
836
- .hidden { display: none !important; }
837
- .codemirror-wrapper .cm-line { font-size: 12px !important; }
838
- label { font-weight: bold !important; }
839
- #start_training.clicked { background: silver; color: black; }
840
- """
841
-
842
- js = """
843
- function() {
844
- let autoscroll = document.querySelector("#autoscroll")
845
- if (window.iidxx) {
846
- window.clearInterval(window.iidxx);
847
- }
848
- window.iidxx = window.setInterval(function() {
849
- let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim()
850
- let img = document.querySelector("#logo")
851
- if (text.length > 0) {
852
- autoscroll.classList.remove("hidden")
853
- if (autoscroll.classList.contains("on")) {
854
- autoscroll.textContent = "Autoscroll ON"
855
- window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" });
856
- img.classList.add("rotate")
857
- } else {
858
- autoscroll.textContent = "Autoscroll OFF"
859
- img.classList.remove("rotate")
860
- }
861
- }
862
- }, 500);
863
- console.log("autoscroll", autoscroll)
864
- autoscroll.addEventListener("click", (e) => {
865
- autoscroll.classList.toggle("on")
866
- })
867
- function debounce(fn, delay) {
868
- let timeoutId;
869
- return function(...args) {
870
- clearTimeout(timeoutId);
871
- timeoutId = setTimeout(() => fn(...args), delay);
872
- };
873
- }
874
-
875
- function handleClick() {
876
- console.log("refresh")
877
- document.querySelector("#refresh").click();
878
- }
879
- const debouncedClick = debounce(handleClick, 1000);
880
- document.addEventListener("input", debouncedClick);
881
-
882
- document.querySelector("#start_training").addEventListener("click", (e) => {
883
- e.target.classList.add("clicked")
884
- e.target.innerHTML = "Training..."
885
- })
886
-
887
- }
888
- """
889
-
890
- current_account = account_hf()
891
- print(f"current_account={current_account}")
892
-
893
- with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
894
- with gr.Tabs() as tabs:
895
- with gr.TabItem("Gym"):
896
- output_components = []
897
- with gr.Row():
898
- gr.HTML("""<nav>
899
- <img id='logo' src='/file=icon.png' width='80' height='80'>
900
- <div class='flexible'></div>
901
- <button id='autoscroll' class='on hidden'></button>
902
- </nav>
903
- """)
904
- with gr.Row(elem_id='container'):
905
- with gr.Column():
906
- gr.Markdown(
907
- """# Step 1. LoRA Info
908
- <p style="margin-top:0">Configure your LoRA train settings.</p>
909
- """, elem_classes="group_padding")
910
- lora_name = gr.Textbox(
911
- label="The name of your LoRA",
912
- info="This has to be a unique name",
913
- placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
914
- )
915
- concept_sentence = gr.Textbox(
916
- elem_id="--concept_sentence",
917
- label="Trigger word/sentence",
918
- info="Trigger word or sentence to be used",
919
- placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
920
- interactive=True,
921
- )
922
- model_names = list(models.keys())
923
- print(f"model_names={model_names}")
924
- base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0])
925
- vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True)
926
- num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
927
- max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
928
- total_steps = gr.Number(0, interactive=False, label="Expected training steps")
929
- sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True)
930
- sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True)
931
- resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
932
- with gr.Column():
933
- gr.Markdown(
934
- """# Step 2. Dataset
935
- <p style="margin-top:0">Make sure the captions include the trigger word.</p>
936
- """, elem_classes="group_padding")
937
- with gr.Group():
938
- images = gr.File(
939
- file_types=["image", ".txt"],
940
- label="Upload your images",
941
- #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)",
942
- file_count="multiple",
943
- interactive=True,
944
- visible=True,
945
- scale=1,
946
- )
947
- with gr.Group(visible=False) as captioning_area:
948
- do_captioning = gr.Button("Add AI captions with Florence-2")
949
- output_components.append(captioning_area)
950
- #output_components = [captioning_area]
951
- caption_list = []
952
- for i in range(1, MAX_IMAGES + 1):
953
- locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
954
- with locals()[f"captioning_row_{i}"]:
955
- locals()[f"image_{i}"] = gr.Image(
956
- type="filepath",
957
- width=111,
958
- height=111,
959
- min_width=111,
960
- interactive=False,
961
- scale=2,
962
- show_label=False,
963
- show_share_button=False,
964
- show_download_button=False,
965
- )
966
- locals()[f"caption_{i}"] = gr.Textbox(
967
- label=f"Caption {i}", scale=15, interactive=True
968
- )
969
-
970
- output_components.append(locals()[f"captioning_row_{i}"])
971
- output_components.append(locals()[f"image_{i}"])
972
- output_components.append(locals()[f"caption_{i}"])
973
- caption_list.append(locals()[f"caption_{i}"])
974
- with gr.Column():
975
- gr.Markdown(
976
- """# Step 3. Train
977
- <p style="margin-top:0">Press start to start training.</p>
978
- """, elem_classes="group_padding")
979
- refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
980
- start = gr.Button("Start training", visible=False, elem_id="start_training")
981
- output_components.append(start)
982
- train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
983
- train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
984
- with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
985
- with gr.Row():
986
- with gr.Column(min_width=300):
987
- seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
988
- with gr.Column(min_width=300):
989
- workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True)
990
- with gr.Column(min_width=300):
991
- learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True)
992
- with gr.Column(min_width=300):
993
- save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True)
994
- with gr.Column(min_width=300):
995
- guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True)
996
- with gr.Column(min_width=300):
997
- timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True)
998
- with gr.Column(min_width=300):
999
- network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True)
1000
- advanced_components, advanced_component_ids = init_advanced()
1001
- with gr.Row():
1002
- terminal = LogsView(label="Train log", elem_id="terminal")
1003
- with gr.Row():
1004
- gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
1005
-
1006
- with gr.TabItem("Publish") as publish_tab:
1007
- hf_token = gr.Textbox(label="Huggingface Token")
1008
- hf_login = gr.Button("Login")
1009
- hf_logout = gr.Button("Logout")
1010
- with gr.Row() as row:
1011
- gr.Markdown("**LoRA**")
1012
- gr.Markdown("**Upload**")
1013
- loras = get_loras()
1014
- with gr.Row():
1015
- lora_rows = refresh_publish_tab()
1016
- with gr.Column():
1017
- with gr.Row():
1018
- repo_owner = gr.Textbox(label="Account", interactive=False)
1019
- repo_name = gr.Textbox(label="Repository Name")
1020
- repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public")
1021
- upload_button = gr.Button("Upload to HuggingFace")
1022
- upload_button.click(
1023
- fn=upload_hf,
1024
- inputs=[
1025
- base_model,
1026
- lora_rows,
1027
- repo_owner,
1028
- repo_name,
1029
- repo_visibility,
1030
- hf_token,
1031
- ]
1032
- )
1033
- hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
1034
- hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1035
-
1036
-
1037
- publish_tab.select(refresh_publish_tab, outputs=lora_rows)
1038
- lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])
1039
-
1040
- dataset_folder = gr.State()
1041
-
1042
- listeners = [
1043
- base_model,
1044
- lora_name,
1045
- resolution,
1046
- seed,
1047
- workers,
1048
- concept_sentence,
1049
- learning_rate,
1050
- network_dim,
1051
- max_train_epochs,
1052
- save_every_n_epochs,
1053
- timestep_sampling,
1054
- guidance_scale,
1055
- vram,
1056
- num_repeats,
1057
- sample_prompts,
1058
- sample_every_n_steps,
1059
- *advanced_components
1060
- ]
1061
- advanced_component_ids = [x.elem_id for x in advanced_components]
1062
- original_advanced_component_values = [comp.value for comp in advanced_components]
1063
- images.upload(
1064
- load_captioning,
1065
- inputs=[images, concept_sentence],
1066
- outputs=output_components
1067
- )
1068
- images.delete(
1069
- load_captioning,
1070
- inputs=[images, concept_sentence],
1071
- outputs=output_components
1072
- )
1073
- images.clear(
1074
- hide_captioning,
1075
- outputs=[captioning_area, start]
1076
- )
1077
- max_train_epochs.change(
1078
- fn=update_total_steps,
1079
- inputs=[max_train_epochs, num_repeats, images],
1080
- outputs=[total_steps]
1081
- )
1082
- num_repeats.change(
1083
- fn=update_total_steps,
1084
- inputs=[max_train_epochs, num_repeats, images],
1085
- outputs=[total_steps]
1086
- )
1087
- images.upload(
1088
- fn=update_total_steps,
1089
- inputs=[max_train_epochs, num_repeats, images],
1090
- outputs=[total_steps]
1091
- )
1092
- images.delete(
1093
- fn=update_total_steps,
1094
- inputs=[max_train_epochs, num_repeats, images],
1095
- outputs=[total_steps]
1096
- )
1097
- images.clear(
1098
- fn=update_total_steps,
1099
- inputs=[max_train_epochs, num_repeats, images],
1100
- outputs=[total_steps]
1101
- )
1102
- concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
1103
- start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then(
1104
- fn=start_training,
1105
- inputs=[
1106
- base_model,
1107
- lora_name,
1108
- train_script,
1109
- train_config,
1110
- sample_prompts,
1111
- ],
1112
- outputs=terminal,
1113
- )
1114
- do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
1115
- demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1116
- refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])
1117
- if __name__ == "__main__":
1118
- cwd = os.path.dirname(os.path.abspath(__file__))
1119
- demo.launch(debug=True, show_error=True, allowed_paths=[cwd])
 
1
+ import os
2
+ import sys
3
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
4
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = '0'
5
+ sys.path.insert(0, os.getcwd())
6
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))
7
+ import subprocess
8
+ import gradio as gr
9
+ from PIL import Image
10
+ import torch
11
+ import uuid
12
+ import shutil
13
+ import json
14
+ import yaml
15
+ from slugify import slugify
16
+ from transformers import AutoProcessor, AutoModelForCausalLM
17
+ from gradio_logsview import LogsView, LogsViewRunner
18
+ from huggingface_hub import hf_hub_download, HfApi
19
+ #from library import flux_train_utils, huggingface_util
20
+ from argparse import Namespace
21
+ import train_network
22
+ import toml
23
+ import re
24
+ MAX_IMAGES = 150
25
+
26
+ with open('models.yaml', 'r') as file:
27
+ models = yaml.safe_load(file)
28
+
29
+ def readme(base_model, lora_name, instance_prompt, sample_prompts):
30
+
31
+ # model license
32
+ model_config = models[base_model]
33
+ model_file = model_config["file"]
34
+ base_model_name = model_config["base"]
35
+ license = None
36
+ license_name = None
37
+ license_link = None
38
+ license_items = []
39
+ if "license" in model_config:
40
+ license = model_config["license"]
41
+ license_items.append(f"license: {license}")
42
+ if "license_name" in model_config:
43
+ license_name = model_config["license_name"]
44
+ license_items.append(f"license_name: {license_name}")
45
+ if "license_link" in model_config:
46
+ license_link = model_config["license_link"]
47
+ license_items.append(f"license_link: {license_link}")
48
+ license_str = "\n".join(license_items)
49
+ print(f"license_items={license_items}")
50
+ print(f"license_str = {license_str}")
51
+
52
+ # tags
53
+ tags = [ "text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym" ]
54
+
55
+ # widgets
56
+ widgets = []
57
+ sample_image_paths = []
58
+ output_name = slugify(lora_name)
59
+ samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
60
+ try:
61
+ for filename in os.listdir(samples_dir):
62
+ # Filename Schema: [name]_[steps]_[index]_[timestamp].png
63
+ match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
64
+ if match:
65
+ steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3))
66
+ sample_image_paths.append((steps, index, f"sample/{filename}"))
67
+
68
+ # Sort by numeric index
69
+ sample_image_paths.sort(key=lambda x: x[0], reverse=True)
70
+
71
+ final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
72
+ final_sample_image_paths.sort(key=lambda x: x[1])
73
+ for i, prompt in enumerate(sample_prompts):
74
+ _, _, image_path = final_sample_image_paths[i]
75
+ widgets.append(
76
+ {
77
+ "text": prompt,
78
+ "output": {
79
+ "url": image_path
80
+ },
81
+ }
82
+ )
83
+ except:
84
+ print(f"no samples")
85
+ dtype = "torch.bfloat16"
86
+ # Construct the README content
87
+ readme_content = f"""---
88
+ tags:
89
+ {yaml.dump(tags, indent=4).strip()}
90
+ {"widget:" if os.path.isdir(samples_dir) else ""}
91
+ {yaml.dump(widgets, indent=4).strip() if widgets else ""}
92
+ base_model: {base_model_name}
93
+ {"instance_prompt: " + instance_prompt if instance_prompt else ""}
94
+ {license_str}
95
+ ---
96
+
97
+ # {lora_name}
98
+
99
+ A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)
100
+
101
+ <Gallery />
102
+
103
+ ## Trigger words
104
+
105
+ {"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}
106
+
107
+ ## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.
108
+
109
+ Weights for this model are available in Safetensors format.
110
+
111
+ """
112
+ return readme_content
113
+
114
+ def account_hf():
115
+ try:
116
+ with open("HF_TOKEN", "r") as file:
117
+ token = file.read()
118
+ api = HfApi(token=token)
119
+ try:
120
+ account = api.whoami()
121
+ return { "token": token, "account": account['name'] }
122
+ except:
123
+ return None
124
+ except:
125
+ return None
126
+
127
+ """
128
+ hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
129
+ """
130
+ def logout_hf():
131
+ os.remove("HF_TOKEN")
132
+ global current_account
133
+ current_account = account_hf()
134
+ print(f"current_account={current_account}")
135
+ return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
136
+
137
+
138
+ """
139
+ hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
140
+ """
141
+ def login_hf(hf_token):
142
+ api = HfApi(token=hf_token)
143
+ try:
144
+ account = api.whoami()
145
+ if account != None:
146
+ if "name" in account:
147
+ with open("HF_TOKEN", "w") as file:
148
+ file.write(hf_token)
149
+ global current_account
150
+ current_account = account_hf()
151
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
152
+ return gr.update(), gr.update(), gr.update(), gr.update()
153
+ except:
154
+ print(f"incorrect hf_token")
155
+ return gr.update(), gr.update(), gr.update(), gr.update()
156
+
157
+ def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
158
+ src = lora_rows
159
+ repo_id = f"{repo_owner}/{repo_name}"
160
+ gr.Info(f"Uploading to Huggingface. Please Stand by...", duration=None)
161
+ args = Namespace(
162
+ huggingface_repo_id=repo_id,
163
+ huggingface_repo_type="model",
164
+ huggingface_repo_visibility=repo_visibility,
165
+ huggingface_path_in_repo="",
166
+ huggingface_token=hf_token,
167
+ async_upload=False
168
+ )
169
+ print(f"upload_hf args={args}")
170
+ huggingface_util.upload(args=args, src=src)
171
+ gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None)
172
+
173
+ def load_captioning(uploaded_files, concept_sentence):
174
+ uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
175
+ txt_files = [file for file in uploaded_files if file.endswith('.txt')]
176
+ txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
177
+ updates = []
178
+ if len(uploaded_images) <= 1:
179
+ raise gr.Error(
180
+ "Please upload at least 2 images to train your model (the ideal number with default settings is between 4-30)"
181
+ )
182
+ elif len(uploaded_images) > MAX_IMAGES:
183
+ raise gr.Error(f"For now, only {MAX_IMAGES} or less images are allowed for training")
184
+ # Update for the captioning_area
185
+ # for _ in range(3):
186
+ updates.append(gr.update(visible=True))
187
+ # Update visibility and image for each captioning row and image
188
+ for i in range(1, MAX_IMAGES + 1):
189
+ # Determine if the current row and image should be visible
190
+ visible = i <= len(uploaded_images)
191
+
192
+ # Update visibility of the captioning row
193
+ updates.append(gr.update(visible=visible))
194
+
195
+ # Update for image component - display image if available, otherwise hide
196
+ image_value = uploaded_images[i - 1] if visible else None
197
+ updates.append(gr.update(value=image_value, visible=visible))
198
+
199
+ corresponding_caption = False
200
+ if(image_value):
201
+ base_name = os.path.splitext(os.path.basename(image_value))[0]
202
+ if base_name in txt_files_dict:
203
+ with open(txt_files_dict[base_name], 'r') as file:
204
+ corresponding_caption = file.read()
205
+
206
+ # Update value of captioning area
207
+ text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None
208
+ updates.append(gr.update(value=text_value, visible=visible))
209
+
210
+ # Update for the sample caption area
211
+ updates.append(gr.update(visible=True))
212
+ updates.append(gr.update(visible=True))
213
+
214
+ return updates
215
+
216
+ def hide_captioning():
217
+ return gr.update(visible=False), gr.update(visible=False)
218
+
219
+ def resize_image(image_path, output_path, size):
220
+ with Image.open(image_path) as img:
221
+ width, height = img.size
222
+ if width < height:
223
+ new_width = size
224
+ new_height = int((size/width) * height)
225
+ else:
226
+ new_height = size
227
+ new_width = int((size/height) * width)
228
+ print(f"resize {image_path} : {new_width}x{new_height}")
229
+ img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
230
+ img_resized.save(output_path)
231
+
232
+ def create_dataset(destination_folder, size, *inputs):
233
+ print("Creating dataset")
234
+ images = inputs[0]
235
+ if not os.path.exists(destination_folder):
236
+ os.makedirs(destination_folder)
237
+
238
+ for index, image in enumerate(images):
239
+ # copy the images to the datasets folder
240
+ new_image_path = shutil.copy(image, destination_folder)
241
+
242
+ # if it's a caption text file skip the next bit
243
+ ext = os.path.splitext(new_image_path)[-1].lower()
244
+ if ext == '.txt':
245
+ continue
246
+
247
+ # resize the images
248
+ resize_image(new_image_path, new_image_path, size)
249
+
250
+ # copy the captions
251
+
252
+ original_caption = inputs[index + 1]
253
+
254
+ image_file_name = os.path.basename(new_image_path)
255
+ caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
256
+ caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
257
+ print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")
258
+ # if caption_path exists, do not write
259
+ if os.path.exists(caption_path):
260
+ print(f"{caption_path} already exists. use the existing .txt file")
261
+ else:
262
+ print(f"{caption_path} create a .txt caption file")
263
+ with open(caption_path, 'w') as file:
264
+ file.write(original_caption)
265
+
266
+ print(f"destination_folder {destination_folder}")
267
+ return destination_folder
268
+
269
+
270
+ def run_captioning(images, concept_sentence, *captions):
271
+ print(f"run_captioning")
272
+ print(f"concept sentence {concept_sentence}")
273
+ print(f"captions {captions}")
274
+ #Load internally to not consume resources for training
275
+ device = "cuda" if torch.cuda.is_available() else "cpu"
276
+ print(f"device={device}")
277
+ torch_dtype = torch.float16
278
+ model = AutoModelForCausalLM.from_pretrained(
279
+ "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
280
+ ).to(device)
281
+ processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)
282
+
283
+ captions = list(captions)
284
+ for i, image_path in enumerate(images):
285
+ print(captions[i])
286
+ if isinstance(image_path, str): # If image is a file path
287
+ image = Image.open(image_path).convert("RGB")
288
+
289
+ prompt = "<DETAILED_CAPTION>"
290
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
291
+ print(f"inputs {inputs}")
292
+
293
+ generated_ids = model.generate(
294
+ input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
295
+ )
296
+ print(f"generated_ids {generated_ids}")
297
+
298
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
299
+ print(f"generated_text: {generated_text}")
300
+ parsed_answer = processor.post_process_generation(
301
+ generated_text, task=prompt, image_size=(image.width, image.height)
302
+ )
303
+ print(f"parsed_answer = {parsed_answer}")
304
+ caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
305
+ print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
306
+ if concept_sentence:
307
+ caption_text = f"{concept_sentence} {caption_text}"
308
+ captions[i] = caption_text
309
+
310
+ yield captions
311
+ model.to("cpu")
312
+ del model
313
+ del processor
314
+ if torch.cuda.is_available():
315
+ torch.cuda.empty_cache()
316
+
317
+ def recursive_update(d, u):
318
+ for k, v in u.items():
319
+ if isinstance(v, dict) and v:
320
+ d[k] = recursive_update(d.get(k, {}), v)
321
+ else:
322
+ d[k] = v
323
+ return d
324
+
325
+ def download(base_model):
326
+ model = models[base_model]
327
+ model_file = model["file"]
328
+ repo = model["repo"]
329
+
330
+ # download unet
331
+ if base_model == "flux-dev" or base_model == "flux-schnell":
332
+ unet_folder = "models/unet"
333
+ else:
334
+ unet_folder = f"models/unet/{repo}"
335
+ unet_path = os.path.join(unet_folder, model_file)
336
+ if not os.path.exists(unet_path):
337
+ os.makedirs(unet_folder, exist_ok=True)
338
+ gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
339
+ print(f"download {base_model}")
340
+ hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file)
341
+
342
+ # download vae
343
+ vae_folder = "models/vae"
344
+ vae_path = os.path.join(vae_folder, "ae.sft")
345
+ if not os.path.exists(vae_path):
346
+ os.makedirs(vae_folder, exist_ok=True)
347
+ gr.Info(f"Downloading vae")
348
+ print(f"downloading ae.sft...")
349
+ hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft")
350
+
351
+ # download clip
352
+ clip_folder = "models/clip"
353
+ clip_l_path = os.path.join(clip_folder, "clip_l.safetensors")
354
+ if not os.path.exists(clip_l_path):
355
+ os.makedirs(clip_folder, exist_ok=True)
356
+ gr.Info(f"Downloading clip...")
357
+ print(f"download clip_l.safetensors")
358
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors")
359
+
360
+ # download t5xxl
361
+ t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors")
362
+ if not os.path.exists(t5xxl_path):
363
+ print("download t5xxl_fp16.safetensors")
364
+ gr.Info("Downloading T5-XXL...")
365
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors")
366
+
367
+
368
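+ # resolve_path returns an absolute path wrapped in double quotes (safe to splice into the generated shell command); resolve_path_without_quotes returns the bare path for Python file I/O.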
+ def resolve_path(p):
369
+ current_dir = os.path.dirname(os.path.abspath(__file__))
370
+ norm_path = os.path.normpath(os.path.join(current_dir, p))
371
+ return f"\"{norm_path}\""
372
+ def resolve_path_without_quotes(p):
373
+ current_dir = os.path.dirname(os.path.abspath(__file__))
374
+ norm_path = os.path.normpath(os.path.join(current_dir, p))
375
+ return norm_path
376
+
377
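+ # Build the train script: an accelerate launch command for sd-scripts/flux_train_network.py, with optimizer flags chosen by available VRAM and any changed "advanced" flags appended at the end.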
+ def gen_sh(
378
+ base_model,
379
+ output_name,
380
+ resolution,
381
+ seed,
382
+ workers,
383
+ learning_rate,
384
+ network_dim,
385
+ max_train_epochs,
386
+ save_every_n_epochs,
387
+ timestep_sampling,
388
+ guidance_scale,
389
+ vram,
390
+ sample_prompts,
391
+ sample_every_n_steps,
392
+ *advanced_components
393
+ ):
394
+
395
+ print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")
396
+
397
+ output_dir = resolve_path(f"outputs/{output_name}")
398
+ sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")
399
+
400
+ line_break = "\\"
401
+ file_type = "sh"
402
+ if sys.platform == "win32":
403
+ line_break = "^"
404
+ file_type = "bat"
405
+
406
+ ############# Sample args ########################
407
+ sample = ""
408
+ if len(sample_prompts) > 0 and sample_every_n_steps > 0:
409
+ sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}"""
410
+
411
+
412
+ ############# Optimizer args ########################
413
+ # if vram == "8G":
414
+ # optimizer = f"""--optimizer_type adafactor {line_break}
415
+ # --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
416
+ # --split_mode {line_break}
417
+ # --network_args "train_blocks=single" {line_break}
418
+ # --lr_scheduler constant_with_warmup {line_break}
419
+ # --max_grad_norm 0.0 {line_break}"""
420
+ if vram == "16G":
421
+ # 16G VRAM
422
+ optimizer = f"""--optimizer_type adafactor {line_break}
423
+ --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
424
+ --lr_scheduler constant_with_warmup {line_break}
425
+ --max_grad_norm 0.0 {line_break}"""
426
+ elif vram == "12G":
427
+ # 12G VRAM
428
+ optimizer = f"""--optimizer_type adafactor {line_break}
429
+ --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
430
+ --split_mode {line_break}
431
+ --network_args "train_blocks=single" {line_break}
432
+ --lr_scheduler constant_with_warmup {line_break}
433
+ --max_grad_norm 0.0 {line_break}"""
434
+ else:
435
+ # 20G+ VRAM
436
+ optimizer = f"--optimizer_type adamw8bit {line_break}"
437
+
438
+
439
+ #######################################################
440
+ model_config = models[base_model]
441
+ model_file = model_config["file"]
442
+ repo = model_config["repo"]
443
+ if base_model == "flux-dev" or base_model == "flux-schnell":
444
+ model_folder = "models/unet"
445
+ else:
446
+ model_folder = f"models/unet/{repo}"
447
+ model_path = os.path.join(model_folder, model_file)
448
+ pretrained_model_path = resolve_path(model_path)
449
+
450
+ clip_path = resolve_path("models/clip/clip_l.safetensors")
451
+ t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
452
+ ae_path = resolve_path("models/vae/ae.sft")
453
+ sh = f"""accelerate launch {line_break}
454
+ --mixed_precision bf16 {line_break}
455
+ --num_cpu_threads_per_process 1 {line_break}
456
+ sd-scripts/flux_train_network.py {line_break}
457
+ --pretrained_model_name_or_path {pretrained_model_path} {line_break}
458
+ --clip_l {clip_path} {line_break}
459
+ --t5xxl {t5_path} {line_break}
460
+ --ae {ae_path} {line_break}
461
+ --cache_latents_to_disk {line_break}
462
+ --save_model_as safetensors {line_break}
463
+ --sdpa --persistent_data_loader_workers {line_break}
464
+ --max_data_loader_n_workers {workers} {line_break}
465
+ --seed {seed} {line_break}
466
+ --gradient_checkpointing {line_break}
467
+ --mixed_precision bf16 {line_break}
468
+ --save_precision bf16 {line_break}
469
+ --network_module networks.lora_flux {line_break}
470
+ --network_dim {network_dim} {line_break}
471
+ {optimizer}{sample}
472
+ --learning_rate {learning_rate} {line_break}
473
+ --cache_text_encoder_outputs {line_break}
474
+ --cache_text_encoder_outputs_to_disk {line_break}
475
+ --fp8_base {line_break}
476
+ --highvram {line_break}
477
+ --max_train_epochs {max_train_epochs} {line_break}
478
+ --save_every_n_epochs {save_every_n_epochs} {line_break}
479
+ --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break}
480
+ --output_dir {output_dir} {line_break}
481
+ --output_name {output_name} {line_break}
482
+ --timestep_sampling {timestep_sampling} {line_break}
483
+ --discrete_flow_shift 3.1582 {line_break}
484
+ --model_prediction_type raw {line_break}
485
+ --guidance_scale {guidance_scale} {line_break}
486
+ --loss_type l2 {line_break}"""
487
+
488
+
489
+
490
+ ############# Advanced args ########################
491
+ global advanced_component_ids
492
+ global original_advanced_component_values
493
+
494
+ # check dirty
495
+ print(f"original_advanced_component_values = {original_advanced_component_values}")
496
+ advanced_flags = []
497
+ for i, current_value in enumerate(advanced_components):
498
+ # print(f"compare {advanced_component_ids[i]}: old={original_advanced_component_values[i]}, new={current_value}")
499
+ if original_advanced_component_values[i] != current_value:
500
+ # dirty
501
+ if current_value is True:
502
+ # Boolean
503
+ advanced_flags.append(advanced_component_ids[i])
504
+ else:
505
+ # string
506
+ advanced_flags.append(f"{advanced_component_ids[i]} {current_value}")
507
+
508
+ if len(advanced_flags) > 0:
509
+ advanced_flags_str = f" {line_break}\n ".join(advanced_flags)
510
+ sh = sh + "\n " + advanced_flags_str
511
+
512
+ return sh
513
+
514
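+ # Render the sd-scripts dataset config (dataset.toml) pointing at the uploaded image folder, trigger word and repeat count.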
+ def gen_toml(
515
+ dataset_folder,
516
+ resolution,
517
+ class_tokens,
518
+ num_repeats
519
+ ):
520
+ toml = f"""[general]
521
+ shuffle_caption = false
522
+ caption_extension = '.txt'
523
+ keep_tokens = 1
524
+
525
+ [[datasets]]
526
+ resolution = {resolution}
527
+ batch_size = 1
528
+ keep_tokens = 1
529
+
530
+ [[datasets.subsets]]
531
+ image_dir = '{resolve_path_without_quotes(dataset_folder)}'
532
+ class_tokens = '{class_tokens}'
533
+ num_repeats = {num_repeats}"""
534
+ return toml
535
+
536
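+ # Expected steps = epochs * images * repeats, since the dataset config pins batch_size to 1.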
+ def update_total_steps(max_train_epochs, num_repeats, images):
537
+ try:
538
+ num_images = len(images)
539
+ total_steps = max_train_epochs * num_images * num_repeats
540
+ print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
541
+ return gr.update(value = total_steps)
542
+ except Exception:
543
+ pass  # images not uploaded yet; leave the expected-steps field unchanged
544
+
545
+ def set_repo(lora_rows):
546
+ selected_name = os.path.basename(lora_rows)
547
+ return gr.update(value=selected_name)
548
+
549
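+ # List previously trained LoRA folders under outputs/, newest first, for the Publish tab.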
+ def get_loras():
550
+ try:
551
+ outputs_path = resolve_path_without_quotes("outputs")
552
+ files = os.listdir(outputs_path)
553
+ folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"]
554
+ folders.sort(key=lambda file: os.path.getctime(file), reverse=True)
555
+ return folders
556
+ except Exception:
557
+ return []
558
+
559
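+ # List a LoRA's generated sample images, newest first, for the gallery.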
+ def get_samples(lora_name):
560
+ output_name = slugify(lora_name)
561
+ try:
562
+ samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
563
+ files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
564
+ files.sort(key=lambda file: os.path.getctime(file), reverse=True)
565
+ return files
566
+ except Exception:
567
+ return []
568
+
569
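+ # Write the train script, dataset.toml and sample_prompts.txt into outputs/<name>/, stream the training run's logs into the UI, then emit a README.md for publishing.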
+ def start_training(
570
+ base_model,
571
+ lora_name,
572
+ train_script,
573
+ train_config,
574
+ sample_prompts,
575
+ ):
576
+ # write custom script and toml
577
+ if not os.path.exists("models"):
578
+ os.makedirs("models", exist_ok=True)
579
+ if not os.path.exists("outputs"):
580
+ os.makedirs("outputs", exist_ok=True)
581
+ output_name = slugify(lora_name)
582
+ output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
583
+ if not os.path.exists(output_dir):
584
+ os.makedirs(output_dir, exist_ok=True)
585
+
586
+ download(base_model)
587
+
588
+ file_type = "sh"
589
+ if sys.platform == "win32":
590
+ file_type = "bat"
591
+
592
+ sh_filename = f"train.{file_type}"
593
+ sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
594
+ with open(sh_filepath, 'w', encoding="utf-8") as file:
595
+ file.write(train_script)
596
+ gr.Info(f"Generated train script at {sh_filename}")
597
+
598
+
599
+ dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
600
+ with open(dataset_path, 'w', encoding="utf-8") as file:
601
+ file.write(train_config)
602
+ gr.Info(f"Generated dataset.toml")
603
+
604
+ sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
605
+ with open(sample_prompts_path, 'w', encoding='utf-8') as file:
606
+ file.write(sample_prompts)
607
+ gr.Info(f"Generated sample_prompts.txt")
608
+
609
+ # Train
610
+ if sys.platform == "win32":
611
+ command = sh_filepath
612
+ else:
613
+ command = f"bash \"{sh_filepath}\""
614
+
615
+ # Use Popen to run the command and capture output in real-time
616
+ env = os.environ.copy()
617
+ env['PYTHONIOENCODING'] = 'utf-8'
618
+ env['LOG_LEVEL'] = 'DEBUG'
619
+ runner = LogsViewRunner()
620
+ cwd = os.path.dirname(os.path.abspath(__file__))
621
+ gr.Info(f"Started training")
622
+ yield from runner.run_command([command], cwd=cwd)
623
+ yield runner.log(f"Runner: {runner}")
624
+
625
+ # Generate Readme
626
+ config = toml.loads(train_config)
627
+ concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens']
628
+ print(f"concept_sentence={concept_sentence}")
629
+ print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
630
+ sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
631
+ with open(sample_prompts_path, "r", encoding="utf-8") as f:
632
+ lines = f.readlines()
633
+ sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and not line.strip().startswith("#")]
634
+ md = readme(base_model, lora_name, concept_sentence, sample_prompts)
635
+ readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
636
+ with open(readme_path, "w", encoding="utf-8") as f:
637
+ f.write(md)
638
+
639
+ gr.Info(f"Training Complete. Check the outputs folder for the LoRA files.", duration=None)
640
+
641
+
642
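+ # Regenerate the train script and dataset config from the current UI values; wired to the hidden Refresh button that the client-side JS clicks (debounced) on every input.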
+ def update(
643
+ base_model,
644
+ lora_name,
645
+ resolution,
646
+ seed,
647
+ workers,
648
+ class_tokens,
649
+ learning_rate,
650
+ network_dim,
651
+ max_train_epochs,
652
+ save_every_n_epochs,
653
+ timestep_sampling,
654
+ guidance_scale,
655
+ vram,
656
+ num_repeats,
657
+ sample_prompts,
658
+ sample_every_n_steps,
659
+ *advanced_components,
660
+ ):
661
+ output_name = slugify(lora_name)
662
+ dataset_folder = f"datasets/{output_name}"
663
+ sh = gen_sh(
664
+ base_model,
665
+ output_name,
666
+ resolution,
667
+ seed,
668
+ workers,
669
+ learning_rate,
670
+ network_dim,
671
+ max_train_epochs,
672
+ save_every_n_epochs,
673
+ timestep_sampling,
674
+ guidance_scale,
675
+ vram,
676
+ sample_prompts,
677
+ sample_every_n_steps,
678
+ *advanced_components,
679
+ )
680
+ toml = gen_toml(
681
+ dataset_folder,
682
+ resolution,
683
+ class_tokens,
684
+ num_repeats
685
+ )
686
+ return gr.update(value=sh), gr.update(value=toml), dataset_folder
687
+
688
+ # Wired up at the bottom of the file as:
689
+ # demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
691
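+ # On page load: restore a saved Hugging Face login, if any, and toggle the login/logout widgets accordingly.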
+ def loaded():
692
+ global current_account
693
+ current_account = account_hf()
694
+ print(f"current_account={current_account}")
695
+ if current_account != None:
696
+ return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
697
+ else:
698
+ return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
699
+
700
+ def update_sample(concept_sentence):
701
+ return gr.update(value=concept_sentence)
702
+
703
+ def refresh_publish_tab():
704
+ loras = get_loras()
705
+ return gr.Dropdown(label="Trained LoRAs", choices=loras)
706
+
707
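+ # Auto-generate the "Advanced options" widgets from the sd-scripts argument parser, skipping every flag already exposed as a basic field above.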
+ def init_advanced():
708
+ # if basic_args
709
+ basic_args = {
710
+ 'pretrained_model_name_or_path',
711
+ 'clip_l',
712
+ 't5xxl',
713
+ 'ae',
714
+ 'cache_latents_to_disk',
715
+ 'save_model_as',
716
+ 'sdpa',
717
+ 'persistent_data_loader_workers',
718
+ 'max_data_loader_n_workers',
719
+ 'seed',
720
+ 'gradient_checkpointing',
721
+ 'mixed_precision',
722
+ 'save_precision',
723
+ 'network_module',
724
+ 'network_dim',
725
+ 'learning_rate',
726
+ 'cache_text_encoder_outputs',
727
+ 'cache_text_encoder_outputs_to_disk',
728
+ 'fp8_base',
729
+ 'highvram',
730
+ 'max_train_epochs',
731
+ 'save_every_n_epochs',
732
+ 'dataset_config',
733
+ 'output_dir',
734
+ 'output_name',
735
+ 'timestep_sampling',
736
+ 'discrete_flow_shift',
737
+ 'model_prediction_type',
738
+ 'guidance_scale',
739
+ 'loss_type',
740
+ 'optimizer_type',
741
+ 'optimizer_args',
742
+ 'lr_scheduler',
743
+ 'sample_prompts',
744
+ 'sample_every_n_steps',
745
+ 'max_grad_norm',
746
+ 'split_mode',
747
+ 'network_args'
748
+ }
749
+
750
+ # generate a UI config
751
+ # if not in basic_args, create a simple form
752
+ parser = train_network.setup_parser()
753
+ flux_train_utils.add_flux_train_arguments(parser)
754
+ args_info = {}
755
+ for action in parser._actions:
756
+ if action.dest != 'help': # Skip the default help argument
757
+ # if the dest is included in basic_args
758
+ args_info[action.dest] = {
759
+ "action": action.option_strings, # Option strings like '--use_8bit_adam'
760
+ "type": action.type, # Type of the argument
761
+ "help": action.help, # Help message
762
+ "default": action.default, # Default value, if any
763
+ "required": action.required # Whether the argument is required
764
+ }
765
+ temp = []
766
+ for key in args_info:
767
+ temp.append({ 'key': key, 'action': args_info[key] })
768
+ temp.sort(key=lambda x: x['key'])
769
+ advanced_component_ids = []
770
+ advanced_components = []
771
+ for item in temp:
772
+ key = item['key']
773
+ action = item['action']
774
+ if key in basic_args:
775
+ pass  # already exposed as a basic option above; no advanced widget needed
776
+ else:
777
+ action_type = str(action['type'])
778
+ component = None
779
+ with gr.Column(min_width=300):
780
+ if action_type == "None":
781
+ # radio
782
+ component = gr.Checkbox()
783
+ # elif action_type == "<class 'str'>":
784
+ # component = gr.Textbox()
785
+ # elif action_type == "<class 'int'>":
786
+ # component = gr.Number(precision=0)
787
+ # elif action_type == "<class 'float'>":
788
+ # component = gr.Number()
789
+ # elif "int_or_float" in action_type:
790
+ # component = gr.Number()
791
+ else:
792
+ component = gr.Textbox(value="")
793
+ if component != None:
794
+ component.interactive = True
795
+ component.elem_id = action['action'][0]
796
+ component.label = component.elem_id
797
+ component.elem_classes = ["advanced"]
798
+ if action['help'] != None:
799
+ component.info = action['help']
800
+ advanced_components.append(component)
801
+ advanced_component_ids.append(component.elem_id)
802
+ return advanced_components, advanced_component_ids
803
+
804
+
805
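+ # --- Gradio UI: theme, CSS, and client-side JS (autoscroll + debounced refresh) ---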
+ theme = gr.themes.Monochrome(
806
+ text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
807
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
808
+ )
809
+ css = """
810
+ @keyframes rotate {
811
+ 0% {
812
+ transform: rotate(0deg);
813
+ }
814
+ 100% {
815
+ transform: rotate(360deg);
816
+ }
817
+ }
818
+ #advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; }
819
+ h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;}
820
+ h3{margin-top: 0}
821
+ .tabitem{border: 0px}
822
+ .group_padding{}
823
+ nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); }
824
+ nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; }
825
+ nav img { height: 40px; width: 40px; border-radius: 40px; }
826
+ nav img.rotate { animation: rotate 2s linear infinite; }
827
+ .flexible { flex-grow: 1; }
828
+ .tast-details { margin: 10px 0 !important; }
829
+ .toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); }
830
+ .toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; }
831
+ .toast-body { border: none !important; }
832
+ #terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); }
833
+ #terminal .generating { border: none !important; }
834
+ #terminal label { position: absolute !important; }
835
+ .tabs { margin-top: 50px; }
836
+ .hidden { display: none !important; }
837
+ .codemirror-wrapper .cm-line { font-size: 12px !important; }
838
+ label { font-weight: bold !important; }
839
+ #start_training.clicked { background: silver; color: black; }
840
+ """
841
+
842
+ js = """
843
+ function() {
844
+ let autoscroll = document.querySelector("#autoscroll")
845
+ if (window.iidxx) {
846
+ window.clearInterval(window.iidxx);
847
+ }
848
+ window.iidxx = window.setInterval(function() {
849
+ let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim()
850
+ let img = document.querySelector("#logo")
851
+ if (text.length > 0) {
852
+ autoscroll.classList.remove("hidden")
853
+ if (autoscroll.classList.contains("on")) {
854
+ autoscroll.textContent = "Autoscroll ON"
855
+ window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" });
856
+ img.classList.add("rotate")
857
+ } else {
858
+ autoscroll.textContent = "Autoscroll OFF"
859
+ img.classList.remove("rotate")
860
+ }
861
+ }
862
+ }, 500);
863
+ console.log("autoscroll", autoscroll)
864
+ autoscroll.addEventListener("click", (e) => {
865
+ autoscroll.classList.toggle("on")
866
+ })
867
+ function debounce(fn, delay) {
868
+ let timeoutId;
869
+ return function(...args) {
870
+ clearTimeout(timeoutId);
871
+ timeoutId = setTimeout(() => fn(...args), delay);
872
+ };
873
+ }
874
+
875
+ function handleClick() {
876
+ console.log("refresh")
877
+ document.querySelector("#refresh").click();
878
+ }
879
+ const debouncedClick = debounce(handleClick, 1000);
880
+ document.addEventListener("input", debouncedClick);
881
+
882
+ document.querySelector("#start_training").addEventListener("click", (e) => {
883
+ e.target.classList.add("clicked")
884
+ e.target.innerHTML = "Training..."
885
+ })
886
+
887
+ }
888
+ """
889
+
890
+ current_account = account_hf()
891
+ print(f"current_account={current_account}")
892
+
893
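+ # Main app: the Gym tab walks through LoRA info -> dataset & captions -> training; the Publish tab uploads a finished LoRA to Hugging Face.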
+ with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
894
+ with gr.Tabs() as tabs:
895
+ with gr.TabItem("Gym"):
896
+ output_components = []
897
+ with gr.Row():
898
+ gr.HTML("""<nav>
899
+ <img id='logo' src='/file=icon.png' width='80' height='80'>
900
+ <div class='flexible'></div>
901
+ <button id='autoscroll' class='on hidden'></button>
902
+ </nav>
903
+ """)
904
+ with gr.Row(elem_id='container'):
905
+ with gr.Column():
906
+ gr.Markdown(
907
+ """# Step 1. LoRA Info
908
+ <p style="margin-top:0">Configure your LoRA training settings.</p>
909
+ """, elem_classes="group_padding")
910
+ lora_name = gr.Textbox(
911
+ label="The name of your LoRA",
912
+ info="This has to be a unique name",
913
+ placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
914
+ )
915
+ concept_sentence = gr.Textbox(
916
+ elem_id="--concept_sentence",
917
+ label="Trigger word/sentence",
918
+ info="Trigger word or sentence to be used",
919
+ placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
920
+ interactive=True,
921
+ )
922
+ model_names = list(models.keys())
923
+ print(f"model_names={model_names}")
924
+ base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0])
925
+ vram = gr.Radio(["20G", "16G", "12G" ], value="20G", label="VRAM", interactive=True)
926
+ num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
927
+ max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
928
+ total_steps = gr.Number(0, interactive=False, label="Expected training steps")
929
+ sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True)
930
+ sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True)
931
+ resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
932
+ with gr.Column():
933
+ gr.Markdown(
934
+ """# Step 2. Dataset
935
+ <p style="margin-top:0">Make sure the captions include the trigger word.</p>
936
+ """, elem_classes="group_padding")
937
+ with gr.Group():
938
+ images = gr.File(
939
+ file_types=["image", ".txt"],
940
+ label="Upload your images",
941
+ #info="If you want, you can also manually upload caption files that match the image names (example: img0.png => img0.txt)",
942
+ file_count="multiple",
943
+ interactive=True,
944
+ visible=True,
945
+ scale=1,
946
+ )
947
+ with gr.Group(visible=False) as captioning_area:
948
+ do_captioning = gr.Button("Add AI captions with Florence-2")
949
+ output_components.append(captioning_area)
950
+ #output_components = [captioning_area]
951
+ caption_list = []
952
+ for i in range(1, MAX_IMAGES + 1):
953
+ locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
954
+ with locals()[f"captioning_row_{i}"]:
955
+ locals()[f"image_{i}"] = gr.Image(
956
+ type="filepath",
957
+ width=111,
958
+ height=111,
959
+ min_width=111,
960
+ interactive=False,
961
+ scale=2,
962
+ show_label=False,
963
+ show_share_button=False,
964
+ show_download_button=False,
965
+ )
966
+ locals()[f"caption_{i}"] = gr.Textbox(
967
+ label=f"Caption {i}", scale=15, interactive=True
968
+ )
969
+
970
+ output_components.append(locals()[f"captioning_row_{i}"])
971
+ output_components.append(locals()[f"image_{i}"])
972
+ output_components.append(locals()[f"caption_{i}"])
973
+ caption_list.append(locals()[f"caption_{i}"])
974
+ with gr.Column():
975
+ gr.Markdown(
976
+ """# Step 3. Train
977
+ <p style="margin-top:0">Press Start to begin training.</p>
978
+ """, elem_classes="group_padding")
979
+ refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
980
+ start = gr.Button("Start training", visible=False, elem_id="start_training")
981
+ output_components.append(start)
982
+ train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
983
+ train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
984
+ with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
985
+ with gr.Row():
986
+ with gr.Column(min_width=300):
987
+ seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
988
+ with gr.Column(min_width=300):
989
+ workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True)
990
+ with gr.Column(min_width=300):
991
+ learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True)
992
+ with gr.Column(min_width=300):
993
+ save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True)
994
+ with gr.Column(min_width=300):
995
+ guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True)
996
+ with gr.Column(min_width=300):
997
+ timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True)
998
+ with gr.Column(min_width=300):
999
+ network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True)
1000
+ advanced_components, advanced_component_ids = init_advanced()
1001
+ with gr.Row():
1002
+ terminal = LogsView(label="Train log", elem_id="terminal")
1003
+ with gr.Row():
1004
+ gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
1005
+
1006
+ with gr.TabItem("Publish") as publish_tab:
1007
+ hf_token = gr.Textbox(label="Hugging Face Token")
1008
+ hf_login = gr.Button("Login")
1009
+ hf_logout = gr.Button("Logout")
1010
+ with gr.Row() as row:
1011
+ gr.Markdown("**LoRA**")
1012
+ gr.Markdown("**Upload**")
1013
+ loras = get_loras()
1014
+ with gr.Row():
1015
+ lora_rows = refresh_publish_tab()
1016
+ with gr.Column():
1017
+ with gr.Row():
1018
+ repo_owner = gr.Textbox(label="Account", interactive=False)
1019
+ repo_name = gr.Textbox(label="Repository Name")
1020
+ repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public")
1021
+ upload_button = gr.Button("Upload to HuggingFace")
1022
+ upload_button.click(
1023
+ fn=upload_hf,
1024
+ inputs=[
1025
+ base_model,
1026
+ lora_rows,
1027
+ repo_owner,
1028
+ repo_name,
1029
+ repo_visibility,
1030
+ hf_token,
1031
+ ]
1032
+ )
1033
+ hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
1034
+ hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1035
+
1036
+
1037
+ publish_tab.select(refresh_publish_tab, outputs=lora_rows)
1038
+ lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])
1039
+
1040
+ dataset_folder = gr.State()
1041
+
1042
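+ # Every field listed here feeds update(), so changing it regenerates the train script and dataset config.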
+ listeners = [
1043
+ base_model,
1044
+ lora_name,
1045
+ resolution,
1046
+ seed,
1047
+ workers,
1048
+ concept_sentence,
1049
+ learning_rate,
1050
+ network_dim,
1051
+ max_train_epochs,
1052
+ save_every_n_epochs,
1053
+ timestep_sampling,
1054
+ guidance_scale,
1055
+ vram,
1056
+ num_repeats,
1057
+ sample_prompts,
1058
+ sample_every_n_steps,
1059
+ *advanced_components
1060
+ ]
1061
+ advanced_component_ids = [x.elem_id for x in advanced_components]
1062
+ original_advanced_component_values = [comp.value for comp in advanced_components]
1063
+ images.upload(
1064
+ load_captioning,
1065
+ inputs=[images, concept_sentence],
1066
+ outputs=output_components
1067
+ )
1068
+ images.delete(
1069
+ load_captioning,
1070
+ inputs=[images, concept_sentence],
1071
+ outputs=output_components
1072
+ )
1073
+ images.clear(
1074
+ hide_captioning,
1075
+ outputs=[captioning_area, start]
1076
+ )
1077
+ max_train_epochs.change(
1078
+ fn=update_total_steps,
1079
+ inputs=[max_train_epochs, num_repeats, images],
1080
+ outputs=[total_steps]
1081
+ )
1082
+ num_repeats.change(
1083
+ fn=update_total_steps,
1084
+ inputs=[max_train_epochs, num_repeats, images],
1085
+ outputs=[total_steps]
1086
+ )
1087
+ images.upload(
1088
+ fn=update_total_steps,
1089
+ inputs=[max_train_epochs, num_repeats, images],
1090
+ outputs=[total_steps]
1091
+ )
1092
+ images.delete(
1093
+ fn=update_total_steps,
1094
+ inputs=[max_train_epochs, num_repeats, images],
1095
+ outputs=[total_steps]
1096
+ )
1097
+ images.clear(
1098
+ fn=update_total_steps,
1099
+ inputs=[max_train_epochs, num_repeats, images],
1100
+ outputs=[total_steps]
1101
+ )
1102
+ concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
1103
+ start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then(
1104
+ fn=start_training,
1105
+ inputs=[
1106
+ base_model,
1107
+ lora_name,
1108
+ train_script,
1109
+ train_config,
1110
+ sample_prompts,
1111
+ ],
1112
+ outputs=terminal,
1113
+ )
1114
+ do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
1115
+ demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
1116
+ refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])
1117
+ if __name__ == "__main__":
1118
+ cwd = os.path.dirname(os.path.abspath(__file__))
1119
+ demo.launch(debug=True, show_error=True, allowed_paths=[cwd])