cpu test
Files changed:
- app.py (+24 -9)
- diffrhythm/infer/infer.py (+6 -5)
app.py CHANGED

@@ -22,16 +22,13 @@ from diffrhythm.infer.infer_utils import (
 )
 from diffrhythm.infer.infer import inference
 
-device='cuda'
+device='cpu'
 cfm, tokenizer, muq, vae = prepare_model(device)
 cfm = torch.compile(cfm)
 
-def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
-
-
-    # print(lrc_list)
-
-    # return "./gift_of_the_world.wav"
+def infer_music(lrc, ref_audio_path, steps, sway_sampling_coef_bool, max_frames=2048, device='cpu'):
+
+    sway_sampling_coef = -1 if sway_sampling_coef_bool else None
     lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
     style_prompt = get_style_prompt(muq, ref_audio_path)
     negative_style_prompt = get_negative_style_prompt(device)

@@ -43,6 +40,8 @@ def infer_music(lrc, ref_audio_path, max_frames=2048, device='cuda'):
         duration=max_frames,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
+        steps=steps,
+        sway_sampling_coef=sway_sampling_coef,
         start_time=start_time
     )
     return generated_song

@@ -150,6 +149,22 @@ with gr.Blocks(css=css) as demo:
         audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")
 
     with gr.Column():
+        steps = gr.Slider(
+            minimum=10,
+            maximum=40,
+            value=32,
+            step=1,
+            label="Diffusion Steps",
+            interactive=True,
+            elem_id="step_slider"
+        )
+        sway_sampling_coef_bool = gr.Radio(
+            choices=[("False", False), ("True", True)],
+            label="Use sway_sampling_coef",
+            value=False,
+            interactive=True,
+            elem_classes="horizontal-radio"
+        )
         lyrics_btn = gr.Button("Submit", variant="primary")
         audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
 

@@ -210,7 +225,7 @@ with gr.Blocks(css=css) as demo:
             [01:24.20]Your laughter spins aurora threads
             [01:28.65]Weaving dawn through featherbed"""]
         ],
-        inputs=[lrc],
+        inputs=[lrc],
         label="Lrc Examples",
         examples_per_page=2
     )

@@ -306,7 +321,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt],
+        inputs=[lrc, audio_prompt, steps, sway_sampling_coef_bool],
         outputs=audio_output
     )
 
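For context on the app.py wiring: Gradio passes component values to a click handler positionally, in the order listed in `inputs`, so the new slider arrives as `steps` and the radio as `sway_sampling_coef_bool`. Below is a minimal, self-contained sketch of this flow. The component wiring and the handler signature mirror the diff; the handler body is a stub, `lrc` is assumed to be a plain `gr.Textbox` since its definition falls outside this diff, and the text output is a stand-in for the Space's `gr.Audio` result.

import gradio as gr

def infer_music(lrc, ref_audio_path, steps, sway_sampling_coef_bool,
                max_frames=2048, device='cpu'):
    # The Radio's (label, value) choices yield a real bool here; it is mapped
    # to the value the sampler expects: -1 enables sway sampling, None keeps
    # the timestep schedule uniform.
    sway_sampling_coef = -1 if sway_sampling_coef_bool else None
    return f"steps={steps}, sway_sampling_coef={sway_sampling_coef}"  # stub

with gr.Blocks() as demo:
    lrc = gr.Textbox(label="Lrc")  # assumption: defined outside this diff
    audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")
    steps = gr.Slider(minimum=10, maximum=40, value=32, step=1,
                      label="Diffusion Steps", interactive=True)
    sway_sampling_coef_bool = gr.Radio(choices=[("False", False), ("True", True)],
                                       label="Use sway_sampling_coef", value=False)
    lyrics_btn = gr.Button("Submit", variant="primary")
    out = gr.Textbox(label="Result")  # the Space uses gr.Audio here
    # Values are delivered in the order of `inputs`, matching the diff.
    lyrics_btn.click(fn=infer_music,
                     inputs=[lrc, audio_prompt, steps, sway_sampling_coef_bool],
                     outputs=out)

demo.launch()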
diffrhythm/infer/infer.py CHANGED

@@ -72,7 +72,7 @@ def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
     y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
     return y_final
 
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, start_time):
+def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time):
     # import pdb; pdb.set_trace()
     with torch.inference_mode():
         generated, _ = cfm_model.sample(

@@ -81,8 +81,9 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
         duration=duration,
         style_prompt=style_prompt,
         negative_style_prompt=negative_style_prompt,
-        steps=32,
+        steps=steps,
         cfg_strength=4.0,
+        sway_sampling_coef=sway_sampling_coef,
         start_time=start_time
     )
 

@@ -100,10 +101,10 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--lrc-path', type=str, default="
-    parser.add_argument('--ref-audio-path', type=str, default="
+    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc") # lyrics of target song
+    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3") # reference audio as style prompt for target song
     parser.add_argument('--audio-length', type=int, default=95) # length of target song
-    parser.add_argument('--output-dir', type=str, default="
+    parser.add_argument('--output-dir', type=str, default="example/output")
     args = parser.parse_args()
 
     device = 'cuda'
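For reference on what the new `sway_sampling_coef` argument does downstream: in F5-TTS-style flow-matching samplers, which DiffRhythm's CFM follows, the coefficient warps the otherwise uniform ODE timestep grid so that more integration steps land near t=0, where coarse structure forms. A minimal sketch of that schedule, assuming `cfm_model.sample` applies the standard F5-TTS formula; the function name `timestep_schedule` is illustrative, not part of the repo.

import torch

def timestep_schedule(steps, sway_sampling_coef=None):
    # Uniform integration times on [0, 1], one per ODE step boundary.
    t = torch.linspace(0, 1, steps + 1)
    if sway_sampling_coef is not None:
        # Sway sampling: with a negative coefficient (the app passes -1),
        # the grid bends toward t=0, front-loading the sampling steps.
        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
    return t

print(timestep_schedule(8))      # uniform spacing
print(timestep_schedule(8, -1))  # front-loaded: simplifies to 1 - cos(pi*t/2)

With coefficient -1 the endpoints are unchanged (t=0 maps to 0, t=1 maps to 1), so the sampler still integrates the full trajectory, just with non-uniform step sizes.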