File size: 3,300 Bytes
c1f7300
0e7cb07
d491e94
093967b
6810bbb
c1f7300
0e7cb07
 
 
c1f7300
0e7cb07
6810bbb
0e7cb07
 
edc3608
0e7cb07
 
 
 
 
 
 
 
 
edc3608
0e7cb07
 
 
 
 
 
 
 
d491e94
0e7cb07
41394ac
0e7cb07
 
 
 
 
41394ac
0e7cb07
41394ac
 
0e7cb07
 
41394ac
 
 
 
 
 
 
0e7cb07
41394ac
 
 
 
 
 
 
0e7cb07
 
41394ac
 
035f115
 
41394ac
0e7cb07
 
035f115
41394ac
0e7cb07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41394ac
0e7cb07
 
 
41394ac
 
0e7cb07
 
 
 
 
 
 
 
41394ac
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import sys
import subprocess
import time
from huggingface_hub import hf_hub_download

# Every path below is resolved against the directory the launcher is started from.
BASE_DIR = os.getcwd()
# WEIGHTS_DIR receives the downloaded checkpoints; OUTPUT_BASEPATH is where the
# low-VRAM sample run saves its results.
WEIGHTS_DIR, OUTPUT_BASEPATH = (
    os.path.join(BASE_DIR, subdir) for subdir in ("weights", "results-poor")
)

def _download_weight_file(filename):
    """Fetch one file from the tencent/HunyuanVideo-Avatar repo into WEIGHTS_DIR.

    Returns the local path reported by ``hf_hub_download``.
    """
    return hf_hub_download(
        repo_id="tencent/HunyuanVideo-Avatar",
        filename=filename,
        cache_dir=WEIGHTS_DIR,
        local_dir=WEIGHTS_DIR,
        # Ask for real files rather than symlinks. NOTE(review): recent
        # huggingface_hub releases deprecate this argument; kept for older
        # versions — confirm against the pinned huggingface_hub version.
        local_dir_use_symlinks=False,
    )


def download_checkpoints():
    """Download the two transformer checkpoints this launcher needs.

    Files are stored under ``WEIGHTS_DIR``, which is created if missing.

    Returns:
        ``(checkpoint, checkpoint_fp8)``: local paths of the regular checkpoint
        (passed to the Flask/Gradio server) and of the FP8 checkpoint (passed
        to the low-VRAM sample run).

    Exits the process with status 1 if either download fails.
    """
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    print("⬇️  Downloading necessary checkpoint files...")

    transformers_dir = "ckpts/hunyuan-video-t2v-720p/transformers"
    try:
        # FP8 checkpoint first, preserving the original download order.
        checkpoint_fp8 = _download_weight_file(
            f"{transformers_dir}/mp_rank_00_model_states_fp8.pt"
        )
        checkpoint = _download_weight_file(
            f"{transformers_dir}/mp_rank_00_model_states.pt"
        )
    except Exception as e:  # top-level boundary: report and abort the launcher
        print(f"❌ Error during checkpoint download: {e}", file=sys.stderr)
        sys.exit(1)

    return checkpoint, checkpoint_fp8

