# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
import gradio as gr
import torchaudio
import datetime, hashlib
import spaces  # provides the @spaces.GPU() decorator used on music_generation below
from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables
# Prepare environment and model files (unchanged from original)
os.system('nvidia-smi')
os.system('apt update -y && apt-get install -y apt-utils && apt install -y unzip')
os.environ['PYTHONPATH'] = 'third_party/Matcha-TTS'
os.system(
    'mkdir pretrained_models && cd pretrained_models && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz.git && '
    'git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz.git && '
    # Rewrite relative paths in each model's YAML config so they resolve from the repo root
    'for i in InspireMusic-Base InspireMusic-Base-24kHz InspireMusic-1.5B InspireMusic-1.5B-24kHz InspireMusic-1.5B-Long; '
    r'do sed -i -e "s/..\/..\///g" ${i}/inspiremusic.yaml; done && cd ..'
)
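# Sanity check: print the cuDNN version so the startup logs confirm a CUDA-capable runtime.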
print(torch.backends.cudnn.version())
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(f"{ROOT_DIR}/third_party/Matcha-TTS")
# Define available model options
MODELS = ["InspireMusic-1.5B-Long", "InspireMusic-1.5B", "InspireMusic-Base",
          "InspireMusic-1.5B-24kHz", "InspireMusic-Base-24kHz"]
AUDIO_PROMPT_DIR = "demo/audio_prompts"
OUTPUT_AUDIO_DIR = "demo/outputs"
# **Initialize global model state at startup**
loaded_model = None
current_model_name = None
# Set environment variables once (e.g., for torch performance, precision settings)
set_env_variables()
# Load the default model into GPU memory
current_model_name = "InspireMusic-1.5B-Long" # default selected model in the UI
loaded_model = InspireMusicUnified(
    model_name=current_model_name,
    model_dir=os.path.join("pretrained_models", current_model_name),
    min_generate_audio_seconds=10.0,
    max_generate_audio_seconds=30.0,
    sample_rate=24000,
    output_sample_rate=48000,  # 48 kHz output for the default (non-24kHz) model
    load_jit=True,
    load_onnx=False,
    fast=False,                # fast mode is only used for 24 kHz output
    result_dir=OUTPUT_AUDIO_DIR
)
# (The model is now loaded on the GPU and ready for reuse)
def generate_filename():
    # (unchanged: generates a unique output filename by hashing the current Unix timestamp)
    timestamp = str(int(datetime.datetime.now().timestamp())).encode()
    hash_object = hashlib.sha256(timestamp)
    return hash_object.hexdigest()[:10]
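# Note: the name is derived from a second-resolution timestamp, so requests starting within the
# same second would receive the same output filename.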
def get_args(task, text="", audio=None, model_name="InspireMusic-Base",
             chorus="intro", output_sample_rate=48000, max_generate_audio_seconds=30.0,
             time_start=0.0, time_end=30.0, trim=False):
    """Prepare the arguments dictionary for a generation task."""
    # 24kHz models can only produce 24 kHz audio, so force the output sample rate
    if "24kHz" in model_name:
        output_sample_rate = 24000
    # Fast mode is used for 24 kHz output, which skips the upsampling stage
    fast = output_sample_rate == 24000
    args = {
        "task": task,
        "text": text,
        "audio_prompt": audio,
        "model_name": model_name,
        "chorus": chorus,
        "fast": fast,
        "fade_out": True,
        "trim": trim,
        "output_sample_rate": output_sample_rate,
        "min_generate_audio_seconds": 10.0,
        "max_generate_audio_seconds": max_generate_audio_seconds,
        "max_audio_prompt_length": 5.0,
        "model_dir": os.path.join("pretrained_models", model_name),
        "result_dir": OUTPUT_AUDIO_DIR,
        "output_fn": generate_filename(),
        "format": "wav",
        "time_start": time_start or 0.0,
        "time_end": time_end or (time_start + max_generate_audio_seconds),
        "fade_out_duration": 1.0,
    }
    return args
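# Example (illustrative only): get_args("text-to-music", text="calm piano", model_name="InspireMusic-1.5B-24kHz")
# returns a dict with output_sample_rate forced to 24000 and fast=True, while the default 48 kHz
# models keep fast=False so the upsampling stage runs.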
# **Refactored inference function using the preloaded model**
@spaces.GPU()
def music_generation(args):
    """Generate music with the InspireMusic model, reusing the preloaded model when possible."""
    global loaded_model, current_model_name
    requested_model = args["model_name"]
    # If the requested model differs from the one currently loaded, swap models
    if loaded_model is None or requested_model != current_model_name:
        # Free GPU memory held by the old model
        if loaded_model is not None:
            del loaded_model
            torch.cuda.empty_cache()  # release cached GPU memory
        # Load the requested model into GPU memory
        loaded_model = InspireMusicUnified(
            model_name=requested_model,
            model_dir=args["model_dir"],
            min_generate_audio_seconds=args["min_generate_audio_seconds"],
            max_generate_audio_seconds=args["max_generate_audio_seconds"],
            sample_rate=24000,
            output_sample_rate=args["output_sample_rate"],
            load_jit=True,
            load_onnx=False,
            fast=args["fast"],
            result_dir=args["result_dir"]
        )
        current_model_name = requested_model
    # Run inference with the loaded model; no gradient computation is needed
    with torch.no_grad():
        output_path = loaded_model.inference(
            task=args["task"],
            text=args["text"],
            audio_prompt=args["audio_prompt"],
            chorus=args["chorus"],
            time_start=args["time_start"],
            time_end=args["time_end"],
            output_fn=args["output_fn"],
            max_audio_prompt_length=args["max_audio_prompt_length"],
            fade_out_duration=args["fade_out_duration"],
            output_format=args["format"],
            fade_out_mode=args["fade_out"],
            trim=args["trim"]
        )
    return output_path
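# Note: loaded_model persists at module level across Gradio callbacks, so repeated requests with the
# same model skip re-initialization; selecting a different model frees the old weights and loads new ones.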
# Demo helper functions (these call music_generation internally)
def demo_inspiremusic_t2m(text, model_name, chorus, output_sample_rate, max_generate_audio_seconds):
    args = get_args(task="text-to-music", text=text, audio=None,
                    model_name=model_name, chorus=chorus,
                    output_sample_rate=output_sample_rate,
                    max_generate_audio_seconds=max_generate_audio_seconds)
    return music_generation(args)

def demo_inspiremusic_con(text, audio, model_name, chorus, output_sample_rate, max_generate_audio_seconds):
    # Trim the audio prompt to its first 5 seconds and use it for continuation
    trimmed_audio = trim_audio(audio, cut_seconds=5)
    args = get_args(task="continuation", text=text, audio=trimmed_audio,
                    model_name=model_name, chorus=chorus,
                    output_sample_rate=output_sample_rate,
                    max_generate_audio_seconds=max_generate_audio_seconds)
    return music_generation(args)
def trim_audio(audio_file, cut_seconds=5):
    # (unchanged: load the prompt audio and keep only its first cut_seconds seconds)
    audio_tensor, sr = torchaudio.load(audio_file)
    num_samples = int(cut_seconds * sr)
    trimmed_audio = audio_tensor[:, :num_samples]
    output_path = os.path.join(AUDIO_PROMPT_DIR, "audio_prompt_" + generate_filename() + ".wav")
    torchaudio.save(output_path, trimmed_audio, sr)
    return output_path
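# The trimmed prompt is saved to AUDIO_PROMPT_DIR as a new WAV file and handed to the continuation
# task by file path, matching the audio_prompt argument that get_args() expects.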
def main():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# InspireMusic\n"
                    "- A demo for music generation with high audio quality (up to 48kHz) and long-form capabilities.\n"
                    "- GitHub: https://github.com/FunAudioLLM/InspireMusic\n"
                    "- Available models: InspireMusic-1.5B-Long, InspireMusic-1.5B, InspireMusic-Base, InspireMusic-1.5B-24kHz, InspireMusic-Base-24kHz (on Hugging Face and ModelScope).\n"
                    "*(Note: Only English text prompts are supported.)*")
        # Input components
        model_name = gr.Dropdown(MODELS, label="Select Model Name", value="InspireMusic-1.5B-Long")
        chorus = gr.Dropdown(["intro", "verse", "chorus", "outro"], label="Chorus Mode", value="intro")
        output_sample_rate = gr.Dropdown([48000, 24000], label="Output Audio Sample Rate (Hz)", value=48000)
        max_generate_audio_seconds = gr.Slider(10, 300, label="Generate Audio Length (s)", value=30)
        with gr.Row():
            text_input = gr.Textbox(label="Input Text (For Text-to-Music Task)", value="Experience soothing ... ambiance.")
            audio_input = gr.Audio(label="Input Audio Prompt (For Music Continuation Task)", type="filepath")
        music_output = gr.Audio(label="Generated Music", type="filepath", autoplay=True, show_download_button=True)
        # Buttons to trigger generation
        with gr.Row():
            t2m_button = gr.Button("Start Text-to-Music Task")
            con_button = gr.Button("Start Music Continuation Task")
        # Bind button clicks to the respective task functions
        t2m_button.click(fn=demo_inspiremusic_t2m,
                         inputs=[text_input, model_name, chorus, output_sample_rate, max_generate_audio_seconds],
                         outputs=music_output)
        con_button.click(fn=demo_inspiremusic_con,
                         inputs=[text_input, audio_input, model_name, chorus, output_sample_rate, max_generate_audio_seconds],
                         outputs=music_output)
        gr.Examples(examples=[...], inputs=[text_input])  # (example prompts list truncated for brevity)
    demo.launch()
if __name__ == "__main__":
    # Ensure output directories exist before the app starts writing prompts and results
    os.makedirs(AUDIO_PROMPT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_AUDIO_DIR, exist_ok=True)
    main()