Spaces:
Runtime error
Runtime error
Merge pull request #11 from LightricksResearch/rm-dist-util
Browse files
xora/models/autoencoders/vae_encode.py
CHANGED
|
@@ -6,8 +6,10 @@ from torch import Tensor
|
|
| 6 |
|
| 7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
| 8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
|
| 13 |
"""
|
|
@@ -54,10 +56,12 @@ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae
|
|
| 54 |
encode_bs = len(media_items) // split_size
|
| 55 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
| 56 |
latents = []
|
| 57 |
-
|
|
|
|
| 58 |
for image_batch in media_items.split(encode_bs):
|
| 59 |
latents.append(vae.encode(image_batch).latent_dist.sample())
|
| 60 |
-
|
|
|
|
| 61 |
latents = torch.cat(latents, dim=0)
|
| 62 |
else:
|
| 63 |
latents = vae.encode(media_items).latent_dist.sample()
|
|
|
|
| 6 |
|
| 7 |
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
|
| 8 |
from xora.models.autoencoders.video_autoencoder import Downsample3D, VideoAutoencoder
|
| 9 |
+
try:
|
| 10 |
+
import torch_xla.core.xla_model as xm
|
| 11 |
+
except:
|
| 12 |
+
pass
|
| 13 |
|
| 14 |
def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
|
| 15 |
"""
|
|
|
|
| 56 |
encode_bs = len(media_items) // split_size
|
| 57 |
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
| 58 |
latents = []
|
| 59 |
+
if media_items.device.type == "xla":
|
| 60 |
+
xm.mark_step()
|
| 61 |
for image_batch in media_items.split(encode_bs):
|
| 62 |
latents.append(vae.encode(image_batch).latent_dist.sample())
|
| 63 |
+
if media_items.device.type == "xla":
|
| 64 |
+
xm.mark_step()
|
| 65 |
latents = torch.cat(latents, dim=0)
|
| 66 |
else:
|
| 67 |
latents = vae.encode(media_items).latent_dist.sample()
|
xora/utils/dist_util.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
from enum import Enum
|
| 2 |
-
|
| 3 |
-
def execute_graph() -> None:
|
| 4 |
-
if _acceleration_type == AccelerationType.TPU:
|
| 5 |
-
xm.mark_step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|