Merge pull request #26 from LightricksResearch/cuda_optional
Files changed:
- .gitignore +4 -1
- inference.py +21 -7
- xora/pipelines/pipeline_xora_video.py +6 -1
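This PR makes CUDA optional: every hardcoded .cuda() / .to("cuda") call is gated behind torch.cuda.is_available(), so inference also runs on CPU-only machines.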
.gitignore
CHANGED

@@ -159,4 +159,7 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
+.idea/
+
+# From inference.py
+video_output_*.mp4
inference.py
CHANGED

@@ -55,7 +55,9 @@ def load_vae(vae_dir):
     vae = CausalVideoAutoencoder.from_config(vae_config)
     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
     vae.load_state_dict(vae_state_dict)
-    return vae.cuda().to(torch.bfloat16)
+    if torch.cuda.is_available():
+        vae = vae.cuda()
+    return vae.to(torch.bfloat16)
 
 
 def load_unet(unet_dir):

@@ -65,7 +67,9 @@ def load_unet(unet_dir):
     transformer = Transformer3DModel.from_config(transformer_config)
     unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
     transformer.load_state_dict(unet_state_dict, strict=True)
-    return transformer.cuda()
+    if torch.cuda.is_available():
+        transformer = transformer.cuda()
+    return transformer
 
 
 def load_scheduler(scheduler_dir):

@@ -254,7 +258,9 @@ def main():
     patchifier = SymmetricPatchifier(patch_size=1)
     text_encoder = T5EncoderModel.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
-    ).to("cuda")
+    )
+    if torch.cuda.is_available():
+        text_encoder = text_encoder.to("cuda")
     tokenizer = T5Tokenizer.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
     )

@@ -272,7 +278,9 @@ def main():
         "vae": vae,
     }
 
-    pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
+    pipeline = XoraVideoPipeline(**submodel_dict)
+    if torch.cuda.is_available():
+        pipeline = pipeline.to("cuda")
 
     # Prepare input for the pipeline
     sample = {

@@ -286,8 +294,12 @@ def main():
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    torch.cuda.manual_seed(args.seed)
-    generator = torch.Generator(device="cuda").manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(args.seed)
+
+    generator = torch.Generator(
+        device="cuda" if torch.cuda.is_available() else "cpu"
+    ).manual_seed(args.seed)
 
     images = pipeline(
         num_inference_steps=args.num_inference_steps,

@@ -322,7 +334,9 @@ def main():
     )
 
     for i in range(images.shape[0]):
-        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
+        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
+        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
+        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
         fps = args.frame_rate
         height, width = video_np.shape[1:3]
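Every inference.py hunk applies the same pattern: move a module to the GPU only when one is present, and fall back to the CPU otherwise. A minimal sketch of that pattern follows; the helper name to_best_device and the nn.Linear example are illustrative, not part of the PR.

import torch


def to_best_device(module: torch.nn.Module) -> torch.nn.Module:
    # Use CUDA only when a GPU is actually available; otherwise stay on CPU.
    if torch.cuda.is_available():
        return module.cuda()
    return module


# Usage: the same call works on CUDA and on CPU-only machines.
model = to_best_device(torch.nn.Linear(4, 4))
print(next(model.parameters()).device)  # cuda:0 with a GPU, cpu without

On a CUDA machine this prints cuda:0; on a CPU-only machine the same script prints cpu, which is the behavior the PR adds to load_vae, load_unet, the text encoder, and the pipeline.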
xora/pipelines/pipeline_xora_video.py
CHANGED

@@ -1010,7 +1010,12 @@ class XoraVideoPipeline(DiffusionPipeline):
             current_timestep = current_timestep * (1 - conditioning_mask)
             # Choose the appropriate context manager based on `mixed_precision`
             if mixed_precision:
-                context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
+                if "xla" in device.type:
+                    raise NotImplementedError(
+                        "Mixed precision is not supported yet on XLA devices."
+                    )
+
+                context_manager = torch.autocast(device, dtype=torch.bfloat16)
             else:
                 context_manager = nullcontext()  # Dummy context manager
 
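The pipeline change keeps bfloat16 autocast for mixed precision but now derives the target from the runtime device and rejects XLA explicitly, where autocast is not wired up here. A standalone sketch of the same selection logic; the function name precision_context is illustrative, not from the PR.

from contextlib import nullcontext

import torch


def precision_context(device: torch.device, mixed_precision: bool):
    # bfloat16 autocast when mixed precision is requested, a no-op context
    # otherwise, and an explicit failure on XLA devices.
    if not mixed_precision:
        return nullcontext()
    if "xla" in device.type:
        raise NotImplementedError("Mixed precision is not supported yet on XLA devices.")
    return torch.autocast(device.type, dtype=torch.bfloat16)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with precision_context(device, mixed_precision=True):
    out = torch.randn(8, 8, device=device) @ torch.randn(8, 8, device=device)

Returning the context manager instead of entering it directly lets the caller use a single with-statement for both the mixed- and full-precision paths, which is the design the pipeline code follows.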