Surn committed on
Commit 439c593 · 1 Parent(s): 2e5d66f

Change Garbage collection

Files changed (2)
  1. app.py +37 -21
  2. audiocraft/data/audio_utils.py +1 -1
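The first app.py hunk below gives load_model a gr.Progress(track_tqdm=True) parameter so the tqdm bar it opens is mirrored in the Gradio UI while a checkpoint loads. A minimal standalone sketch of that pattern, assuming a recent Gradio release; the dropdown choices and the sleep calls standing in for real loading work are illustrative:

import time

import gradio as gr
from tqdm import tqdm


def load_model(version, progress=gr.Progress(track_tqdm=True)):
    # track_tqdm=True makes Gradio mirror any tqdm bar opened inside this
    # event handler in the web UI while the call is running.
    with tqdm(total=100, desc=f"Loading model '{version}'", unit="step") as pbar:
        time.sleep(1)   # stand-in for the real checkpoint download/load
        pbar.update(50)
        time.sleep(1)
        pbar.update(50)
    return f"Loaded {version}"


demo = gr.Interface(fn=load_model,
                    inputs=gr.Dropdown(["melody", "melody-large"], label="Model"),
                    outputs="text")
demo.launch()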
app.py CHANGED
@@ -95,26 +95,34 @@ def get_waveform(*args, **kwargs):
     return out


-def load_model(version):
-    global MODEL, MODELS, UNLOAD_MODEL
-    print("Loading model", version)
-    if MODELS is None:
-        return MusicGen.get_pretrained(version)
-    else:
-        t1 = time.monotonic()
-        if MODEL is not None:
-            MODEL.to('cpu') # move to cache
-            print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
-        t1 = time.monotonic()
-        if MODELS.get(version) is None:
-            print("Loading model %s from disk" % version)
-            result = MusicGen.get_pretrained(version)
-            MODELS[version] = result
-            print("Model loaded in %.2fs" % (time.monotonic() - t1))
-            return result
-        result = MODELS[version].to('cuda')
-        print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
-        return result
+def load_model(version, progress=gr.Progress(track_tqdm=True)):
+    global MODEL, MODELS, UNLOAD_MODEL
+    print("Loading model", version)
+
+    with tqdm(total=100, desc=f"Loading model '{version}'", unit="step") as pbar:
+        if MODELS is None:
+            pbar.update(50) # Simulate progress for loading
+            result = MusicGen.get_pretrained(version)
+            pbar.update(50) # Complete progress
+            return result
+        else:
+            t1 = time.monotonic()
+            if MODEL is not None:
+                MODEL.to('cpu') # Move to cache
+                print("Previous model moved to CPU in %.2fs" % (time.monotonic() - t1))
+                pbar.update(30) # Simulate progress for moving model to CPU
+            t1 = time.monotonic()
+            if MODELS.get(version) is None:
+                print("Loading model %s from disk" % version)
+                result = MusicGen.get_pretrained(version)
+                MODELS[version] = result
+                print("Model loaded in %.2fs" % (time.monotonic() - t1))
+                pbar.update(70) # Simulate progress for loading from disk
+                return result
+            result = MODELS[version].to('cuda')
+            print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
+            pbar.update(100) # Complete progress
+            return result

 def get_melody(melody_filepath):
     audio_data= list(librosa.load(melody_filepath, sr=None))

@@ -188,7 +196,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
     INTERRUPTED = False
     INTERRUPTING = False
     if temperature < 0:
-        temperature -0
+        temperature = 0.1
        raise gr.Error("Temperature must be >= 0.")
     if topk < 0:
        topk = 1

@@ -197,8 +205,16 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
        topp =1
        raise gr.Error("Topp must be non-negative.")

+    # Clean up GPU resources only if the model changes
+    if MODEL is not None and model not in MODEL.name:
+        print(f"Switching model from {MODEL.name} to {model}. Cleaning up resources.")
+        del MODEL # Delete the current model
+        torch.cuda.empty_cache() # Clear GPU memory
+        gc.collect() # Force garbage collection
+        MODEL = None
+
     try:
-        if MODEL is None or MODEL.name != model:
+        if MODEL is None or model not in MODEL.name:
            MODEL = load_model(model)
        else:
            if MOVE_TO_CPU:

@@ -433,12 +449,12 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
     del output_segments, output, melody, melody_name, melody_extension, metadata, mp4

     # Force garbage collection
-    gc.collect()
+    #gc.collect()

     # Synchronize CUDA streams
     torch.cuda.synchronize()

-    torch.cuda.empty_cache()
+    #torch.cuda.empty_cache()
     torch.cuda.ipc_collect()
     return waveform_video_path, file.name, seed

@@ -556,7 +572,7 @@ def ui(**kwargs):
                3.75
            ],
            [
-               "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
+               "4/4 120bpm 320kbps 48khz, a light and cheery EDM track, with syncopated drums, aery pads, and strong emotions",
                "./assets/bach.mp3",
                "melody-large",
                "EDM my Bach",
audiocraft/data/audio_utils.py CHANGED
@@ -200,7 +200,7 @@ def apply_tafade(audio: torch.Tensor, sample_rate, duration=3.0, out=True, start
     if out:
         fade_transform.fade_out_len = fade_samples
     else:
-        fade_transform.fade_in_len = fade_samples
+        fade_transform.fade_in_len = fade_samples

     # Select the portion of the audio to apply the fade
     if start:
  if start: