# InspireMusic / app.py
# Last commit: 22ee199 ("update") by chong.zhang (7.27 kB)
import os
import spaces
import gradio as gr
from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables
import torchaudio
import datetime
import hashlib
import torch
from modelscope import snapshot_download
# Diagnostics printed at import time: GPU status from the nvidia-smi CLI
# (its exit status is ignored) and the cuDNN version reported by torch.
os.system('nvidia-smi')
print(torch.backends.cudnn.version())
def generate_filename():
    """Return a fresh 64-character hexadecimal name for an output/prompt file.

    The name is the SHA-256 hex digest of the current local time plus 16
    random bytes, so calls made within the same second get distinct names.
    """
    # Hashing only the whole-second timestamp made two requests in the same
    # second (e.g. concurrent users) get identical filenames and overwrite
    # each other's audio; the random bytes remove that collision.
    seed = datetime.datetime.now().isoformat().encode() + os.urandom(16)
    return hashlib.sha256(seed).hexdigest()
def get_args(
        task, text="", audio=None, model_name="InspireMusic-Base",
        chorus="intro",
        output_sample_rate=48000, max_generate_audio_seconds=30.0,
        time_start=0.0, time_end=30.0, trim=False,
        max_audio_prompt_length=5.0):
    """Build the argument dict consumed by music_generation().

    Args:
        task: Task name passed to inference, e.g. 'text-to-music' or
            'continuation'.
        text: Text prompt.
        audio: Path to an audio prompt file, or None.
        model_name: Model name; weights are read from
            pretrained_models/<model_name>.
        chorus: Structure tag passed through to inference (the UI offers
            'intro', 'verse', 'chorus', 'outro').
        output_sample_rate: Output rate in Hz; 24000 sets ``fast=True``.
        max_generate_audio_seconds: Upper bound on generated length; also
            determines ``time_end``.
        time_start: Start time in seconds; None is treated as 0.0.
        time_end: Ignored -- always recomputed as
            ``time_start + max_generate_audio_seconds``.
        trim: Forwarded as the inference ``trim`` flag.
        max_audio_prompt_length: Forwarded to inference; presumably seconds,
            matching the 5-second prompt cut used by the continuation demo.

    Returns:
        dict of values read by music_generation().
    """
    fast = output_sample_rate == 24000
    args = {
        "task": task,
        "text": text,
        "audio_prompt": audio,
        "model_name": model_name,
        "chorus": chorus,
        "fast": fast,
        "fade_out": True,
        "trim": trim,
        "output_sample_rate": output_sample_rate,
        "min_generate_audio_seconds": 10.0,
        "max_generate_audio_seconds": max_generate_audio_seconds,
        "model_dir": os.path.join("pretrained_models", model_name),
        "result_dir": "exp/inspiremusic",
        # Fresh name per request so outputs do not overwrite each other.
        "output_fn": generate_filename(),
        "format": "wav",
        "time_start": time_start,
        "time_end": time_end,
        "fade_out_duration": 1.0,
        # music_generation() reads this key; it was previously missing, which
        # made every request fail with KeyError.
        "max_audio_prompt_length": max_audio_prompt_length,
    }
    if args["time_start"] is None:
        args["time_start"] = 0.0
    # The generation window always spans exactly max_generate_audio_seconds.
    args["time_end"] = args["time_start"] + args["max_generate_audio_seconds"]
    print(args)
    return args
def trim_audio(audio_file, cut_seconds=5):
    """Keep only the first ``cut_seconds`` of ``audio_file`` and save it as a WAV.

    Args:
        audio_file: Path to an audio file readable by ``torchaudio.load``.
        cut_seconds: Length to keep, in seconds; may be fractional.

    Returns:
        Absolute path of the trimmed file, written to the current working
        directory as ``audio_prompt_<hash>.wav``.
    """
    waveform, sample_rate = torchaudio.load(audio_file)
    # Slice bounds must be integers; int() lets callers pass float seconds
    # (e.g. 2.5) instead of raising TypeError.
    num_samples = int(cut_seconds * sample_rate)
    trimmed = waveform[:, :num_samples]
    output_path = os.path.join(
        os.getcwd(), "audio_prompt_" + generate_filename() + ".wav")
    torchaudio.save(output_path, trimmed, sample_rate)
    return output_path
