import os
import sys

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
sys.path.insert(0, os.getcwd())
sys.path.append(os.path.join(os.path.dirname(__file__), 'sd-scripts'))

import subprocess
import gradio as gr
from PIL import Image
import torch
import uuid
import shutil
import json
import yaml
from slugify import slugify
from transformers import AutoProcessor, AutoModelForCausalLM
from gradio_logsview import LogsView, LogsViewRunner
from huggingface_hub import hf_hub_download, HfApi
from library import flux_train_utils, huggingface_util
from argparse import Namespace
import train_network
import toml
import re

MAX_IMAGES = 150

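# models.yaml maps each base model key to its file, repo, and (optionally) license metadata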
with open('models.yaml', 'r') as file:
    models = yaml.safe_load(file)

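# Build the README.md (model card) for a trained LoRA: license metadata,
# sample image widgets, and trigger word instructions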
def readme(base_model, lora_name, instance_prompt, sample_prompts):

    # License metadata from models.yaml, if present
    model_config = models[base_model]
    model_file = model_config["file"]
    base_model_name = model_config["base"]
    license = None
    license_name = None
    license_link = None
    license_items = []
    if "license" in model_config:
        license = model_config["license"]
        license_items.append(f"license: {license}")
    if "license_name" in model_config:
        license_name = model_config["license_name"]
        license_items.append(f"license_name: {license_name}")
    if "license_link" in model_config:
        license_link = model_config["license_link"]
        license_items.append(f"license_link: {license_link}")
    license_str = "\n".join(license_items)
    print(f"license_items={license_items}")
    print(f"license_str = {license_str}")

    tags = ["text-to-image", "flux", "lora", "diffusers", "template:sd-lora", "fluxgym"]

    # Collect sample images generated during training for the model card widgets
    widgets = []
    sample_image_paths = []
    output_name = slugify(lora_name)
    samples_dir = resolve_path_without_quotes(f"outputs/{output_name}/sample")
    try:
        for filename in os.listdir(samples_dir):
            # Sample filenames end with _[steps]_[index]_[timestamp].png
            match = re.search(r"_(\d+)_(\d+)_(\d+)\.png$", filename)
            if match:
                steps, index, timestamp = int(match.group(1)), int(match.group(2)), int(match.group(3))
                sample_image_paths.append((steps, index, f"sample/{filename}"))

        # Most recent samples first
        sample_image_paths.sort(key=lambda x: x[0], reverse=True)

        # Keep the latest sample per prompt, then restore prompt order
        final_sample_image_paths = sample_image_paths[:len(sample_prompts)]
        final_sample_image_paths.sort(key=lambda x: x[1])
        for i, prompt in enumerate(sample_prompts):
            _, _, image_path = final_sample_image_paths[i]
            widgets.append(
                {
                    "text": prompt,
                    "output": {
                        "url": image_path
                    },
                }
            )
    except Exception:
        print("no samples")
    dtype = "torch.bfloat16"

    readme_content = f"""---
tags:
{yaml.dump(tags, indent=4).strip()}
{"widget:" if os.path.isdir(samples_dir) else ""}
{yaml.dump(widgets, indent=4).strip() if widgets else ""}
base_model: {base_model_name}
{"instance_prompt: " + instance_prompt if instance_prompt else ""}
{license_str}
---

# {lora_name}

A Flux LoRA trained on a local computer with [Fluxgym](https://github.com/cocktailpeanut/fluxgym)

<Gallery />

## Trigger words

{"You should use `" + instance_prompt + "` to trigger the image generation." if instance_prompt else "No trigger words defined."}

## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, Forge, etc.

Weights for this model are available in Safetensors format.

"""
    return readme_content

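# Read the cached HF_TOKEN file and look up the account; returns None if missing or invalid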
def account_hf():
    try:
        with open("HF_TOKEN", "r") as file:
            token = file.read()
            api = HfApi(token=token)
            try:
                account = api.whoami()
                return {"token": token, "account": account['name']}
            except Exception:
                return None
    except Exception:
        return None

"""
|
|
hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
|
"""
|
|
def logout_hf():
|
|
os.remove("HF_TOKEN")
|
|
global current_account
|
|
current_account = account_hf()
|
|
print(f"current_account={current_account}")
|
|
return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
|
|
|
|
|
|
"""
|
|
hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
|
"""
|
|
def login_hf(hf_token):
|
|
api = HfApi(token=hf_token)
|
|
try:
|
|
account = api.whoami()
|
|
if account != None:
|
|
if "name" in account:
|
|
with open("HF_TOKEN", "w") as file:
|
|
file.write(hf_token)
|
|
global current_account
|
|
current_account = account_hf()
|
|
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
|
|
return gr.update(), gr.update(), gr.update(), gr.update()
|
|
except:
|
|
print(f"incorrect hf_token")
|
|
return gr.update(), gr.update(), gr.update(), gr.update()
|
|
|
|
def upload_hf(base_model, lora_rows, repo_owner, repo_name, repo_visibility, hf_token):
    src = lora_rows
    repo_id = f"{repo_owner}/{repo_name}"
    gr.Info("Uploading to Huggingface. Please stand by...", duration=None)
    args = Namespace(
        huggingface_repo_id=repo_id,
        huggingface_repo_type="model",
        huggingface_repo_visibility=repo_visibility,
        huggingface_path_in_repo="",
        huggingface_token=hf_token,
        async_upload=False
    )
    print(f"upload_hf args={args}")
    huggingface_util.upload(args=args, src=src)
    gr.Info(f"[Upload Complete] https://huggingface.co/{repo_id}", duration=None)

def load_captioning(uploaded_files, concept_sentence):
    uploaded_images = [file for file in uploaded_files if not file.endswith('.txt')]
    txt_files = [file for file in uploaded_files if file.endswith('.txt')]
    txt_files_dict = {os.path.splitext(os.path.basename(txt_file))[0]: txt_file for txt_file in txt_files}
    updates = []
    if len(uploaded_images) <= 1:
        raise gr.Error(
            "Please upload at least 2 images to train your model (the ideal number with default settings is between 4 and 30)"
        )
    elif len(uploaded_images) > MAX_IMAGES:
        raise gr.Error(f"For now, only {MAX_IMAGES} or fewer images are allowed for training")

    # Show the captioning area
    updates.append(gr.update(visible=True))

    for i in range(1, MAX_IMAGES + 1):
        # Rows up to the number of uploaded images become visible
        visible = i <= len(uploaded_images)

        # Update visibility of the captioning row
        updates.append(gr.update(visible=visible))

        # Show the image if available, otherwise hide the component
        image_value = uploaded_images[i - 1] if visible else None
        updates.append(gr.update(value=image_value, visible=visible))

        # Prefer an uploaded .txt caption matching the image's base name
        corresponding_caption = False
        if image_value:
            base_name = os.path.splitext(os.path.basename(image_value))[0]
            if base_name in txt_files_dict:
                with open(txt_files_dict[base_name], 'r') as file:
                    corresponding_caption = file.read()

        # Fall back to the trigger sentence if no caption file was uploaded
        text_value = corresponding_caption if visible and corresponding_caption else concept_sentence if visible and concept_sentence else None
        updates.append(gr.update(value=text_value, visible=visible))

    # Reveal the remaining output components
    updates.append(gr.update(visible=True))
    updates.append(gr.update(visible=True))

    return updates

def hide_captioning():
    return gr.update(visible=False), gr.update(visible=False)

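# Resize so the short side equals `size`, preserving aspect ratio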
def resize_image(image_path, output_path, size):
    with Image.open(image_path) as img:
        width, height = img.size
        if width < height:
            new_width = size
            new_height = int((size / width) * height)
        else:
            new_height = size
            new_width = int((size / height) * width)
        print(f"resize {image_path} : {new_width}x{new_height}")
        img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
        img_resized.save(output_path)

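# Copy uploaded images into the dataset folder, resize them, and write one .txt caption per image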
def create_dataset(destination_folder, size, *inputs):
    print("Creating dataset")
    images = inputs[0]
    if not os.path.exists(destination_folder):
        os.makedirs(destination_folder)

    for index, image in enumerate(images):
        # Copy the file into the dataset folder
        new_image_path = shutil.copy(image, destination_folder)

        # Uploaded .txt caption files are copied as-is; skip resizing/captioning them
        ext = os.path.splitext(new_image_path)[-1].lower()
        if ext == '.txt':
            continue

        # Resize the image in place
        resize_image(new_image_path, new_image_path, size)

        # The caption textbox values follow the image list in *inputs
        original_caption = inputs[index + 1]

        image_file_name = os.path.basename(new_image_path)
        caption_file_name = os.path.splitext(image_file_name)[0] + ".txt"
        caption_path = resolve_path_without_quotes(os.path.join(destination_folder, caption_file_name))
        print(f"image_path={new_image_path}, caption_path = {caption_path}, original_caption={original_caption}")

        # If a caption file was uploaded alongside the image, keep it; otherwise write one
        if os.path.exists(caption_path):
            print(f"{caption_path} already exists. use the existing .txt file")
        else:
            print(f"{caption_path} create a .txt caption file")
            with open(caption_path, 'w') as file:
                file.write(original_caption)

    print(f"destination_folder {destination_folder}")
    return destination_folder

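# Caption each image with Florence-2, prepending the trigger sentence; yields the
# caption list after every image so the UI updates progressively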
def run_captioning(images, concept_sentence, *captions):
    print("run_captioning")
    print(f"concept sentence {concept_sentence}")
    print(f"captions {captions}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"device={device}")
    torch_dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        "multimodalart/Florence-2-large-no-flash-attn", torch_dtype=torch_dtype, trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained("multimodalart/Florence-2-large-no-flash-attn", trust_remote_code=True)

    captions = list(captions)
    for i, image_path in enumerate(images):
        print(captions[i])
        if isinstance(image_path, str):  # the image component provides file paths
            image = Image.open(image_path).convert("RGB")

        prompt = "<DETAILED_CAPTION>"
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
        print(f"inputs {inputs}")

        generated_ids = model.generate(
            input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
        )
        print(f"generated_ids {generated_ids}")

        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        print(f"generated_text: {generated_text}")
        parsed_answer = processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        print(f"parsed_answer = {parsed_answer}")
        caption_text = parsed_answer["<DETAILED_CAPTION>"].replace("The image shows ", "")
        print(f"caption_text = {caption_text}, concept_sentence={concept_sentence}")
        if concept_sentence:
            caption_text = f"{concept_sentence} {caption_text}"
        captions[i] = caption_text

        yield captions

    # Free the captioning model
    model.to("cpu")
    del model
    del processor
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

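# Recursively merge dict u into dict d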
def recursive_update(d, u):
    for k, v in u.items():
        if isinstance(v, dict) and v:
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d

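# Download any missing model files (unet, VAE, CLIP-L, T5-XXL) into models/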
def download(base_model):
    model = models[base_model]
    model_file = model["file"]
    repo = model["repo"]

    # Base model (unet)
    if base_model == "flux-dev" or base_model == "flux-schnell":
        unet_folder = "models/unet"
    else:
        unet_folder = f"models/unet/{repo}"
    unet_path = os.path.join(unet_folder, model_file)
    if not os.path.exists(unet_path):
        os.makedirs(unet_folder, exist_ok=True)
        gr.Info(f"Downloading base model: {base_model}. Please wait. (You can check the terminal for the download progress)", duration=None)
        print(f"download {base_model}")
        hf_hub_download(repo_id=repo, local_dir=unet_folder, filename=model_file)

    # VAE
    vae_folder = "models/vae"
    vae_path = os.path.join(vae_folder, "ae.sft")
    if not os.path.exists(vae_path):
        os.makedirs(vae_folder, exist_ok=True)
        gr.Info("Downloading vae...")
        print("downloading ae.sft...")
        hf_hub_download(repo_id="cocktailpeanut/xulf-dev", local_dir=vae_folder, filename="ae.sft")

    # CLIP-L text encoder
    clip_folder = "models/clip"
    clip_l_path = os.path.join(clip_folder, "clip_l.safetensors")
    if not os.path.exists(clip_l_path):
        os.makedirs(clip_folder, exist_ok=True)
        gr.Info("Downloading clip...")
        print("download clip_l.safetensors")
        hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="clip_l.safetensors")

    # T5-XXL text encoder
    t5xxl_path = os.path.join(clip_folder, "t5xxl_fp16.safetensors")
    if not os.path.exists(t5xxl_path):
        print("download t5xxl_fp16.safetensors")
        gr.Info("Downloading t5xxl...")
        hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", local_dir=clip_folder, filename="t5xxl_fp16.safetensors")

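# Resolve a path relative to this file; the quoted variant is for embedding in shell commands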
def resolve_path(p):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    norm_path = os.path.normpath(os.path.join(current_dir, p))
    return f"\"{norm_path}\""

def resolve_path_without_quotes(p):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    norm_path = os.path.normpath(os.path.join(current_dir, p))
    return norm_path

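# Generate the train.sh / train.bat script that invokes sd-scripts/flux_train_network.py
# via accelerate, with optimizer flags chosen by the selected VRAM budget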
def gen_sh(
    base_model,
    output_name,
    resolution,
    seed,
    workers,
    learning_rate,
    network_dim,
    max_train_epochs,
    save_every_n_epochs,
    timestep_sampling,
    guidance_scale,
    vram,
    sample_prompts,
    sample_every_n_steps,
    *advanced_components
):

    print(f"gen_sh: network_dim:{network_dim}, max_train_epochs={max_train_epochs}, save_every_n_epochs={save_every_n_epochs}, timestep_sampling={timestep_sampling}, guidance_scale={guidance_scale}, vram={vram}, sample_prompts={sample_prompts}, sample_every_n_steps={sample_every_n_steps}")

    output_dir = resolve_path(f"outputs/{output_name}")
    sample_prompts_path = resolve_path(f"outputs/{output_name}/sample_prompts.txt")

    # Shell line continuation differs between bash and Windows batch files
    line_break = "\\"
    file_type = "sh"
    if sys.platform == "win32":
        line_break = "^"
        file_type = "bat"

    # Sample image flags (only when prompts and an interval are set)
    sample = ""
    if len(sample_prompts) > 0 and sample_every_n_steps > 0:
        sample = f"""--sample_prompts={sample_prompts_path} --sample_every_n_steps="{sample_every_n_steps}" {line_break}"""

    # Optimizer preset per VRAM budget
    if vram == "16G":
        # 16G VRAM
        optimizer = f"""--optimizer_type adafactor {line_break}
  --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
  --lr_scheduler constant_with_warmup {line_break}
  --max_grad_norm 0.0 {line_break}"""
    elif vram == "12G":
        # 12G VRAM: additionally split the model and train single blocks
        optimizer = f"""--optimizer_type adafactor {line_break}
  --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" {line_break}
  --split_mode {line_break}
  --network_args "train_blocks=single" {line_break}
  --lr_scheduler constant_with_warmup {line_break}
  --max_grad_norm 0.0 {line_break}"""
    else:
        # 20G+ VRAM
        optimizer = f"--optimizer_type adamw8bit {line_break}"

    model_config = models[base_model]
    model_file = model_config["file"]
    repo = model_config["repo"]
    if base_model == "flux-dev" or base_model == "flux-schnell":
        model_folder = "models/unet"
    else:
        model_folder = f"models/unet/{repo}"
    model_path = os.path.join(model_folder, model_file)
    pretrained_model_path = resolve_path(model_path)

    clip_path = resolve_path("models/clip/clip_l.safetensors")
    t5_path = resolve_path("models/clip/t5xxl_fp16.safetensors")
    ae_path = resolve_path("models/vae/ae.sft")
    sh = f"""accelerate launch {line_break}
  --mixed_precision bf16 {line_break}
  --num_cpu_threads_per_process 1 {line_break}
  sd-scripts/flux_train_network.py {line_break}
  --pretrained_model_name_or_path {pretrained_model_path} {line_break}
  --clip_l {clip_path} {line_break}
  --t5xxl {t5_path} {line_break}
  --ae {ae_path} {line_break}
  --cache_latents_to_disk {line_break}
  --save_model_as safetensors {line_break}
  --sdpa --persistent_data_loader_workers {line_break}
  --max_data_loader_n_workers {workers} {line_break}
  --seed {seed} {line_break}
  --gradient_checkpointing {line_break}
  --mixed_precision bf16 {line_break}
  --save_precision bf16 {line_break}
  --network_module networks.lora_flux {line_break}
  --network_dim {network_dim} {line_break}
  {optimizer}{sample}
  --learning_rate {learning_rate} {line_break}
  --cache_text_encoder_outputs {line_break}
  --cache_text_encoder_outputs_to_disk {line_break}
  --fp8_base {line_break}
  --highvram {line_break}
  --max_train_epochs {max_train_epochs} {line_break}
  --save_every_n_epochs {save_every_n_epochs} {line_break}
  --dataset_config {resolve_path(f"outputs/{output_name}/dataset.toml")} {line_break}
  --output_dir {output_dir} {line_break}
  --output_name {output_name} {line_break}
  --timestep_sampling {timestep_sampling} {line_break}
  --discrete_flow_shift 3.1582 {line_break}
  --model_prediction_type raw {line_break}
  --guidance_scale {guidance_scale} {line_break}
  --loss_type l2 {line_break}"""

    # Append any advanced flags whose values differ from their defaults
    global advanced_component_ids
    global original_advanced_component_values

    print(f"original_advanced_component_values = {original_advanced_component_values}")
    advanced_flags = []
    for i, current_value in enumerate(advanced_components):
        # Only values changed from their defaults are emitted
        if original_advanced_component_values[i] != current_value:
            # Boolean flags are emitted bare; everything else as "--flag value"
            if current_value is True:
                advanced_flags.append(advanced_component_ids[i])
            else:
                advanced_flags.append(f"{advanced_component_ids[i]} {current_value}")

    if len(advanced_flags) > 0:
        advanced_flags_str = f" {line_break}\n  ".join(advanced_flags)
        sh = sh + "\n  " + advanced_flags_str

    return sh

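# Generate the dataset.toml config consumed by sd-scripts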
def gen_toml(
    dataset_folder,
    resolution,
    class_tokens,
    num_repeats
):
    toml = f"""[general]
shuffle_caption = false
caption_extension = '.txt'
keep_tokens = 1

[[datasets]]
resolution = {resolution}
batch_size = 1
keep_tokens = 1

[[datasets.subsets]]
image_dir = '{resolve_path_without_quotes(dataset_folder)}'
class_tokens = '{class_tokens}'
num_repeats = {num_repeats}"""
    return toml

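# Expected total steps = epochs * images * repeats (batch_size is fixed at 1 in gen_toml)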
def update_total_steps(max_train_epochs, num_repeats, images):
    try:
        num_images = len(images)
        total_steps = max_train_epochs * num_images * num_repeats
        print(f"max_train_epochs={max_train_epochs} num_images={num_images}, num_repeats={num_repeats}, total_steps={total_steps}")
        return gr.update(value=total_steps)
    except Exception:
        pass

def set_repo(lora_rows):
    selected_name = os.path.basename(lora_rows)
    return gr.update(value=selected_name)

def get_loras():
    try:
        outputs_path = resolve_path_without_quotes("outputs")
        files = os.listdir(outputs_path)
        folders = [os.path.join(outputs_path, item) for item in files if os.path.isdir(os.path.join(outputs_path, item)) and item != "sample"]
        folders.sort(key=lambda file: os.path.getctime(file), reverse=True)
        return folders
    except Exception:
        return []

def get_samples(lora_name):
    output_name = slugify(lora_name)
    try:
        samples_path = resolve_path_without_quotes(f"outputs/{output_name}/sample")
        files = [os.path.join(samples_path, file) for file in os.listdir(samples_path)]
        files.sort(key=lambda file: os.path.getctime(file), reverse=True)
        return files
    except Exception:
        return []

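# Write the train script, dataset config, and sample prompts to outputs/<name>/,
# run the script while streaming its logs into the UI, then write the README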
def start_training(
    base_model,
    lora_name,
    train_script,
    train_config,
    sample_prompts,
):

    if not os.path.exists("models"):
        os.makedirs("models", exist_ok=True)
    if not os.path.exists("outputs"):
        os.makedirs("outputs", exist_ok=True)
    output_name = slugify(lora_name)
    output_dir = resolve_path_without_quotes(f"outputs/{output_name}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    download(base_model)

    file_type = "sh"
    if sys.platform == "win32":
        file_type = "bat"

    # Write the train script (train.sh / train.bat)
    sh_filename = f"train.{file_type}"
    sh_filepath = resolve_path_without_quotes(f"outputs/{output_name}/{sh_filename}")
    with open(sh_filepath, 'w', encoding="utf-8") as file:
        file.write(train_script)
    gr.Info(f"Generated train script at {sh_filename}")

    # Write the dataset config
    dataset_path = resolve_path_without_quotes(f"outputs/{output_name}/dataset.toml")
    with open(dataset_path, 'w', encoding="utf-8") as file:
        file.write(train_config)
    gr.Info("Generated dataset.toml")

    # Write the sample prompts
    sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
    with open(sample_prompts_path, 'w', encoding='utf-8') as file:
        file.write(sample_prompts)
    gr.Info("Generated sample_prompts.txt")

    # Run the training script, streaming logs into the LogsView
    if sys.platform == "win32":
        command = sh_filepath
    else:
        command = f"bash \"{sh_filepath}\""

    env = os.environ.copy()
    env['PYTHONIOENCODING'] = 'utf-8'
    env['LOG_LEVEL'] = 'DEBUG'
    runner = LogsViewRunner()
    cwd = os.path.dirname(os.path.abspath(__file__))
    gr.Info("Started training")
    yield from runner.run_command([command], cwd=cwd)
    yield runner.log(f"Runner: {runner}")

    # Generate the README (model card) once training finishes
    config = toml.loads(train_config)
    concept_sentence = config['datasets'][0]['subsets'][0]['class_tokens']
    print(f"concept_sentence={concept_sentence}")
    print(f"lora_name {lora_name}, concept_sentence={concept_sentence}, output_name={output_name}")
    sample_prompts_path = resolve_path_without_quotes(f"outputs/{output_name}/sample_prompts.txt")
    with open(sample_prompts_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    sample_prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
    md = readme(base_model, lora_name, concept_sentence, sample_prompts)
    readme_path = resolve_path_without_quotes(f"outputs/{output_name}/README.md")
    with open(readme_path, "w", encoding="utf-8") as f:
        f.write(md)

    gr.Info("Training Complete. Check the outputs folder for the LoRA files.", duration=None)

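# Regenerate the train script and dataset config from the current UI values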
def update(
    base_model,
    lora_name,
    resolution,
    seed,
    workers,
    class_tokens,
    learning_rate,
    network_dim,
    max_train_epochs,
    save_every_n_epochs,
    timestep_sampling,
    guidance_scale,
    vram,
    num_repeats,
    sample_prompts,
    sample_every_n_steps,
    *advanced_components
):
    output_name = slugify(lora_name)
    dataset_folder = str(f"datasets/{output_name}")
    sh = gen_sh(
        base_model,
        output_name,
        resolution,
        seed,
        workers,
        learning_rate,
        network_dim,
        max_train_epochs,
        save_every_n_epochs,
        timestep_sampling,
        guidance_scale,
        vram,
        sample_prompts,
        sample_every_n_steps,
        *advanced_components,
    )
    toml = gen_toml(
        dataset_folder,
        resolution,
        class_tokens,
        num_repeats
    )
    return gr.update(value=sh), gr.update(value=toml), dataset_folder

"""
|
|
demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, hf_account])
|
|
"""
|
|
def loaded():
|
|
global current_account
|
|
current_account = account_hf()
|
|
print(f"current_account={current_account}")
|
|
if current_account != None:
|
|
return gr.update(value=current_account["token"]), gr.update(visible=False), gr.update(visible=True), gr.update(value=current_account["account"], visible=True)
|
|
else:
|
|
return gr.update(value=""), gr.update(visible=True), gr.update(visible=False), gr.update(value="", visible=False)
|
|
|
|
def update_sample(concept_sentence):
    return gr.update(value=concept_sentence)

def refresh_publish_tab():
    loras = get_loras()
    return gr.Dropdown(label="Trained LoRAs", choices=loras)

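# Auto-generate the "Advanced options" components from the sd-scripts argparse parser,
# skipping arguments that already have dedicated controls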
def init_advanced():
    # Arguments that already have dedicated UI controls (or are managed internally)
    basic_args = {
        'pretrained_model_name_or_path',
        'clip_l',
        't5xxl',
        'ae',
        'cache_latents_to_disk',
        'save_model_as',
        'sdpa',
        'persistent_data_loader_workers',
        'max_data_loader_n_workers',
        'seed',
        'gradient_checkpointing',
        'mixed_precision',
        'save_precision',
        'network_module',
        'network_dim',
        'learning_rate',
        'cache_text_encoder_outputs',
        'cache_text_encoder_outputs_to_disk',
        'fp8_base',
        'highvram',
        'max_train_epochs',
        'save_every_n_epochs',
        'dataset_config',
        'output_dir',
        'output_name',
        'timestep_sampling',
        'discrete_flow_shift',
        'model_prediction_type',
        'guidance_scale',
        'loss_type',
        'optimizer_type',
        'optimizer_args',
        'lr_scheduler',
        'sample_prompts',
        'sample_every_n_steps',
        'max_grad_norm',
        'split_mode',
        'network_args'
    }

    # Collect argument metadata from the sd-scripts parser
    parser = train_network.setup_parser()
    flux_train_utils.add_flux_train_arguments(parser)
    args_info = {}
    for action in parser._actions:
        if action.dest != 'help':  # skip the built-in help action
            args_info[action.dest] = {
                "action": action.option_strings,
                "type": action.type,
                "help": action.help,
                "default": action.default,
                "required": action.required
            }
    temp = []
    for key in args_info:
        temp.append({'key': key, 'action': args_info[key]})
    temp.sort(key=lambda x: x['key'])
    advanced_component_ids = []
    advanced_components = []
    for item in temp:
        key = item['key']
        action = item['action']
        if key in basic_args:
            pass  # already covered by a dedicated control
        else:
            action_type = str(action['type'])
            component = None
            with gr.Column(min_width=300):
                if action_type == "None":
                    # Typeless arguments are flags; render them as checkboxes
                    component = gr.Checkbox()
                else:
                    component = gr.Textbox(value="")
                if component is not None:
                    component.interactive = True
                    component.elem_id = action['action'][0]
                    component.label = component.elem_id
                    component.elem_classes = ["advanced"]
                if action['help'] is not None:
                    component.info = action['help']
            advanced_components.append(component)
            advanced_component_ids.append(component.elem_id)
    return advanced_components, advanced_component_ids

theme = gr.themes.Monochrome(
    text_size=gr.themes.Size(lg="18px", md="15px", sm="13px", xl="22px", xs="12px", xxl="24px", xxs="9px"),
    font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui", "sans-serif"],
)

css = """
|
|
@keyframes rotate {
|
|
0% {
|
|
transform: rotate(0deg);
|
|
}
|
|
100% {
|
|
transform: rotate(360deg);
|
|
}
|
|
}
|
|
#advanced_options .advanced:nth-child(even) { background: rgba(0,0,100,0.04) !important; }
|
|
h1{font-family: georgia; font-style: italic; font-weight: bold; font-size: 30px; letter-spacing: -1px;}
|
|
h3{margin-top: 0}
|
|
.tabitem{border: 0px}
|
|
.group_padding{}
|
|
nav{position: fixed; top: 0; left: 0; right: 0; z-index: 1000; text-align: center; padding: 10px; box-sizing: border-box; display: flex; align-items: center; backdrop-filter: blur(10px); }
|
|
nav button { background: none; color: firebrick; font-weight: bold; border: 2px solid firebrick; padding: 5px 10px; border-radius: 5px; font-size: 14px; }
|
|
nav img { height: 40px; width: 40px; border-radius: 40px; }
|
|
nav img.rotate { animation: rotate 2s linear infinite; }
|
|
.flexible { flex-grow: 1; }
|
|
.tast-details { margin: 10px 0 !important; }
|
|
.toast-wrap { bottom: var(--size-4) !important; top: auto !important; border: none !important; backdrop-filter: blur(10px); }
|
|
.toast-title, .toast-text, .toast-icon, .toast-close { color: black !important; font-size: 14px; }
|
|
.toast-body { border: none !important; }
|
|
#terminal { box-shadow: none !important; margin-bottom: 25px; background: rgba(0,0,0,0.03); }
|
|
#terminal .generating { border: none !important; }
|
|
#terminal label { position: absolute !important; }
|
|
.tabs { margin-top: 50px; }
|
|
.hidden { display: none !important; }
|
|
.codemirror-wrapper .cm-line { font-size: 12px !important; }
|
|
label { font-weight: bold !important; }
|
|
#start_training.clicked { background: silver; color: black; }
|
|
"""
|
|
|
|
js = """
|
|
function() {
|
|
let autoscroll = document.querySelector("#autoscroll")
|
|
if (window.iidxx) {
|
|
window.clearInterval(window.iidxx);
|
|
}
|
|
window.iidxx = window.setInterval(function() {
|
|
let text=document.querySelector(".codemirror-wrapper .cm-line").innerText.trim()
|
|
let img = document.querySelector("#logo")
|
|
if (text.length > 0) {
|
|
autoscroll.classList.remove("hidden")
|
|
if (autoscroll.classList.contains("on")) {
|
|
autoscroll.textContent = "Autoscroll ON"
|
|
window.scrollTo(0, document.body.scrollHeight, { behavior: "smooth" });
|
|
img.classList.add("rotate")
|
|
} else {
|
|
autoscroll.textContent = "Autoscroll OFF"
|
|
img.classList.remove("rotate")
|
|
}
|
|
}
|
|
}, 500);
|
|
console.log("autoscroll", autoscroll)
|
|
autoscroll.addEventListener("click", (e) => {
|
|
autoscroll.classList.toggle("on")
|
|
})
|
|
function debounce(fn, delay) {
|
|
let timeoutId;
|
|
return function(...args) {
|
|
clearTimeout(timeoutId);
|
|
timeoutId = setTimeout(() => fn(...args), delay);
|
|
};
|
|
}
|
|
|
|
function handleClick() {
|
|
console.log("refresh")
|
|
document.querySelector("#refresh").click();
|
|
}
|
|
const debouncedClick = debounce(handleClick, 1000);
|
|
document.addEventListener("input", debouncedClick);
|
|
|
|
document.querySelector("#start_training").addEventListener("click", (e) => {
|
|
e.target.classList.add("clicked")
|
|
e.target.innerHTML = "Training..."
|
|
})
|
|
|
|
}
|
|
"""
|
|
|
|
current_account = account_hf()
print(f"current_account={current_account}")

with gr.Blocks(elem_id="app", theme=theme, css=css, fill_width=True) as demo:
    with gr.Tabs() as tabs:
        with gr.TabItem("Gym"):
            output_components = []
            with gr.Row():
                gr.HTML("""<nav>
    <img id='logo' src='/file=icon.png' width='80' height='80'>
    <div class='flexible'></div>
    <button id='autoscroll' class='on hidden'></button>
</nav>
""")
            with gr.Row(elem_id='container'):
                # Step 1: name, trigger word, base model, and core training settings
                with gr.Column():
                    gr.Markdown(
                        """# Step 1. LoRA Info
<p style="margin-top:0">Configure your LoRA train settings.</p>
""", elem_classes="group_padding")
                    lora_name = gr.Textbox(
                        label="The name of your LoRA",
                        info="This has to be a unique name",
                        placeholder="e.g.: Persian Miniature Painting style, Cat Toy",
                    )
                    concept_sentence = gr.Textbox(
                        elem_id="--concept_sentence",
                        label="Trigger word/sentence",
                        info="Trigger word or sentence to be used",
                        placeholder="uncommon word like p3rs0n or trtcrd, or sentence like 'in the style of CNSTLL'",
                        interactive=True,
                    )
                    model_names = list(models.keys())
                    print(f"model_names={model_names}")
                    base_model = gr.Dropdown(label="Base model (edit the models.yaml file to add more to this list)", choices=model_names, value=model_names[0])
                    vram = gr.Radio(["20G", "16G", "12G"], value="20G", label="VRAM", interactive=True)
                    num_repeats = gr.Number(value=10, precision=0, label="Repeat trains per image", interactive=True)
                    max_train_epochs = gr.Number(label="Max Train Epochs", value=16, interactive=True)
                    total_steps = gr.Number(0, interactive=False, label="Expected training steps")
                    sample_prompts = gr.Textbox("", lines=5, label="Sample Image Prompts (Separate with new lines)", interactive=True)
                    sample_every_n_steps = gr.Number(0, precision=0, label="Sample Image Every N Steps", interactive=True)
                    resolution = gr.Number(value=512, precision=0, label="Resize dataset images")
                # Step 2: dataset upload and captioning
                with gr.Column():
                    gr.Markdown(
                        """# Step 2. Dataset
<p style="margin-top:0">Make sure the captions include the trigger word.</p>
""", elem_classes="group_padding")
                    with gr.Group():
                        images = gr.File(
                            file_types=["image", ".txt"],
                            label="Upload your images",
                            file_count="multiple",
                            interactive=True,
                            visible=True,
                            scale=1,
                        )
                    with gr.Group(visible=False) as captioning_area:
                        do_captioning = gr.Button("Add AI captions with Florence-2")
                        output_components.append(captioning_area)
                        caption_list = []
                        # One hidden row per possible image; load_captioning toggles visibility
                        for i in range(1, MAX_IMAGES + 1):
                            locals()[f"captioning_row_{i}"] = gr.Row(visible=False)
                            with locals()[f"captioning_row_{i}"]:
                                locals()[f"image_{i}"] = gr.Image(
                                    type="filepath",
                                    width=111,
                                    height=111,
                                    min_width=111,
                                    interactive=False,
                                    scale=2,
                                    show_label=False,
                                    show_share_button=False,
                                    show_download_button=False,
                                )
                                locals()[f"caption_{i}"] = gr.Textbox(
                                    label=f"Caption {i}", scale=15, interactive=True
                                )
                            output_components.append(locals()[f"captioning_row_{i}"])
                            output_components.append(locals()[f"image_{i}"])
                            output_components.append(locals()[f"caption_{i}"])
                            caption_list.append(locals()[f"caption_{i}"])
                # Step 3: generated train script/config, advanced flags, logs, samples
                with gr.Column():
                    gr.Markdown(
                        """# Step 3. Train
<p style="margin-top:0">Press start to start training.</p>
""", elem_classes="group_padding")
                    refresh = gr.Button("Refresh", elem_id="refresh", visible=False)
                    start = gr.Button("Start training", visible=False, elem_id="start_training")
                    output_components.append(start)
                    train_script = gr.Textbox(label="Train script", max_lines=100, interactive=True)
                    train_config = gr.Textbox(label="Train config", max_lines=100, interactive=True)
            with gr.Accordion("Advanced options", elem_id='advanced_options', open=False):
                with gr.Row():
                    with gr.Column(min_width=300):
                        seed = gr.Number(label="--seed", info="Seed", value=42, interactive=True)
                    with gr.Column(min_width=300):
                        workers = gr.Number(label="--max_data_loader_n_workers", info="Number of Workers", value=2, interactive=True)
                    with gr.Column(min_width=300):
                        learning_rate = gr.Textbox(label="--learning_rate", info="Learning Rate", value="8e-4", interactive=True)
                    with gr.Column(min_width=300):
                        save_every_n_epochs = gr.Number(label="--save_every_n_epochs", info="Save every N epochs", value=4, interactive=True)
                    with gr.Column(min_width=300):
                        guidance_scale = gr.Number(label="--guidance_scale", info="Guidance Scale", value=1.0, interactive=True)
                    with gr.Column(min_width=300):
                        timestep_sampling = gr.Textbox(label="--timestep_sampling", info="Timestep Sampling", value="shift", interactive=True)
                    with gr.Column(min_width=300):
                        network_dim = gr.Number(label="--network_dim", info="LoRA Rank", value=4, minimum=4, maximum=128, step=4, interactive=True)
                    advanced_components, advanced_component_ids = init_advanced()
            with gr.Row():
                terminal = LogsView(label="Train log", elem_id="terminal")
            with gr.Row():
                gallery = gr.Gallery(get_samples, inputs=[lora_name], label="Samples", every=10, columns=6)
with gr.TabItem("Publish") as publish_tab:
|
|
hf_token = gr.Textbox(label="Huggingface Token")
|
|
hf_login = gr.Button("Login")
|
|
hf_logout = gr.Button("Logout")
|
|
with gr.Row() as row:
|
|
gr.Markdown("**LoRA**")
|
|
gr.Markdown("**Upload**")
|
|
loras = get_loras()
|
|
with gr.Row():
|
|
lora_rows = refresh_publish_tab()
|
|
with gr.Column():
|
|
with gr.Row():
|
|
repo_owner = gr.Textbox(label="Account", interactive=False)
|
|
repo_name = gr.Textbox(label="Repository Name")
|
|
repo_visibility = gr.Textbox(label="Repository Visibility ('public' or 'private')", value="public")
|
|
upload_button = gr.Button("Upload to HuggingFace")
|
|
upload_button.click(
|
|
fn=upload_hf,
|
|
inputs=[
|
|
base_model,
|
|
lora_rows,
|
|
repo_owner,
|
|
repo_name,
|
|
repo_visibility,
|
|
hf_token,
|
|
]
|
|
)
|
|
hf_login.click(fn=login_hf, inputs=[hf_token], outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
|
hf_logout.click(fn=logout_hf, outputs=[hf_token, hf_login, hf_logout, repo_owner])
|
|
|
|
|
|
publish_tab.select(refresh_publish_tab, outputs=lora_rows)
|
|
lora_rows.select(fn=set_repo, inputs=[lora_rows], outputs=[repo_name])
|
|
|
|
    dataset_folder = gr.State()

    # Inputs forwarded to update() via the hidden Refresh button,
    # which the debounced JS input handler clicks on every edit
    listeners = [
        base_model,
        lora_name,
        resolution,
        seed,
        workers,
        concept_sentence,
        learning_rate,
        network_dim,
        max_train_epochs,
        save_every_n_epochs,
        timestep_sampling,
        guidance_scale,
        vram,
        num_repeats,
        sample_prompts,
        sample_every_n_steps,
        *advanced_components,
    ]
    advanced_component_ids = [x.elem_id for x in advanced_components]
    original_advanced_component_values = [comp.value for comp in advanced_components]
    images.upload(
        load_captioning,
        inputs=[images, concept_sentence],
        outputs=output_components
    )
    images.delete(
        load_captioning,
        inputs=[images, concept_sentence],
        outputs=output_components
    )
    images.clear(
        hide_captioning,
        outputs=[captioning_area, start]
    )
    max_train_epochs.change(
        fn=update_total_steps,
        inputs=[max_train_epochs, num_repeats, images],
        outputs=[total_steps]
    )
    num_repeats.change(
        fn=update_total_steps,
        inputs=[max_train_epochs, num_repeats, images],
        outputs=[total_steps]
    )
    images.upload(
        fn=update_total_steps,
        inputs=[max_train_epochs, num_repeats, images],
        outputs=[total_steps]
    )
    images.delete(
        fn=update_total_steps,
        inputs=[max_train_epochs, num_repeats, images],
        outputs=[total_steps]
    )
    images.clear(
        fn=update_total_steps,
        inputs=[max_train_epochs, num_repeats, images],
        outputs=[total_steps]
    )
    concept_sentence.change(fn=update_sample, inputs=[concept_sentence], outputs=sample_prompts)
    start.click(fn=create_dataset, inputs=[dataset_folder, resolution, images] + caption_list, outputs=dataset_folder).then(
        fn=start_training,
        inputs=[
            base_model,
            lora_name,
            train_script,
            train_config,
            sample_prompts,
        ],
        outputs=terminal,
    )
    do_captioning.click(fn=run_captioning, inputs=[images, concept_sentence] + caption_list, outputs=caption_list)
    demo.load(fn=loaded, js=js, outputs=[hf_token, hf_login, hf_logout, repo_owner])
    refresh.click(update, inputs=listeners, outputs=[train_script, train_config, dataset_folder])

if __name__ == "__main__":
|
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
|
demo.launch(debug=True, show_error=True, allowed_paths=[cwd])
|
|
|