# chong.zhang
# update
# 957f1a3
# raw
# history blame
# 14.4 kB
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torchaudio
import time
import logging
import argparse
from inspiremusic.cli.inspiremusic import InspireMusic
from inspiremusic.utils.file_utils import logging
import torch
from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
def set_env_variables():
    """Prepare process environment variables and ``sys.path`` for inference.

    Sets ``PYTHONIOENCODING`` and ``TOKENIZERS_PARALLELISM``, then makes the
    current working directory, its ``inspiremusic`` sub-directory and
    ``third_party/Matcha-TTS`` importable, both for this interpreter (via
    ``sys.path``) and for child processes (via ``PYTHONPATH``).

    The function is idempotent: calling it repeatedly does not grow
    ``PYTHONPATH`` or ``sys.path``.
    """
    os.environ['PYTHONIOENCODING'] = 'UTF-8'
    os.environ['TOKENIZERS_PARALLELISM'] = 'False'

    main_root = os.getcwd()
    bin_dir = os.path.join(main_root, 'inspiremusic')
    third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')

    search_paths = [main_root, bin_dir, third_party_matcha_tts_path]
    # Keep whatever the user already had on PYTHONPATH, dropping empty entries
    # and anything we are about to add again.
    inherited = [
        entry
        for entry in os.environ.get('PYTHONPATH', '').split(os.pathsep)
        if entry and entry not in search_paths
    ]
    # This is a module search path, so it belongs in PYTHONPATH. Writing it to
    # PATH would wipe the executable search path of this process and of every
    # child process it spawns.
    os.environ['PYTHONPATH'] = os.pathsep.join(search_paths + inherited)

    for path in (main_root, third_party_matcha_tts_path):
        if path not in sys.path:
            sys.path.append(path)
