Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,7 +30,7 @@ def change_training_setup(training_type):
|
|
| 30 |
elif training_type == "concept" :
|
| 31 |
return 2000, 1000
|
| 32 |
|
| 33 |
-
def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps):
|
| 34 |
|
| 35 |
script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
|
| 36 |
|
|
@@ -42,6 +42,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
|
|
| 42 |
f"--instance_data_dir={instance_data_dir}",
|
| 43 |
f"--output_dir={b_lora_trained_folder}",
|
| 44 |
f"--instance_prompt='{instance_prompt}'",
|
|
|
|
| 45 |
#f"--validation_prompt=a teddy bear in {instance_prompt} style",
|
| 46 |
"--num_validation_images=1",
|
| 47 |
"--validation_epochs=500",
|
|
@@ -68,7 +69,7 @@ def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instan
|
|
| 68 |
except subprocess.CalledProcessError as e:
|
| 69 |
print(f"An error occurred: {e}")
|
| 70 |
|
| 71 |
-
def main(image_path, b_lora_trained_folder, instance_prompt, training_type, training_steps):
|
| 72 |
|
| 73 |
if is_shared_ui:
|
| 74 |
raise gr.Error("This Space only works in duplicated instances")
|
|
@@ -100,7 +101,7 @@ def main(image_path, b_lora_trained_folder, instance_prompt, training_type, trai
|
|
| 100 |
|
| 101 |
max_train_steps = training_steps
|
| 102 |
|
| 103 |
-
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, max_train_steps, checkpoint_steps)
|
| 104 |
|
| 105 |
your_username = api.whoami(token=hf_token)["name"]
|
| 106 |
|
|
@@ -208,7 +209,9 @@ with gr.Blocks(css=css) as demo:
|
|
| 208 |
with gr.Column():
|
| 209 |
training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
|
| 210 |
b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
|
| 211 |
-
|
|
|
|
|
|
|
| 212 |
training_steps = gr.Number(label="Training steps", value=1000, interactive=False)
|
| 213 |
checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
|
| 214 |
train_btn = gr.Button("Train B-LoRa")
|
|
@@ -222,7 +225,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 222 |
|
| 223 |
train_btn.click(
|
| 224 |
fn = main,
|
| 225 |
-
inputs = [image, b_lora_name, instance_prompt, training_type, training_steps],
|
| 226 |
outputs = [status]
|
| 227 |
)
|
| 228 |
|
|
|
|
| 30 |
elif training_type == "concept" :
|
| 31 |
return 2000, 1000
|
| 32 |
|
| 33 |
+
def train_dreambooth_blora_sdxl(instance_data_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps):
|
| 34 |
|
| 35 |
script_filename = "train_dreambooth_b-lora_sdxl.py" # Assuming it's in the same folder
|
| 36 |
|
|
|
|
| 42 |
f"--instance_data_dir={instance_data_dir}",
|
| 43 |
f"--output_dir={b_lora_trained_folder}",
|
| 44 |
f"--instance_prompt='{instance_prompt}'",
|
| 45 |
+
f"--class_prompt={class_prompt}",
|
| 46 |
#f"--validation_prompt=a teddy bear in {instance_prompt} style",
|
| 47 |
"--num_validation_images=1",
|
| 48 |
"--validation_epochs=500",
|
|
|
|
| 69 |
except subprocess.CalledProcessError as e:
|
| 70 |
print(f"An error occurred: {e}")
|
| 71 |
|
| 72 |
+
def main(image_path, b_lora_trained_folder, instance_prompt, class_prompt, training_type, training_steps):
|
| 73 |
|
| 74 |
if is_shared_ui:
|
| 75 |
raise gr.Error("This Space only works in duplicated instances")
|
|
|
|
| 101 |
|
| 102 |
max_train_steps = training_steps
|
| 103 |
|
| 104 |
+
train_dreambooth_blora_sdxl(local_dir, b_lora_trained_folder, instance_prompt, class_prompt, max_train_steps, checkpoint_steps)
|
| 105 |
|
| 106 |
your_username = api.whoami(token=hf_token)["name"]
|
| 107 |
|
|
|
|
| 209 |
with gr.Column():
|
| 210 |
training_type = gr.Radio(label="Training type", choices=["style", "concept"], value="style")
|
| 211 |
b_lora_name = gr.Textbox(label="Name your B-LoRa model", placeholder="b_lora_trained_folder")
|
| 212 |
+
with gr.Row():
|
| 213 |
+
instance_prompt = gr.Textbox(label="Create instance prompt", placeholder="A [v42] <class_prompt>")
|
| 214 |
+
class_prompt = gr.Textbox(label="Specify class prompt", placeholder="style | person | dog ")
|
| 215 |
training_steps = gr.Number(label="Training steps", value=1000, interactive=False)
|
| 216 |
checkpoint_step = gr.Number(label="checkpoint step", visible=False, value=500)
|
| 217 |
train_btn = gr.Button("Train B-LoRa")
|
|
|
|
| 225 |
|
| 226 |
train_btn.click(
|
| 227 |
fn = main,
|
| 228 |
+
inputs = [image, b_lora_name, instance_prompt, class_prompt, training_type, training_steps],
|
| 229 |
outputs = [status]
|
| 230 |
)
|
| 231 |
|