giulio98 commited on
Commit
ef4f2e6
·
verified ·
1 Parent(s): 65fe534

Update unet/__main__.py

Browse files
Files changed (1) hide show
  1. unet/__main__.py +2 -1
unet/__main__.py CHANGED
@@ -105,6 +105,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
105
  class_embed_type: Optional[str] = None,
106
  num_class_embeds: Optional[int] = None,
107
  num_train_timesteps: Optional[int] = None,
 
108
  ):
109
  super().__init__()
110
 
@@ -127,7 +128,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
127
 
128
  # time
129
  if time_embedding_type == "fourier":
130
- self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16, set_W_to_weight=False)
131
  timestep_input_dim = 2 * block_out_channels[0]
132
  elif time_embedding_type == "positional":
133
  self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
 
105
  class_embed_type: Optional[str] = None,
106
  num_class_embeds: Optional[int] = None,
107
  num_train_timesteps: Optional[int] = None,
108
+ set_W_to_weight: Optional[bool] = True,
109
  ):
110
  super().__init__()
111
 
 
128
 
129
  # time
130
  if time_embedding_type == "fourier":
131
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16, set_W_to_weight=set_W_to_weight)
132
  timestep_input_dim = 2 * block_out_channels[0]
133
  elif time_embedding_type == "positional":
134
  self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)