cocktailpeanut committed on
Commit
dc6a3d5
·
1 Parent(s): febf9c9
Files changed (2) hide show
  1. app.py +4 -0
  2. diffrhythm/infer/infer.py +10 -1
app.py CHANGED
@@ -9,6 +9,7 @@ from einops import rearrange
9
  import argparse
10
  import json
11
  import os
 
12
  #import spaces
13
  from tqdm import tqdm
14
  import random
@@ -49,6 +50,9 @@ def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048):
49
  start_time=start_time,
50
  file_type=file_type
51
  )
 
 
 
52
  return generated_song
53
 
54
  def R1_infer1(theme, tags_gen, language):
 
9
  import argparse
10
  import json
11
  import os
12
+ import gc
13
  #import spaces
14
  from tqdm import tqdm
15
  import random
 
50
  start_time=start_time,
51
  file_type=file_type
52
  )
53
+ torch.cuda.empty_cache()
54
+ gc.collect()
55
+
56
  return generated_song
57
 
58
  def R1_infer1(theme, tags_gen, language):
diffrhythm/infer/infer.py CHANGED
@@ -9,6 +9,7 @@ import random
9
  import numpy as np
10
  import time
11
  import io
 
12
  import pydub
13
 
14
  from diffrhythm.infer.infer_utils import (
@@ -88,11 +89,19 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
88
  sway_sampling_coef=sway_sampling_coef,
89
  start_time=start_time
90
  )
 
 
 
91
 
92
  generated = generated.to(torch.float32)
93
  latent = generated.transpose(1, 2) # [b d t]
94
  output = decode_audio(latent, vae_model, chunked=False)
95
 
 
 
 
 
 
96
  # Rearrange audio batch to a single sequence
97
  output = rearrange(output, "b d n -> d (b n)")
98
  output_tensor = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
@@ -157,4 +166,4 @@ if __name__ == "__main__":
157
 
158
  output_path = os.path.join(output_dir, "output.wav")
159
  torchaudio.save(output_path, generated_song, sample_rate=44100)
160
-
 
9
  import numpy as np
10
  import time
11
  import io
12
+ import gc
13
  import pydub
14
 
15
  from diffrhythm.infer.infer_utils import (
 
89
  sway_sampling_coef=sway_sampling_coef,
90
  start_time=start_time
91
  )
92
+ torch.cuda.empty_cache()
93
+ gc.collect()
94
+
95
 
96
  generated = generated.to(torch.float32)
97
  latent = generated.transpose(1, 2) # [b d t]
98
  output = decode_audio(latent, vae_model, chunked=False)
99
 
100
+ del latent, generated
101
+ torch.cuda.empty_cache()
102
+ gc.collect()
103
+
104
+
105
  # Rearrange audio batch to a single sequence
106
  output = rearrange(output, "b d n -> d (b n)")
107
  output_tensor = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
 
166
 
167
  output_path = os.path.join(output_dir, "output.wav")
168
  torchaudio.save(output_path, generated_song, sample_rate=44100)
169
+