Update imagic.py
Browse files
imagic.py
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
import inspect
|
| 6 |
import warnings
|
| 7 |
from typing import List, Optional, Union
|
| 8 |
-
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
|
@@ -236,7 +236,8 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline):
|
|
| 236 |
text_embeddings_orig = text_embeddings.clone()
|
| 237 |
|
| 238 |
# Initialize the optimizer
|
| 239 |
-
|
|
|
|
| 240 |
[text_embeddings], # only optimize the embeddings
|
| 241 |
lr=embedding_learning_rate,
|
| 242 |
)
|
|
|
|
| 5 |
import inspect
|
| 6 |
import warnings
|
| 7 |
from typing import List, Optional, Union
|
| 8 |
+
import bitsandbytes as bnb
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
|
|
|
| 236 |
text_embeddings_orig = text_embeddings.clone()
|
| 237 |
|
| 238 |
# Initialize the optimizer
|
| 239 |
+
|
| 240 |
+
optimizer = bnb.optim.Adam8bit(
|
| 241 |
[text_embeddings], # only optimize the embeddings
|
| 242 |
lr=embedding_learning_rate,
|
| 243 |
)
|