@spaces.GPU
def music_generation(args):
    """Construct an InspireMusicUnified model and run one inference request.

    Args:
        args: dict as built by get_args().

    Returns:
        The value returned by ``InspireMusicUnified.inference``; the UI
        consumes it as the output audio file path.
    """
    set_env_variables()
    # A new model instance is built on every call.
    model = InspireMusicUnified(
        model_name=args["model_name"],
        model_dir=args["model_dir"],
        min_generate_audio_seconds=args["min_generate_audio_seconds"],
        max_generate_audio_seconds=args["max_generate_audio_seconds"],
        sample_rate=24000,
        output_sample_rate=args["output_sample_rate"],
        load_jit=True,
        load_onnx=False,
        fast=args["fast"],
        result_dir=args["result_dir"])
    output_path = model.inference(
        task=args["task"],
        text=args["text"],
        audio_prompt=args["audio_prompt"],
        chorus=args["chorus"],
        time_start=args["time_start"],
        time_end=args["time_end"],
        output_fn=args["output_fn"],
        # Callers may omit this key; a plain args[...] lookup raised KeyError.
        # 5.0 matches the 5-second prompt cut in demo_inspiremusic_con.
        max_audio_prompt_length=args.get("max_audio_prompt_length", 5.0),
        fade_out_duration=args["fade_out_duration"],
        output_format=args["format"],
        fade_out_mode=args["fade_out"],
        trim=args["trim"])
    return output_path
@spaces.GPU
def demo_inspiremusic_t2m(text, model_name, chorus,
                          output_sample_rate, max_generate_audio_seconds):
    """Gradio handler for the "Text to Music" button.

    Builds a text-to-music request without an audio prompt and returns the
    generated audio file path from music_generation().
    """
    options = dict(
        model_name=model_name,
        chorus=chorus,
        output_sample_rate=output_sample_rate,
        max_generate_audio_seconds=max_generate_audio_seconds,
    )
    # Positional order matches get_args(task, text, audio, ...).
    return music_generation(get_args('text-to-music', text, None, **options))
@spaces.GPU
def demo_inspiremusic_con(text, audio, model_name, chorus,
                          output_sample_rate, max_generate_audio_seconds):
    """Gradio handler for the "Music Continuation" button.

    Trims the uploaded clip to its first 5 seconds, uses it as the audio
    prompt for a continuation request, and returns the generated audio path.

    Raises:
        gr.Error: if no audio prompt was uploaded.
    """
    if not audio:
        # Otherwise trim_audio(None) fails inside torchaudio.load with an
        # unhelpful error; gr.Error shows this message to the user.
        raise gr.Error(
            "Please upload an audio prompt for the Music Continuation task.")
    prompt_path = trim_audio(audio, cut_seconds=5)
    args = get_args(
        task='continuation', text=text, audio=prompt_path,
        model_name=model_name, chorus=chorus,
        output_sample_rate=output_sample_rate,
        max_generate_audio_seconds=max_generate_audio_seconds)
    return music_generation(args)
