File size: 3,318 Bytes
c1f7300
d491e94
 
093967b
62066bc
c1f7300
093967b
fedb718
093967b
fedb718
956b411
fedb718
fca41c8
 
 
 
 
 
 
fedb718
 
 
 
 
 
 
c1f7300
d491e94
fedb718
093967b
edc3608
62066bc
 
 
fca41c8
62066bc
edc3608
d491e94
093967b
d491e94
fedb718
 
 
 
62066bc
d491e94
c1f7300
fedb718
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d491e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fedb718
d491e94
 
 
 
fedb718
d491e94
 
fedb718
d491e94
 
 
 
fedb718
 
 
 
 
 
d491e94
 
fedb718
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys
import subprocess
import time
from huggingface_hub import snapshot_download

# Hugging Face repository id handed to snapshot_download.
MODEL_REPO = "tencent/HunyuanVideo-Avatar"

# Every path below is anchored at the directory the launcher is started from.
BASE_DIR = os.getcwd()
WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
OUTPUT_BASEPATH = os.path.join(BASE_DIR, "results-poor")

# Directory inside the weights tree that holds both checkpoint files.
_TRANSFORMERS_DIR = os.path.join(
    WEIGHTS_DIR, "ckpts", "hunyuan-video-t2v-720p", "transformers"
)

# Checkpoint passed to hymm_gradio/flask_audio.py via --ckpt.
CHECKPOINT_FILE = os.path.join(_TRANSFORMERS_DIR, "mp_rank_00_model_states.pt")

# FP8 checkpoint passed to hymm_sp/sample_gpu_poor.py via --ckpt.
CHECKPOINT_FP8_FILE = os.path.join(
    _TRANSFORMERS_DIR, "mp_rank_00_model_states_fp8.pt"
)

def download_model():
    """Download MODEL_REPO into WEIGHTS_DIR and verify both checkpoints exist.

    Creates WEIGHTS_DIR if needed, fetches the repository with
    ``snapshot_download`` and then checks for CHECKPOINT_FILE and
    CHECKPOINT_FP8_FILE on disk.

    Raises:
        SystemExit: with exit status 1 (message on stderr) when either
            checkpoint file is still missing after the download.
    """
    print("⬇️  Model not found. Downloading with snapshot_download into weights directory...")
    os.makedirs(WEIGHTS_DIR, exist_ok=True)

    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=WEIGHTS_DIR,
        # NOTE(review): recent huggingface_hub releases document this flag as
        # deprecated/ignored; kept so older installs still get real files
        # instead of symlinks. Confirm against the pinned version.
        local_dir_use_symlinks=False,
    )

    # sys.exit(str) writes the message to stderr and exits with status 1.
    if not os.path.isfile(CHECKPOINT_FILE):
        sys.exit(f"❌ Checkpoint file not found at {CHECKPOINT_FILE} after download.")

    if not os.path.isfile(CHECKPOINT_FP8_FILE):
        sys.exit(
            f"❌ FP8 checkpoint file not found at {CHECKPOINT_FP8_FILE}. "
            "Cannot proceed with sample_gpu_poor.py."
        )

    print("✅ Model downloaded successfully.")

