# Hugging Face Spaces status: Running on Zero
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os | |
import sys | |
import torchaudio | |
import time | |
import logging | |
import argparse | |
from inspiremusic.cli.inspiremusic import InspireMusic | |
from inspiremusic.utils.file_utils import logging | |
import torch | |
from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio | |
def set_env_variables():
    """Configure the process environment for running InspireMusic from the repo root.

    - Sets ``PYTHONIOENCODING=UTF-8`` and ``TOKENIZERS_PARALLELISM=False``.
    - Prepends the current working directory, its ``inspiremusic`` directory and
      ``third_party/Matcha-TTS`` to ``PYTHONPATH`` (seen by child processes).
    - Appends the working directory and Matcha-TTS to ``sys.path`` (seen by
      the current process, which does not re-read ``PYTHONPATH``).

    Safe to call more than once: path entries are not duplicated.
    """
    os.environ['PYTHONIOENCODING'] = 'UTF-8'
    os.environ['TOKENIZERS_PARALLELISM'] = 'False'
    main_root = os.getcwd()
    bin_dir = os.path.join(main_root, 'inspiremusic')
    third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')

    # These are Python module directories, so they belong in PYTHONPATH.
    # Writing them to PATH would replace the executable search path and break
    # lookup of external binaries. Use os.pathsep for portability.
    # Drop empty entries (an empty entry means "current directory").
    existing = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
    entries = [main_root, bin_dir, third_party_matcha_tts_path] + existing
    # dict.fromkeys removes duplicates while keeping first-occurrence order.
    os.environ['PYTHONPATH'] = os.pathsep.join(dict.fromkeys(entries))

    for path in (main_root, third_party_matcha_tts_path):
        if path not in sys.path:
            sys.path.append(path)
