Spaces:
Runtime error
Runtime error
Memory saving
Browse files
train.py
CHANGED
|
@@ -10,6 +10,9 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
|
|
| 10 |
diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
|
| 11 |
diffuser.train()
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
|
| 14 |
|
| 15 |
optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
|
|
@@ -22,7 +25,11 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
|
|
| 22 |
neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
|
| 23 |
positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
|
| 24 |
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
for i in pbar:
|
| 28 |
|
|
@@ -61,8 +68,11 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
|
|
| 61 |
neutral_latents.requires_grad = False
|
| 62 |
|
| 63 |
loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
loss.backward()
|
| 65 |
-
losses.append(loss.item())
|
| 66 |
optimizer.step()
|
| 67 |
|
| 68 |
torch.save(finetuner.state_dict(), save_path)
|
|
|
|
| 10 |
diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
|
| 11 |
diffuser.train()
|
| 12 |
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
|
| 17 |
|
| 18 |
optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
|
|
|
|
| 25 |
neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
|
| 26 |
positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
|
| 27 |
|
| 28 |
+
del diffuser.vae
|
| 29 |
+
del diffuser.text_encoder
|
| 30 |
+
del diffuser.tokenizer
|
| 31 |
+
|
| 32 |
+
torch.cuda.empty_cache()
|
| 33 |
|
| 34 |
for i in pbar:
|
| 35 |
|
|
|
|
| 68 |
neutral_latents.requires_grad = False
|
| 69 |
|
| 70 |
loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
|
| 71 |
+
|
| 72 |
+
del negative_latents, neutral_latents, positive_latents, latents_steps, latents
|
| 73 |
+
torch.cuda.empty_cache()
|
| 74 |
+
|
| 75 |
loss.backward()
|
|
|
|
| 76 |
optimizer.step()
|
| 77 |
|
| 78 |
torch.save(finetuner.state_dict(), save_path)
|