Spaces:
Running
on
Zero
Running
on
Zero
Update cog_sdxl_dataset_and_utils.py
Browse files
cog_sdxl_dataset_and_utils.py
CHANGED
@@ -33,6 +33,11 @@ def prepare_mask(mask: PIL.Image.Image, width: int = 512, height: int = 512) ->
|
|
33 |
return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
class PreprocessedDataset(Dataset):
|
37 |
def __init__(
|
38 |
self,
|
@@ -175,3 +180,4 @@ def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
|
|
175 |
unet.to(device, dtype=weight_dtype)
|
176 |
|
177 |
return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
|
|
|
|
33 |
return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
|
34 |
|
35 |
|
36 |
+
class TokenEmbeddingsHandler:
|
37 |
+
def __init__(self, text_encoders, tokenizers):
|
38 |
+
self.text_encoders = text_encoders
|
39 |
+
self.tokenizers = tokenizers
|
40 |
+
|
41 |
class PreprocessedDataset(Dataset):
|
42 |
def __init__(
|
43 |
self,
|
|
|
180 |
unet.to(device, dtype=weight_dtype)
|
181 |
|
182 |
return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
|
183 |
+
|