def main():
    """Build and launch the InspireMusic Gradio demo.

    The page has one settings row (model, chorus mode, output sample rate,
    generation length) shared by two actions: text-to-music and music
    continuation. Both actions also read the same text prompt box.
    """
    intro_markdown = """
# InspireMusic
- Support text-to-music, music continuation, audio super-resolution, audio reconstruction tasks with high audio quality, with available sampling rates of 24kHz, 48kHz.
- Support long audio generation in multiple output audio formats, i.e., wav, flac, mp3, m4a.
- Open-source [InspireMusic-Base](https://modelscope.cn/models/iic/InspireMusic/summary), [InspireMusic-Base-24kHz](https://modelscope.cn/models/iic/InspireMusic-Base-24kHz/summary), [InspireMusic-1.5B](https://modelscope.cn/models/iic/InspireMusic-1.5B/summary), [InspireMusic-1.5B-24kHz](https://modelscope.cn/models/iic/InspireMusic-1.5B-24kHz/summary), [InspireMusic-1.5B-Long](https://modelscope.cn/models/iic/InspireMusic-1.5B-Long/summary) models for music generation.
- Currently only support English text prompts.
"""
    model_choices = [
        "InspireMusic-1.5B-Long",
        "InspireMusic-1.5B",
        "InspireMusic-1.5B-24kHz",
        "InspireMusic-Base",
        "InspireMusic-Base-24kHz",
    ]
    chorus_choices = ["intro", "verse", "chorus", "outro"]
    default_prompt = (
        "Experience soothing and sensual instrumental jazz with a touch of "
        "Bossa Nova, perfect for a relaxing restaurant or spa ambiance."
    )

    # Components render in creation order, so the statements below fix the
    # page layout.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(intro_markdown)

        with gr.Row(equal_height=True):
            # NOTE(review): the __main__ block only downloads
            # InspireMusic-Base; other choices need their weights already
            # present under pretrained_models/ -- confirm.
            model_dd = gr.Dropdown(choices=model_choices,
                                   label="Select Model Name",
                                   value="InspireMusic-Base")
            chorus_dd = gr.Dropdown(choices=chorus_choices,
                                    label="Chorus Mode", value="intro")
            sample_rate_dd = gr.Dropdown(choices=[48000, 24000],
                                         label="Output Audio Sample Rate (Hz)",
                                         value=48000)
            length_slider = gr.Slider(minimum=10, maximum=120,
                                      label="Generate Audio Length (s)",
                                      value=30)
        shared_settings = [model_dd, chorus_dd, sample_rate_dd, length_slider]

        # Text-to-music section.
        prompt_box = gr.Textbox(label="Input Text (For Text-to-Music Task)",
                                value=default_prompt)
        t2m_output = gr.Audio(label="Text to Music Output", type="filepath")
        t2m_button = gr.Button("Text to Music")
        t2m_button.click(demo_inspiremusic_t2m,
                         inputs=[prompt_box, *shared_settings],
                         outputs=t2m_output)

        # Music continuation section.
        audio_prompt = gr.Audio(
            label="Input Audio Prompt (For Music Continuation Task)",
            type="filepath")
        con_output = gr.Audio(label="Music Continuation Output",
                              type="filepath")
        con_button = gr.Button("Music Continuation")
        con_button.click(demo_inspiremusic_con,
                         inputs=[prompt_box, audio_prompt, *shared_settings],
                         outputs=con_output)
    demo.launch()
def download_model(model_name):
    """Fetch ``model_name`` from ModelScope unless it is already on disk.

    Returns:
        The local directory ``pretrained_models/<model_name>``.
    """
    model_dir = f"pretrained_models/{model_name}"
    if not os.path.isdir(model_dir):
        # InspireMusic-Base lives under the repo id "iic/InspireMusic"; the
        # other models use their own name as the repo id.
        if model_name == "InspireMusic-Base":
            repo_id = "iic/InspireMusic"
        else:
            repo_id = f"iic/{model_name}"
        snapshot_download(repo_id, local_dir=model_dir)
    return model_dir


def fix_pretrained_model_paths(yaml_file_path):
    """Rewrite '../../pretrained_models' to 'pretrained_models' in a config.

    Only ``basemodel_path`` / ``generator_path`` lines are touched, so the
    paths resolve relative to the working directory, where this app keeps
    its models (``pretrained_models/<name>``). Re-running is a no-op.
    """
    with open(yaml_file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
    updated_lines = [
        line.replace('../../pretrained_models', 'pretrained_models')
        if ("basemodel_path: '../../pretrained_models/" in line
            or "generator_path: '../../pretrained_models" in line)
        else line
        for line in lines
    ]
    with open(yaml_file_path, 'w', encoding='utf-8') as file:
        file.writelines(updated_lines)


if __name__ == '__main__':
    # Only the base model is fetched at startup. Other candidates:
    # "InspireMusic-1.5B-Long", "InspireMusic-1.5B".
    model_list = ["InspireMusic-Base"]
    for model_name in model_list:
        model_dir = download_model(model_name)
        fix_pretrained_model_paths(os.path.join(model_dir, 'inspiremusic.yaml'))
    main()