Merge pull request #26 from LightricksResearch/cuda_optional
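Summary (inferred from the diff below): this PR makes CUDA optional throughout inference. Unconditional .cuda() / .to("cuda") calls are now guarded by torch.cuda.is_available(), the seeded generator is created on whichever device is present, and mixed precision raises an explicit error on unsupported XLA devices.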
Files changed:
- .gitignore (+4 -1)
- inference.py (+21 -7)
- xora/pipelines/pipeline_xora_video.py (+6 -1)
.gitignore
CHANGED
@@ -159,4 +159,7 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
+.idea/
+
+# From inference.py
+video_output_*.mp4
inference.py
CHANGED
@@ -55,7 +55,9 @@ def load_vae(vae_dir):
     vae = CausalVideoAutoencoder.from_config(vae_config)
     vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
     vae.load_state_dict(vae_state_dict)
-    return vae.cuda().to(torch.bfloat16)
+    if torch.cuda.is_available():
+        vae = vae.cuda()
+    return vae.to(torch.bfloat16)


 def load_unet(unet_dir):
@@ -65,7 +67,9 @@ def load_unet(unet_dir):
     transformer = Transformer3DModel.from_config(transformer_config)
     unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
     transformer.load_state_dict(unet_state_dict, strict=True)
-    return transformer.cuda()
+    if torch.cuda.is_available():
+        transformer = transformer.cuda()
+    return transformer


 def load_scheduler(scheduler_dir):
@@ -254,7 +258,9 @@ def main():
     patchifier = SymmetricPatchifier(patch_size=1)
     text_encoder = T5EncoderModel.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
-    ).to("cuda")
+    )
+    if torch.cuda.is_available():
+        text_encoder = text_encoder.to("cuda")
     tokenizer = T5Tokenizer.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
     )
@@ -272,7 +278,9 @@ def main():
         "vae": vae,
     }

-    pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
+    pipeline = XoraVideoPipeline(**submodel_dict)
+    if torch.cuda.is_available():
+        pipeline = pipeline.to("cuda")

     # Prepare input for the pipeline
     sample = {
@@ -286,8 +294,12 @@ def main():
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    torch.cuda.manual_seed(args.seed)
-    generator = torch.Generator(device="cuda").manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(args.seed)
+
+    generator = torch.Generator(
+        device="cuda" if torch.cuda.is_available() else "cpu"
+    ).manual_seed(args.seed)

     images = pipeline(
         num_inference_steps=args.num_inference_steps,
@@ -322,7 +334,9 @@ def main():
     )

     for i in range(images.shape[0]):
-        video_np = images[i].permute(1, 2, 3, 0).cpu().numpy()
+        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
+        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
+        # Unnormalizing images to [0, 255] range
         video_np = (video_np * 255).astype(np.uint8)
         fps = args.frame_rate
         height, width = video_np.shape[1:3]
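The same torch.cuda.is_available() guard recurs at every device touchpoint in this diff. A minimal standalone sketch of the pattern, for context only; the place() helper and the seed value 42 are illustrative, not part of the repo:

import torch


def place(module: torch.nn.Module) -> torch.nn.Module:
    # Move to the GPU only when one is actually present; on a
    # CPU-only machine the module is returned unchanged.
    if torch.cuda.is_available():
        module = module.cuda()
    return module


# Create the RNG on whichever device will run the sampling loop,
# mirroring the generator change in main() above.
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device=device).manual_seed(42)

The key point is that device placement becomes a runtime decision rather than a hard requirement, so the same script runs on CPU-only hosts, just more slowly.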
xora/pipelines/pipeline_xora_video.py
CHANGED
@@ -1010,7 +1010,12 @@ class XoraVideoPipeline(DiffusionPipeline):
                 current_timestep = current_timestep * (1 - conditioning_mask)
                 # Choose the appropriate context manager based on `mixed_precision`
                 if mixed_precision:
-                    context_manager = torch.autocast(device, dtype=torch.bfloat16)
+                    if "xla" in device.type:
+                        raise NotImplementedError(
+                            "Mixed precision is not supported yet on XLA devices."
+                        )
+
+                    context_manager = torch.autocast(device, dtype=torch.bfloat16)
                 else:
                     context_manager = nullcontext()  # Dummy context manager

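For context, the mixed_precision switch boils down to choosing between an autocast region and a no-op context manager. A minimal sketch of that idiom in isolation, assuming a plain device-type string; run_step, model, and x are illustrative names, not from the pipeline:

from contextlib import nullcontext

import torch


def run_step(model, x, mixed_precision: bool, device_type: str = "cpu"):
    # bfloat16 autocast when mixed precision is requested; otherwise
    # nullcontext() makes the `with` block a no-op.
    ctx = (
        torch.autocast(device_type, dtype=torch.bfloat16)
        if mixed_precision
        else nullcontext()
    )
    with ctx:
        return model(x)

Using nullcontext() as the fallback keeps the call site a single `with` block instead of duplicating the forward pass in two branches.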