rahul7star commited on
Commit
f7aef4e
Β·
verified Β·
1 Parent(s): 15ae5bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -75
app.py CHANGED
@@ -1,78 +1,57 @@
 
1
  import os
2
  import sys
3
- from huggingface_hub import hf_hub_download, list_repo_files
4
  import subprocess
 
 
5
 
6
  MODEL_REPO = "tencent/HunyuanVideo-Avatar"
7
  BASE_DIR = os.getcwd()
8
  WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
9
  OUTPUT_BASEPATH = os.path.join(BASE_DIR, "results-poor")
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Only download this specific file from transformers
13
- ESSENTIAL_PATHS = [
14
- "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt",
15
- ]
16
-
17
- # Download everything from these folders
18
- FULL_DIRS = [
19
- "hunyuan-video-t2v-720p/vae",
20
- #"llava_llama_image",
21
- "text_encoder_2",
22
- "whisper-tiny",
23
- "det_align",
24
- ]
25
-
26
-
27
- def list_ckpt_files():
28
- """Return a list of files to download: specific files + all from some folders."""
29
- try:
30
- all_files = list_repo_files(MODEL_REPO)
31
- except Exception as e:
32
- print(f"❌ Failed to list files from {MODEL_REPO}: {e}")
33
- return []
34
-
35
- files_to_download = ESSENTIAL_PATHS.copy()
36
- for path in all_files:
37
- if any(path.startswith(folder + "/") for folder in FULL_DIRS):
38
- files_to_download.append(path)
39
-
40
- return files_to_download
41
-
42
-
43
- def download_ckpts(files):
44
- """Download selected files to local directory, preserving structure."""
45
- for file_path in files:
46
- local_path = os.path.join(WEIGHTS_DIR, file_path)
47
- if os.path.exists(local_path):
48
- print(f"βœ… Already exists: {local_path}")
49
- continue
50
-
51
- print(f"⬇️ Downloading: {file_path}")
52
- try:
53
- hf_hub_download(
54
- repo_id=MODEL_REPO,
55
- filename=file_path,
56
- local_dir=WEIGHTS_DIR,
57
- local_dir_use_symlinks=False,
58
- resume_download=True,
59
- )
60
- except EntryNotFoundError:
61
- print(f"❌ Entry not found: {file_path}")
62
- except Exception as e:
63
- print(f"❌ Failed to download {file_path}: {e}")
64
 
65
  def run_sample_gpu_poor():
