Spaces:
Runtime error
Runtime error
VAE: Check for timesteps parameter in decoder before calling
Browse files
xora/models/autoencoders/vae.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from typing import Optional, Union
|
| 2 |
|
| 3 |
import torch
|
|
|
|
| 4 |
import math
|
| 5 |
import torch.nn as nn
|
| 6 |
from diffusers import ConfigMixin, ModelMixin
|
|
@@ -60,6 +61,8 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 60 |
self.dims = dims
|
| 61 |
self.z_sample_size = 1
|
| 62 |
|
|
|
|
|
|
|
| 63 |
# only relevant if vae tiling is enabled
|
| 64 |
self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
|
| 65 |
|
|
@@ -257,7 +260,10 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
|
| 257 |
timesteps: Optional[torch.Tensor] = None,
|
| 258 |
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 259 |
z = self.post_quant_conv(z)
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
| 261 |
return dec
|
| 262 |
|
| 263 |
def decode(
|
|
|
|
| 1 |
from typing import Optional, Union
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
import inspect
|
| 5 |
import math
|
| 6 |
import torch.nn as nn
|
| 7 |
from diffusers import ConfigMixin, ModelMixin
|
|
|
|
| 61 |
self.dims = dims
|
| 62 |
self.z_sample_size = 1
|
| 63 |
|
| 64 |
+
self.decoder_params = inspect.signature(self.decoder.forward).parameters
|
| 65 |
+
|
| 66 |
# only relevant if vae tiling is enabled
|
| 67 |
self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
|
| 68 |
|
|
|
|
| 260 |
timesteps: Optional[torch.Tensor] = None,
|
| 261 |
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 262 |
z = self.post_quant_conv(z)
|
| 263 |
+
if "timesteps" in self.decoder_params:
|
| 264 |
+
dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps)
|
| 265 |
+
else:
|
| 266 |
+
dec = self.decoder(z, target_shape=target_shape)
|
| 267 |
return dec
|
| 268 |
|
| 269 |
def decode(
|