class InspireMusicUnified:
    """Wrapper that resolves/downloads an InspireMusic model, loads it and runs inference.

    Generated audio is post-processed (optional trim, length clamping, optional
    fade-out) and written to ``result_dir``.
    """

    # Tasks for which ``inference`` knows how to build the model input.
    _SUPPORTED_TASKS = ("text-to-music", "continuation")
    # Output formats ``_save_audio`` knows how to write.
    _SUPPORTED_FORMATS = ("wav", "flac", "mp3", "m4a")

    def __init__(self,
                 model_name: str = "InspireMusic-1.5B-Long",
                 model_dir: str = None,
                 min_generate_audio_seconds: float = 10.0,
                 max_generate_audio_seconds: float = 30.0,
                 sample_rate: int = 24000,
                 output_sample_rate: int = 48000,
                 load_jit: bool = True,
                 load_onnx: bool = False,
                 fast: bool = False,
                 fp16: bool = True,
                 gpu: int = 0,
                 result_dir: str = None,
                 hub="modelscope"):
        """Resolve the model directory (downloading if needed) and load the model.

        Args:
            model_name: Model name; also used for the default model and result dirs.
            model_dir: Model directory. Defaults to ``pretrained_models/<model_name>``.
            min_generate_audio_seconds: Outputs shorter than this are not saved.
            max_generate_audio_seconds: Outputs are truncated to this length. The
                value is also passed to the model as ``duration_to_gen``.
            sample_rate: Sample rate used for prompt audio and passed to the model.
            output_sample_rate: Sample rate of the saved audio (forced to 24000 in fast mode).
            load_jit, load_onnx, fast, fp16: Forwarded to ``InspireMusic``.
                When ``fp16`` is true, the LLM is additionally cast to float16.
            gpu: GPU index exported as ``CUDA_VISIBLE_DEVICES``; negative selects CPU.
            result_dir: Output directory. Defaults to ``exp/<model_name>``.
            hub: Source for automatic download; only ``"modelscope"`` is handled.

        Raises:
            ValueError: if ``min_generate_audio_seconds > max_generate_audio_seconds``.
        """
        # Validate before any download or model load so bad arguments fail fast.
        if min_generate_audio_seconds > max_generate_audio_seconds:
            raise ValueError("Min audio seconds must be less than or equal to max audio seconds")
        # Configure logging first so messages emitted during download/load are formatted.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
        # Expose only the requested GPU. 'cuda' below then refers to that device.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

        if model_dir is None:
            model_dir = f"pretrained_models/{model_name}"
        else:
            # Rewrites "../../"-relative paths to be relative to the working directory.
            model_dir = model_dir.replace("../../", "./")
        if not os.path.isfile(f"{model_dir}/llm.pt"):
            if hub == "modelscope":
                from modelscope import snapshot_download
                # The base model is published under the repo id "iic/InspireMusic".
                repo_id = "iic/InspireMusic" if model_name == "InspireMusic-Base" else f"iic/{model_name}"
                snapshot_download(repo_id, local_dir=model_dir)
            else:
                logging.warning(f"{model_dir}/llm.pt not found and automatic download from hub '{hub}' is not supported.")
        self.model_dir = model_dir
        print(self.model_dir)

        self.sample_rate = sample_rate
        # Fast mode (no flow matching) produces output at 24 kHz.
        self.output_sample_rate = 24000 if fast else output_sample_rate
        self.result_dir = result_dir or f"exp/{model_name}"
        os.makedirs(self.result_dir, exist_ok=True)
        self.min_generate_audio_seconds = min_generate_audio_seconds
        self.max_generate_audio_seconds = max_generate_audio_seconds
        # Length limits in output samples.
        self.min_generate_audio_length = int(self.output_sample_rate * self.min_generate_audio_seconds)
        self.max_generate_audio_length = int(self.output_sample_rate * self.max_generate_audio_seconds)

        use_cuda = gpu >= 0 and torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, fast=fast, fp16=fp16)
        # Honour the fp16 flag; fp32 inference must not be silently downcast.
        if fp16:
            self.model.model.llm = self.model.model.llm.to(torch.float16)

    @staticmethod
    def _chorus_label(chorus):
        """Map a structure tag name to the integer label passed to the model.

        Unknown names fall back to "verse" (1). "random" draws a fresh label.
        """
        labels = {"intro": 0, "verse": 1, "chorus": 2, "outro": 4}
        if chorus == "random":
            # Drawn only when requested, so other calls do not consume RNG state.
            # torch.randint's upper bound is exclusive: values are 1..4.
            # NOTE(review): 3 has no named tag and intro (0) is excluded; confirm intended.
            return torch.randint(1, 5, (1,)).item()
        return labels.get(chorus, 1)

    def _load_prompt_audio(self, audio_prompt, max_audio_prompt_length):
        """Load the continuation prompt audio and trim it to ``max_audio_prompt_length`` seconds.

        Returns None when no prompt is given or when the prompt is shorter than
        ``self.sample_rate`` samples (1 s at that rate).
        """
        if audio_prompt is None:
            logging.warning("No audio prompt provided for continuation task.")
            return None
        audio, _ = process_audio(audio_prompt, self.sample_rate)
        if audio.size(1) < self.sample_rate:
            logging.warning("Warning: Input prompt audio length is shorter than 1s. Please provide an appropriate length audio prompt and try again.")
            return None
        max_audio_prompt_length_samples = int(max_audio_prompt_length * self.sample_rate)
        return audio[:, :max_audio_prompt_length_samples]

    def _save_audio(self, music_fn, music_audio, output_format):
        """Write ``music_audio`` to ``music_fn``.

        Returns True if a file was written. Unknown formats are logged and
        return False. Save errors are logged and re-raised.
        """
        try:
            if output_format in ["wav", "flac"]:
                torchaudio.save(music_fn, music_audio,
                                sample_rate=self.output_sample_rate,
                                encoding="PCM_S",
                                bits_per_sample=24)
            elif output_format in ["mp3", "m4a"]:
                torchaudio.backend.sox_io_backend.save(
                    filepath=music_fn, src=music_audio,
                    sample_rate=self.output_sample_rate,
                    format=output_format)
            else:
                logging.error("Format is not supported. Please choose from wav, mp3, m4a, flac.")
                return False
        except Exception as e:
            logging.error(f"Error saving file: {e}")
            raise
        return True

    def inference(self,
                  task: str = 'text-to-music',
                  text: str = None,
                  audio_prompt: str = None,  # audio prompt file path
                  chorus: str = "verse",
                  time_start: float = 0.0,
                  time_end: float = 30.0,
                  output_fn: str = "output_audio",
                  max_audio_prompt_length: float = 5.0,
                  fade_out_duration: float = 1.0,
                  output_format: str = "wav",
                  fade_out_mode: bool = True,
                  trim: bool = False,
                  ):
        """Generate music and save it as ``<result_dir>/<output_fn>.<output_format>``.

        Args:
            task: "text-to-music" or "continuation".
            text: Text prompt.
            audio_prompt: Path to an audio file. For "continuation" it is loaded,
                rejected if shorter than 1 s, and trimmed to
                ``max_audio_prompt_length`` seconds. For "text-to-music" it is
                forwarded to the model unchanged.
            chorus: Structure tag: intro, verse, chorus, outro or random
                (unknown values are treated as verse).
            time_start, time_end: Start/end times in seconds, passed to the model.
            output_fn: Output file name without extension.
            max_audio_prompt_length: Maximum prompt length in seconds.
            fade_out_duration: Fade-out length in seconds (if ``fade_out_mode``).
            output_format: One of wav, flac, mp3, m4a.
            fade_out_mode: Apply a fade-out to the end of the audio.
            trim: Trim trailing silence before length checks.

        Returns:
            The saved file path, or None if no file was written by this call.

        Raises:
            ValueError: if ``task`` is not supported.
        """
        if task not in self._SUPPORTED_TASKS:
            raise ValueError(f"Unsupported task '{task}'. Supported tasks: {', '.join(self._SUPPORTED_TASKS)}.")
        # Reject unsupported formats before spending time on generation.
        if output_format not in self._SUPPORTED_FORMATS:
            logging.error("Format is not supported. Please choose from wav, mp3, m4a, flac.")
            return None

        with torch.no_grad():
            text_prompt = f"<|{time_start}|><|{chorus}|><|{text}|><|{time_end}|>"
            chorus_tensor = torch.tensor([self._chorus_label(chorus)], dtype=torch.int).to(self.device)
            time_start_tensor = torch.tensor([time_start], dtype=torch.float64).to(self.device)
            time_end_tensor = torch.tensor([time_end], dtype=torch.float64).to(self.device)
            music_fn = os.path.join(self.result_dir, f'{output_fn}.{output_format}')

            bench_start = time.time()
            if task == 'continuation':
                prompt = self._load_prompt_audio(audio_prompt, max_audio_prompt_length)
            else:
                prompt = audio_prompt
            model_input = {
                "text": text,
                "audio_prompt": prompt,
                "time_start": time_start_tensor,
                "time_end": time_end_tensor,
                "chorus": chorus_tensor,
                "task": task,
                "stream": False,
                "duration_to_gen": self.max_generate_audio_seconds,
                "sr": self.sample_rate,
            }
            music_audios = [output['music_audio'] for output in self.model.cli_inference(**model_input)]
            bench_end = time.time()

            if not music_audios:
                logging.error("Model produced no audio output.")
                return None
            music_audio = music_audios[0]
            if trim:
                music_audio = trim_audio(music_audio,
                                         sample_rate=self.output_sample_rate,
                                         threshold=0.05,
                                         min_silence_duration=0.8)
            if music_audio.shape[0] == 0:
                logging.error("Generated audio is empty.")
                return None
            if music_audio.shape[1] > self.max_generate_audio_length:
                music_audio = music_audio[:, :self.max_generate_audio_length]
            if music_audio.shape[1] < self.min_generate_audio_length:
                logging.error("Generated audio length is shorter than minimum required audio length.")
                return None

            if fade_out_mode:
                music_audio = fade_out(music_audio, self.output_sample_rate, fade_out_duration)
            # Tiles the rows twice along dim 0.
            # NOTE(review): assumes mono (1, T) output becomes 2-channel; confirm.
            music_audio = music_audio.repeat(2, 1)
            saved = self._save_audio(music_fn, music_audio, output_format)

            audio_duration = music_audio.shape[1] / self.output_sample_rate
            elapsed = bench_end - bench_start
            rtf = elapsed / audio_duration if audio_duration > 0 else float("inf")
            logging.info(f"Processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")

        # Only report success for a file written by this call; never return a
        # stale file left over from a previous run with the same name.
        if not saved:
            return None
        if os.path.exists(music_fn):
            logging.info(f"Generated audio file {music_fn} is saved.")
            return music_fn
        logging.error(f"{music_fn} does not exist.")
        return None