class InspireMusicUnified:
    """Load an InspireMusic model once and render generated music to audio files.

    The constructor resolves (and, when missing, downloads) the model
    directory, computes the allowed output length in samples and instantiates
    ``InspireMusic``. ``inference`` then runs one generation request and writes
    the result below ``result_dir``.
    """

    # Tasks for which ``inference`` knows how to build a model input.
    _SUPPORTED_TASKS = ("text-to-music", "continuation")

    def __init__(self,
                 model_name: str = "InspireMusic-1.5B-Long",
                 model_dir: str = None,
                 min_generate_audio_seconds: float = 10.0,
                 max_generate_audio_seconds: float = 30.0,
                 sample_rate: int = 24000,
                 output_sample_rate: int = 48000,
                 load_jit: bool = True,
                 load_onnx: bool = False,
                 fast: bool = False,
                 fp16: bool = True,
                 gpu: int = 0,
                 result_dir: str = None,
                 hub="modelscope"):
        """Resolve the model directory, load the model and prepare output limits.

        Args:
            model_name: Model identifier; also used for the default model and
                result directories.
            model_dir: Directory holding the checkpoint. Defaults to
                ``pretrained_models/<model_name>``.
            min_generate_audio_seconds: Shortest generated clip that is saved.
            max_generate_audio_seconds: Generated audio is cut to this length.
            sample_rate: Sample rate passed to the model and used for prompts.
            output_sample_rate: Sample rate of the written file; forced to
                24000 when ``fast`` is set.
            load_jit: Forwarded to ``InspireMusic``.
            load_onnx: Forwarded to ``InspireMusic``.
            fast: Forwarded to ``InspireMusic``; also forces 24 kHz output.
            fp16: Forwarded to ``InspireMusic``.
            gpu: Value written to ``CUDA_VISIBLE_DEVICES``; a negative value
                selects the CPU device.
            result_dir: Output directory. Defaults to ``exp/<model_name>``.
            hub: Download source used when ``llm.pt`` is missing. Only
                ``"modelscope"`` triggers a download.

        Raises:
            AssertionError: If the minimum duration exceeds the maximum.
        """
        # Validate first so a bad configuration fails before any download or
        # model load takes place.
        assert min_generate_audio_seconds <= max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"

        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

        # Set model_dir or default to downloading if it doesn't exist
        if model_dir is None:
            model_dir = f"pretrained_models/{model_name}"
        else:
            model_dir = model_dir.replace("../../", "./")

        if not os.path.isfile(f"{model_dir}/llm.pt"):
            if hub == "modelscope":
                from modelscope import snapshot_download
                # The base model lives in a repository that is not named
                # after the model itself.
                if model_name == "InspireMusic-Base":
                    repo_id = "iic/InspireMusic"
                else:
                    repo_id = f"iic/{model_name}"
                snapshot_download(repo_id, local_dir=model_dir)
            else:
                logging.warning(f"{model_dir}/llm.pt not found and hub '{hub}' is not supported for download.")

        self.model_dir = model_dir
        print(self.model_dir)

        self.sample_rate = sample_rate
        self.output_sample_rate = 24000 if fast else output_sample_rate
        self.result_dir = result_dir or f"exp/{model_name}"
        os.makedirs(self.result_dir, exist_ok=True)

        self.min_generate_audio_seconds = min_generate_audio_seconds
        self.max_generate_audio_seconds = max_generate_audio_seconds
        self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds)
        self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds)

        use_cuda = gpu >= 0 and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, fast=fast, fp16=fp16)
        # NOTE(review): the LLM is cast to float16 regardless of ``fp16`` and
        # of the selected device -- confirm this is intended.
        self.model.model.llm = self.model.model.llm.to(torch.float16)

        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    def _load_audio_prompt(self, audio_prompt, max_audio_prompt_length):
        """Load and trim the continuation prompt; return ``None`` if unusable.

        ``None`` is returned when no prompt path is given or when the loaded
        audio is shorter than one second.
        """
        if audio_prompt is None:
            return None
        # assumes process_audio returns a (channels, samples) tensor first --
        # TODO confirm against inspiremusic.utils.audio_utils
        audio, _ = process_audio(audio_prompt, self.sample_rate)
        if audio.size(1) < self.sample_rate:
            logging.warning("Warning: Input prompt audio length is shorter than 1s. Please provide an appropriate length audio prompt and try again.")
            return None
        max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate)
        return audio[:, :max_audio_prompt_length_samples]  # Trimming prompt audio

    def _write_audio(self, music_fn, music_audio, output_format):
        """Write ``music_audio`` to ``music_fn``; return True if a file was written."""
        if output_format in ["wav", "flac"]:
            torchaudio.save(music_fn, music_audio,
                            sample_rate=self.output_sample_rate,
                            encoding="PCM_S",
                            bits_per_sample=24)
            return True
        if output_format in ["mp3", "m4a"]:
            # Use the public entry point instead of the deprecated
            # torchaudio.backend.sox_io_backend module path.
            torchaudio.save(music_fn, music_audio,
                            sample_rate=self.output_sample_rate,
                            format=output_format)
            return True
        logging.info("Format is not supported. Please choose from wav, mp3, m4a, flac.")
        return False

    @torch.inference_mode()
    def inference(self,
                  task: str = 'text-to-music',
                  text: str = None,
                  audio_prompt: str = None,  # audio prompt file path
                  chorus: str = "verse",
                  time_start: float = 0.0,
                  time_end: float = 30.0,
                  output_fn: str = "output_audio",
                  max_audio_prompt_length: float = 5.0,
                  fade_out_duration: float = 1.0,
                  output_format: str = "wav",
                  fade_out_mode: bool = True,
                  trim: bool = False,
                  ):
        """Generate one clip and write it to ``<result_dir>/<output_fn>.<output_format>``.

        Args:
            task: ``'text-to-music'`` or ``'continuation'``.
            text: Text prompt.
            audio_prompt: Path of the prompt audio file. For ``continuation``
                it is loaded and trimmed; for ``text-to-music`` the value is
                forwarded to the model unchanged.
            chorus: One of ``random``, ``intro``, ``verse``, ``chorus``,
                ``outro``; unknown values fall back to ``verse``.
            time_start: Start time forwarded to the model.
            time_end: End time forwarded to the model.
            output_fn: Output file name without extension.
            max_audio_prompt_length: Maximum prompt length in seconds.
            fade_out_duration: Fade-out length in seconds.
            output_format: ``wav``, ``flac``, ``mp3`` or ``m4a``.
            fade_out_mode: Whether to apply a fade-out before saving.
            trim: Whether to pass the generated audio through ``trim_audio``.

        Returns:
            The output file path when a file was written by this call,
            otherwise ``None``.

        Raises:
            ValueError: If ``task`` is not supported.
            Exception: Any error raised while post-processing or saving is
                logged and re-raised.
        """
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(f"Unsupported task {task!r}; supported tasks: {', '.join(self._SUPPORTED_TASKS)}")

        text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>"

        # The random draw is evaluated on every call (as before) so the RNG
        # stream seen by the model does not depend on the chosen tag.
        chorus_dict = {"random": torch.randint(1, 5, (1,)).item(), "intro": 0, "verse": 1, "chorus": 2, "outro": 4}
        chorus_id = chorus_dict.get(chorus, 1)
        chorus_tensor = torch.tensor([chorus_id], dtype=torch.int).to(self.device)

        time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device)
        time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device)

        music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}')

        bench_start = time.time()

        if task == 'continuation':
            # A missing or too-short prompt yields None, so the model input is
            # always well defined.
            prompt = self._load_audio_prompt(audio_prompt, max_audio_prompt_length)
        else:
            prompt = audio_prompt

        model_input = {
            "text"           : text,
            "audio_prompt"   : prompt,
            "time_start"     : time_start_tensor,
            "time_end"       : time_end_tensor,
            "chorus"         : chorus_tensor,
            "task"           : task,
            "stream"         : False,
            "duration_to_gen": self.max_generate_audio_seconds,
            "sr"             : self.sample_rate
        }

        music_audios = [model_output['music_audio'] for model_output in self.model.cli_inference(**model_input)]
        bench_end = time.time()

        if not music_audios:
            logging.error(f"Model returned no audio for task {task}.")
            return None

        music_audio = music_audios[0]
        if trim:
            music_audio = trim_audio(music_audio,
                                     sample_rate=self.output_sample_rate,
                                     threshold=0.05,
                                     min_silence_duration=0.8)

        # Only a file written by THIS call may be reported as the result; a
        # file left over from a previous run must not be returned.
        saved = False
        if music_audio.shape[0] != 0:
            if music_audio.shape[1] > self.max_generate_audio_length:
                music_audio = music_audio[:, :self.max_generate_audio_length]

            if music_audio.shape[1] >= self.min_generate_audio_length:
                try:
                    if fade_out_mode:
                        music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration)
                    # Duplicate the channel dimension to write a two-channel file.
                    music_audio = music_audio.repeat(2, 1)
                    saved = self._write_audio(music_fn, music_audio, output_format)
                except Exception as e:
                    logging.error(f"Error saving file: {e}")
                    raise

                elapsed = bench_end - bench_start
                audio_duration = music_audio.shape[1] / self.output_sample_rate
                # Guard against an empty clip when the minimum length is zero.
                rtf = elapsed / audio_duration if audio_duration > 0 else float('inf')
                logging.info(f"Processing time: {int(elapsed)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
            else:
                logging.error("Generated audio length is shorter than minimum required audio length.")

        if not saved:
            logging.error(f"No audio file was written for this request ({music_fn}).")
            return None
        if not os.path.exists(music_fn):
            logging.error(f"{music_fn} does not exist.")
            return None
        logging.info(f"Generated audio file {music_fn} is saved.")
        return music_fn
def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``. ``model_dir`` is filled in with
        ``pretrained_models/<model_name>`` when it was not given.
    """
    parser = argparse.ArgumentParser(description='Run inference with your model')
    parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
                        help='Model name')
    parser.add_argument('-d', '--model_dir',
                        help='Model folder path')
    parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
                        help='Prompt text')
    parser.add_argument('-a', '--audio_prompt', default=None,
                        help='Prompt audio')
    parser.add_argument('-c', '--chorus', default="intro",
                        help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
    parser.add_argument('-f', '--fast', type=_str2bool, default=False,
                        help='Enable fast inference mode (without flow matching)')
    parser.add_argument('-g', '--gpu', type=int, default=0,
                        help='GPU ID for this rank, -1 for CPU')
    parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
                        help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
    parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
                        help='Directory to save generated audio')
    parser.add_argument('-o', '--output_fn', default="output_audio",
                        help='Output file name')
    parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"],
                        help='Format of output audio')
    parser.add_argument('--sample_rate', type=int, default=24000,
                        help='Sampling rate of input audio')
    parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000],
                        help='Sampling rate of generated output audio')
    parser.add_argument('-s', '--time_start', type=float, default=0.0,
                        help='Start time in seconds')
    parser.add_argument('-e', '--time_end', type=float, default=30.0,
                        help='End time in seconds')
    parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
                        help='Maximum audio prompt length in seconds')
    parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
                        help='Minimum generated audio length in seconds')
    parser.add_argument('--max_generate_audio_seconds', type=float, default=300.0,
                        help='Maximum generated audio length in seconds')
    parser.add_argument('--fp16', type=_str2bool, default=True,
                        help='Inference with fp16 model')
    parser.add_argument('--fade_out', type=_str2bool, default=True,
                        help='Apply fade out effect to generated audio')
    parser.add_argument('--fade_out_duration', type=float, default=1.0,
                        help='Fade out duration in seconds')
    parser.add_argument('--trim', type=_str2bool, default=False,
                        help='Trim the silence ending of generated audio')

    args = parser.parse_args(argv)

    if not args.model_dir:
        args.model_dir = os.path.join("pretrained_models", args.model_name)

    print(args)
    return args
def main():
    """Command-line entry point: set up the environment, load the model, generate once.

    The parsed options are split into the constructor options of
    ``InspireMusicUnified`` and the per-request options of ``inference``.
    The value returned by ``inference`` is not used.
    """
    set_env_variables()
    cli = get_args()

    # JIT loading is always on and ONNX loading always off for the CLI.
    model_options = {
        "model_name": cli.model_name,
        "model_dir": cli.model_dir,
        "min_generate_audio_seconds": cli.min_generate_audio_seconds,
        "max_generate_audio_seconds": cli.max_generate_audio_seconds,
        "sample_rate": cli.sample_rate,
        "output_sample_rate": cli.output_sample_rate,
        "load_jit": True,
        "load_onnx": False,
        "fast": cli.fast,
        "fp16": cli.fp16,
        "gpu": cli.gpu,
        "result_dir": cli.result_dir,
    }
    generator = InspireMusicUnified(**model_options)

    request = {
        "task": cli.task,
        "text": cli.text,
        "audio_prompt": cli.audio_prompt,
        "chorus": cli.chorus,
        "time_start": cli.time_start,
        "time_end": cli.time_end,
        "output_fn": cli.output_fn,
        "max_audio_prompt_length": cli.max_audio_prompt_length,
        "fade_out_duration": cli.fade_out_duration,
        "output_format": cli.format,
        "fade_out_mode": cli.fade_out,
        "trim": cli.trim,
    }
    generator.inference(**request)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()