Update unet/__main__.py
Browse files- unet/__main__.py +2 -1
unet/__main__.py
CHANGED
@@ -105,6 +105,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
105 |
class_embed_type: Optional[str] = None,
|
106 |
num_class_embeds: Optional[int] = None,
|
107 |
num_train_timesteps: Optional[int] = None,
|
|
|
108 |
):
|
109 |
super().__init__()
|
110 |
|
@@ -127,7 +128,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|
127 |
|
128 |
# time
|
129 |
if time_embedding_type == "fourier":
|
130 |
-
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16, set_W_to_weight=
|
131 |
timestep_input_dim = 2 * block_out_channels[0]
|
132 |
elif time_embedding_type == "positional":
|
133 |
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
|
|
105 |
class_embed_type: Optional[str] = None,
|
106 |
num_class_embeds: Optional[int] = None,
|
107 |
num_train_timesteps: Optional[int] = None,
|
108 |
+
set_W_to_weight: Optional[bool] = True,
|
109 |
):
|
110 |
super().__init__()
|
111 |
|
|
|
128 |
|
129 |
# time
|
130 |
if time_embedding_type == "fourier":
|
131 |
+
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16, set_W_to_weight=set_W_to_weight)
|
132 |
timestep_input_dim = 2 * block_out_channels[0]
|
133 |
elif time_embedding_type == "positional":
|
134 |
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|