Surn committed · Commit 8fcd249 · Parent: 28a61b8

Progress Bars Update

app.py CHANGED

@@ -17,6 +17,7 @@ from pathlib import Path
 import time
 import typing as tp
 import warnings
+import gc
 from tqdm import tqdm
 from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
@@ -139,7 +140,7 @@ def load_melody_filepath(melody_filepath, title, assigned_model,topp, temperatur
     symbols = ['_', '.', '-']
     MAX_OVERLAP = int(segment_length // 2) - 1
     if (melody_filepath is None) or (melody_filepath == ""):
-        return title, gr.update(maximum=0, value=0), gr.update(value="medium", interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
+        return title, gr.update(maximum=0, value=-1), gr.update(value="medium", interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
 
     if (title is None) or ("MusicGen" in title) or (title == ""):
         melody_name, melody_extension = get_filename_from_filepath(melody_filepath)
@@ -166,7 +167,7 @@ def load_melody_filepath(melody_filepath, title, assigned_model,topp, temperatur
     print(f"Melody length: {len(melody_data)}, Melody segments: {total_melodys}\n")
     MAX_PROMPT_INDEX = total_melodys
 
-    return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
+    return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=-1), gr.update(value=assigned_model, interactive=True), gr.update(value=topp), gr.update(value=temperature), gr.update(value=cfg_coef), gr.update(maximum=MAX_OVERLAP)
 
 def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False, profile = gr.OAuthProfile, segment_length = 30, settings_font_size=28, progress=gr.Progress(track_tqdm=True)):
     global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
@@ -331,7 +332,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
         audio_write(
             file.name, output, MODEL.sample_rate, strategy="loudness",
             loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2)
-        waveform_video_path = get_waveform(file.name, bg_image=background, bar_count=45, name=title_file_name, animate=False)
+        waveform_video_path = get_waveform(file.name, bg_image=background, bar_count=45, name=title_file_name, animate=False, progress=gr.Progress(track_tqdm=True))
         # Remove the extension from file.name
         file_name_without_extension = os.path.splitext(file.name)[0]
         # Get the directory, filename, name, extension, and new extension of the waveform video path
@@ -345,6 +346,8 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
 
         commit = commit_hash()
         metadata = {
+            "Title": title,
+            "Year": time.strftime("%Y"),
             "prompt": text,
             "negative_prompt": "",
             "Seed": seed,
@@ -407,6 +410,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
             video=waveform_video_path,
             label=title,
             metadata=metadata,
+            progress=gr.Progress(track_tqdm=True)
         )
 
 
@@ -414,6 +418,16 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
         MODEL.to('cpu')
     if UNLOAD_MODEL:
         MODEL = None
+
+    # Explicitly delete large tensors or objects
+    del output_segments, output, melody, melody_name, melody_extension, metadata, mp4
+
+    # Force garbage collection
+    gc.collect()
+
+    # Synchronize CUDA streams
+    torch.cuda.synchronize()
+
     torch.cuda.empty_cache()
     torch.cuda.ipc_collect()
     return waveform_video_path, file.name, seed
@@ -552,7 +566,9 @@ def ui(**kwargs):
         )
 
         with gr.Tab("User History") as history_tab:
+            modules.user_history.setup(display_type="video_path")
             modules.user_history.render()
+
         user_profile = gr.State(None)
 
         with gr.Row("Versions") as versions_row:
 
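Note on the load_melody_filepath hunks above: resetting the prompt-segment slider with value=-1 instead of value=0 leaves no segment pre-selected when a melody is loaded or cleared. A minimal sketch of the same gr.update pattern, assuming an illustrative handler and component names (the real app wires several more outputs):

import gradio as gr

def on_melody_change(melody_filepath):
    # Returning gr.update(...) retargets an output component in place.
    if not melody_filepath:
        # No melody: collapse the slider range and deselect, as in the commit.
        return gr.update(maximum=0, value=-1)
    max_index = 5  # stand-in for the computed number of melody segments
    return gr.update(maximum=max_index, value=-1)

with gr.Blocks() as demo:
    melody = gr.Textbox(label="Melody file path")
    prompt_index = gr.Slider(minimum=-1, maximum=0, step=1, value=-1, label="Prompt segment")
    melody.change(on_melody_change, inputs=melody, outputs=prompt_index)

if __name__ == "__main__":
    demo.launch()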
 
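Note on the teardown block added at the end of predict: it follows the usual PyTorch cleanup order of dropping Python references, forcing garbage collection, synchronizing, and only then releasing cached CUDA memory. A self-contained sketch of that sequence, with hypothetical stand-in tensors for the generation buffers:

import gc
import torch

def cleanup_after_generation():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output = torch.randn(2, 32000, device=device)  # stand-in for generated audio
    melody = torch.randn(2, 32000, device=device)  # stand-in for conditioning audio

    # 1. Drop the references so the allocator can reclaim the storage.
    del output, melody
    # 2. Force a collection pass for anything caught in reference cycles.
    gc.collect()
    if torch.cuda.is_available():
        # 3. Wait for in-flight kernels before releasing memory.
        torch.cuda.synchronize()
        # 4. Return cached blocks to the driver and release IPC handles.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

del only removes names from the current scope, which is why the commit deletes the tensors inside predict itself rather than delegating to a helper.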
audiocraft/models/musicgen.py CHANGED

@@ -411,8 +411,8 @@ class MusicGen:
 
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
             generated_tokens += current_gen_offset
-            generated_tokens /= 50
-            tokens_to_generate /= 50
+            generated_tokens /= ((tokens_to_generate - 3) / self.duration)
+            tokens_to_generate /= ((tokens_to_generate - 3) / self.duration)
             if self._progress_callback is not None:
                 # Note that total_gen_len might be quite wrong depending on the
                 # codebook pattern used, but with delay it is almost accurate.
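Note on this hunk: the callback previously divided both counters by a hardcoded 50, MusicGen's EnCodec frame rate in tokens per second, so the bar only read correctly when the token budget was exactly 50 * duration. The new code derives the scale from the request itself, (tokens_to_generate - 3) / self.duration, where the - 3 presumably absorbs the few extra tokens introduced by the delay codebook pattern. A toy version of the rescaling:

def scale_progress(generated_tokens: int, tokens_to_generate: int, duration: float):
    # Convert token counts into seconds of audio using the request's own
    # token budget instead of a hardcoded 50 tokens/second.
    tokens_per_second = (tokens_to_generate - 3) / duration
    return generated_tokens / tokens_per_second, tokens_to_generate / tokens_per_second

# For a 30-second request with a 1503-token budget this recovers the
# familiar 50 tokens/second scale:
done_s, total_s = scale_progress(753, 1503, 30.0)
print(f"{done_s:.2f}/{total_s:.2f} s")  # -> 15.06/30.06 s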
audiocraft/utils/extend.py CHANGED

@@ -14,6 +14,7 @@ from huggingface_hub import hf_hub_download
 import librosa
 import gradio as gr
 import re
+from tqdm import tqdm
 
 
 INTERRUPTING = False
@@ -72,6 +73,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
     excess_duration = segment_duration - (total_segments * segment_duration - duration)
     print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration} Overlap Loss {duration_loss}")
     duration += duration_loss
+    pbar = tqdm(total=total_segments*2, desc="Generating segments", leave=False)
     while excess_duration + duration_loss > segment_duration:
         total_segments += 1
         #calculate duration loss from segment overlap
@@ -82,6 +84,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
         if excess_duration + duration_loss > segment_duration:
             duration += duration_loss
             duration_loss = 0
+        pbar.update(1)
     total_segments = min(total_segments, (720 // segment_duration))
 
     # If melody_segments is shorter than total_segments, repeat the segments until the total_segments is reached
@@ -90,6 +93,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
         for i in range(total_segments - len(melody_segments)):
             segment = melody_segments[i]
             melody_segments.append(segment)
+            pbar.update(1)
         print(f"melody_segments: {len(melody_segments)} fixed")
 
     # Iterate over the segments to create list of Meldoy tensors
@@ -116,7 +120,8 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
 
         # Append the segment to the melodys list
         melodys.append(verse)
-
+        pbar.update(1)
+    pbar.close()
     torch.manual_seed(seed)
 
     # If user selects a prompt segment, generate a new prompt segment to use on all segments
@@ -147,7 +152,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
             prompt=None,
         )
 
-    for idx, verse in enumerate(melodys):
+    for idx, verse in tqdm(enumerate(melodys), total=len(melodys), desc="Generating melody segments"):
         if INTERRUPTING:
             return output_segments, duration
 
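Note on the generate_music_segments changes: one manually managed bar now spans the preparation phases, sized at total_segments * 2 and ticked as the duration is fixed up, as melody segments are repeated to length, and as each verse tensor is appended. A minimal sketch of the general pattern (one bar shared across phases, illustrative names):

from tqdm import tqdm

def build_segments(total_segments: int):
    # One bar across both phases; leave=False clears it when finished.
    pbar = tqdm(total=total_segments * 2, desc="Generating segments", leave=False)

    slices = []
    for i in range(total_segments):   # phase 1: slice/repeat the melody
        slices.append(f"slice-{i}")
        pbar.update(1)

    verses = []
    for s in slices:                  # phase 2: build one tensor per slice
        verses.append(s.upper())
        pbar.update(1)

    pbar.close()  # bars created manually must be closed explicitly
    return verses

print(build_segments(4))

Because predict in app.py declares progress=gr.Progress(track_tqdm=True), these tqdm bars are mirrored into the web UI rather than appearing only on the server console.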
 
modules/gradio.py CHANGED

@@ -9,6 +9,7 @@ import shutil
 import subprocess
 from tempfile import NamedTemporaryFile
 from pathlib import Path
+from tqdm import tqdm
 
 
 class MatplotlibBackendMananger:
@@ -42,6 +43,7 @@ def make_waveform(
     bar_width: float = 0.6,
     animate: bool = False,
     name: str = "",
+    progress= gr.Progress(track_tqdm=True)
 ) -> str:
     """
     Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
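Note on the make_waveform change: the new progress parameter follows Gradio's track_tqdm convention. Gradio injects a live gr.Progress when an event handler declares one as a default argument, and with track_tqdm=True every tqdm loop executed during that event is mirrored to the bar in the browser. A minimal sketch of a helper written in this style, with a hypothetical render_frame standing in for the real per-frame matplotlib work:

import gradio as gr
from tqdm import tqdm

def render_frame(i: int) -> str:
    return f"frame-{i:04d}.png"  # placeholder for real matplotlib rendering

def make_waveform_sketch(n_frames: int = 120,
                         progress=gr.Progress(track_tqdm=True)) -> str:
    # The tqdm loop is what track_tqdm picks up; the helper never needs to
    # call progress(...) explicitly.
    frames = [render_frame(i) for i in tqdm(range(n_frames), desc="Rendering frames")]
    return f"encoded {len(frames)} frames to waveform.mp4"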