def get_args(): | |
parser = argparse.ArgumentParser(description='Run inference with your model') | |
parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long", | |
help='Model name') | |
parser.add_argument('-d', '--model_dir', | |
help='Model folder path') | |
parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", | |
help='Prompt text') | |
parser.add_argument('-a', '--audio_prompt', default=None, | |
help='Prompt audio') | |
parser.add_argument('-c', '--chorus', default="intro", | |
help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)') | |
parser.add_argument('-f', '--fast', type=bool, default=False, | |
help='Enable fast inference mode (without flow matching)') | |
parser.add_argument('-g', '--gpu', type=int, default=0, | |
help='GPU ID for this rank, -1 for CPU') | |
parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'], | |
help='Inference task type: text-to-music, continuation, reconstruct, super_resolution') | |
parser.add_argument('-r', '--result_dir', default="exp/inspiremusic", | |
help='Directory to save generated audio') | |
parser.add_argument('-o', '--output_fn', default="output_audio", | |
help='Output file name') | |
parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"], | |
help='Format of output audio') | |
parser.add_argument('--sample_rate', type=int, default=24000, | |
help='Sampling rate of input audio') | |
parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000], | |
help='Sampling rate of generated output audio') | |
parser.add_argument('-s', '--time_start', type=float, default=0.0, | |
help='Start time in seconds') | |
parser.add_argument('-e', '--time_end', type=float, default=30.0, | |
help='End time in seconds') | |
parser.add_argument('--max_audio_prompt_length', type=float, default=5.0, | |
help='Maximum audio prompt length in seconds') | |
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, | |
help='Minimum generated audio length in seconds') | |
parser.add_argument('--max_generate_audio_seconds', type=float, default=300.0, | |
help='Maximum generated audio length in seconds') | |
parser.add_argument('--fp16', type=bool, default=True, | |
help='Inference with fp16 model') | |
parser.add_argument('--fade_out', type=bool, default=True, | |
help='Apply fade out effect to generated audio') | |
parser.add_argument('--fade_out_duration', type=float, default=1.0, | |
help='Fade out duration in seconds') | |
parser.add_argument('--trim', type=bool, default=False, | |
help='Trim the silence ending of generated audio') | |
args = parser.parse_args() | |
if not args.model_dir: | |
args.model_dir = os.path.join("pretrained_models", args.model_name) | |
print(args) | |
return args | |
def main():
    """Command-line entry point.

    Prepares the environment, parses the CLI options, builds the model wrapper
    (JIT on, ONNX off) and runs a single inference call.
    """
    set_env_variables()
    cli = get_args()

    model_options = dict(
        model_name=cli.model_name,
        model_dir=cli.model_dir,
        min_generate_audio_seconds=cli.min_generate_audio_seconds,
        max_generate_audio_seconds=cli.max_generate_audio_seconds,
        sample_rate=cli.sample_rate,
        output_sample_rate=cli.output_sample_rate,
        load_jit=True,
        load_onnx=False,
        fast=cli.fast,
        fp16=cli.fp16,
        gpu=cli.gpu,
        result_dir=cli.result_dir,
    )
    generation_options = dict(
        task=cli.task,
        text=cli.text,
        audio_prompt=cli.audio_prompt,
        chorus=cli.chorus,
        time_start=cli.time_start,
        time_end=cli.time_end,
        output_fn=cli.output_fn,
        max_audio_prompt_length=cli.max_audio_prompt_length,
        fade_out_duration=cli.fade_out_duration,
        output_format=cli.format,
        fade_out_mode=cli.fade_out,
        trim=cli.trim,
    )

    wrapper = InspireMusicUnified(**model_options)
    wrapper.inference(**generation_options)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()