def run_sample_gpu_poor(checkpoint_fp8_path):
    """Render the sample video with the low-VRAM pipeline (FP8 + CPU offload).

    Runs ``hymm_sp/sample_gpu_poor.py`` on ``assets/test.csv`` as a blocking
    child process and saves the output under ``OUTPUT_BASEPATH``.

    Args:
        checkpoint_fp8_path: Local path of the FP8 transformer checkpoint.

    Exits the process with status 1 if the child process fails.
    """
    print("🎬 Running sample_gpu_poor.py...")

    cmd = [
        # Use the interpreter running this launcher so the child sees the same
        # installed packages; a bare "python3" is resolved via PATH and may be
        # a different installation.
        sys.executable or "python3", "hymm_sp/sample_gpu_poor.py",
        "--input", "assets/test.csv",
        "--ckpt", checkpoint_fp8_path,
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
        "--save-path", OUTPUT_BASEPATH,
        "--use-fp8",
        "--cpu-offload",
        "--infer-min",
    ]

    env = os.environ.copy()
    # Put the repo root first on the import path without discarding any
    # PYTHONPATH entries the caller already configured.
    env["PYTHONPATH"] = os.pathsep.join(
        entry for entry in ("./", env.get("PYTHONPATH")) if entry
    )
    env["MODEL_BASE"] = WEIGHTS_DIR
    env["CPU_OFFLOAD"] = "1"
    env["CUDA_VISIBLE_DEVICES"] = "0"  # expose only the first GPU to the child

    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        print("❌ sample_gpu_poor.py failed.", file=sys.stderr)
        sys.exit(1)

    print("✅ sample_gpu_poor.py completed successfully.")

def _server_env():
    """Return a copy of the environment with the repo root and weights dir exported.

    Mirrors the PYTHONPATH/MODEL_BASE setup used for sample_gpu_poor.py.
    NOTE(review): presumably the server scripts under hymm_gradio/ import the
    repo's packages and read MODEL_BASE the same way — confirm.
    """
    env = os.environ.copy()
    env["PYTHONPATH"] = os.pathsep.join(
        entry for entry in ("./", env.get("PYTHONPATH")) if entry
    )
    env["MODEL_BASE"] = WEIGHTS_DIR
    return env


def run_flask_audio(checkpoint_path):
    """Start the Flask inference server in the background via torchrun.

    Args:
        checkpoint_path: Local path of the regular transformer checkpoint.

    Returns:
        The ``subprocess.Popen`` handle of the launched process, so the caller
        can wait on or stop it.
    """
    print("🚀 Starting flask_audio.py...")
    cmd = [
        # ``python -m torch.distributed.run`` is what the ``torchrun`` console
        # script invokes; running it this way binds it to this interpreter.
        sys.executable or "python3", "-m", "torch.distributed.run",
        "--nnodes=1",
        "--nproc_per_node=1",
        "--master_port=29605",
        "hymm_gradio/flask_audio.py",
        "--input", "assets/test.csv",
        "--ckpt", checkpoint_path,
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
    ]
    return subprocess.Popen(cmd, env=_server_env())


def run_gradio_ui():
    """Start the Gradio UI in the background.

    Returns:
        The ``subprocess.Popen`` handle of the launched process.
    """
    print("🟢 Starting gradio_audio.py UI...")
    cmd = [sys.executable or "python3", "hymm_gradio/gradio_audio.py"]
    return subprocess.Popen(cmd, env=_server_env())


def _stop_process(proc, timeout=10):
    """Terminate ``proc`` if it is still running, killing it after ``timeout`` s."""
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def main():
    """Download weights, render the sample video, then serve the UI.

    Stays in the foreground until the Gradio UI process exits (or Ctrl-C),
    then stops both server processes so neither is left orphaned.
    """
    # Step 1: Download only needed files from Hugging Face repo
    checkpoint, checkpoint_fp8 = download_checkpoints()

    # Step 2: Run poor sample video generation
    run_sample_gpu_poor(checkpoint_fp8)

    # Step 3: Launch Flask + Gradio UIs
    flask_proc = run_flask_audio(checkpoint)
    gradio_proc = None
    try:
        # Fixed head start for the Flask server; this is not a readiness check.
        time.sleep(5)
        gradio_proc = run_gradio_ui()
        # Block here; previously main() returned at once and abandoned both servers.
        gradio_proc.wait()
    except KeyboardInterrupt:
        print("🛑 Interrupted, stopping servers...", file=sys.stderr)
    finally:
        for proc in (gradio_proc, flask_proc):
            if proc is not None:
                _stop_process(proc)


if __name__ == "__main__":
    main()