Spaces:
Running
on
Zero
Running
on
Zero
Curinha
commited on
Commit
·
f8bd524
1
Parent(s):
5241704
Refactor sound_generator.py to load models on GPU if available and streamline audio generation process
Browse files- sound_generator.py +11 -13
sound_generator.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
from audiocraft.models import AudioGen, MusicGen
|
2 |
from audiocraft.data.audio import audio_write
|
|
|
3 |
|
4 |
-
# Load the pretrained
|
5 |
-
|
6 |
-
music_model = MusicGen.get_pretrained('facebook/musicgen-small')
|
7 |
|
8 |
-
|
|
|
|
|
|
|
9 |
sound_model.set_generation_params(duration=5)
|
10 |
music_model.set_generation_params(duration=5)
|
11 |
|
@@ -19,17 +22,14 @@ def generate_sound(prompt: str):
|
|
19 |
Returns:
|
20 |
- str: The path to the saved audio file.
|
21 |
"""
|
22 |
-
|
23 |
-
|
24 |
-
wav = sound_model.generate(descriptions) # Generates 2 samples
|
25 |
|
26 |
-
# Save the generated audio file with loudness normalization
|
27 |
output_path = 'generated_audio'
|
28 |
audio_write(output_path, wav[0].cpu(), sound_model.sample_rate, strategy="loudness")
|
29 |
|
30 |
return f"{output_path}.wav"
|
31 |
|
32 |
-
|
33 |
def generate_music(prompt: str):
|
34 |
"""
|
35 |
Generate music using Audiocraft based on the given prompt.
|
@@ -40,11 +40,9 @@ def generate_music(prompt: str):
|
|
40 |
Returns:
|
41 |
- str: The path to the saved audio file.
|
42 |
"""
|
43 |
-
|
44 |
-
|
45 |
-
wav = music_model.generate(descriptions) # Generates 2 samples
|
46 |
|
47 |
-
# Save the generated audio file with loudness normalization
|
48 |
output_path = 'generated_audio'
|
49 |
audio_write(output_path, wav[0].cpu(), music_model.sample_rate, strategy="loudness")
|
50 |
|
|
|
1 |
from audiocraft.models import AudioGen, MusicGen
|
2 |
from audiocraft.data.audio import audio_write
|
3 |
+
import torch
|
4 |
|
5 |
+
# Load the pretrained models and move them to GPU if available
|
6 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
7 |
|
8 |
+
sound_model = AudioGen.get_pretrained('facebook/audiogen-medium').to(device)
|
9 |
+
music_model = MusicGen.get_pretrained('facebook/musicgen-small').to(device)
|
10 |
+
|
11 |
+
# Set generation parameters (for example, audio duration of 5 seconds)
|
12 |
sound_model.set_generation_params(duration=5)
|
13 |
music_model.set_generation_params(duration=5)
|
14 |
|
|
|
22 |
Returns:
|
23 |
- str: The path to the saved audio file.
|
24 |
"""
|
25 |
+
descriptions = [prompt]
|
26 |
+
wav = sound_model.generate(descriptions) # Generate audio
|
|
|
27 |
|
|
|
28 |
output_path = 'generated_audio'
|
29 |
audio_write(output_path, wav[0].cpu(), sound_model.sample_rate, strategy="loudness")
|
30 |
|
31 |
return f"{output_path}.wav"
|
32 |
|
|
|
33 |
def generate_music(prompt: str):
|
34 |
"""
|
35 |
Generate music using Audiocraft based on the given prompt.
|
|
|
40 |
Returns:
|
41 |
- str: The path to the saved audio file.
|
42 |
"""
|
43 |
+
descriptions = [prompt]
|
44 |
+
wav = music_model.generate(descriptions) # Generate music
|
|
|
45 |
|
|
|
46 |
output_path = 'generated_audio'
|
47 |
audio_write(output_path, wav[0].cpu(), music_model.sample_rate, strategy="loudness")
|
48 |
|