def run_sample_gpu_poor():
    """Run hymm_sp/sample_gpu_poor.py to completion with the FP8 checkpoint.

    The child process gets a copy of the current environment with
    PYTHONPATH extended by ``./`` and MODEL_BASE, CPU_OFFLOAD and
    CUDA_VISIBLE_DEVICES set explicitly. Results are written to
    OUTPUT_BASEPATH via ``--save-path``.

    Raises:
        SystemExit: with exit status 1 (message on stderr) when the child
            exits with a non-zero return code.
    """
    print("🎬 Running sample_gpu_poor.py...")
    cmd = [
        # Use the interpreter running this launcher so the child sees the same
        # installed packages; fall back to "python3" if it cannot be determined.
        sys.executable or "python3", "hymm_sp/sample_gpu_poor.py",
        "--input", "assets/test.csv",
        "--ckpt", CHECKPOINT_FP8_FILE,
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
        "--save-path", OUTPUT_BASEPATH,
        "--use-fp8",
        "--cpu-offload",
        "--infer-min",
    ]

    env = os.environ.copy()
    # Put "./" first but keep whatever PYTHONPATH the caller already exported.
    inherited_path = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        os.pathsep.join(("./", inherited_path)) if inherited_path else "./"
    )
    env["MODEL_BASE"] = WEIGHTS_DIR
    env["CPU_OFFLOAD"] = "1"
    env["CUDA_VISIBLE_DEVICES"] = "0"

    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit(f"❌ sample_gpu_poor.py failed (exit code {proc.returncode}).")
    print("✅ sample_gpu_poor.py completed successfully.")

def run_flask_audio(nproc_per_node=8):
    """Start hymm_gradio/flask_audio.py in the background through torchrun.

    Args:
        nproc_per_node: Number of worker processes torchrun starts on this
            node. Defaults to 8, the value that was previously hard-coded.

    Returns:
        subprocess.Popen: handle of the spawned torchrun process, so the
        caller can wait on it or terminate it. The call does not block.
    """
    print("🚀 Starting flask_audio.py...")
    cmd = [
        "torchrun",
        "--nnodes=1",
        f"--nproc_per_node={nproc_per_node}",
        "--master_port=29605",
        "hymm_gradio/flask_audio.py",
        "--input", "assets/test.csv",
        "--ckpt", CHECKPOINT_FILE,
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
    ]

    # Mirror the environment run_sample_gpu_poor() gives its child.
    # NOTE(review): assumes flask_audio.py resolves project imports and weights
    # through PYTHONPATH / MODEL_BASE like the sample script - confirm.
    env = os.environ.copy()
    inherited_path = env.get("PYTHONPATH")
    env["PYTHONPATH"] = (
        os.pathsep.join(("./", inherited_path)) if inherited_path else "./"
    )
    # setdefault: an explicitly exported MODEL_BASE keeps working as before.
    env.setdefault("MODEL_BASE", WEIGHTS_DIR)

    return subprocess.Popen(cmd, env=env)

def run_gradio_ui():
    """Start hymm_gradio/gradio_audio.py in the background.

    Returns:
        subprocess.Popen: handle of the spawned UI process, so the caller can
        wait on it or terminate it. The call does not block.
    """
    print("🟢 Starting gradio_audio.py UI...")
    # Use the interpreter running this launcher so the child sees the same
    # installed packages; fall back to "python3" if it cannot be determined.
    cmd = [sys.executable or "python3", "hymm_gradio/gradio_audio.py"]
    return subprocess.Popen(cmd)

def main():
    """Ensure the checkpoints exist, run the sample job, then start the servers.

    Steps, in order:
      1. Download the model unless both checkpoint files are already on disk.
      2. Run sample_gpu_poor.py to completion (exits the process on failure).
      3. Start flask_audio.py, pause 5 seconds, then start the Gradio UI.
      4. Block on any background process handles the launch helpers return,
         so the launcher does not exit and orphan the servers.
    """
    if os.path.isfile(CHECKPOINT_FILE) and os.path.isfile(CHECKPOINT_FP8_FILE):
        print("✅ Model checkpoint already exists. Skipping download.")
    else:
        download_model()

    run_sample_gpu_poor()

    # Optional: Start Flask and Gradio UIs after poor sample run
    flask_proc = run_flask_audio()
    # Fixed pause between the two launches, kept from the original flow.
    time.sleep(5)
    ui_proc = run_gradio_ui()

    # The launch helpers may or may not hand back a process handle; wait only
    # on the ones that do, otherwise fall through and return as before.
    for child in (ui_proc, flask_proc):
        if child is not None:
            child.wait()


if __name__ == "__main__":
    main()