Curinha committed on
Commit
f8bd524
·
1 Parent(s): 5241704

Refactor sound_generator.py to load models on GPU if available and streamline audio generation process

Browse files
Files changed (1) hide show
  1. sound_generator.py +11 -13
sound_generator.py CHANGED
@@ -1,11 +1,14 @@
1
  from audiocraft.models import AudioGen, MusicGen
2
  from audiocraft.data.audio import audio_write
 
3
 
4
- # Load the pretrained model (you can choose "small", "medium", or "large")
5
- sound_model = AudioGen.get_pretrained('facebook/audiogen-medium')
6
- music_model = MusicGen.get_pretrained('facebook/musicgen-small')
7
 
8
- # Set generation parameters (for example, audio duration of 8 seconds)
 
 
 
9
  sound_model.set_generation_params(duration=5)
10
  music_model.set_generation_params(duration=5)
11
 
@@ -19,17 +22,14 @@ def generate_sound(prompt: str):
19
  Returns:
20
  - str: The path to the saved audio file.
21
  """
22
- # Generate the audio for the provided prompt
23
- descriptions = [prompt] # We use the prompt as a description for the model
24
- wav = sound_model.generate(descriptions) # Generates 2 samples
25
 
26
- # Save the generated audio file with loudness normalization
27
  output_path = 'generated_audio'
28
  audio_write(output_path, wav[0].cpu(), sound_model.sample_rate, strategy="loudness")
29
 
30
  return f"{output_path}.wav"
31
 
32
-
33
  def generate_music(prompt: str):
34
  """
35
  Generate music using Audiocraft based on the given prompt.
@@ -40,11 +40,9 @@ def generate_music(prompt: str):
40
  Returns:
41
  - str: The path to the saved audio file.
42
  """
43
- # Generate the music for the provided prompt
44
- descriptions = [prompt] # We use the prompt as a description for the model
45
- wav = music_model.generate(descriptions) # Generates 2 samples
46
 
47
- # Save the generated audio file with loudness normalization
48
  output_path = 'generated_audio'
49
  audio_write(output_path, wav[0].cpu(), music_model.sample_rate, strategy="loudness")
50
 
 
1
  from audiocraft.models import AudioGen, MusicGen
2
  from audiocraft.data.audio import audio_write
3
+ import torch
4
 
5
+ # Load the pretrained models and move them to GPU if available
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
7
 
8
+ sound_model = AudioGen.get_pretrained('facebook/audiogen-medium').to(device)
9
+ music_model = MusicGen.get_pretrained('facebook/musicgen-small').to(device)
10
+
11
+ # Set generation parameters (for example, audio duration of 5 seconds)
12
  sound_model.set_generation_params(duration=5)
13
  music_model.set_generation_params(duration=5)
14
 
 
22
  Returns:
23
  - str: The path to the saved audio file.
24
  """
25
+ descriptions = [prompt]
26
+ wav = sound_model.generate(descriptions) # Generate audio
 
27
 
 
28
  output_path = 'generated_audio'
29
  audio_write(output_path, wav[0].cpu(), sound_model.sample_rate, strategy="loudness")
30
 
31
  return f"{output_path}.wav"
32
 
 
33
  def generate_music(prompt: str):
34
  """
35
  Generate music using Audiocraft based on the given prompt.
 
40
  Returns:
41
  - str: The path to the saved audio file.
42
  """
43
+ descriptions = [prompt]
44
+ wav = music_model.generate(descriptions) # Generate music
 
45
 
 
46
  output_path = 'generated_audio'
47
  audio_write(output_path, wav[0].cpu(), music_model.sample_rate, strategy="loudness")
48