thuanz123 commited on
Commit
275aca4
·
1 Parent(s): ef0d21e

Upload 9 files

Browse files
Files changed (9) hide show
  1. README.md +6 -5
  2. app.py +275 -0
  3. colab.py +371 -0
  4. inference.py +63 -0
  5. requirements.txt +12 -0
  6. style.css +3 -0
  7. train_realfill.py +952 -0
  8. trainer.py +145 -0
  9. uploader.py +17 -0
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Peft Sd Realfill
3
- emoji: 🏆
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.48.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Peft Lora Sd Dreambooth
3
+ emoji: 🎨
4
+ colorFrom: purple
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.16.2
8
  app_file: app.py
9
  pinned: false
10
+ license: openrail
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)
4
+
5
+ The code in this repo is partly adapted from the following repositories:
6
+ https://huggingface.co/spaces/hysts/LoRA-SD-training
7
+ https://huggingface.co/spaces/multimodalart/dreambooth-training
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import pathlib
13
+
14
+ import gradio as gr
15
+ import torch
16
+ from typing import List
17
+
18
+ from inference import InferencePipeline
19
+ from trainer import Trainer
20
+ from uploader import upload
21
+
22
+
23
+ TITLE = "# RealFill Training and Inference Demo 🎨"
24
+ DESCRIPTION = "Demo showcasing parameter-efficient fine-tuning of Stable Diffusion Inpainting via RealFill leveraging 🤗 PEFT (https://github.com/huggingface/peft)."
25
+
26
+
27
+ ORIGINAL_SPACE_ID = "thuanz123/peft-sd-realfill"
28
+
29
+ SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)
30
+ SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
31
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
32
+ """
33
+ if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
34
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
35
+
36
+ else:
37
+ SETTINGS = "Settings"
38
+ CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
39
+ <center>
40
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
41
+ "T4 small" is sufficient to run this demo.
42
+ </center>
43
+ """
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ def update_output_files() -> dict:
54
+ paths = sorted(pathlib.Path("results").glob("*.pt"))
55
+ config_paths = sorted(pathlib.Path("results").glob("*.json"))
56
+ paths = paths + config_paths
57
+ paths = [path.as_posix() for path in paths] # type: ignore
58
+ return gr.update(value=paths or None)
59
+
60
+
61
+ def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks:
62
+ with gr.Blocks() as demo:
63
+ base_model = gr.Dropdown(
64
+ choices=[
65
+ "runwayml/stable-diffusion-inpainting",
66
+ "stabilityai/stable-diffusion-2-inpainting",
67
+ ],
68
+ value="stabilityai/stable-diffusion-2-inpainting",
69
+ label="Base Model",
70
+ visible=True,
71
+ )
72
+ resolution = gr.Dropdown(choices=["512"], value="512", label="Resolution", visible=False)
73
+
74
+ with gr.Row():
75
+ with gr.Box():
76
+ gr.Markdown("Training Data")
77
+ ref_images = gr.Files(label="Reference images")
78
+ target_image = gr.Files(label="Target image")
79
+ target_mask = gr.Files(label="Target mask")
80
+ gr.Markdown(
81
+ """
82
+ - Upload reference images of the scene you are planning on training on.
83
+ - For a concept prompt, use a unique, made up word to avoid collisions.
84
+ - Guidelines for getting good results:
85
+ - 1-5 images of the object from different angles
86
+ - 2000 iterations should be good enough.
87
+ - LoRA Rank for unet - 8
88
+ - LoRA Alpha for unet - 16
89
+ - lora dropout - 0.1
90
+ - LoRA Bias for unet - `none`
91
+ - Uncheck `FP16` and `8bit-Adam` only if you have VRAM at least 32GB
92
+ - Experiment with various values for lora dropouts, enabling/disabling fp16 and 8bit-Adam
93
+ """
94
+ )
95
+ with gr.Box():
96
+ gr.Markdown("Training Parameters")
97
+ num_training_steps = gr.Number(label="Number of Training Steps", value=2000, precision=0)
98
+ unet_learning_rate = gr.Number(label="Unet Learning Rate", value=2e-4)
99
+ text_encoder_learning_rate = gr.Number(label="Text Encoder Learning Rate", value=4e-5)
100
+ gradient_checkpointing = gr.Checkbox(label="Whether to use gradient checkpointing", value=True)
101
+ lora_rank = gr.Number(label="LoRA Rank for unet", value=8, precision=0)
102
+ lora_alpha = gr.Number(
103
+ label="LoRA Alpha for unet. scaling factor = lora_alpha/lora_r", value=16, precision=0
104
+ )
105
+ lora_dropout = gr.Number(label="lora dropout", value=0.1)
106
+ lora_bias = gr.Dropdown(
107
+ choices=["none", "all", "lora_only"],
108
+ value="none",
109
+ label="LoRA Bias for unet. This enables bias params to be trainable based on the bias type",
110
+ visible=True,
111
+ )
112
+ gradient_accumulation = gr.Number(label="Number of Gradient Accumulation", value=1, precision=0)
113
+ fp16 = gr.Checkbox(label="FP16", value=True)
114
+ use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=True)
115
+ gr.Markdown(
116
+ """
117
+ - It will take about 40-60 minutes to train for 2000 steps with a T4 GPU.
118
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
119
+ - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
120
+ """
121
+ )
122
+
123
+ run_button = gr.Button("Start Training")
124
+ with gr.Box():
125
+ with gr.Row():
126
+ check_status_button = gr.Button("Check Training Status")
127
+ with gr.Column():
128
+ with gr.Box():
129
+ gr.Markdown("Message")
130
+ training_status = gr.Markdown()
131
+ output_files = gr.Files(label="Trained Model Files")
132
+
133
+ run_button.click(fn=pipe.clear)
134
+
135
+ run_button.click(
136
+ fn=trainer.run,
137
+ inputs=[
138
+ base_model,
139
+ resolution,
140
+ num_training_steps,
141
+ ref_images,
142
+ target_image,
143
+ target_mask,
144
+ unet_learning_rate,
145
+ text_encoder_learning_rate,
146
+ gradient_accumulation,
147
+ fp16,
148
+ use_8bit_adam,
149
+ gradient_checkpointing,
150
+ lora_rank,
151
+ lora_alpha,
152
+ lora_bias,
153
+ lora_dropout,
154
+ ],
155
+ outputs=[
156
+ training_status,
157
+ output_files,
158
+ ],
159
+ queue=False,
160
+ )
161
+ check_status_button.click(fn=trainer.check_if_running, inputs=None, outputs=training_status, queue=False)
162
+ check_status_button.click(fn=update_output_files, inputs=None, outputs=output_files, queue=False)
163
+ return demo
164
+
165
+
166
+ def find_model_files() -> list[str]:
167
+ curr_dir = pathlib.Path(__file__).parent
168
+ paths = sorted(curr_dir.glob('*'))
169
+ paths = [
170
+ path for path in paths
171
+ if (path / 'model_index.json').exists()
172
+ ]
173
+ return [path.relative_to(curr_dir).as_posix() for path in paths]
174
+
175
+
176
+ def reload_realfill_model_list() -> dict:
177
+ return gr.update(choices=find_model_files())
178
+
179
+
180
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
181
+ with gr.Blocks() as demo:
182
+ with gr.Row():
183
+ with gr.Column():
184
+ reload_button = gr.Button("Reload Model List")
185
+ realfill_model = gr.Dropdown(
186
+ choices=find_model_files(), label="RealFill Model File"
187
+ )
188
+ target_image = gr.Files(label="Target image")
189
+ target_mask = gr.Files(label="Target mask")
190
+ seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=1)
191
+ with gr.Accordion("Other Parameters", open=False):
192
+ num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=1000, step=1, value=50)
193
+ guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7)
194
+
195
+ run_button = gr.Button("Generate")
196
+
197
+ gr.Markdown(
198
+ """
199
+ - After training, you can press "Reload Model List" button to load your trained model names.
200
+ """
201
+ )
202
+ with gr.Column():
203
+ result = gr.Image(label="Result")
204
+
205
+ reload_button.click(fn=reload_realfill_model_list, inputs=None, outputs=realfill_model)
206
+ run_button.click(
207
+ fn=pipe.run,
208
+ inputs=[
209
+ realfill_model,
210
+ target_image,
211
+ target_mask,
212
+ seed,
213
+ num_steps,
214
+ guidance_scale,
215
+ ],
216
+ outputs=result,
217
+ queue=False,
218
+ )
219
+ seed.change(
220
+ fn=pipe.run,
221
+ inputs=[
222
+ realfill_model,
223
+ target_image,
224
+ target_mask,
225
+ seed,
226
+ num_steps,
227
+ guidance_scale,
228
+ ],
229
+ outputs=result,
230
+ queue=False,
231
+ )
232
+ return demo
233
+
234
+
235
+ def create_upload_demo() -> gr.Blocks:
236
+ with gr.Blocks() as demo:
237
+ model_name = gr.Textbox(label="Model Name")
238
+ hf_token = gr.Textbox(label="Hugging Face Token (with write permission)")
239
+ upload_button = gr.Button("Upload")
240
+ with gr.Box():
241
+ gr.Markdown("Message")
242
+ result = gr.Markdown()
243
+ gr.Markdown(
244
+ """
245
+ - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
246
+ - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
247
+ """
248
+ )
249
+
250
+ upload_button.click(fn=upload, inputs=[model_name, hf_token], outputs=result)
251
+
252
+ return demo
253
+
254
+
255
+ pipe = InferencePipeline()
256
+ trainer = Trainer()
257
+
258
+ with gr.Blocks(css="style.css") as demo:
259
+ if os.getenv("IS_SHARED_UI"):
260
+ show_warning(SHARED_UI_WARNING)
261
+ if not torch.cuda.is_available():
262
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
263
+
264
+ gr.Markdown(TITLE)
265
+ gr.Markdown(DESCRIPTION)
266
+
267
+ with gr.Tabs():
268
+ with gr.TabItem("Train"):
269
+ create_training_demo(trainer, pipe)
270
+ with gr.TabItem("Test"):
271
+ create_inference_demo(pipe)
272
+ with gr.TabItem("Upload"):
273
+ create_upload_demo()
274
+
275
+ demo.queue(default_enabled=False).launch(share=False)
colab.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)
4
+
5
+ The code in this repo is partly adapted from the following repositories:
6
+ https://huggingface.co/spaces/hysts/LoRA-SD-training
7
+ https://huggingface.co/spaces/multimodalart/dreambooth-training
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ import pathlib
13
+
14
+ import gradio as gr
15
+ import torch
16
+ from typing import List
17
+
18
+ from inference import InferencePipeline
19
+ from trainer import Trainer
20
+ from uploader import upload
21
+
22
+
23
+ TITLE = "# LoRA + Dreambooth Training and Inference Demo 🎨"
24
+ DESCRIPTION = "Demo showcasing parameter-efficient fine-tuning of Stable Dissfusion via Dreambooth leveraging 🤗 PEFT (https://github.com/huggingface/peft)."
25
+
26
+
27
+ ORIGINAL_SPACE_ID = "smangrul/peft-lora-sd-dreambooth"
28
+
29
+ SPACE_ID = os.getenv("SPACE_ID", ORIGINAL_SPACE_ID)
30
+ SHARED_UI_WARNING = f"""# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
31
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
32
+ """
33
+ if os.getenv("SYSTEM") == "spaces" and SPACE_ID != ORIGINAL_SPACE_ID:
34
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
35
+
36
+ else:
37
+ SETTINGS = "Settings"
38
+ CUDA_NOT_AVAILABLE_WARNING = f"""# Attention - Running on CPU.
39
+ <center>
40
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
41
+ "T4 small" is sufficient to run this demo.
42
+ </center>
43
+ """
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ def update_output_files() -> dict:
54
+ paths = sorted(pathlib.Path("results").glob("*.pt"))
55
+ config_paths = sorted(pathlib.Path("results").glob("*.json"))
56
+ paths = paths + config_paths
57
+ paths = [path.as_posix() for path in paths] # type: ignore
58
+ return gr.update(value=paths or None)
59
+
60
+
61
+ def create_training_demo(trainer: Trainer, pipe: InferencePipeline) -> gr.Blocks:
62
+ with gr.Blocks() as demo:
63
+ base_model = gr.Dropdown(
64
+ choices=[
65
+ "CompVis/stable-diffusion-v1-4",
66
+ "runwayml/stable-diffusion-v1-5",
67
+ "stabilityai/stable-diffusion-2-1-base",
68
+ ],
69
+ value="runwayml/stable-diffusion-v1-5",
70
+ label="Base Model",
71
+ visible=True,
72
+ )
73
+ resolution = gr.Dropdown(choices=["512"], value="512", label="Resolution", visible=False)
74
+
75
+ with gr.Row():
76
+ with gr.Box():
77
+ gr.Markdown("Training Data")
78
+ concept_images = gr.Files(label="Images for your concept")
79
+ concept_prompt = gr.Textbox(label="Concept Prompt", max_lines=1)
80
+ gr.Markdown(
81
+ """
82
+ - Upload images of the style you are planning on training on.
83
+ - For a concept prompt, use a unique, made up word to avoid collisions.
84
+ - Guidelines for getting good results:
85
+ - Dreambooth for an `object` or `style`:
86
+ - 5-10 images of the object from different angles
87
+ - 500-800 iterations should be good enough.
88
+ - Prior preservation is recommended.
89
+ - `class_prompt`:
90
+ - `a photo of object`
91
+ - `style`
92
+ - `concept_prompt`:
93
+ - `<concept prompt> object`
94
+ - `<concept prompt> style`
95
+ - `a photo of <concept prompt> object`
96
+ - `a photo of <concept prompt> style`
97
+ - Dreambooth for a `Person/Face`:
98
+ - 15-50 images of the person from different angles, lighting, and expressions.
99
+ Have considerable photos with close up faces.
100
+ - 800-1200 iterations should be good enough.
101
+ - good defaults for hyperparams
102
+ - Model - `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1-base`
103
+ - Use/check Prior preservation.
104
+ - Number of class images to use - 200
105
+ - Prior Loss Weight - 1
106
+ - LoRA Rank for unet - 16
107
+ - LoRA Alpha for unet - 20
108
+ - lora dropout - 0
109
+ - LoRA Bias for unet - `all`
110
+ - LoRA Rank for CLIP - 16
111
+ - LoRA Alpha for CLIP - 17
112
+ - LoRA Bias for CLIP - `all`
113
+ - lora dropout for CLIP - 0
114
+ - Uncheck `FP16` and `8bit-Adam` (don't use them for faces)
115
+ - `class_prompt`: Use the gender related word of the person
116
+ - `man`
117
+ - `woman`
118
+ - `boy`
119
+ - `girl`
120
+ - `concept_prompt`: just the unique, made up word, e.g., `srm`
121
+ - Choose `all` for `lora_bias` and `text_encode_lora_bias`
122
+ - Dreambooth for a `Scene`:
123
+ - 15-50 images of the scene from different angles, lighting, and expressions.
124
+ - 800-1200 iterations should be good enough.
125
+ - Prior preservation is recommended.
126
+ - `class_prompt`:
127
+ - `scene`
128
+ - `landscape`
129
+ - `city`
130
+ - `beach`
131
+ - `mountain`
132
+ - `concept_prompt`:
133
+ - `<concept prompt> scene`
134
+ - `<concept prompt> landscape`
135
+ - Experiment with various values for lora dropouts, enabling/disabling fp16 and 8bit-Adam
136
+ """
137
+ )
138
+ with gr.Box():
139
+ gr.Markdown("Training Parameters")
140
+ num_training_steps = gr.Number(label="Number of Training Steps", value=1000, precision=0)
141
+ learning_rate = gr.Number(label="Learning Rate", value=0.0001)
142
+ gradient_checkpointing = gr.Checkbox(label="Whether to use gradient checkpointing", value=True)
143
+ train_text_encoder = gr.Checkbox(label="Train Text Encoder", value=True)
144
+ with_prior_preservation = gr.Checkbox(label="Prior Preservation", value=True)
145
+ class_prompt = gr.Textbox(
146
+ label="Class Prompt", max_lines=1, placeholder='Example: "a photo of object"'
147
+ )
148
+ num_class_images = gr.Number(label="Number of class images to use", value=50, precision=0)
149
+ prior_loss_weight = gr.Number(label="Prior Loss Weight", value=1.0, precision=1)
150
+ # use_lora = gr.Checkbox(label="Whether to use LoRA", value=True)
151
+ lora_r = gr.Number(label="LoRA Rank for unet", value=4, precision=0)
152
+ lora_alpha = gr.Number(
153
+ label="LoRA Alpha for unet. scaling factor = lora_r/lora_alpha", value=4, precision=0
154
+ )
155
+ lora_dropout = gr.Number(label="lora dropout", value=0.00)
156
+ lora_bias = gr.Dropdown(
157
+ choices=["none", "all", "lora_only"],
158
+ value="none",
159
+ label="LoRA Bias for unet. This enables bias params to be trainable based on the bias type",
160
+ visible=True,
161
+ )
162
+ lora_text_encoder_r = gr.Number(label="LoRA Rank for CLIP", value=4, precision=0)
163
+ lora_text_encoder_alpha = gr.Number(
164
+ label="LoRA Alpha for CLIP. scaling factor = lora_r/lora_alpha", value=4, precision=0
165
+ )
166
+ lora_text_encoder_dropout = gr.Number(label="lora dropout for CLIP", value=0.00)
167
+ lora_text_encoder_bias = gr.Dropdown(
168
+ choices=["none", "all", "lora_only"],
169
+ value="none",
170
+ label="LoRA Bias for CLIP. This enables bias params to be trainable based on the bias type",
171
+ visible=True,
172
+ )
173
+ gradient_accumulation = gr.Number(label="Number of Gradient Accumulation", value=1, precision=0)
174
+ fp16 = gr.Checkbox(label="FP16", value=True)
175
+ use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=True)
176
+ gr.Markdown(
177
+ """
178
+ - It will take about 20-30 minutes to train for 1000 steps with a T4 GPU.
179
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
180
+ - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
181
+ """
182
+ )
183
+
184
+ run_button = gr.Button("Start Training")
185
+ with gr.Box():
186
+ with gr.Row():
187
+ check_status_button = gr.Button("Check Training Status")
188
+ with gr.Column():
189
+ with gr.Box():
190
+ gr.Markdown("Message")
191
+ training_status = gr.Markdown()
192
+ output_files = gr.Files(label="Trained Weight Files and Configs")
193
+
194
+ run_button.click(fn=pipe.clear)
195
+
196
+ run_button.click(
197
+ fn=trainer.run,
198
+ inputs=[
199
+ base_model,
200
+ resolution,
201
+ num_training_steps,
202
+ concept_images,
203
+ concept_prompt,
204
+ learning_rate,
205
+ gradient_accumulation,
206
+ fp16,
207
+ use_8bit_adam,
208
+ gradient_checkpointing,
209
+ train_text_encoder,
210
+ with_prior_preservation,
211
+ prior_loss_weight,
212
+ class_prompt,
213
+ num_class_images,
214
+ lora_r,
215
+ lora_alpha,
216
+ lora_bias,
217
+ lora_dropout,
218
+ lora_text_encoder_r,
219
+ lora_text_encoder_alpha,
220
+ lora_text_encoder_bias,
221
+ lora_text_encoder_dropout,
222
+ ],
223
+ outputs=[
224
+ training_status,
225
+ output_files,
226
+ ],
227
+ queue=False,
228
+ )
229
+ check_status_button.click(fn=trainer.check_if_running, inputs=None, outputs=training_status, queue=False)
230
+ check_status_button.click(fn=update_output_files, inputs=None, outputs=output_files, queue=False)
231
+ return demo
232
+
233
+
234
+ def find_weight_files() -> List[str]:
235
+ curr_dir = pathlib.Path(__file__).parent
236
+ paths = sorted(curr_dir.rglob("*.pt"))
237
+ return [path.relative_to(curr_dir).as_posix() for path in paths]
238
+
239
+
240
+ def reload_lora_weight_list() -> dict:
241
+ return gr.update(choices=find_weight_files())
242
+
243
+
244
+ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
245
+ with gr.Blocks() as demo:
246
+ with gr.Row():
247
+ with gr.Column():
248
+ base_model = gr.Dropdown(
249
+ choices=[
250
+ "CompVis/stable-diffusion-v1-4",
251
+ "runwayml/stable-diffusion-v1-5",
252
+ "stabilityai/stable-diffusion-2-1-base",
253
+ ],
254
+ value="runwayml/stable-diffusion-v1-5",
255
+ label="Base Model",
256
+ visible=True,
257
+ )
258
+ reload_button = gr.Button("Reload Weight List")
259
+ lora_weight_name = gr.Dropdown(
260
+ choices=find_weight_files(), value="lora/lora_disney.pt", label="LoRA Weight File"
261
+ )
262
+ prompt = gr.Textbox(label="Prompt", max_lines=1, placeholder='Example: "style of sks, baby lion"')
263
+ negative_prompt = gr.Textbox(
264
+ label="Negative Prompt", max_lines=1, placeholder='Example: "blurry, botched, low quality"'
265
+ )
266
+ seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, value=1)
267
+ with gr.Accordion("Other Parameters", open=False):
268
+ num_steps = gr.Slider(label="Number of Steps", minimum=0, maximum=1000, step=1, value=50)
269
+ guidance_scale = gr.Slider(label="CFG Scale", minimum=0, maximum=50, step=0.1, value=7)
270
+
271
+ run_button = gr.Button("Generate")
272
+
273
+ gr.Markdown(
274
+ """
275
+ - After training, you can press "Reload Weight List" button to load your trained model names.
276
+ - Few repos to refer for ideas:
277
+ - https://huggingface.co/smangrul/smangrul
278
+ - https://huggingface.co/smangrul/painting-in-the-style-of-smangrul
279
+ - https://huggingface.co/smangrul/erenyeager
280
+ """
281
+ )
282
+ with gr.Column():
283
+ result = gr.Image(label="Result")
284
+
285
+ reload_button.click(fn=reload_lora_weight_list, inputs=None, outputs=lora_weight_name)
286
+ prompt.submit(
287
+ fn=pipe.run,
288
+ inputs=[
289
+ base_model,
290
+ lora_weight_name,
291
+ prompt,
292
+ negative_prompt,
293
+ seed,
294
+ num_steps,
295
+ guidance_scale,
296
+ ],
297
+ outputs=result,
298
+ queue=False,
299
+ )
300
+ run_button.click(
301
+ fn=pipe.run,
302
+ inputs=[
303
+ base_model,
304
+ lora_weight_name,
305
+ prompt,
306
+ negative_prompt,
307
+ seed,
308
+ num_steps,
309
+ guidance_scale,
310
+ ],
311
+ outputs=result,
312
+ queue=False,
313
+ )
314
+ seed.change(
315
+ fn=pipe.run,
316
+ inputs=[
317
+ base_model,
318
+ lora_weight_name,
319
+ prompt,
320
+ negative_prompt,
321
+ seed,
322
+ num_steps,
323
+ guidance_scale,
324
+ ],
325
+ outputs=result,
326
+ queue=False,
327
+ )
328
+ return demo
329
+
330
+
331
+ def create_upload_demo() -> gr.Blocks:
332
+ with gr.Blocks() as demo:
333
+ model_name = gr.Textbox(label="Model Name")
334
+ hf_token = gr.Textbox(label="Hugging Face Token (with write permission)")
335
+ upload_button = gr.Button("Upload")
336
+ with gr.Box():
337
+ gr.Markdown("Message")
338
+ result = gr.Markdown()
339
+ gr.Markdown(
340
+ """
341
+ - You can upload your trained model to your private Model repo (i.e. https://huggingface.co/{your_username}/{model_name}).
342
+ - You can find your Hugging Face token [here](https://huggingface.co/settings/tokens).
343
+ """
344
+ )
345
+
346
+ upload_button.click(fn=upload, inputs=[model_name, hf_token], outputs=result)
347
+
348
+ return demo
349
+
350
+
351
+ pipe = InferencePipeline()
352
+ trainer = Trainer()
353
+
354
+ with gr.Blocks(css="style.css") as demo:
355
+ if os.getenv("IS_SHARED_UI"):
356
+ show_warning(SHARED_UI_WARNING)
357
+ if not torch.cuda.is_available():
358
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
359
+
360
+ gr.Markdown(TITLE)
361
+ gr.Markdown(DESCRIPTION)
362
+
363
+ with gr.Tabs():
364
+ with gr.TabItem("Train"):
365
+ create_training_demo(trainer, pipe)
366
+ with gr.TabItem("Test"):
367
+ create_inference_demo(pipe)
368
+ with gr.TabItem("Upload"):
369
+ create_upload_demo()
370
+
371
+ demo.queue(default_enabled=False).launch(share=True)
inference.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import json
5
+ import pathlib
6
+ import sys
7
+
8
+ import gradio as gr
9
+ import PIL.Image
10
+ import torch
11
+ from diffusers import StableDiffusionInpaintingPipeline
12
+
13
+
14
+ class InferencePipeline:
15
+ def __init__(self):
16
+ self.pipe = None
17
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+
19
+ def clear(self) -> None:
20
+ del self.pipe
21
+ self.pipe = None
22
+ torch.cuda.empty_cache()
23
+ gc.collect()
24
+
25
+ def load_pipe(self, realfill_model: str) -> None:
26
+ pipe = StableDiffusionInpaintingPipeline.from_pretrained(
27
+ realfill_model, torch_dtype=torch.float16
28
+ ).to(self.device)
29
+ pipe = pipe.to(self.device)
30
+ self.pipe = pipe
31
+
32
+ def run(
33
+ self,
34
+ realfill_model: str,
35
+ target_image: PIL.Image,
36
+ target_mask: PIL.Image,
37
+ seed: int,
38
+ n_steps: int,
39
+ guidance_scale: float,
40
+ ) -> PIL.Image.Image:
41
+ if not torch.cuda.is_available():
42
+ raise gr.Error("CUDA is not available.")
43
+
44
+ self.load_pipe(realfill_model)
45
+
46
+ image = PIL.Image.open(target_image)
47
+ mask_image = PIL.Image.open(target_mask)
48
+
49
+ generator = torch.Generator(device=self.device).manual_seed(seed)
50
+ out = self.pipe(
51
+ "a photo of sks",
52
+ image=image,
53
+ mask_image=mask_image,
54
+ num_inference_steps=n_steps,
55
+ guidance_scale=guidance_scale,
56
+ generator=generator,
57
+ ).images[0] # type: ignore
58
+
59
+ erode_kernel = PIL.ImageFilter.MaxFilter(3)
60
+ mask_image = mask_image.filter(erode_kernel)
61
+
62
+ result = PIL.Image.composite(result, out, mask_image)
63
+ return result
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.20.1
2
+ accelerate==0.23.0
3
+ transformers==4.34.0
4
+ peft==0.5.0
5
+ torch==2.0.1
6
+ torchvision==0.15.2
7
+ ftfy==6.1.1
8
+ tensorboard==2.14.0
9
+ Jinja2==3.1.2
10
+ Pillow==10.0.1
11
+ bitsandbytes==0.41.1
12
+ gradio==3.47.1
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_realfill.py ADDED
@@ -0,0 +1,952 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import argparse
3
+ import copy
4
+ import itertools
5
+ import logging
6
+ import math
7
+ import os
8
+ import shutil
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint
15
+ import torchvision.transforms.v2 as transforms_v2
16
+ import transformers
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import set_seed
20
+ from huggingface_hub import create_repo, upload_folder
21
+ from packaging import version
22
+ from PIL import Image
23
+ from PIL.ImageOps import exif_transpose
24
+ from torch.utils.data import Dataset
25
+ from torchvision import transforms
26
+ from tqdm.auto import tqdm
27
+ from transformers import AutoTokenizer, CLIPTextModel
28
+
29
+ import diffusers
30
+ from diffusers import (
31
+ AutoencoderKL,
32
+ DDPMScheduler,
33
+ StableDiffusionInpaintPipeline,
34
+ DPMSolverMultistepScheduler,
35
+ UNet2DConditionModel,
36
+ )
37
+ from diffusers.optimization import get_scheduler
38
+ from diffusers.utils import check_min_version, is_wandb_available
39
+ from diffusers.utils.import_utils import is_xformers_available
40
+
41
+ from peft import PeftModel, LoraConfig, get_peft_model
42
+
43
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
44
+ check_min_version("0.20.1")
45
+
46
+ logger = get_logger(__name__)
47
+
48
+ def make_mask(images, resolution, times=30):
49
+ mask, times = torch.ones_like(images[0:1, :, :]), np.random.randint(1, times)
50
+ min_size, max_size, margin = np.array([0.03, 0.25, 0.01]) * resolution
51
+ max_size = min(max_size, resolution - margin * 2)
52
+
53
+ for _ in range(times):
54
+ width = np.random.randint(int(min_size), int(max_size))
55
+ height = np.random.randint(int(min_size), int(max_size))
56
+
57
+ x_start = np.random.randint(int(margin), resolution - int(margin) - width + 1)
58
+ y_start = np.random.randint(int(margin), resolution - int(margin) - height + 1)
59
+ mask[:, y_start:y_start + height, x_start:x_start + width] = 0
60
+
61
+ mask = 1 - mask if random.random() < 0.5 else mask
62
+ return mask
63
+
64
+ def save_model_card(
65
+ repo_id: str,
66
+ images=None,
67
+ base_model=str,
68
+ repo_folder=None,
69
+ ):
70
+ img_str = ""
71
+ for i, image in enumerate(images):
72
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
73
+ img_str += f"![img_{i}](./image_{i}.png)\n"
74
+
75
+ yaml = f"""
76
+ ---
77
+ license: creativeml-openrail-m
78
+ base_model: {base_model}
79
+ prompt: "a photo of sks"
80
+ tags:
81
+ - stable-diffusion-inpainting
82
+ - stable-diffusion-inpainting-diffusers
83
+ - text-to-image
84
+ - diffusers
85
+ - realfill
86
+ inference: true
87
+ ---
88
+ """
89
+ model_card = f"""
90
+ # RealFill - {repo_id}
91
+
92
+ This is a realfill model derived from {base_model}. The weights were trained using [RealFill](https://realfill.github.io/).
93
+ You can find some example images in the following. \n
94
+ {img_str}
95
+ """
96
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
97
+ f.write(yaml + model_card)
98
+
99
+ def log_validation(
100
+ text_encoder,
101
+ tokenizer,
102
+ unet,
103
+ args,
104
+ accelerator,
105
+ weight_dtype,
106
+ epoch,
107
+ ):
108
+ logger.info(
109
+ f"Running validation... \nGenerating {args.num_validation_images} images"
110
+ )
111
+
112
+ # create pipeline (note: unet and vae are loaded again in float32)
113
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
114
+ args.pretrained_model_name_or_path,
115
+ tokenizer=tokenizer,
116
+ revision=args.revision,
117
+ torch_dtype=weight_dtype,
118
+ )
119
+
120
+ # set `keep_fp32_wrapper` to True because we do not want to remove
121
+ # mixed precision hooks while we are still training
122
+ pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
123
+ pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
124
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
125
+
126
+ pipeline = pipeline.to(accelerator.device)
127
+ pipeline.set_progress_bar_config(disable=True)
128
+
129
+ # run inference
130
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
131
+
132
+ target_dir = Path(args.train_data_dir) / "target"
133
+ target_image, target_mask = target_dir / "target.png", target_dir / "mask.png"
134
+ image, mask_image = Image.open(target_image), Image.open(target_mask)
135
+
136
+ if image.mode != "RGB":
137
+ image = image.convert("RGB")
138
+
139
+ images = []
140
+ for _ in range(args.num_validation_images):
141
+ image = pipeline(
142
+ prompt="a photo of sks", image=image, mask_image=mask_image,
143
+ num_inference_steps=25, guidance_scale=5, generator=generator
144
+ ).images[0]
145
+ images.append(image)
146
+
147
+ for tracker in accelerator.trackers:
148
+ if tracker.name == "tensorboard":
149
+ np_images = np.stack([np.asarray(img) for img in images])
150
+ tracker.writer.add_images(f"validation", np_images, epoch, dataformats="NHWC")
151
+ if tracker.name == "wandb":
152
+ tracker.log(
153
+ {
154
+ f"validation": [
155
+ wandb.Image(image, caption=str(i)) for i, image in enumerate(images)
156
+ ]
157
+ }
158
+ )
159
+
160
+ del pipeline
161
+ torch.cuda.empty_cache()
162
+
163
+ return images
164
+
165
+ def parse_args(input_args=None):
166
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
167
+ parser.add_argument(
168
+ "--pretrained_model_name_or_path",
169
+ type=str,
170
+ default=None,
171
+ required=True,
172
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
173
+ )
174
+ parser.add_argument(
175
+ "--revision",
176
+ type=str,
177
+ default=None,
178
+ required=False,
179
+ help="Revision of pretrained model identifier from huggingface.co/models.",
180
+ )
181
+ parser.add_argument(
182
+ "--tokenizer_name",
183
+ type=str,
184
+ default=None,
185
+ help="Pretrained tokenizer name or path if not the same as model_name",
186
+ )
187
+ parser.add_argument(
188
+ "--train_data_dir",
189
+ type=str,
190
+ default=None,
191
+ required=True,
192
+ help="A folder containing the training data of images.",
193
+ )
194
+ parser.add_argument(
195
+ "--num_validation_images",
196
+ type=int,
197
+ default=4,
198
+ help="Number of images that should be generated during validation with `validation_conditioning`.",
199
+ )
200
+ parser.add_argument(
201
+ "--validation_steps",
202
+ type=int,
203
+ default=100,
204
+ help=(
205
+ "Run realfill validation every X steps. RealFill validation consists of running the conditioning"
206
+ " `args.validation_conditioning` multiple times: `args.num_validation_images`."
207
+ ),
208
+ )
209
+ parser.add_argument(
210
+ "--output_dir",
211
+ type=str,
212
+ default="realfill-model",
213
+ help="The output directory where the model predictions and checkpoints will be written.",
214
+ )
215
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
216
+ parser.add_argument(
217
+ "--resolution",
218
+ type=int,
219
+ default=512,
220
+ help=(
221
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
222
+ " resolution"
223
+ ),
224
+ )
225
+ parser.add_argument(
226
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
227
+ )
228
+ parser.add_argument("--num_train_epochs", type=int, default=1)
229
+ parser.add_argument(
230
+ "--max_train_steps",
231
+ type=int,
232
+ default=None,
233
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
234
+ )
235
+ parser.add_argument(
236
+ "--checkpointing_steps",
237
+ type=int,
238
+ default=500,
239
+ help=(
240
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
241
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
242
+ " training using `--resume_from_checkpoint`."
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--checkpoints_total_limit",
247
+ type=int,
248
+ default=None,
249
+ help=("Max number of checkpoints to store."),
250
+ )
251
+ parser.add_argument(
252
+ "--resume_from_checkpoint",
253
+ type=str,
254
+ default=None,
255
+ help=(
256
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
257
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
258
+ ),
259
+ )
260
+ parser.add_argument(
261
+ "--gradient_accumulation_steps",
262
+ type=int,
263
+ default=1,
264
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
265
+ )
266
+ parser.add_argument(
267
+ "--gradient_checkpointing",
268
+ action="store_true",
269
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
270
+ )
271
+ parser.add_argument(
272
+ "--unet_learning_rate",
273
+ type=float,
274
+ default=2e-4,
275
+ help="Learning rate to use for unet.",
276
+ )
277
+ parser.add_argument(
278
+ "--text_encoder_learning_rate",
279
+ type=float,
280
+ default=4e-5,
281
+ help="Learning rate to use for text encoder.",
282
+ )
283
+ parser.add_argument(
284
+ "--scale_lr",
285
+ action="store_true",
286
+ default=False,
287
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
288
+ )
289
+ parser.add_argument(
290
+ "--lr_scheduler",
291
+ type=str,
292
+ default="constant",
293
+ help=(
294
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
295
+ ' "constant", "constant_with_warmup"]'
296
+ ),
297
+ )
298
+ parser.add_argument(
299
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
300
+ )
301
+ parser.add_argument(
302
+ "--lr_num_cycles",
303
+ type=int,
304
+ default=1,
305
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
306
+ )
307
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
308
+ parser.add_argument(
309
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
310
+ )
311
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
312
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
313
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
314
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
315
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
316
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
317
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
318
+ parser.add_argument(
319
+ "--hub_model_id",
320
+ type=str,
321
+ default=None,
322
+ help="The name of the repository to keep in sync with the local `output_dir`.",
323
+ )
324
+ parser.add_argument(
325
+ "--logging_dir",
326
+ type=str,
327
+ default="logs",
328
+ help=(
329
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
330
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
331
+ ),
332
+ )
333
+ parser.add_argument(
334
+ "--allow_tf32",
335
+ action="store_true",
336
+ help=(
337
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
338
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
339
+ ),
340
+ )
341
+ parser.add_argument(
342
+ "--report_to",
343
+ type=str,
344
+ default="tensorboard",
345
+ help=(
346
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
347
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--wandb_key",
352
+ type=str,
353
+ default=None,
354
+ help=("If report to option is set to wandb, api-key for wandb used for login to wandb "),
355
+ )
356
+ parser.add_argument(
357
+ "--wandb_project_name",
358
+ type=str,
359
+ default=None,
360
+ help=("If report to option is set to wandb, project name in wandb for log tracking "),
361
+ )
362
+ parser.add_argument(
363
+ "--mixed_precision",
364
+ type=str,
365
+ default=None,
366
+ choices=["no", "fp16", "bf16"],
367
+ help=(
368
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
369
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
370
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
371
+ ),
372
+ )
373
+ parser.add_argument(
374
+ "--set_grads_to_none",
375
+ action="store_true",
376
+ help=(
377
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
378
+ " behaviors, so disable this argument if it causes any problems. More info:"
379
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
380
+ ),
381
+ )
382
+ parser.add_argument(
383
+ "--lora_rank",
384
+ type=int,
385
+ default=16,
386
+ help=("The dimension of the LoRA update matrices."),
387
+ )
388
+ parser.add_argument(
389
+ "--lora_alpha",
390
+ type=int,
391
+ default=27,
392
+ help=("The alpha constant of the LoRA update matrices."),
393
+ )
394
+ parser.add_argument(
395
+ "--lora_dropout",
396
+ type=float,
397
+ default=0.1,
398
+ help="The dropout rate of the LoRA update matrices.",
399
+ )
400
+ parser.add_argument(
401
+ "--lora_bias",
402
+ type=str,
403
+ default="none",
404
+ help="The bias type of the Lora update matrices. Must be 'none', 'all' or 'lora_only'.",
405
+ )
406
+
407
+ if input_args is not None:
408
+ args = parser.parse_args(input_args)
409
+ else:
410
+ args = parser.parse_args()
411
+
412
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
413
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
414
+ args.local_rank = env_local_rank
415
+
416
+ return args
417
+
418
+ class RealFillDataset(Dataset):
419
+ """
420
+ A dataset to prepare the training and conditioning images and
421
+ the masks with the dummy prompt for fine-tuning the model.
422
+ It pre-processes the images, masks and tokenizes the prompts.
423
+ """
424
+
425
+ def __init__(
426
+ self,
427
+ train_data_root,
428
+ tokenizer,
429
+ size=512,
430
+ ):
431
+ self.size = size
432
+ self.tokenizer = tokenizer
433
+
434
+ self.ref_data_root = Path(train_data_root) / "ref"
435
+ self.target_image = Path(train_data_root) / "target" / "target.png"
436
+ self.target_mask = Path(train_data_root) / "target" / "mask.png"
437
+ if not (self.ref_data_root.exists() and self.target_image.exists() and self.target_mask.exists()):
438
+ raise ValueError("Train images root doesn't exists.")
439
+
440
+ self.train_images_path = list(self.ref_data_root.iterdir()) + [self.target_image]
441
+ self.num_train_images = len(self.train_images_path)
442
+ self.train_prompt = "a photo of sks"
443
+
444
+ self.image_transforms = transforms.Compose(
445
+ [
446
+ transforms_v2.RandomResize(size, int(1.125 * size)),
447
+ transforms.RandomCrop(size),
448
+ transforms.ToTensor(),
449
+ transforms.Normalize([0.5], [0.5]),
450
+ ]
451
+ )
452
+
453
+ def __len__(self):
454
+ return self.num_train_images
455
+
456
+ def __getitem__(self, index):
457
+ example = {}
458
+
459
+ image = Image.open(self.train_images_path[index])
460
+ image = exif_transpose(image)
461
+
462
+ if not image.mode == "RGB":
463
+ image = image.convert("RGB")
464
+ example["images"] = self.image_transforms(image)
465
+
466
+ if random.random() < 0.1:
467
+ example["masks"] = torch.ones_like(example["images"][0:1, :, :])
468
+ else:
469
+ example["masks"] = make_mask(example["images"], self.size)
470
+
471
+ if index < len(self) - 1:
472
+ example["weightings"] = torch.ones_like(example["masks"])
473
+ else:
474
+ weighting = Image.open(self.target_mask)
475
+ weighting = exif_transpose(weighting)
476
+
477
+ weightings = self.image_transforms(weighting)
478
+ example["weightings"] = weightings < 0.5
479
+
480
+ example["conditioning_images"] = example["images"] * (example["masks"] < 0.5)
481
+
482
+ train_prompt = "" if random.random() < 0.1 else self.train_prompt
483
+ example["prompt_ids"] = self.tokenizer(
484
+ train_prompt,
485
+ truncation=True,
486
+ padding="max_length",
487
+ max_length=self.tokenizer.model_max_length,
488
+ return_tensors="pt",
489
+ ).input_ids
490
+
491
+ return example
492
+
493
+ def collate_fn(examples):
494
+ input_ids = [example["prompt_ids"] for example in examples]
495
+ images = [example["images"] for example in examples]
496
+
497
+ masks = [example["masks"] for example in examples]
498
+ weightings = [example["weightings"] for example in examples]
499
+ conditioning_images = [example["conditioning_images"] for example in examples]
500
+
501
+ images = torch.stack(images)
502
+ images = images.to(memory_format=torch.contiguous_format).float()
503
+
504
+ masks = torch.stack(masks)
505
+ masks = masks.to(memory_format=torch.contiguous_format).float()
506
+
507
+ weightings = torch.stack(weightings)
508
+ weightings = weightings.to(memory_format=torch.contiguous_format).float()
509
+
510
+ conditioning_images = torch.stack(conditioning_images)
511
+ conditioning_images = conditioning_images.to(memory_format=torch.contiguous_format).float()
512
+
513
+ input_ids = torch.cat(input_ids, dim=0)
514
+
515
+ batch = {
516
+ "input_ids": input_ids,
517
+ "images": images,
518
+ "masks": masks,
519
+ "weightings": weightings,
520
+ "conditioning_images": conditioning_images,
521
+ }
522
+ return batch
523
+
524
+ def main(args):
525
+ logging_dir = Path(args.output_dir, args.logging_dir)
526
+
527
+ accelerator = Accelerator(
528
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
529
+ mixed_precision=args.mixed_precision,
530
+ log_with=args.report_to,
531
+ project_dir=logging_dir,
532
+ )
533
+
534
+ if args.report_to == "wandb":
535
+ if not is_wandb_available():
536
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
537
+ import wandb
538
+
539
+ wandb.login(key=args.wandb_key)
540
+ wandb.init(project=args.wandb_project_name)
541
+
542
+ # Make one log on every process with the configuration for debugging.
543
+ logging.basicConfig(
544
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
545
+ datefmt="%m/%d/%Y %H:%M:%S",
546
+ level=logging.INFO,
547
+ )
548
+ logger.info(accelerator.state, main_process_only=False)
549
+ if accelerator.is_local_main_process:
550
+ transformers.utils.logging.set_verbosity_warning()
551
+ diffusers.utils.logging.set_verbosity_info()
552
+ else:
553
+ transformers.utils.logging.set_verbosity_error()
554
+ diffusers.utils.logging.set_verbosity_error()
555
+
556
+ # If passed along, set the training seed now.
557
+ if args.seed is not None:
558
+ set_seed(args.seed)
559
+
560
+ # Handle the repository creation
561
+ if accelerator.is_main_process:
562
+ if args.output_dir is not None:
563
+ os.makedirs(args.output_dir, exist_ok=True)
564
+
565
+ if args.push_to_hub:
566
+ repo_id = create_repo(
567
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
568
+ ).repo_id
569
+
570
+ # Load the tokenizer
571
+ if args.tokenizer_name:
572
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
573
+ elif args.pretrained_model_name_or_path:
574
+ tokenizer = AutoTokenizer.from_pretrained(
575
+ args.pretrained_model_name_or_path,
576
+ subfolder="tokenizer",
577
+ revision=args.revision,
578
+ use_fast=False,
579
+ )
580
+
581
+ # Load scheduler and models
582
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
583
+ text_encoder = CLIPTextModel.from_pretrained(
584
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
585
+ )
586
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
587
+ unet = UNet2DConditionModel.from_pretrained(
588
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
589
+ )
590
+
591
+ config = LoraConfig(
592
+ r=args.lora_rank,
593
+ lora_alpha=args.lora_alpha,
594
+ target_modules=["to_k", "to_q", "to_v", "key", "query", "value"],
595
+ lora_dropout=args.lora_dropout,
596
+ bias=args.lora_bias,
597
+ )
598
+ unet = get_peft_model(unet, config)
599
+
600
+ config = LoraConfig(
601
+ r=args.lora_rank,
602
+ lora_alpha=args.lora_alpha,
603
+ target_modules=["k_proj", "q_proj", "v_proj"],
604
+ lora_dropout=args.lora_dropout,
605
+ bias=args.lora_bias,
606
+ )
607
+ text_encoder = get_peft_model(text_encoder, config)
608
+
609
+ vae.requires_grad_(False)
610
+
611
+ if args.enable_xformers_memory_efficient_attention:
612
+ if is_xformers_available():
613
+ import xformers
614
+
615
+ xformers_version = version.parse(xformers.__version__)
616
+ if xformers_version == version.parse("0.0.16"):
617
+ logger.warn(
618
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
619
+ )
620
+ unet.enable_xformers_memory_efficient_attention()
621
+ else:
622
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
623
+
624
+ if args.gradient_checkpointing:
625
+ unet.enable_gradient_checkpointing()
626
+ text_encoder.gradient_checkpointing_enable()
627
+
628
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
629
+ def save_model_hook(models, weights, output_dir):
630
+ if accelerator.is_main_process:
631
+ for model in models:
632
+ sub_dir = "unet" if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) else "text_encoder"
633
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
634
+
635
+ # make sure to pop weight so that corresponding model is not saved again
636
+ weights.pop()
637
+
638
+ def load_model_hook(models, input_dir):
639
+ while len(models) > 0:
640
+ # pop models so that they are not loaded again
641
+ model = models.pop()
642
+
643
+ sub_dir = "unet" if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) else "text_encoder"
644
+ model_cls = UNet2DConditionModel if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) else CLIPTextModel
645
+
646
+ load_model = model_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder=sub_dir)
647
+ load_model = PeftModel.from_pretrained(load_model, input_dir, subfolder=sub_dir)
648
+
649
+ model.load_state_dict(load_model.state_dict())
650
+ del load_model
651
+
652
+ accelerator.register_save_state_pre_hook(save_model_hook)
653
+ accelerator.register_load_state_pre_hook(load_model_hook)
654
+
655
+ # Enable TF32 for faster training on Ampere GPUs,
656
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
657
+ if args.allow_tf32:
658
+ torch.backends.cuda.matmul.allow_tf32 = True
659
+
660
+ if args.scale_lr:
661
+ args.unet_learning_rate = (
662
+ args.unet_learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
663
+ )
664
+
665
+ args.text_encoder_learning_rate = (
666
+ args.text_encoder_learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
667
+ )
668
+
669
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
670
+ if args.use_8bit_adam:
671
+ try:
672
+ import bitsandbytes as bnb
673
+ except ImportError:
674
+ raise ImportError(
675
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
676
+ )
677
+
678
+ optimizer_class = bnb.optim.AdamW8bit
679
+ else:
680
+ optimizer_class = torch.optim.AdamW
681
+
682
+ # Optimizer creation
683
+ optimizer = optimizer_class(
684
+ [
685
+ {"params": unet.parameters(), "lr": args.unet_learning_rate},
686
+ {"params": text_encoder.parameters(), "lr": args.text_encoder_learning_rate}
687
+ ],
688
+ betas=(args.adam_beta1, args.adam_beta2),
689
+ weight_decay=args.adam_weight_decay,
690
+ eps=args.adam_epsilon,
691
+ )
692
+
693
+ # Dataset and DataLoaders creation:
694
+ train_dataset = RealFillDataset(
695
+ train_data_root=args.train_data_dir,
696
+ tokenizer=tokenizer,
697
+ size=args.resolution,
698
+ )
699
+
700
+ train_dataloader = torch.utils.data.DataLoader(
701
+ train_dataset,
702
+ batch_size=args.train_batch_size,
703
+ shuffle=True,
704
+ collate_fn=collate_fn,
705
+ num_workers=1,
706
+ )
707
+
708
+ # Scheduler and math around the number of training steps.
709
+ overrode_max_train_steps = False
710
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
711
+ if args.max_train_steps is None:
712
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
713
+ overrode_max_train_steps = True
714
+
715
+ lr_scheduler = get_scheduler(
716
+ args.lr_scheduler,
717
+ optimizer=optimizer,
718
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
719
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
720
+ num_cycles=args.lr_num_cycles,
721
+ power=args.lr_power,
722
+ )
723
+
724
+ # Prepare everything with our `accelerator`.
725
+ unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
726
+ unet, text_encoder, optimizer, train_dataloader
727
+ )
728
+
729
+ # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
730
+ # as these weights are only used for inference, keeping weights in full precision is not required.
731
+ weight_dtype = torch.float32
732
+ if accelerator.mixed_precision == "fp16":
733
+ weight_dtype = torch.float16
734
+ elif accelerator.mixed_precision == "bf16":
735
+ weight_dtype = torch.bfloat16
736
+
737
+ # Move vae to device and cast to weight_dtype
738
+ vae.to(accelerator.device, dtype=weight_dtype)
739
+
740
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
741
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
742
+ if overrode_max_train_steps:
743
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
744
+ # Afterwards we recalculate our number of training epochs
745
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
746
+
747
+ # We need to initialize the trackers we use, and also store our configuration.
748
+ # The trackers initializes automatically on the main process.
749
+ if accelerator.is_main_process:
750
+ tracker_config = vars(copy.deepcopy(args))
751
+ accelerator.init_trackers("realfill", config=tracker_config)
752
+
753
+ # Train!
754
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
755
+
756
+ logger.info("***** Running training *****")
757
+ logger.info(f" Num examples = {len(train_dataset)}")
758
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
759
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
760
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
761
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
762
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
763
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
764
+ global_step = 0
765
+ first_epoch = 0
766
+
767
+ # Potentially load in the weights and states from a previous save
768
+ if args.resume_from_checkpoint:
769
+ if args.resume_from_checkpoint != "latest":
770
+ path = os.path.basename(args.resume_from_checkpoint)
771
+ else:
772
+ # Get the mos recent checkpoint
773
+ dirs = os.listdir(args.output_dir)
774
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
775
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
776
+ path = dirs[-1] if len(dirs) > 0 else None
777
+
778
+ if path is None:
779
+ accelerator.print(
780
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
781
+ )
782
+ args.resume_from_checkpoint = None
783
+ initial_global_step = 0
784
+ else:
785
+ accelerator.print(f"Resuming from checkpoint {path}")
786
+ accelerator.load_state(os.path.join(args.output_dir, path))
787
+ global_step = int(path.split("-")[1])
788
+
789
+ initial_global_step = global_step
790
+ first_epoch = global_step // num_update_steps_per_epoch
791
+ else:
792
+ initial_global_step = 0
793
+
794
+ progress_bar = tqdm(
795
+ range(0, args.max_train_steps),
796
+ initial=initial_global_step,
797
+ desc="Steps",
798
+ # Only show the progress bar once on each machine.
799
+ disable=not accelerator.is_local_main_process,
800
+ )
801
+
802
+ for epoch in range(first_epoch, args.num_train_epochs):
803
+ unet.train()
804
+ text_encoder.train()
805
+
806
+ for step, batch in enumerate(train_dataloader):
807
+ with accelerator.accumulate(unet, text_encoder):
808
+ # Convert images to latent space
809
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
810
+ latents = latents * 0.18215
811
+
812
+ conditionings = vae.encode(batch["conditioning_images"].to(dtype=weight_dtype)).latent_dist.sample()
813
+ conditionings = conditionings * 0.18215
814
+
815
+ masks, size = batch["masks"].to(dtype=weight_dtype), latents.shape[2]
816
+ masks = F.interpolate(masks, size=size)
817
+
818
+ weightings = batch["weightings"].to(dtype=weight_dtype)
819
+ weightings = F.interpolate(weightings, size=size)
820
+
821
+ # Sample noise that we'll add to the latents
822
+ noise = torch.randn_like(latents)
823
+ bsz = latents.shape[0]
824
+
825
+ # Sample a random timestep for each image
826
+ timesteps = torch.randint(
827
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
828
+ )
829
+ timesteps = timesteps.long()
830
+
831
+ # Add noise to the latents according to the noise magnitude at each timestep
832
+ # (this is the forward diffusion process)
833
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
834
+
835
+ # Concatenate noisy latents, masks and conditionings to get inputs to unet
836
+ inputs = torch.cat([noisy_latents, masks, conditionings], dim=1)
837
+
838
+ # Get the text embedding for conditioning
839
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
840
+
841
+ # Predict the noise residual
842
+ model_pred = unet(inputs, timesteps, encoder_hidden_states).sample
843
+
844
+ # Compute the diffusion loss
845
+ assert noise_scheduler.config.prediction_type == "epsilon"
846
+ loss = (weightings * F.mse_loss(model_pred.float(), noise.float(), reduction="none")).mean()
847
+
848
+ accelerator.backward(loss)
849
+ if accelerator.sync_gradients:
850
+ params_to_clip = itertools.chain(
851
+ unet.parameters(), text_encoder.parameters()
852
+ )
853
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
854
+
855
+ optimizer.step()
856
+ lr_scheduler.step()
857
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
858
+
859
+ # Checks if the accelerator has performed an optimization step behind the scenes
860
+ if accelerator.sync_gradients:
861
+ progress_bar.update(1)
862
+ if args.report_to == "wandb":
863
+ accelerator.print(progress_bar)
864
+ global_step += 1
865
+
866
+ if accelerator.is_main_process:
867
+ if global_step % args.checkpointing_steps == 0:
868
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
869
+ if args.checkpoints_total_limit is not None:
870
+ checkpoints = os.listdir(args.output_dir)
871
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
872
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
873
+
874
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
875
+ if len(checkpoints) >= args.checkpoints_total_limit:
876
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
877
+ removing_checkpoints = checkpoints[0:num_to_remove]
878
+
879
+ logger.info(
880
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
881
+ )
882
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
883
+
884
+ for removing_checkpoint in removing_checkpoints:
885
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
886
+ shutil.rmtree(removing_checkpoint)
887
+
888
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
889
+ accelerator.save_state(save_path)
890
+ logger.info(f"Saved state to {save_path}")
891
+
892
+ if global_step % args.validation_steps == 0:
893
+ log_validation(
894
+ text_encoder,
895
+ tokenizer,
896
+ unet,
897
+ args,
898
+ accelerator,
899
+ weight_dtype,
900
+ global_step,
901
+ )
902
+
903
+ logs = {"loss": loss.detach().item()}
904
+ progress_bar.set_postfix(**logs)
905
+ accelerator.log(logs, step=global_step)
906
+
907
+ if global_step >= args.max_train_steps:
908
+ break
909
+
910
+ # Save the lora layers
911
+ accelerator.wait_for_everyone()
912
+ if accelerator.is_main_process:
913
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
914
+ args.pretrained_model_name_or_path,
915
+ unet=accelerator.unwrap_model(unet.merge_and_unload(), keep_fp32_wrapper=True),
916
+ text_encoder=accelerator.unwrap_model(text_encoder.merge_and_unload(), keep_fp32_wrapper=True),
917
+ revision=args.revision,
918
+ )
919
+
920
+ pipeline.save_pretrained(args.output_dir)
921
+
922
+ # Final inference
923
+ images = log_validation(
924
+ text_encoder,
925
+ tokenizer,
926
+ unet,
927
+ args,
928
+ accelerator,
929
+ weight_dtype,
930
+ global_step,
931
+ )
932
+
933
+ if args.push_to_hub:
934
+ save_model_card(
935
+ repo_id,
936
+ images=images,
937
+ base_model=args.pretrained_model_name_or_path,
938
+ repo_folder=args.output_dir,
939
+ )
940
+ upload_folder(
941
+ repo_id=repo_id,
942
+ folder_path=args.output_dir,
943
+ commit_message="End of training",
944
+ ignore_patterns=["step_*", "epoch_*"],
945
+ )
946
+
947
+ accelerator.end_training()
948
+
949
+
950
+ if __name__ == "__main__":
951
+ args = parse_args()
952
+ main(args)
trainer.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import shlex
6
+ import shutil
7
+ import subprocess
8
+
9
+ import gradio as gr
10
+ import PIL.Image
11
+ import torch
12
+
13
+
14
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
15
+ w, h = image.size
16
+ if w == h:
17
+ return image
18
+ elif w > h:
19
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
20
+ new_image.paste(image, (0, (w - h) // 2))
21
+ return new_image
22
+ else:
23
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
24
+ new_image.paste(image, ((h - w) // 2, 0))
25
+ return new_image
26
+
27
+
28
+ class Trainer:
29
+ def __init__(self):
30
+ self.is_running = False
31
+ self.is_running_message = "Another training is in progress."
32
+
33
+ self.output_dir = pathlib.Path("results")
34
+ self.data_dir = pathlib.Path("data")
35
+
36
+ self.ref_data_dir = self.data_dir / "ref"
37
+ self.target_data_dir = self.data_dir / "target"
38
+
39
+ def check_if_running(self) -> dict:
40
+ if self.is_running:
41
+ return gr.update(value=self.is_running_message)
42
+ else:
43
+ return gr.update(value="No training is running.")
44
+
45
+ def cleanup_dirs(self) -> None:
46
+ shutil.rmtree(self.output_dir, ignore_errors=True)
47
+
48
+ def prepare_dataset(
49
+ self,
50
+ ref_images: list,
51
+ target_image: PIL.Image,
52
+ target_mask: PIL.Image,
53
+ resolution: int
54
+ ) -> None:
55
+ self.ref_data_dir.mkdir(parents=True)
56
+ self.target_data_dir.mkdir(parents=True)
57
+
58
+ for i, temp_path in enumerate(ref_images):
59
+ image = PIL.Image.open(temp_path.name)
60
+ image = pad_image(image)
61
+ image = image.resize((resolution, resolution))
62
+ image = image.convert("RGB")
63
+ out_path = self.ref_data_dir / f"{i:03d}.jpg"
64
+ image.save(out_path, format="JPEG", quality=100)
65
+
66
+ target_image.save(self.target_data_dir / "target.jpg", format="JPEG", quality=100)
67
+ target_mask.save(self.target_data_dir / "mask.jpg", format="JPEG", quality=100)
68
+
69
+ def run(
70
+ self,
71
+ base_model: str,
72
+ resolution_s: str,
73
+ n_steps: int,
74
+ ref_images: list | None,
75
+ target_image: PIL.Image,
76
+ target_mask: PIL.Image,
77
+ unet_learning_rate: float,
78
+ text_encoder_learning_rate: float,
79
+ gradient_accumulation: int,
80
+ fp16: bool,
81
+ use_8bit_adam: bool,
82
+ gradient_checkpointing: bool,
83
+ lora_rank: int,
84
+ lora_alpha: int,
85
+ lora_bias: str,
86
+ lora_dropout: float,
87
+ ) -> tuple[dict, list[pathlib.Path]]:
88
+ if not torch.cuda.is_available():
89
+ raise gr.Error("CUDA is not available.")
90
+
91
+ if self.is_running:
92
+ return gr.update(value=self.is_running_message), []
93
+
94
+ if ref_images is None:
95
+ raise gr.Error("You need to upload reference images.")
96
+ if target_image is None:
97
+ raise gr.Error("You need to upload target image.")
98
+ if target_mask is None:
99
+ raise gr.Error("You need to upload target mask.")
100
+
101
+ resolution = int(resolution_s)
102
+
103
+ self.cleanup_dirs()
104
+ self.prepare_dataset(ref_images, target_image, target_mask, resolution)
105
+
106
+ command = f"""
107
+ accelerate launch train_dreambooth.py \
108
+ --pretrained_model_name_or_path={base_model} \
109
+ --train_data_dir={self.data_dir} \
110
+ --output_dir={self.output_dir} \
111
+ --resolution={resolution} \
112
+ --gradient_accumulation_steps={gradient_accumulation} \
113
+ --unet_learning_rate={unet_learning_rate} \
114
+ --text_encoder_learning_rate={text_encoder_learning_rate} \
115
+ --max_train_steps={n_steps} \
116
+ --train_batch_size=16 \
117
+ --lr_scheduler=constant \
118
+ --lr_warmup_steps=100 \
119
+ --lora_r={lora_rank} \
120
+ --lora_alpha={lora_alpha} \
121
+ --lora_bias={lora_bias} \
122
+ --lora_dropout={lora_dropout} \
123
+ """
124
+
125
+ if fp16:
126
+ command += " --mixed_precision fp16"
127
+ if use_8bit_adam:
128
+ command += " --use_8bit_adam"
129
+ if gradient_checkpointing:
130
+ command += " --gradient_checkpointing"
131
+
132
+ with open(self.output_dir / "train.sh", "w") as f:
133
+ command_s = " ".join(command.split())
134
+ f.write(command_s)
135
+
136
+ self.is_running = True
137
+ res = subprocess.run(shlex.split(command))
138
+ self.is_running = False
139
+
140
+ if res.returncode == 0:
141
+ result_message = "Training Completed!"
142
+ else:
143
+ result_message = "Training Failed!"
144
+ model_paths = sorted(self.output_dir.glob("*"))
145
+ return gr.update(value=result_message), model_paths
uploader.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi
3
+
4
+
5
+ def upload(model_name: str, hf_token: str) -> None:
6
+ api = HfApi(token=hf_token)
7
+ user_name = api.whoami()["name"]
8
+ model_id = f"{user_name}/{model_name}"
9
+ try:
10
+ api.create_repo(model_id, repo_type="model", private=True)
11
+ api.upload_folder(repo_id=model_id, folder_path="results", path_in_repo="results", repo_type="model")
12
+ url = f"https://huggingface.co/{model_id}"
13
+ message = f"Your model was successfully uploaded to [{url}]({url})."
14
+ except Exception as e:
15
+ message = str(e)
16
+
17
+ return gr.update(value=message, visible=True)