Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -97,10 +97,13 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
97 |
attn_implementation="flash_attention_2",
|
98 |
).to(device)
|
99 |
model.eval()
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
104 |
|
105 |
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
|
106 |
|
@@ -115,10 +118,6 @@ codec_model = codec_class(**model_config.generator.config).to(device)
|
|
115 |
parameter_dict = torch.load(RESUME_PATH, map_location="cpu")
|
116 |
codec_model.load_state_dict(parameter_dict["codec_model"])
|
117 |
codec_model.eval()
|
118 |
-
try:
|
119 |
-
codec_model = torch.compile(codec_model)
|
120 |
-
except Exception as e:
|
121 |
-
print("torch.compile skipped for codec_model:", e)
|
122 |
|
123 |
# Precompile regex for splitting lyrics
|
124 |
LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
|
@@ -409,12 +408,7 @@ def generate_music(
|
|
409 |
|
410 |
# ---------------- Stage 2: Refinement/Upsampling ----------------
|
411 |
print("Stage 2 inference...")
|
412 |
-
|
413 |
-
STAGE2_MODEL,
|
414 |
-
torch_dtype=torch.float16,
|
415 |
-
attn_implementation="flash_attention_2",
|
416 |
-
).to(device)
|
417 |
-
model_stage2.eval()
|
418 |
stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=STAGE2_BATCH_SIZE)
|
419 |
print("Stage 2 inference completed.")
|
420 |
|
|
|
97 |
attn_implementation="flash_attention_2",
|
98 |
).to(device)
|
99 |
model.eval()
|
100 |
+
|
101 |
+
model_stage2 = AutoModelForCausalLM.from_pretrained(
|
102 |
+
STAGE2_MODEL,
|
103 |
+
torch_dtype=torch.float16,
|
104 |
+
attn_implementation="flash_attention_2",
|
105 |
+
).to(device)
|
106 |
+
model_stage2.eval()
|
107 |
|
108 |
mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
|
109 |
|
|
|
118 |
parameter_dict = torch.load(RESUME_PATH, map_location="cpu")
|
119 |
codec_model.load_state_dict(parameter_dict["codec_model"])
|
120 |
codec_model.eval()
|
|
|
|
|
|
|
|
|
121 |
|
122 |
# Precompile regex for splitting lyrics
|
123 |
LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
|
|
|
408 |
|
409 |
# ---------------- Stage 2: Refinement/Upsampling ----------------
|
410 |
print("Stage 2 inference...")
|
411 |
+
|
|
|
|
|
|
|
|
|
|
|
412 |
stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=STAGE2_BATCH_SIZE)
|
413 |
print("Stage 2 inference completed.")
|
414 |
|