primerz commited on
Commit
23f918b
·
verified ·
1 Parent(s): 58a5607

Update cog_sdxl_dataset_and_utils.py

Browse files
Files changed (1) hide show
  1. cog_sdxl_dataset_and_utils.py +6 -0
cog_sdxl_dataset_and_utils.py CHANGED
@@ -33,6 +33,11 @@ def prepare_mask(mask: PIL.Image.Image, width: int = 512, height: int = 512) ->
33
  return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
34
 
35
 
 
 
 
 
 
36
  class PreprocessedDataset(Dataset):
37
  def __init__(
38
  self,
@@ -175,3 +180,4 @@ def load_models(pretrained_model_name_or_path, revision, device, weight_dtype):
175
  unet.to(device, dtype=weight_dtype)
176
 
177
  return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
 
 
33
  return torch.from_numpy(np.expand_dims(arr, 0)).unsqueeze(0)
34
 
35
 
36
+ class TokenEmbeddingsHandler:
37
+ def __init__(self, text_encoders, tokenizers):
38
+ self.text_encoders = text_encoders
39
+ self.tokenizers = tokenizers
40
+
41
  class PreprocessedDataset(Dataset):
42
  def __init__(
43
  self,
 
180
  unet.to(device, dtype=weight_dtype)
181
 
182
  return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
183
+