Divyanshu04 commited on
Commit
5b512a0
·
1 Parent(s): ef31f0a
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["gunicorn", "-b", "0.0.0.0:7860", "Text2image-api:app"]
DreamBooth_Stable_Diffusion.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Text2image-api.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, jsonify, request
2
+ from pathlib import Path
3
+ import sys
4
+ import torch
5
+ import os
6
+ from torch import autocast
7
+ from diffusers import StableDiffusionPipeline, DDIMScheduler
8
+ import streamlit as st
9
+
10
+ # model_path = WEIGHTS_DIR # If you want to use previously trained model saved in gdrive, replace this with the full path of model in gdrive
11
+
12
+ # pipe = StableDiffusionPipeline.from_pretrained(model_path, safety_checker=None, torch_dtype=torch.float32).to("cuda")
13
+ # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
14
+ # pipe.enable_xformers_memory_efficient_attention()
15
+ # g_cuda = None
16
+
17
+ FILE = Path(__file__).resolve()
18
+ ROOT = FILE.parents[0] # YOLOv5 root directory
19
+ if str(ROOT) not in sys.path:
20
+ sys.path.append(str(ROOT)) # add ROOT to PATH
21
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))
22
+
23
+ app = Flask(__name__)
24
+
25
+ # @app.route('/', methods = ['GET', 'POST'])
26
+ # def home():
27
+ # if(request.method == 'GET'):
28
+
29
+ # data = "Text2Image"
30
+ # return jsonify({'service': data})
31
+
32
+
33
+ @app.route("/", methods=["POST"])
34
+ def generate():
35
+
36
+ # prompt = request.form['prompt']
37
+ # negative_prompt = request.form['Negative prompt']
38
+ # num_samples = request.form['No. of samples']
39
+
40
+ prompt = st.text_area(placeholder = "prompt", key="pmpt")
41
+ negative_prompt = st.text_area(placeholder = "Negative prompt", key="ng_pmpt")
42
+ num_samples = st.number_input("No. of samples")
43
+
44
+ res = st.button("Reset", type="primary")
45
+
46
+ if res:
47
+
48
+ guidance_scale = 7.5
49
+ num_inference_steps = 24
50
+ height = 512
51
+ width = 512
52
+
53
+ g_cuda = torch.Generator(device='cuda')
54
+ seed = 52362
55
+ g_cuda.manual_seed(seed)
56
+
57
+ # commandline_args = os.environ.get('COMMANDLINE_ARGS', "--skip-torch-cuda-test --no-half")
58
+
59
+ with autocast("cuda"), torch.inference_mode():
60
+ images = pipe(
61
+ prompt,
62
+ height=height,
63
+ width=width,
64
+ negative_prompt=negative_prompt,
65
+ num_images_per_prompt=num_samples,
66
+ num_inference_steps=num_inference_steps,
67
+ guidance_scale=guidance_scale,
68
+ generator=g_cuda
69
+ ).images
70
+
71
+ return {"message": "successful"}
72
+
73
+ else:
74
+ return {"message": "Running.."}
75
+
76
+
77
+
78
+
79
+ # driver function
80
+ # if __name__ == '__main__':
81
+
82
+ # app.run(debug = True)
concepts_list.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "instance_prompt": "photo of zwx dog",
4
+ "class_prompt": "photo of a dog",
5
+ "instance_data_dir": "/content/data/zwx",
6
+ "class_data_dir": "/content/data/dog"
7
+ }
8
+ ]
main.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from flask import Flask
2
+
3
+ app = Flask(__name__)
4
+
5
+ @app.route('/')
6
+ def hello():
7
+ return {'hei': "you success"}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ flask
2
+ gunicorn
3
+ xformers==0.0.20
4
+ diffusers
5
+ gradio
train_dreambooth.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import hashlib
3
+ import itertools
4
+ import random
5
+ import json
6
+ import logging
7
+ import math
8
+ import os
9
+ from contextlib import nullcontext
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+ from torch.utils.data import Dataset
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import set_seed
21
+ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
22
+ from diffusers.optimization import get_scheduler
23
+ from diffusers.utils.import_utils import is_xformers_available
24
+ from huggingface_hub import HfFolder, Repository, whoami
25
+ from PIL import Image
26
+ from torchvision import transforms
27
+ from tqdm.auto import tqdm
28
+ from transformers import CLIPTextModel, CLIPTokenizer
29
+
30
+
31
+ torch.backends.cudnn.benchmark = True
32
+
33
+
34
+ logger = get_logger(__name__)
35
+
36
+
37
+ def parse_args(input_args=None):
38
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
39
+ parser.add_argument(
40
+ "--pretrained_model_name_or_path",
41
+ type=str,
42
+ default=None,
43
+ required=True,
44
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
45
+ )
46
+ parser.add_argument(
47
+ "--pretrained_vae_name_or_path",
48
+ type=str,
49
+ default=None,
50
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
51
+ )
52
+ parser.add_argument(
53
+ "--revision",
54
+ type=str,
55
+ default=None,
56
+ required=False,
57
+ help="Revision of pretrained model identifier from huggingface.co/models.",
58
+ )
59
+ parser.add_argument(
60
+ "--tokenizer_name",
61
+ type=str,
62
+ default=None,
63
+ help="Pretrained tokenizer name or path if not the same as model_name",
64
+ )
65
+ parser.add_argument(
66
+ "--instance_data_dir",
67
+ type=str,
68
+ default=None,
69
+ help="A folder containing the training data of instance images.",
70
+ )
71
+ parser.add_argument(
72
+ "--class_data_dir",
73
+ type=str,
74
+ default=None,
75
+ help="A folder containing the training data of class images.",
76
+ )
77
+ parser.add_argument(
78
+ "--instance_prompt",
79
+ type=str,
80
+ default=None,
81
+ help="The prompt with identifier specifying the instance",
82
+ )
83
+ parser.add_argument(
84
+ "--class_prompt",
85
+ type=str,
86
+ default=None,
87
+ help="The prompt to specify images in the same class as provided instance images.",
88
+ )
89
+ parser.add_argument(
90
+ "--save_sample_prompt",
91
+ type=str,
92
+ default=None,
93
+ help="The prompt used to generate sample outputs to save.",
94
+ )
95
+ parser.add_argument(
96
+ "--save_sample_negative_prompt",
97
+ type=str,
98
+ default=None,
99
+ help="The negative prompt used to generate sample outputs to save.",
100
+ )
101
+ parser.add_argument(
102
+ "--n_save_sample",
103
+ type=int,
104
+ default=4,
105
+ help="The number of samples to save.",
106
+ )
107
+ parser.add_argument(
108
+ "--save_guidance_scale",
109
+ type=float,
110
+ default=7.5,
111
+ help="CFG for save sample.",
112
+ )
113
+ parser.add_argument(
114
+ "--save_infer_steps",
115
+ type=int,
116
+ default=20,
117
+ help="The number of inference steps for save sample.",
118
+ )
119
+ parser.add_argument(
120
+ "--pad_tokens",
121
+ default=False,
122
+ action="store_true",
123
+ help="Flag to pad tokens to length 77.",
124
+ )
125
+ parser.add_argument(
126
+ "--with_prior_preservation",
127
+ default=False,
128
+ action="store_true",
129
+ help="Flag to add prior preservation loss.",
130
+ )
131
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
132
+ parser.add_argument(
133
+ "--num_class_images",
134
+ type=int,
135
+ default=100,
136
+ help=(
137
+ "Minimal class images for prior preservation loss. If not have enough images, additional images will be"
138
+ " sampled with class_prompt."
139
+ ),
140
+ )
141
+ parser.add_argument(
142
+ "--output_dir",
143
+ type=str,
144
+ default="text-inversion-model",
145
+ help="The output directory where the model predictions and checkpoints will be written.",
146
+ )
147
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
148
+ parser.add_argument(
149
+ "--resolution",
150
+ type=int,
151
+ default=512,
152
+ help=(
153
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
154
+ " resolution"
155
+ ),
156
+ )
157
+ parser.add_argument(
158
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
159
+ )
160
+ parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
161
+ parser.add_argument(
162
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
163
+ )
164
+ parser.add_argument(
165
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
166
+ )
167
+ parser.add_argument("--num_train_epochs", type=int, default=1)
168
+ parser.add_argument(
169
+ "--max_train_steps",
170
+ type=int,
171
+ default=None,
172
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
173
+ )
174
+ parser.add_argument(
175
+ "--gradient_accumulation_steps",
176
+ type=int,
177
+ default=1,
178
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
179
+ )
180
+ parser.add_argument(
181
+ "--gradient_checkpointing",
182
+ action="store_true",
183
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
184
+ )
185
+ parser.add_argument(
186
+ "--learning_rate",
187
+ type=float,
188
+ default=5e-6,
189
+ help="Initial learning rate (after the potential warmup period) to use.",
190
+ )
191
+ parser.add_argument(
192
+ "--scale_lr",
193
+ action="store_true",
194
+ default=False,
195
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
196
+ )
197
+ parser.add_argument(
198
+ "--lr_scheduler",
199
+ type=str,
200
+ default="constant",
201
+ help=(
202
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
203
+ ' "constant", "constant_with_warmup"]'
204
+ ),
205
+ )
206
+ parser.add_argument(
207
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
208
+ )
209
+ parser.add_argument(
210
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
211
+ )
212
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
213
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
214
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
215
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
216
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
217
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
218
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
219
+ parser.add_argument(
220
+ "--hub_model_id",
221
+ type=str,
222
+ default=None,
223
+ help="The name of the repository to keep in sync with the local `output_dir`.",
224
+ )
225
+ parser.add_argument(
226
+ "--logging_dir",
227
+ type=str,
228
+ default="logs",
229
+ help=(
230
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
231
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
232
+ ),
233
+ )
234
+ parser.add_argument("--log_interval", type=int, default=10, help="Log every N steps.")
235
+ parser.add_argument("--save_interval", type=int, default=10_000, help="Save weights every N steps.")
236
+ parser.add_argument("--save_min_steps", type=int, default=0, help="Start saving weights after N steps.")
237
+ parser.add_argument(
238
+ "--mixed_precision",
239
+ type=str,
240
+ default=None,
241
+ choices=["no", "fp16", "bf16"],
242
+ help=(
243
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
244
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
245
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
246
+ ),
247
+ )
248
+ parser.add_argument("--not_cache_latents", action="store_true", help="Do not precompute and cache latents from VAE.")
249
+ parser.add_argument("--hflip", action="store_true", help="Apply horizontal flip data augmentation.")
250
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
251
+ parser.add_argument(
252
+ "--concepts_list",
253
+ type=str,
254
+ default=None,
255
+ help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
256
+ )
257
+ parser.add_argument(
258
+ "--read_prompts_from_txts",
259
+ action="store_true",
260
+ help="Use prompt per image. Put prompts in the same directory as images, e.g. for image.png create image.png.txt.",
261
+ )
262
+
263
+ if input_args is not None:
264
+ args = parser.parse_args(input_args)
265
+ else:
266
+ args = parser.parse_args()
267
+
268
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
269
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
270
+ args.local_rank = env_local_rank
271
+
272
+ return args
273
+
274
+
275
+ class DreamBoothDataset(Dataset):
276
+ """
277
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
278
+ It pre-processes the images and the tokenizes prompts.
279
+ """
280
+
281
+ def __init__(
282
+ self,
283
+ concepts_list,
284
+ tokenizer,
285
+ with_prior_preservation=True,
286
+ size=512,
287
+ center_crop=False,
288
+ num_class_images=None,
289
+ pad_tokens=False,
290
+ hflip=False,
291
+ read_prompts_from_txts=False,
292
+ ):
293
+ self.size = size
294
+ self.center_crop = center_crop
295
+ self.tokenizer = tokenizer
296
+ self.with_prior_preservation = with_prior_preservation
297
+ self.pad_tokens = pad_tokens
298
+ self.read_prompts_from_txts = read_prompts_from_txts
299
+
300
+ self.instance_images_path = []
301
+ self.class_images_path = []
302
+
303
+ for concept in concepts_list:
304
+ inst_img_path = [
305
+ (x, concept["instance_prompt"])
306
+ for x in Path(concept["instance_data_dir"]).iterdir()
307
+ if x.is_file() and not str(x).endswith(".txt")
308
+ ]
309
+ self.instance_images_path.extend(inst_img_path)
310
+
311
+ if with_prior_preservation:
312
+ class_img_path = [(x, concept["class_prompt"]) for x in Path(concept["class_data_dir"]).iterdir() if x.is_file()]
313
+ self.class_images_path.extend(class_img_path[:num_class_images])
314
+
315
+ random.shuffle(self.instance_images_path)
316
+ self.num_instance_images = len(self.instance_images_path)
317
+ self.num_class_images = len(self.class_images_path)
318
+ self._length = max(self.num_class_images, self.num_instance_images)
319
+
320
+ self.image_transforms = transforms.Compose(
321
+ [
322
+ transforms.RandomHorizontalFlip(0.5 * hflip),
323
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
324
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
325
+ transforms.ToTensor(),
326
+ transforms.Normalize([0.5], [0.5]),
327
+ ]
328
+ )
329
+
330
+ def __len__(self):
331
+ return self._length
332
+
333
+ def __getitem__(self, index):
334
+ example = {}
335
+ instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
336
+
337
+ if self.read_prompts_from_txts:
338
+ with open(str(instance_path) + ".txt") as f:
339
+ instance_prompt = f.read().strip()
340
+
341
+ instance_image = Image.open(instance_path)
342
+ if not instance_image.mode == "RGB":
343
+ instance_image = instance_image.convert("RGB")
344
+
345
+ example["instance_images"] = self.image_transforms(instance_image)
346
+ example["instance_prompt_ids"] = self.tokenizer(
347
+ instance_prompt,
348
+ padding="max_length" if self.pad_tokens else "do_not_pad",
349
+ truncation=True,
350
+ max_length=self.tokenizer.model_max_length,
351
+ ).input_ids
352
+
353
+ if self.with_prior_preservation:
354
+ class_path, class_prompt = self.class_images_path[index % self.num_class_images]
355
+ class_image = Image.open(class_path)
356
+ if not class_image.mode == "RGB":
357
+ class_image = class_image.convert("RGB")
358
+ example["class_images"] = self.image_transforms(class_image)
359
+ example["class_prompt_ids"] = self.tokenizer(
360
+ class_prompt,
361
+ padding="max_length" if self.pad_tokens else "do_not_pad",
362
+ truncation=True,
363
+ max_length=self.tokenizer.model_max_length,
364
+ ).input_ids
365
+
366
+ return example
367
+
368
+
369
+ class PromptDataset(Dataset):
370
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
371
+
372
+ def __init__(self, prompt, num_samples):
373
+ self.prompt = prompt
374
+ self.num_samples = num_samples
375
+
376
+ def __len__(self):
377
+ return self.num_samples
378
+
379
+ def __getitem__(self, index):
380
+ example = {}
381
+ example["prompt"] = self.prompt
382
+ example["index"] = index
383
+ return example
384
+
385
+
386
+ class LatentsDataset(Dataset):
387
+ def __init__(self, latents_cache, text_encoder_cache):
388
+ self.latents_cache = latents_cache
389
+ self.text_encoder_cache = text_encoder_cache
390
+
391
+ def __len__(self):
392
+ return len(self.latents_cache)
393
+
394
+ def __getitem__(self, index):
395
+ return self.latents_cache[index], self.text_encoder_cache[index]
396
+
397
+
398
+ class AverageMeter:
399
+ def __init__(self, name=None):
400
+ self.name = name
401
+ self.reset()
402
+
403
+ def reset(self):
404
+ self.sum = self.count = self.avg = 0
405
+
406
+ def update(self, val, n=1):
407
+ self.sum += val * n
408
+ self.count += n
409
+ self.avg = self.sum / self.count
410
+
411
+
412
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
413
+ if token is None:
414
+ token = HfFolder.get_token()
415
+ if organization is None:
416
+ username = whoami(token)["name"]
417
+ return f"{username}/{model_id}"
418
+ else:
419
+ return f"{organization}/{model_id}"
420
+
421
+
422
+ def main(args):
423
+ logging_dir = Path(args.output_dir, "0", args.logging_dir)
424
+
425
+ accelerator = Accelerator(
426
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
427
+ mixed_precision=args.mixed_precision,
428
+ log_with="tensorboard",
429
+ project_dir=logging_dir,
430
+ )
431
+
432
+ logging.basicConfig(
433
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
434
+ datefmt="%m/%d/%Y %H:%M:%S",
435
+ level=logging.INFO,
436
+ )
437
+
438
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
439
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
440
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
441
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
442
+ raise ValueError(
443
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
444
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
445
+ )
446
+
447
+ if args.seed is not None:
448
+ set_seed(args.seed)
449
+
450
+ if args.concepts_list is None:
451
+ args.concepts_list = [
452
+ {
453
+ "instance_prompt": args.instance_prompt,
454
+ "class_prompt": args.class_prompt,
455
+ "instance_data_dir": args.instance_data_dir,
456
+ "class_data_dir": args.class_data_dir
457
+ }
458
+ ]
459
+ else:
460
+ with open(args.concepts_list, "r") as f:
461
+ args.concepts_list = json.load(f)
462
+
463
+ if args.with_prior_preservation:
464
+ pipeline = None
465
+ for concept in args.concepts_list:
466
+ class_images_dir = Path(concept["class_data_dir"])
467
+ class_images_dir.mkdir(parents=True, exist_ok=True)
468
+ cur_class_images = len(list(class_images_dir.iterdir()))
469
+
470
+ if cur_class_images < args.num_class_images:
471
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
472
+ if pipeline is None:
473
+ pipeline = StableDiffusionPipeline.from_pretrained(
474
+ args.pretrained_model_name_or_path,
475
+ vae=AutoencoderKL.from_pretrained(
476
+ args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
477
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
478
+ revision=None if args.pretrained_vae_name_or_path else args.revision,
479
+ torch_dtype=torch_dtype
480
+ ),
481
+ torch_dtype=torch_dtype,
482
+ safety_checker=None,
483
+ revision=args.revision
484
+ )
485
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
486
+ if is_xformers_available():
487
+ pipeline.enable_xformers_memory_efficient_attention()
488
+ pipeline.set_progress_bar_config(disable=True)
489
+ pipeline.to(accelerator.device)
490
+
491
+ num_new_images = args.num_class_images - cur_class_images
492
+ logger.info(f"Number of class images to sample: {num_new_images}.")
493
+
494
+ sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
495
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
496
+
497
+ sample_dataloader = accelerator.prepare(sample_dataloader)
498
+
499
+ with torch.autocast("cuda"), torch.inference_mode():
500
+ for example in tqdm(
501
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
502
+ ):
503
+ images = pipeline(
504
+ example["prompt"],
505
+ num_inference_steps=args.save_infer_steps
506
+ ).images
507
+
508
+ for i, image in enumerate(images):
509
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
510
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
511
+ image.save(image_filename)
512
+
513
+ del pipeline
514
+ if torch.cuda.is_available():
515
+ torch.cuda.empty_cache()
516
+
517
+ # Load the tokenizer
518
+ if args.tokenizer_name:
519
+ tokenizer = CLIPTokenizer.from_pretrained(
520
+ args.tokenizer_name,
521
+ revision=args.revision,
522
+ )
523
+ elif args.pretrained_model_name_or_path:
524
+ tokenizer = CLIPTokenizer.from_pretrained(
525
+ args.pretrained_model_name_or_path,
526
+ subfolder="tokenizer",
527
+ revision=args.revision,
528
+ )
529
+
530
+ # Load models and create wrapper for stable diffusion
531
+ text_encoder = CLIPTextModel.from_pretrained(
532
+ args.pretrained_model_name_or_path,
533
+ subfolder="text_encoder",
534
+ revision=args.revision,
535
+ )
536
+ vae = AutoencoderKL.from_pretrained(
537
+ args.pretrained_model_name_or_path,
538
+ subfolder="vae",
539
+ revision=args.revision,
540
+ )
541
+ unet = UNet2DConditionModel.from_pretrained(
542
+ args.pretrained_model_name_or_path,
543
+ subfolder="unet",
544
+ revision=args.revision,
545
+ torch_dtype=torch.float32
546
+ )
547
+
548
+ vae.requires_grad_(False)
549
+ if not args.train_text_encoder:
550
+ text_encoder.requires_grad_(False)
551
+
552
+ if is_xformers_available():
553
+ vae.enable_xformers_memory_efficient_attention()
554
+ unet.enable_xformers_memory_efficient_attention()
555
+ else:
556
+ logger.warning("xformers is not available. Make sure it is installed correctly")
557
+
558
+ if args.gradient_checkpointing:
559
+ unet.enable_gradient_checkpointing()
560
+ if args.train_text_encoder:
561
+ text_encoder.gradient_checkpointing_enable()
562
+
563
+ if args.scale_lr:
564
+ args.learning_rate = (
565
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
566
+ )
567
+
568
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
569
+ if args.use_8bit_adam:
570
+ try:
571
+ import bitsandbytes as bnb
572
+ except ImportError:
573
+ raise ImportError(
574
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
575
+ )
576
+
577
+ optimizer_class = bnb.optim.AdamW8bit
578
+ else:
579
+ optimizer_class = torch.optim.AdamW
580
+
581
+ params_to_optimize = (
582
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
583
+ )
584
+ optimizer = optimizer_class(
585
+ params_to_optimize,
586
+ lr=args.learning_rate,
587
+ betas=(args.adam_beta1, args.adam_beta2),
588
+ weight_decay=args.adam_weight_decay,
589
+ eps=args.adam_epsilon,
590
+ )
591
+
592
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
593
+
594
+ train_dataset = DreamBoothDataset(
595
+ concepts_list=args.concepts_list,
596
+ tokenizer=tokenizer,
597
+ with_prior_preservation=args.with_prior_preservation,
598
+ size=args.resolution,
599
+ center_crop=args.center_crop,
600
+ num_class_images=args.num_class_images,
601
+ pad_tokens=args.pad_tokens,
602
+ hflip=args.hflip,
603
+ read_prompts_from_txts=args.read_prompts_from_txts,
604
+ )
605
+
606
+ def collate_fn(examples):
607
+ input_ids = [example["instance_prompt_ids"] for example in examples]
608
+ pixel_values = [example["instance_images"] for example in examples]
609
+
610
+ # Concat class and instance examples for prior preservation.
611
+ # We do this to avoid doing two forward passes.
612
+ if args.with_prior_preservation:
613
+ input_ids += [example["class_prompt_ids"] for example in examples]
614
+ pixel_values += [example["class_images"] for example in examples]
615
+
616
+ pixel_values = torch.stack(pixel_values)
617
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
618
+
619
+ input_ids = tokenizer.pad(
620
+ {"input_ids": input_ids},
621
+ padding=True,
622
+ return_tensors="pt",
623
+ ).input_ids
624
+
625
+ batch = {
626
+ "input_ids": input_ids,
627
+ "pixel_values": pixel_values,
628
+ }
629
+ return batch
630
+
631
+ train_dataloader = torch.utils.data.DataLoader(
632
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, pin_memory=True
633
+ )
634
+
635
+ weight_dtype = torch.float32
636
+ if args.mixed_precision == "fp16":
637
+ weight_dtype = torch.float16
638
+ elif args.mixed_precision == "bf16":
639
+ weight_dtype = torch.bfloat16
640
+
641
+ # Move text_encode and vae to gpu.
642
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
643
+ # as these models are only used for inference, keeping weights in full precision is not required.
644
+ vae.to(accelerator.device, dtype=weight_dtype)
645
+ if not args.train_text_encoder:
646
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
647
+
648
+ if not args.not_cache_latents:
649
+ latents_cache = []
650
+ text_encoder_cache = []
651
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
652
+ with torch.no_grad():
653
+ batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
654
+ batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
655
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
656
+ if args.train_text_encoder:
657
+ text_encoder_cache.append(batch["input_ids"])
658
+ else:
659
+ text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
660
+ train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
661
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
662
+
663
+ del vae
664
+ if not args.train_text_encoder:
665
+ del text_encoder
666
+ if torch.cuda.is_available():
667
+ torch.cuda.empty_cache()
668
+
669
+ # Scheduler and math around the number of training steps.
670
+ overrode_max_train_steps = False
671
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
672
+ if args.max_train_steps is None:
673
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
674
+ overrode_max_train_steps = True
675
+
676
+ lr_scheduler = get_scheduler(
677
+ args.lr_scheduler,
678
+ optimizer=optimizer,
679
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
680
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
681
+ )
682
+
683
+ if args.train_text_encoder:
684
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
685
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
686
+ )
687
+ else:
688
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
689
+ unet, optimizer, train_dataloader, lr_scheduler
690
+ )
691
+
692
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
693
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
694
+ if overrode_max_train_steps:
695
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
696
+ # Afterwards we recalculate our number of training epochs
697
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
698
+
699
+ # We need to initialize the trackers we use, and also store our configuration.
700
+ # The trackers initializes automatically on the main process.
701
+ if accelerator.is_main_process:
702
+ accelerator.init_trackers("dreambooth")
703
+
704
+ # Train!
705
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
706
+
707
+ logger.info("***** Running training *****")
708
+ logger.info(f" Num examples = {len(train_dataset)}")
709
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
710
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
711
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
712
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
713
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
714
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
715
+
716
+ def save_weights(step):
717
+ # Create the pipeline using using the trained modules and save it.
718
+ if accelerator.is_main_process:
719
+ if args.train_text_encoder:
720
+ text_enc_model = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True)
721
+ else:
722
+ text_enc_model = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
723
+ pipeline = StableDiffusionPipeline.from_pretrained(
724
+ args.pretrained_model_name_or_path,
725
+ unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True),
726
+ text_encoder=text_enc_model,
727
+ vae=AutoencoderKL.from_pretrained(
728
+ args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
729
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
730
+ revision=None if args.pretrained_vae_name_or_path else args.revision,
731
+ ),
732
+ safety_checker=None,
733
+ torch_dtype=torch.float16,
734
+ revision=args.revision,
735
+ )
736
+ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
737
+ if is_xformers_available():
738
+ pipeline.enable_xformers_memory_efficient_attention()
739
+ save_dir = os.path.join(args.output_dir, f"{step}")
740
+ pipeline.save_pretrained(save_dir)
741
+ with open(os.path.join(save_dir, "args.json"), "w") as f:
742
+ json.dump(args.__dict__, f, indent=2)
743
+
744
+ if args.save_sample_prompt is not None:
745
+ pipeline = pipeline.to(accelerator.device)
746
+ g_cuda = torch.Generator(device=accelerator.device).manual_seed(args.seed)
747
+ pipeline.set_progress_bar_config(disable=True)
748
+ sample_dir = os.path.join(save_dir, "samples")
749
+ os.makedirs(sample_dir, exist_ok=True)
750
+ with torch.autocast("cuda"), torch.inference_mode():
751
+ for i in tqdm(range(args.n_save_sample), desc="Generating samples"):
752
+ images = pipeline(
753
+ args.save_sample_prompt,
754
+ negative_prompt=args.save_sample_negative_prompt,
755
+ guidance_scale=args.save_guidance_scale,
756
+ num_inference_steps=args.save_infer_steps,
757
+ generator=g_cuda
758
+ ).images
759
+ images[0].save(os.path.join(sample_dir, f"{i}.png"))
760
+ del pipeline
761
+ if torch.cuda.is_available():
762
+ torch.cuda.empty_cache()
763
+ print(f"[*] Weights saved at {save_dir}")
764
+
765
+ # Only show the progress bar once on each machine.
766
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
767
+ progress_bar.set_description("Steps")
768
+ global_step = 0
769
+ loss_avg = AverageMeter()
770
+ text_enc_context = nullcontext() if args.train_text_encoder else torch.no_grad()
771
+ for epoch in range(args.num_train_epochs):
772
+ unet.train()
773
+ if args.train_text_encoder:
774
+ text_encoder.train()
775
+ for step, batch in enumerate(train_dataloader):
776
+ with accelerator.accumulate(unet):
777
+ # Convert images to latent space
778
+ with torch.no_grad():
779
+ if not args.not_cache_latents:
780
+ latent_dist = batch[0][0]
781
+ else:
782
+ latent_dist = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist
783
+ latents = latent_dist.sample() * 0.18215
784
+
785
+ # Sample noise that we'll add to the latents
786
+ noise = torch.randn_like(latents)
787
+ bsz = latents.shape[0]
788
+ # Sample a random timestep for each image
789
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
790
+ timesteps = timesteps.long()
791
+
792
+ # Add noise to the latents according to the noise magnitude at each timestep
793
+ # (this is the forward diffusion process)
794
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
795
+
796
+ # Get the text embedding for conditioning
797
+ with text_enc_context:
798
+ if not args.not_cache_latents:
799
+ if args.train_text_encoder:
800
+ encoder_hidden_states = text_encoder(batch[0][1])[0]
801
+ else:
802
+ encoder_hidden_states = batch[0][1]
803
+ else:
804
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
805
+
806
+ # Predict the noise residual
807
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
808
+
809
+ # Get the target for loss depending on the prediction type
810
+ if noise_scheduler.config.prediction_type == "epsilon":
811
+ target = noise
812
+ elif noise_scheduler.config.prediction_type == "v_prediction":
813
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
814
+ else:
815
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
816
+
817
+ if args.with_prior_preservation:
818
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
819
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
820
+ target, target_prior = torch.chunk(target, 2, dim=0)
821
+
822
+ # Compute instance loss
823
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
824
+
825
+ # Compute prior loss
826
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
827
+
828
+ # Add the prior loss to the instance loss.
829
+ loss = loss + args.prior_loss_weight * prior_loss
830
+ else:
831
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
832
+
833
+ accelerator.backward(loss)
834
+ # if accelerator.sync_gradients:
835
+ # params_to_clip = (
836
+ # itertools.chain(unet.parameters(), text_encoder.parameters())
837
+ # if args.train_text_encoder
838
+ # else unet.parameters()
839
+ # )
840
+ # accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
841
+ optimizer.step()
842
+ lr_scheduler.step()
843
+ optimizer.zero_grad(set_to_none=True)
844
+ loss_avg.update(loss.detach_(), bsz)
845
+
846
+ if not global_step % args.log_interval:
847
+ logs = {"loss": loss_avg.avg.item(), "lr": lr_scheduler.get_last_lr()[0]}
848
+ progress_bar.set_postfix(**logs)
849
+ accelerator.log(logs, step=global_step)
850
+
851
+ if global_step > 0 and not global_step % args.save_interval and global_step >= args.save_min_steps:
852
+ save_weights(global_step)
853
+
854
+ progress_bar.update(1)
855
+ global_step += 1
856
+
857
+ if global_step >= args.max_train_steps:
858
+ break
859
+
860
+ accelerator.wait_for_everyone()
861
+
862
+ save_weights(global_step)
863
+
864
+ accelerator.end_training()
865
+
866
+
867
+ if __name__ == "__main__":
868
+ args = parse_args()
869
+ main(args)