66
- ckpt_fp8 = os.path.join(WEIGHTS_DIR, "ckpts", "hunyuan-video-t2v-720p", "transformers", "mp_rank_00_model_states_fp8.pt")
67
-
68
- if not os.path.isfile(ckpt_fp8):
69
- print(f"❌ Missing checkpoint: {ckpt_fp8}")
70
- return
71
-
72
  cmd = [
73
  "python3", "hymm_sp/sample_gpu_poor.py",
74
  "--input", "assets/test.csv",
75
- "--ckpt", ckpt_fp8,
76
  "--sample-n-frames", "129",
77
  "--seed", "128",
78
  "--image-size", "704",
@@ -92,27 +71,49 @@ def run_sample_gpu_poor():
92
  env["CPU_OFFLOAD"] = "1"
93
  env["CUDA_VISIBLE_DEVICES"] = "0"
94
 
95
- print("πŸš€ Running sample_gpu_poor.py...")
96
- result = subprocess.run(cmd, env=env, capture_output=True, text=True)
 
 
 
97
 
98
- if result.returncode != 0:
99
- print(f"❌ sample_gpu_poor.py failed:\n{result.stderr}")
100
- else:
101
- print("βœ… sample_gpu_poor.py ran successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
 
 
 
 
103
 
104
  def main():
105
- files = list_ckpt_files()
106
- if not files:
107
- print("❌ No checkpoint files found in repo under ckpts folder.")
108
- sys.exit(1)
109
 
110
- download_ckpts(files)
111
  run_sample_gpu_poor()
112
 
 
 
 
 
113
 
114
  if __name__ == "__main__":
115
- main()
116
-
117
-
118
-
 
1
+
2
  import os
3
  import sys
 
4
  import subprocess
5
+ import time
6
+ from huggingface_hub import snapshot_download
7
 
8
  MODEL_REPO = "tencent/HunyuanVideo-Avatar"
9
  BASE_DIR = os.getcwd()
10
  WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
11
  OUTPUT_BASEPATH = os.path.join(BASE_DIR, "results-poor")
12
 
13
+ # Specific checkpoint to use in the poor sampling run
14
+ CHECKPOINT_FILE = os.path.join(
15
+ WEIGHTS_DIR,
16
+ "ckpts",
17
+ "hunyuan-video-t2v-720p",
18
+ "transformers",
19
+ "mp_rank_00_model_states.pt"
20
+ )
21
+ CHECKPOINT_FP8_FILE = os.path.join(
22
+ WEIGHTS_DIR,
23
+ "ckpts",
24
+ "hunyuan-video-t2v-720p",
25
+ "transformers",
26
+ "mp_rank_00_model_states_fp8.pt"
27
+ )
28
+
29
+ def download_model():
30
+ print("⬇️ Model not found. Downloading with snapshot_download into weights directory...")
31
+ os.makedirs(WEIGHTS_DIR, exist_ok=True)
32
+
33
+ snapshot_download(
34
+ repo_id=MODEL_REPO,
35
+ local_dir=WEIGHTS_DIR,
36
+ local_dir_use_symlinks=False
37
+ )
38
+
39
+ if not os.path.isfile(CHECKPOINT_FILE):
40
+ print(f"❌ Checkpoint file not found at {CHECKPOINT_FILE} after download.")
41
+ sys.exit(1)
42
+
43
+ if not os.path.isfile(CHECKPOINT_FP8_FILE):
44
+ print(f"❌ FP8 checkpoint file not found at {CHECKPOINT_FP8_FILE}. Cannot proceed with sample_gpu_poor.py.")
45
+ sys.exit(1)
46
 
47
+ print("βœ… Model downloaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def run_sample_gpu_poor():
50
+ print("🎬 Running sample_gpu_poor.py...")
 
 
 
 
 
51
  cmd = [
52
  "python3", "hymm_sp/sample_gpu_poor.py",
53
  "--input", "assets/test.csv",
54
+ "--ckpt", CHECKPOINT_FP8_FILE,
55
  "--sample-n-frames", "129",
56
  "--seed", "128",
57
  "--image-size", "704",
 
71
  env["CPU_OFFLOAD"] = "1"
72
  env["CUDA_VISIBLE_DEVICES"] = "0"
73
 
74
+ proc = subprocess.run(cmd, env=env)
75
+ if proc.returncode != 0:
76
+ print("❌ sample_gpu_poor.py failed.")
77
+ sys.exit(1)
78
+ print("βœ… sample_gpu_poor.py completed successfully.")
79
 
80
+ def run_flask_audio():
81
+ print("πŸš€ Starting flask_audio.py...")
82
+ cmd = [
83
+ "torchrun",
84
+ "--nnodes=1",
85
+ "--nproc_per_node=8",
86
+ "--master_port=29605",
87
+ "hymm_gradio/flask_audio.py",
88
+ "--input", "assets/test.csv",
89
+ "--ckpt", CHECKPOINT_FILE,
90
+ "--sample-n-frames", "129",
91
+ "--seed", "128",
92
+ "--image-size", "704",
93
+ "--cfg-scale", "7.5",
94
+ "--infer-steps", "50",
95
+ "--use-deepcache", "1",
96
+ "--flow-shift-eval-video", "5.0"
97
+ ]
98
+ subprocess.Popen(cmd)
99
 
100
+ def run_gradio_ui():
101
+ print("🟒 Starting gradio_audio.py UI...")
102
+ cmd = ["python3", "hymm_gradio/gradio_audio.py"]
103
+ subprocess.Popen(cmd)
104
 
105
  def main():
106
+ if os.path.isfile(CHECKPOINT_FILE) and os.path.isfile(CHECKPOINT_FP8_FILE):
107
+ print("βœ… Model checkpoint already exists. Skipping download.")
108
+ else:
109
+ download_model()
110
 
 
111
  run_sample_gpu_poor()
112
 
113
+ # Optional: Start Flask and Gradio UIs after poor sample run
114
+ run_flask_audio()
115
+ time.sleep(5)
116
+ run_gradio_ui()
117
 
118
  if __name__ == "__main__":
119
+ main()