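# Multi-GPU (sequence-parallel) inference script for Pyramid Flow video generation.
# Supports text-to-video (t2v) and image-to-video (i2v) with either the pyramid_flux
# or pyramid_mmdit model variant.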
import os
import sys
import torch
import argparse
from PIL import Image
from diffusers.utils import export_to_video
# Add the project root directory to sys.path
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from pyramid_dit import PyramidDiTForVideoGeneration
from trainer_misc import init_distributed_mode, init_sequence_parallel_group
def get_args():
    parser = argparse.ArgumentParser('PyTorch multi-process inference script', add_help=False)
    parser.add_argument('--model_name', default='pyramid_mmdit', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16, fp16, or fp32")
    parser.add_argument('--model_path', required=True, type=str, help='Path to the downloaded checkpoint directory')
    parser.add_argument('--variant', default='diffusion_transformer_768p', type=str)
    parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
    parser.add_argument('--temp', default=16, type=int, help='The number of generated temporal latents; num_frames = temp * 8 + 1')
    parser.add_argument('--sp_group_size', default=2, type=int, help="The number of GPUs used for sequence-parallel inference; should be 2 or 4")
    parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of processes used for sequence parallelism; -1 means using all processes")
    parser.add_argument('--prompt', type=str, required=True, help="Text prompt for video generation")
    parser.add_argument('--image_path', type=str, help="Path to the input image for image-to-video")
    parser.add_argument('--video_guidance_scale', type=float, default=5.0, help="Video guidance scale")
    parser.add_argument('--guidance_scale', type=float, default=9.0, help="Guidance scale for text-to-video")
    parser.add_argument('--resolution', type=str, default='768p', choices=['768p', '384p'], help="Model resolution")
    parser.add_argument('--output_path', type=str, required=True, help="Path to save the generated video")
    return parser.parse_args()
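# A minimal launch sketch, assuming init_distributed_mode() reads the standard torchrun
# environment variables (RANK / WORLD_SIZE / LOCAL_RANK) and that this file lives under
# a scripts/ directory in the checkout; paths and the prompt are placeholders:
#
#   torchrun --nproc_per_node=2 scripts/inference_multigpu.py \
#       --model_name pyramid_mmdit --variant diffusion_transformer_768p \
#       --model_path /path/to/pyramid-flow-checkpoint \
#       --task t2v --temp 16 --sp_group_size 2 \
#       --prompt "A cat walking in the rain" \
#       --output_path ./output.mp4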
def main():
    args = get_args()

    # Setup DDP
    init_distributed_mode(args)
    assert args.world_size == args.sp_group_size, "The sequence parallel size should match the DDP world size"

    # Enable sequence parallel
    init_sequence_parallel_group(args)

    device = torch.device('cuda')
    rank = args.rank
    model_dtype = args.model_dtype

    if args.model_name == "pyramid_flux":
        assert args.variant != "diffusion_transformer_768p", (
            "pyramid_flux does not support the 768p resolution yet; it will be released after "
            "training finishes. Set --model_name pyramid_mmdit for 768p generation."
        )
    model = PyramidDiTForVideoGeneration(
        args.model_path,
        model_dtype,
        model_name=args.model_name,
        model_variant=args.variant,
    )

    model.vae.to(device)
    model.dit.to(device)
    model.text_encoder.to(device)
    # Tiled VAE decoding keeps peak memory manageable for long / high-resolution videos
    model.vae.enable_tiling()
    if model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    # The video generation config
    if args.resolution == '768p':
        width = 1280
        height = 768
    else:
        width = 640
        height = 384
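    # Note: guidance_scale primarily steers the text-conditioned first frame (visual
    # quality), while video_guidance_scale steers the motion of subsequent frames;
    # the argparse defaults above match the values typically suggested for the 768p checkpoint.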
    try:
        if args.task == 't2v':
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                # The step lists give the number of inference steps for each pyramid stage
                frames = model.generate(
                    prompt=prompt,
                    num_inference_steps=[20, 20, 20],
                    video_num_inference_steps=[10, 10, 10],
                    height=height,
                    width=width,
                    temp=args.temp,
                    guidance_scale=args.guidance_scale,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)
        elif args.task == 'i2v':
            if not args.image_path:
                raise ValueError("Image path is required for the image-to-video task")
            image = Image.open(args.image_path).convert("RGB")
            image = image.resize((width, height))
            prompt = args.prompt
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=(model_dtype != 'fp32'), dtype=torch_dtype):
                frames = model.generate_i2v(
                    prompt=prompt,
                    input_image=image,
                    num_inference_steps=[10, 10, 10],
                    temp=args.temp,
                    video_guidance_scale=args.video_guidance_scale,
                    output_type="pil",
                    save_memory=True,
                    cpu_offloading=False,
                    inference_multigpu=True,
                )
            if rank == 0:
                export_to_video(frames, args.output_path, fps=24)
    except Exception as e:
        if rank == 0:
            print(f"[ERROR] Error during video generation: {e}")
        raise
    finally:
        # Keep all ranks in sync before exiting
        torch.distributed.barrier()
if __name__ == "__main__":
    main()