cocktailpeanut committed on
Commit
793fc05
·
1 Parent(s): 30a6bb8
app.py CHANGED
@@ -51,7 +51,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048):
51
  start_time=start_time,
52
  file_type=file_type
53
  )
54
- torch.cuda.empty_cache()
55
  gc.collect()
56
  print(">4")
57
 
@@ -207,7 +207,7 @@ with gr.Blocks(css=css) as demo:
207
  interactive=True,
208
  elem_id="step_slider"
209
  )
210
- file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
211
 
212
 
213
 
 
51
  start_time=start_time,
52
  file_type=file_type
53
  )
54
+ devicetorch.empty_cache(torch)
55
  gc.collect()
56
  print(">4")
57
 
 
207
  interactive=True,
208
  elem_id="step_slider"
209
  )
210
+ file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
211
 
212
 
213
 
diffrhythm/infer/infer.py CHANGED
@@ -134,7 +134,11 @@ if __name__ == "__main__":
134
  parser.add_argument('--output-dir', type=str, default="example/output")
135
  args = parser.parse_args()
136
 
137
- device = 'cuda'
 
 
 
 
138
 
139
  audio_length = args.audio_length
140
  if audio_length == 95:
 
134
  parser.add_argument('--output-dir', type=str, default="example/output")
135
  args = parser.parse_args()
136
 
137
+ device = "cpu"
138
+ if torch.cuda.is_available():
139
+ device = "cuda"
140
+ elif torch.mps.is_available():
141
+ device = "mps"
142
 
143
  audio_length = args.audio_length
144
  if audio_length == 95:
diffrhythm/infer/infer_utils.py CHANGED
@@ -169,8 +169,7 @@ def get_lrc_token(text, tokenizer, device):
169
  return lrc_emb, normalized_start_time
170
 
171
  def load_checkpoint(model, ckpt_path, device, use_ema=True):
172
- if device == "cuda":
173
- model = model.half()
174
 
175
  ckpt_type = ckpt_path.split(".")[-1]
176
  if ckpt_type == "safetensors":
 
169
  return lrc_emb, normalized_start_time
170
 
171
  def load_checkpoint(model, ckpt_path, device, use_ema=True):
172
+ model = model.half()
 
173
 
174
  ckpt_type = ckpt_path.split(".")[-1]
175
  if ckpt_type == "safetensors":
requirements.txt CHANGED
@@ -12,7 +12,8 @@ pandas==2.2.3
12
  pylance==0.23.2
13
  ema-pytorch==0.7.7
14
  prefigure==0.0.10
15
- bitsandbytes==0.45.3
 
16
  muq==0.1.0
17
  mutagen==1.47.0
18
  pyopenjtalk==0.4.0
@@ -21,7 +22,7 @@ jieba==0.42.1
21
  cn2an==0.5.23
22
  pypinyin==0.53.0
23
  #onnxruntime==1.20.1
24
- onnxruntime-gpu
25
  Unidecode==1.3.8
26
  phonemizer==3.3.0
27
  LangSegment==0.3.5
 
12
  pylance==0.23.2
13
  ema-pytorch==0.7.7
14
  prefigure==0.0.10
15
+ #bitsandbytes==0.45.3
16
+ bitsandbytes
17
  muq==0.1.0
18
  mutagen==1.47.0
19
  pyopenjtalk==0.4.0
 
22
  cn2an==0.5.23
23
  pypinyin==0.53.0
24
  #onnxruntime==1.20.1
25
+ #onnxruntime-gpu
26
  Unidecode==1.3.8
27
  phonemizer==3.3.0
28
  LangSegment==0.3.5