Spaces:
Running
on
Zero
Running
on
Zero
Curinha
Refactor sound generation functions to remove user_id parameter and adjust GPU duration
671b217
import time | |
import spaces | |
import torch | |
from audiocraft.data.audio import audio_write | |
from audiocraft.models import AudioGen, MusicGen | |
# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Load the pretrained models onto the selected device. Passing `device`
# explicitly keeps model placement in sync with the value printed above,
# instead of leaving it to get_pretrained's own default selection.
sound_model = AudioGen.get_pretrained('facebook/audiogen-medium', device=device)
music_model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)

# Length, in seconds, of every clip produced by both models.
GENERATION_DURATION_S = 5
sound_model.set_generation_params(duration=GENERATION_DURATION_S)
music_model.set_generation_params(duration=GENERATION_DURATION_S)
# NOTE(review): `spaces` is imported in this file, but no `@spaces.GPU`
# decorator is visible on the functions below. If they are not wrapped
# elsewhere (e.g. by the app entry point), they will not receive a ZeroGPU
# allocation — confirm.


def _safe_stem(prompt: str, max_len: int = 50) -> str:
    """Turn a free-text prompt into a file-name fragment that is safe to write.

    The prompt is user-supplied and is used to name the output file. Left raw,
    a prompt such as ``"../../x"`` would point the write at another directory;
    audiocraft's ``audio_write`` creates missing parent directories by default
    (``make_parent_dir=True``). A very long prompt could also exceed the
    filesystem's file-name length limit.

    To prevent this, every character other than letters, digits, spaces,
    ``-`` and ``_`` is replaced with ``_``. The result is then truncated to
    ``max_len`` characters, and leading/trailing spaces and underscores are
    stripped.

    Args:
        prompt: User-supplied description.
        max_len: Maximum number of characters kept from the prompt.

    Returns:
        A non-empty string with no path separators or dots. ``"audio"`` is
        returned when nothing usable remains.
    """
    cleaned = "".join(
        ch if ch.isalnum() or ch in " -_" else "_" for ch in prompt
    )
    cleaned = cleaned[:max_len].strip(" _")
    return cleaned or "audio"


def _write_output(prompt: str, wav, sample_rate: int) -> str:
    """Save the first generated clip next to the working directory.

    Args:
        prompt: The prompt the clip was generated from; used (sanitized) in
            the file name.
        wav: Batch returned by the model's ``generate``. Only item 0 is saved,
            which matches the single-prompt batch the callers submit.
        sample_rate: Sample rate of the model that produced ``wav``.

    Returns:
        str: Path of the written ``.wav`` file.
    """
    # The nanosecond timestamp keeps repeated requests for the same prompt
    # from overwriting each other.
    stem = f"{_safe_stem(prompt)}_{time.time_ns()}"
    audio_write(stem, wav[0].cpu(), sample_rate, strategy="loudness")
    # audio_write appends the format suffix (".wav" by default) to the stem.
    return f"{stem}.wav"


def generate_sound(prompt: str) -> str:
    """Generate a sound effect from a text prompt with the AudioGen model.

    Args:
        prompt (str): Description of the sound to generate.

    Returns:
        str: Path of the saved ``.wav`` file.
    """
    wav = sound_model.generate([prompt])  # a one-item batch
    return _write_output(prompt, wav, sound_model.sample_rate)


def generate_music(prompt: str) -> str:
    """Generate music from a text prompt with the MusicGen model.

    Args:
        prompt (str): Description of the music to generate.

    Returns:
        str: Path of the saved ``.wav`` file.
    """
    wav = music_model.generate([prompt])  # a one-item batch
    return _write_output(prompt, wav, music_model.sample_rate)