# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import datetime
import hashlib

import torch
import torchaudio
import gradio as gr

from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables
# Prepare environment and model files (unchanged from original)
os.system('nvidia-smi')
os.system('apt update -y && apt-get install -y apt-utils && apt install -y unzip')
os.environ['PYTHONPATH'] = 'third_party/Matcha-TTS'
os.system(
    'mkdir pretrained_models && cd pretrained_models && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz.git && '
    # Fix relative paths in the model YAML files
    'for i in InspireMusic-Base InspireMusic-Base-24kHz InspireMusic-1.5B InspireMusic-1.5B-24kHz InspireMusic-1.5B-Long; '
    'do sed -i -e "s/..\/..\///g" ${i}/inspiremusic.yaml; done && cd ..'
)
print(torch.backends.cudnn.version())

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f"{ROOT_DIR}/third_party/Matcha-TTS")
# Define available model options
MODELS = ["InspireMusic-1.5B-Long", "InspireMusic-1.5B", "InspireMusic-Base",
          "InspireMusic-1.5B-24kHz", "InspireMusic-Base-24kHz"]
AUDIO_PROMPT_DIR = "demo/audio_prompts"
OUTPUT_AUDIO_DIR = "demo/outputs"
# **Initialize global model state at startup**
loaded_model = None
current_model_name = None

# Set environment variables once (e.g., for torch performance, precision settings)
set_env_variables()

# Load the default model into GPU memory
current_model_name = "InspireMusic-1.5B-Long"  # default selected model in the UI
loaded_model = InspireMusicUnified(
    model_name=current_model_name,
    model_dir=os.path.join("pretrained_models", current_model_name),
    min_generate_audio_seconds=10.0,
    max_generate_audio_seconds=30.0,
    sample_rate=24000,
    output_sample_rate=48000,  # 48 kHz output for the default (non-24kHz) model
    load_jit=True,
    load_onnx=False,
    fast=False,  # False because output is 48000 Hz (not fast mode)
    result_dir=OUTPUT_AUDIO_DIR
)
# (The model is now loaded on the GPU and ready for reuse)
def generate_filename():
    # ... (unchanged: generates a unique filename for outputs)
    timestamp = str(int(datetime.datetime.now().timestamp())).encode()
    hash_object = hashlib.sha256(timestamp)
    return hash_object.hexdigest()[:10]
def get_args(task, text="", audio=None, model_name="InspireMusic-Base",
             chorus="intro", output_sample_rate=48000, max_generate_audio_seconds=30.0,
             time_start=0.0, time_end=30.0, trim=False):
    """Prepare the arguments dictionary for a generation task."""
    # If a 24kHz model is selected, force the output sample rate to 24000
    if "24kHz" in model_name:
        output_sample_rate = 24000
    # Determine fast mode (True for 24 kHz output, which skips upsampling)
    fast = (output_sample_rate == 24000)
    args = {
        "task": task,
        "text": text,
        "audio_prompt": audio,
        "model_name": model_name,
        "chorus": chorus,
        "fast": fast,
        "fade_out": True,
        "trim": trim,
        "output_sample_rate": output_sample_rate,
        "min_generate_audio_seconds": 10.0,
        "max_generate_audio_seconds": max_generate_audio_seconds,
        "max_audio_prompt_length": 5.0,
        "model_dir": os.path.join("pretrained_models", model_name),
        "result_dir": OUTPUT_AUDIO_DIR,
        "output_fn": generate_filename(),
        "format": "wav",
        "time_start": time_start or 0.0,
        "time_end": time_end or (time_start + max_generate_audio_seconds),
        "fade_out_duration": 1.0,
    }
    return args
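# Note (sketch, illustrative values only): selecting a 24kHz checkpoint such as
# "InspireMusic-Base-24kHz" makes get_args() force output_sample_rate to 24000
# and set fast=True (skipping upsampling), while the 48 kHz models keep
# fast=False. For example, get_args(task="text-to-music", text="...",
# model_name="InspireMusic-Base-24kHz")["fast"] evaluates to True.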
# **Refactored inference function using the preloaded model**
def music_generation(args):
    """Generate music using the InspireMusic model, reusing the preloaded model if available."""
    global loaded_model, current_model_name
    requested_model = args["model_name"]
    # If the requested model is not the one currently loaded, load the new model
    if loaded_model is None or requested_model != current_model_name:
        # Free GPU memory held by the old model
        if loaded_model is not None:
            del loaded_model
            torch.cuda.empty_cache()  # release cached GPU memory
        # Load the requested model into GPU memory
        loaded_model = InspireMusicUnified(
            model_name=requested_model,
            model_dir=args["model_dir"],
            min_generate_audio_seconds=args["min_generate_audio_seconds"],
            max_generate_audio_seconds=args["max_generate_audio_seconds"],
            sample_rate=24000,
            output_sample_rate=args["output_sample_rate"],
            load_jit=True,
            load_onnx=False,
            fast=args["fast"],
            result_dir=args["result_dir"]
        )
        current_model_name = requested_model
    # Perform inference with the loaded model (no gradient computation needed)
    with torch.no_grad():  # disable autograd to save memory
        output_path = loaded_model.inference(
            task=args["task"],
            text=args["text"],
            audio_prompt=args["audio_prompt"],
            chorus=args["chorus"],
            time_start=args["time_start"],
            time_end=args["time_end"],
            output_fn=args["output_fn"],
            max_audio_prompt_length=args["max_audio_prompt_length"],
            fade_out_duration=args["fade_out_duration"],
            output_format=args["format"],
            fade_out_mode=args["fade_out"],
            trim=args["trim"]
        )
    return output_path
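# Minimal usage sketch (hypothetical helper, not wired into the UI; prompt text
# is illustrative): calling music_generation() twice with different model names
# exercises both paths above — the first call reuses the preloaded default model,
# the second triggers the unload/reload logic before generating.
def _example_generation(prompt_text="A calm acoustic guitar melody."):
    first = music_generation(get_args(task="text-to-music", text=prompt_text,
                                      model_name="InspireMusic-1.5B-Long"))
    second = music_generation(get_args(task="text-to-music", text=prompt_text,
                                       model_name="InspireMusic-Base-24kHz"))
    return first, second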
# Demo helper functions (using music_generation internally)
def demo_inspiremusic_t2m(text, model_name, chorus, output_sample_rate, max_generate_audio_seconds):
    args = get_args(task="text-to-music", text=text, audio=None,
                    model_name=model_name, chorus=chorus,
                    output_sample_rate=output_sample_rate,
                    max_generate_audio_seconds=max_generate_audio_seconds)
    return music_generation(args)


def demo_inspiremusic_con(text, audio, model_name, chorus, output_sample_rate, max_generate_audio_seconds):
    # Trim the audio prompt to its first 5 seconds and use it for continuation
    trimmed_audio = trim_audio(audio, cut_seconds=5)
    args = get_args(task="continuation", text=text, audio=trimmed_audio,
                    model_name=model_name, chorus=chorus,
                    output_sample_rate=output_sample_rate,
                    max_generate_audio_seconds=max_generate_audio_seconds)
    return music_generation(args)
def trim_audio(audio_file, cut_seconds=5):
    # ... (unchanged: load audio and trim to the first 5 seconds)
    audio_tensor, sr = torchaudio.load(audio_file)
    num_samples = int(cut_seconds * sr)
    trimmed_audio = audio_tensor[:, :num_samples]
    output_path = os.path.join(AUDIO_PROMPT_DIR, "audio_prompt_" + generate_filename() + ".wav")
    torchaudio.save(output_path, trimmed_audio, sr)
    return output_path
def main():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# InspireMusic\n"
                    "- A demo for music generation with high audio quality (up to 48kHz) and long-form capabilities.\n"
                    "- GitHub: https://github.com/FunAudioLLM/InspireMusic\n"
                    "- Available models: InspireMusic-1.5B-Long, InspireMusic-1.5B, InspireMusic-Base, InspireMusic-1.5B-24kHz, InspireMusic-Base-24kHz (on Hugging Face and ModelScope).\n"
                    "*(Note: Only English text prompts are supported.)*")
        # Input components
        model_name = gr.Dropdown(MODELS, label="Select Model Name", value="InspireMusic-1.5B-Long")
        chorus = gr.Dropdown(["intro", "verse", "chorus", "outro"], label="Chorus Mode", value="intro")
        output_sample_rate = gr.Dropdown([48000, 24000], label="Output Audio Sample Rate (Hz)", value=48000)
        max_generate_audio_seconds = gr.Slider(10, 300, label="Generate Audio Length (s)", value=30)
        with gr.Row():
            text_input = gr.Textbox(label="Input Text (For Text-to-Music Task)", value="Experience soothing ... ambiance.")
            audio_input = gr.Audio(label="Input Audio Prompt (For Music Continuation Task)", type="filepath")
        music_output = gr.Audio(label="Generated Music", type="filepath", autoplay=True, show_download_button=True)
        # Buttons to trigger generation
        with gr.Row():
            t2m_button = gr.Button("Start Text-to-Music Task")
            con_button = gr.Button("Start Music Continuation Task")
        # Bind button clicks to the respective functions
        t2m_button.click(fn=demo_inspiremusic_t2m,
                         inputs=[text_input, model_name, chorus, output_sample_rate, max_generate_audio_seconds],
                         outputs=music_output)
        con_button.click(fn=demo_inspiremusic_con,
                         inputs=[text_input, audio_input, model_name, chorus, output_sample_rate, max_generate_audio_seconds],
                         outputs=music_output)
        gr.Examples(examples=[...], inputs=[text_input])  # (example prompts list truncated for brevity)
    demo.launch()
if __name__ == "__main__":
    # Ensure output directories exist
    os.makedirs(AUDIO_PROMPT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_AUDIO_DIR, exist_ok=True)
    main()