cocktailpeanut commited on
Commit
ba4f650
Β·
1 Parent(s): 6f225e3
Files changed (2) hide show
  1. app.py +5 -4
  2. requirements.txt +3 -3
app.py CHANGED
@@ -22,13 +22,14 @@ from diffrhythm.infer.infer_utils import (
22
  get_negative_style_prompt
23
  )
24
  from diffrhythm.infer.infer import inference
 
25
 
26
 
27
- device='cuda'
28
  cfm, tokenizer, muq, vae = prepare_model(device)
29
  cfm = torch.compile(cfm)
30
 
31
- @spaces.GPU
32
  def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048, device='cuda'):
33
 
34
  sway_sampling_coef = -1 if steps < 32 else None
@@ -326,7 +327,7 @@ with gr.Blocks(css=css) as demo:
326
 
327
  lyrics_btn.click(
328
  fn=infer_music,
329
- inputs=[lrc, audio_prompt, steps, file_type],
330
  outputs=audio_output
331
  )
332
 
@@ -336,4 +337,4 @@ demo.queue().launch(show_api=False, show_error=True)
336
 
337
 
338
  if __name__ == "__main__":
339
- demo.launch()
 
22
  get_negative_style_prompt
23
  )
24
  from diffrhythm.infer.infer import inference
25
+ import devicetorch
26
 
27
 
28
+ device=devicetorch.get(torch)
29
  cfm, tokenizer, muq, vae = prepare_model(device)
30
  cfm = torch.compile(cfm)
31
 
32
+ #@spaces.GPU
33
  def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048, device='cuda'):
34
 
35
  sway_sampling_coef = -1 if steps < 32 else None
 
327
 
328
  lyrics_btn.click(
329
  fn=infer_music,
330
+ inputs=[lrc, audio_prompt, steps, file_type, device],
331
  outputs=audio_output
332
  )
333
 
 
337
 
338
 
339
  if __name__ == "__main__":
340
+ demo.launch()
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  gradio==5.20.0
2
  accelerate==1.4.0
3
  inflect==7.5.0
4
- torchdiffeq==0.2.5
5
- torchaudio==2.6.0
6
  x-transformers==2.1.2
7
  transformers==4.49.0
8
  librosa==0.10.2.post1
@@ -30,4 +30,4 @@ einops==0.8.1
30
  lazy_loader==0.4
31
  scipy==1.15.2
32
  ftfy==6.3.1
33
- torchdiffeq==0.2.5
 
1
  gradio==5.20.0
2
  accelerate==1.4.0
3
  inflect==7.5.0
4
+ #torchdiffeq==0.2.5
5
+ #torchaudio==2.6.0
6
  x-transformers==2.1.2
7
  transformers==4.49.0
8
  librosa==0.10.2.post1
 
30
  lazy_loader==0.4
31
  scipy==1.15.2
32
  ftfy==6.3.1
33
+ torchdiffeq==0.2.5