OmPrakashSingh1704 commited on
Commit
6e8e48d
·
1 Parent(s): ad69708
Files changed (1) hide show
  1. options/Video_model/Model.py +20 -20
options/Video_model/Model.py CHANGED
@@ -56,26 +56,26 @@ def Video(
56
  # Perform computation with appropriate dtype based on device
57
  if device == "cuda":
58
  # Use float16 for GPU
59
- with torch.autocast(device_type='cuda', dtype=torch.float16):
60
- frames = pipeline(
61
- image, height=height, width=width,
62
- num_inference_steps=num_inference_steps,
63
- min_guidance_scale=min_guidance_scale,
64
- max_guidance_scale=max_guidance_scale,
65
- num_frames=num_frames, fps=fps, motion_bucket_id=motion_bucket_id,
66
- generator=generator,
67
- ).frames[0]
68
- else:
69
- # Use bfloat16 for CPU as it's supported in torch.autocast
70
- with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
71
- frames = pipeline(
72
- image, height=height, width=width,
73
- num_inference_steps=num_inference_steps,
74
- min_guidance_scale=min_guidance_scale,
75
- max_guidance_scale=max_guidance_scale,
76
- num_frames=num_frames, fps=fps, motion_bucket_id=motion_bucket_id,
77
- generator=generator,
78
- ).frames[0]
79
 
80
 
81
  # Save the generated video
 
56
  # Perform computation with appropriate dtype based on device
57
  if device == "cuda":
58
  # Use float16 for GPU
59
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
60
+ frames = pipeline(
61
+ image, height=height, width=width,
62
+ num_inference_steps=num_inference_steps,
63
+ min_guidance_scale=min_guidance_scale,
64
+ max_guidance_scale=max_guidance_scale,
65
+ num_frames=num_frames, fps=fps, motion_bucket_id=motion_bucket_id,
66
+ generator=generator,
67
+ ).frames[0]
68
+ else:
69
+ # Use bfloat16 for CPU as it's supported in torch.autocast
70
+ with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
71
+ frames = pipeline(
72
+ image, height=height, width=width,
73
+ num_inference_steps=num_inference_steps,
74
+ min_guidance_scale=min_guidance_scale,
75
+ max_guidance_scale=max_guidance_scale,
76
+ num_frames=num_frames, fps=fps, motion_bucket_id=motion_bucket_id,
77
+ generator=generator,
78
+ ).frames[0]
79
 
80
 
81
  # Save the generated video