tolgacangoz commited on
Commit
d3d0ea5
·
verified ·
1 Parent(s): c8ee5fc

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. scheduler/matryoshka.py +14 -2
scheduler/matryoshka.py CHANGED
@@ -3746,10 +3746,12 @@ class MatryoshkaPipeline(
3746
  self,
3747
  text_encoder: T5EncoderModel,
3748
  tokenizer: T5TokenizerFast,
3749
- unet: MatryoshkaUNet2DConditionModel,
3750
  scheduler: MatryoshkaDDIMScheduler,
 
3751
  feature_extractor: CLIPImageProcessor = None,
3752
  image_encoder: CLIPVisionModelWithProjection = None,
 
 
3753
  ):
3754
  super().__init__()
3755
 
@@ -3801,6 +3803,16 @@ class MatryoshkaPipeline(
3801
  new_config["sample_size"] = 64
3802
  unet._internal_dict = FrozenDict(new_config)
3803
 
 
 
 
 
 
 
 
 
 
 
3804
  self.register_modules(
3805
  text_encoder=text_encoder,
3806
  tokenizer=tokenizer,
@@ -4510,7 +4522,7 @@ class MatryoshkaPipeline(
4510
  timesteps, num_inference_steps = retrieve_timesteps(
4511
  self.scheduler, num_inference_steps, device, timesteps, sigmas
4512
  )
4513
- timesteps = timesteps[:-1] # is this correct???
4514
  else:
4515
  timesteps = self.scheduler.timesteps
4516
 
 
3746
  self,
3747
  text_encoder: T5EncoderModel,
3748
  tokenizer: T5TokenizerFast,
 
3749
  scheduler: MatryoshkaDDIMScheduler,
3750
+ unet: MatryoshkaUNet2DConditionModel = None,
3751
  feature_extractor: CLIPImageProcessor = None,
3752
  image_encoder: CLIPVisionModelWithProjection = None,
3753
+ trust_remote_code: bool = False,
3754
+ nesting_level: int = 0,
3755
  ):
3756
  super().__init__()
3757
 
 
3803
  new_config["sample_size"] = 64
3804
  unet._internal_dict = FrozenDict(new_config)
3805
 
3806
+ if nesting_level == 0:
3807
+ unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3808
+ subfolder="unet/nesting_level_0")
3809
+ elif nesting_level == 1:
3810
+ unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3811
+ subfolder="unet/nesting_level_1")
3812
+ elif nesting_level == 2:
3813
+ unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
3814
+ subfolder="unet/nesting_level_2")
3815
+
3816
  self.register_modules(
3817
  text_encoder=text_encoder,
3818
  tokenizer=tokenizer,
 
4522
  timesteps, num_inference_steps = retrieve_timesteps(
4523
  self.scheduler, num_inference_steps, device, timesteps, sigmas
4524
  )
4525
+ timesteps = timesteps[:-1]
4526
  else:
4527
  timesteps = self.scheduler.timesteps
4528