|
#!/usr/bin/env bash
#
# Launch FloVD ControlNet fine-tuning on top of CogVideoX-5B-I2V via
# `accelerate launch` (the command at the end of this script).
#
# Run it from the directory that contains accelerate_config.yaml and
# train.py: both are referenced by relative path.
#
# Strict mode: stop on the first failing command, on any use of an unset
# variable, and on a failure in any stage of a pipeline, so a misconfigured
# run aborts early instead of carrying on.
set -euo pipefail

# Turn off the Hugging Face tokenizers thread pool. Presumably this avoids the
# library's fork warning when data-loader worker processes are started.
export TOKENIZERS_PARALLELISM=false

# Ask PyTorch's CUDA caching allocator to use expandable segments, which the
# PyTorch docs describe as a way to reduce memory fragmentation.
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
|
|
|
|
# Pretrained checkpoint to start from, plus the model name / type / training
# mode that train.py should select (values are interpreted by train.py).
MODEL_ARGS=()
MODEL_ARGS+=(--model_path "THUDM/CogVideoX-5b-I2V")
MODEL_ARGS+=(--model_name "cogvideox-flovd")
MODEL_ARGS+=(--model_type "i2vFlow")
MODEL_ARGS+=(--training_type "controlnet")
|
|
|
|
|
# Output location and experiment-tracking settings.
#
# The default output directory below is a placeholder and is a *relative*
# path despite its wording. Set FLOVD_OUTPUT_DIR (or edit the default) to a
# real absolute path before launching. If FLOVD_OUTPUT_DIR is unset or empty,
# the original placeholder is passed unchanged.
OUTPUT_ARGS=()
OUTPUT_ARGS+=(--output_dir "${FLOVD_OUTPUT_DIR:-absolute/path/to/output}")
OUTPUT_ARGS+=(--report_to "wandb")
OUTPUT_ARGS+=(--run_name "FloVD_CogVideoX_controlnet")
|
|
|
|
|
# Training data location and layout.
#
# The default data root below is a placeholder (a relative path despite its
# wording). Set FLOVD_DATA_ROOT (or edit the default) before launching. If it
# is unset or empty, the original placeholder is passed unchanged.
# NOTE(review): the caption/video list files are presumably resolved relative
# to --data_root, and "49x480x720" presumably means frames x height x width.
# Confirm both in train.py.
DATA_ARGS=()
DATA_ARGS+=(--data_root "${FLOVD_DATA_ROOT:-absolute/path/to/whole_data}")
DATA_ARGS+=(--caption_column "prompt.txt")
DATA_ARGS+=(--video_column "videos.txt")
DATA_ARGS+=(--train_resolution "49x480x720")
|
|
|
|
|
# Optimisation settings: epoch count, RNG seed, batch size (presumably per
# process — confirm in train.py), gradient accumulation, mixed-precision mode
# and learning rate.
TRAIN_ARGS=()
TRAIN_ARGS+=(--train_epochs 10)
TRAIN_ARGS+=(--seed 42)
TRAIN_ARGS+=(--batch_size 1)
TRAIN_ARGS+=(--gradient_accumulation_steps 2)
TRAIN_ARGS+=(--mixed_precision "bf16")
TRAIN_ARGS+=(--learning_rate 1e-5)
|
|
|
|
|
# Data-loader and distributed-runtime settings.
SYSTEM_ARGS=()
SYSTEM_ARGS+=(--num_workers 8)
# Capitalised "True" is passed verbatim; train.py decides how to parse it.
SYSTEM_ARGS+=(--pin_memory True)
# Units are defined by train.py (likely seconds — confirm).
SYSTEM_ARGS+=(--nccl_timeout 1800)
|
|
|
|
|
# Checkpoint cadence and retention. The flag names suggest "save every N
# steps, keep at most M"; confirm the exact semantics in train.py.
CHECKPOINT_ARGS=()
CHECKPOINT_ARGS+=(--checkpointing_steps 2000)
CHECKPOINT_ARGS+=(--checkpointing_limit 2)
|
|
|
|
|
# Periodic validation / sample-generation settings.
#
# The default validation directory below is a placeholder (a relative path
# despite its wording). Set FLOVD_VALIDATION_DIR (or edit the default) before
# launching. If it is unset or empty, the original placeholder is passed
# unchanged.
# NOTE(review): the prompts file here is "prompts.txt", while the training
# captions file in DATA_ARGS is "prompt.txt", and both default to the same
# directory. Confirm both files really exist under those names.
VALIDATION_ARGS=()
VALIDATION_ARGS+=(--do_validation true)
VALIDATION_ARGS+=(--validation_dir "${FLOVD_VALIDATION_DIR:-absolute/path/to/whole_data}")
VALIDATION_ARGS+=(--validation_steps 2000)
VALIDATION_ARGS+=(--validation_prompts "prompts.txt")
VALIDATION_ARGS+=(--validation_images "images.txt")
VALIDATION_ARGS+=(--gen_fps 16)
VALIDATION_ARGS+=(--max_scene 4)
|
|
|
|
|
# ControlNet branch configuration and timestep-sampling options.
CONTROLNET_ARGS=()
CONTROLNET_ARGS+=(--controlnet_transformer_num_layers 6)
CONTROLNET_ARGS+=(--controlnet_input_channels 16)
CONTROLNET_ARGS+=(--controlnet_weights 1.0)
# NOTE(review): start/end presumably bound the fraction of the denoising
# schedule in which ControlNet guidance is applied — confirm in train.py.
CONTROLNET_ARGS+=(--controlnet_guidance_start 0.0)
CONTROLNET_ARGS+=(--controlnet_guidance_end 0.4)
CONTROLNET_ARGS+=(--controlnet_out_proj_dim_factor 64)
# Time sampling is disabled, yet the time_sampling_* values below are still
# passed. Whether train.py ignores them while disabled is not visible here.
CONTROLNET_ARGS+=(--enable_time_sampling false)
CONTROLNET_ARGS+=(--time_sampling_type "truncated_normal")
CONTROLNET_ARGS+=(--time_sampling_mean 0.95)
CONTROLNET_ARGS+=(--time_sampling_std 0.1)
CONTROLNET_ARGS+=(--notextinflow true)
|
|
|
|
|
|
|
# Start training through Accelerate. accelerate_config.yaml and train.py are
# resolved relative to the current working directory.
#
# Any arguments given to this script are forwarded to train.py after all of
# the flags above, e.g.  bash <this-script> --learning_rate 2e-5
# With no arguments, the command line is exactly as before.
# NOTE(review): overriding a flag this way relies on train.py's parser letting
# a later repeated flag win (argparse's default "store" action does) — confirm.
accelerate launch --config_file accelerate_config.yaml train.py \
  "${MODEL_ARGS[@]}" \
  "${OUTPUT_ARGS[@]}" \
  "${DATA_ARGS[@]}" \
  "${TRAIN_ARGS[@]}" \
  "${SYSTEM_ARGS[@]}" \
  "${CHECKPOINT_ARGS[@]}" \
  "${VALIDATION_ARGS[@]}" \
  "${CONTROLNET_ARGS[@]}" \
  "$@"
|
|