KingNish committed
Commit d21cf89 · verified · 1 Parent(s): 4c3deb1

Update app.py

Files changed (1):
  1. app.py +8 -14
app.py CHANGED
@@ -97,10 +97,13 @@ model = AutoModelForCausalLM.from_pretrained(
     attn_implementation="flash_attention_2",
 ).to(device)
 model.eval()
-try:
-    model = torch.compile(model)
-except Exception as e:
-    print("torch.compile skipped for Stage1 model:", e)
+
+model_stage2 = AutoModelForCausalLM.from_pretrained(
+    STAGE2_MODEL,
+    torch_dtype=torch.float16,
+    attn_implementation="flash_attention_2",
+).to(device)
+model_stage2.eval()
 
 mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
 
@@ -115,10 +118,6 @@ codec_model = codec_class(**model_config.generator.config).to(device)
 parameter_dict = torch.load(RESUME_PATH, map_location="cpu")
 codec_model.load_state_dict(parameter_dict["codec_model"])
 codec_model.eval()
-try:
-    codec_model = torch.compile(codec_model)
-except Exception as e:
-    print("torch.compile skipped for codec_model:", e)
 
 # Precompile regex for splitting lyrics
 LYRICS_PATTERN = re.compile(r"\[(\w+)\](.*?)\n(?=\[|\Z)", re.DOTALL)
@@ -409,12 +408,7 @@ def generate_music(
 
     # ---------------- Stage 2: Refinement/Upsampling ----------------
     print("Stage 2 inference...")
-    model_stage2 = AutoModelForCausalLM.from_pretrained(
-        STAGE2_MODEL,
-        torch_dtype=torch.float16,
-        attn_implementation="flash_attention_2",
-    ).to(device)
-    model_stage2.eval()
+
     stage2_result = stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=STAGE2_BATCH_SIZE)
     print("Stage 2 inference completed.")
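For context, a minimal sketch of the pattern the diff moves to: the Stage 2 model is loaded once at module import, alongside the Stage 1 model, and reused on every call to generate_music instead of being reloaded per request. This is an illustrative reduction, not the full app.py; the placeholder STAGE2_MODEL value is hypothetical, and stage2_inference is the existing helper referenced in the diff above.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical placeholders; the real values are defined elsewhere in app.py.
STAGE2_MODEL = "org/stage2-model"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time (as in this commit), then shared across requests.
model_stage2 = AutoModelForCausalLM.from_pretrained(
    STAGE2_MODEL,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(device)
model_stage2.eval()

def generate_music(stage1_output_set, stage2_output_dir, batch_size):
    # Reuse the preloaded Stage 2 model rather than instantiating it per call.
    return stage2_inference(model_stage2, stage1_output_set, stage2_output_dir, batch_size=batch_size)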