Sapir Weissbuch committed on
Commit
7f1b6d2
·
unverified ·
2 Parent(s): 05cb3e4 39316ac

Merge pull request #26 from LightricksResearch/cuda_optional

Browse files
.gitignore CHANGED
@@ -159,4 +159,7 @@ cython_debug/
159
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
- .idea/
 
 
 
 
159
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ .idea/
163
+
164
+ # From inference.py
165
+ video_output_*.mp4
inference.py CHANGED
@@ -55,7 +55,9 @@ def load_vae(vae_dir):
55
  vae = CausalVideoAutoencoder.from_config(vae_config)
56
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
57
  vae.load_state_dict(vae_state_dict)
58
- return vae.cuda().to(torch.bfloat16)
 
 
59
 
60
 
61
  def load_unet(unet_dir):
@@ -65,7 +67,9 @@ def load_unet(unet_dir):
65
  transformer = Transformer3DModel.from_config(transformer_config)
66
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
67
  transformer.load_state_dict(unet_state_dict, strict=True)
68
- return transformer.cuda()
 
 
69
 
70
 
71
  def load_scheduler(scheduler_dir):
@@ -254,7 +258,9 @@ def main():
254
  patchifier = SymmetricPatchifier(patch_size=1)
255
  text_encoder = T5EncoderModel.from_pretrained(
256
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
257
- ).to("cuda")
 
 
258
  tokenizer = T5Tokenizer.from_pretrained(
259
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
260
  )
@@ -272,7 +278,9 @@ def main():
272
  "vae": vae,
273
  }
274
 
275
- pipeline = XoraVideoPipeline(**submodel_dict).to("cuda")
 
 
276
 
277
  # Prepare input for the pipeline
278
  sample = {
@@ -286,8 +294,12 @@ def main():
286
  random.seed(args.seed)
287
  np.random.seed(args.seed)
288
  torch.manual_seed(args.seed)
289
- torch.cuda.manual_seed(args.seed)
290
- generator = torch.Generator(device="cuda").manual_seed(args.seed)
 
 
 
 
291
 
292
  images = pipeline(
293
  num_inference_steps=args.num_inference_steps,
@@ -322,7 +334,9 @@ def main():
322
  )
323
 
324
  for i in range(images.shape[0]):
325
- video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
 
 
326
  video_np = (video_np * 255).astype(np.uint8)
327
  fps = args.frame_rate
328
  height, width = video_np.shape[1:3]
 
55
  vae = CausalVideoAutoencoder.from_config(vae_config)
56
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
57
  vae.load_state_dict(vae_state_dict)
58
+ if torch.cuda.is_available():
59
+ vae = vae.cuda()
60
+ return vae.to(torch.bfloat16)
61
 
62
 
63
  def load_unet(unet_dir):
 
67
  transformer = Transformer3DModel.from_config(transformer_config)
68
  unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
69
  transformer.load_state_dict(unet_state_dict, strict=True)
70
+ if torch.cuda.is_available():
71
+ transformer = transformer.cuda()
72
+ return transformer
73
 
74
 
75
  def load_scheduler(scheduler_dir):
 
258
  patchifier = SymmetricPatchifier(patch_size=1)
259
  text_encoder = T5EncoderModel.from_pretrained(
260
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
261
+ )
262
+ if torch.cuda.is_available():
263
+ text_encoder = text_encoder.to("cuda")
264
  tokenizer = T5Tokenizer.from_pretrained(
265
  "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
266
  )
 
278
  "vae": vae,
279
  }
280
 
281
+ pipeline = XoraVideoPipeline(**submodel_dict)
282
+ if torch.cuda.is_available():
283
+ pipeline = pipeline.to("cuda")
284
 
285
  # Prepare input for the pipeline
286
  sample = {
 
294
  random.seed(args.seed)
295
  np.random.seed(args.seed)
296
  torch.manual_seed(args.seed)
297
+ if torch.cuda.is_available():
298
+ torch.cuda.manual_seed(args.seed)
299
+
300
+ generator = torch.Generator(
301
+ device="cuda" if torch.cuda.is_available() else "cpu"
302
+ ).manual_seed(args.seed)
303
 
304
  images = pipeline(
305
  num_inference_steps=args.num_inference_steps,
 
334
  )
335
 
336
  for i in range(images.shape[0]):
337
+ # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
338
+ video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
339
+ # Unnormalizing images to [0, 255] range
340
  video_np = (video_np * 255).astype(np.uint8)
341
  fps = args.frame_rate
342
  height, width = video_np.shape[1:3]
xora/pipelines/pipeline_xora_video.py CHANGED
@@ -1010,7 +1010,12 @@ class XoraVideoPipeline(DiffusionPipeline):
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
  # Choose the appropriate context manager based on `mixed_precision`
1012
  if mixed_precision:
1013
- context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
 
 
 
 
 
1014
  else:
1015
  context_manager = nullcontext() # Dummy context manager
1016
 
 
1010
  current_timestep = current_timestep * (1 - conditioning_mask)
1011
  # Choose the appropriate context manager based on `mixed_precision`
1012
  if mixed_precision:
1013
+ if "xla" in device.type:
1014
+ raise NotImplementedError(
1015
+ "Mixed precision is not supported yet on XLA devices."
1016
+ )
1017
+
1018
+ context_manager = torch.autocast(device, dtype=torch.bfloat16)
1019
  else:
1020
  context_manager = nullcontext() # Dummy context manager
1021