# Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import spaces
import gradio as gr
from inspiremusic.cli.inference import InspireMusicUnified, set_env_variables
import torchaudio
import datetime
import hashlib
import torch
import sys

# Environment setup for the hosted demo: report GPU status and install unzip.
os.system('nvidia-smi')
os.system('apt update -y && apt-get install -y apt-utils && apt install -y unzip')
# os.system('pip install flash-attn --no-build-isolation')
# os.system('git submodule update --init --recursive')

# Download the released model weights and rewrite the ../../ relative paths in
# each model's inspiremusic.yaml so they resolve from pretrained_models/.
os.system(
	'mkdir pretrained_models && cd pretrained_models && '
	'git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git && '
	'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-Long.git && '
	'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B.git && '
	'git clone https://huggingface.co/FunAudioLLM/InspireMusic-1.5B-24kHz.git && '
	'git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base-24kHz.git && '
	'for i in InspireMusic-Base InspireMusic-Base-24kHz InspireMusic-1.5B '
	'InspireMusic-1.5B-24kHz InspireMusic-1.5B-Long; '
	'do sed -i -e "s/\.\.\/\.\.\///g" ${i}/inspiremusic.yaml; done && cd ..')

# os.system('mkdir pretrained_models && cd pretrained_models && git clone https://huggingface.co/FunAudioLLM/InspireMusic-Base.git && for i in InspireMusic-Base; do sed -i -e "s/\.\.\/\.\.\///g" ${i}/inspiremusic.yaml; done && cd ..')

print(torch.backends.cudnn.version())  # sanity-check that cuDNN is available

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))

def generate_filename():
	"""Return a pseudo-unique output filename: the SHA-256 hex digest of the
	current Unix timestamp (second resolution, so calls within the same
	second collide)."""
	hash_object = hashlib.sha256(str(int(datetime.datetime.now().timestamp())).encode())
	return hash_object.hexdigest()

def get_args(
		task, text="", audio=None, model_name="InspireMusic-Base",
		chorus="intro", output_sample_rate=48000,
		max_generate_audio_seconds=30.0, time_start=0.0, time_end=30.0,
		trim=False):
	"""Construct the argument dict required by InspireMusic inference."""
	# 24kHz output skips the super-resolution stage to 48kHz, i.e. "fast" mode.
	fast = (output_sample_rate == 24000)
	args = {
		"task"                      : task,
		"text"                      : text,
		"audio_prompt"              : audio,
		"model_name"                : model_name,
		"chorus"                    : chorus,
		"fast"                      : fast,
		"fade_out"                  : True,
		"trim"                      : trim,
		"output_sample_rate"        : output_sample_rate,
		"min_generate_audio_seconds": 10.0,
		"max_generate_audio_seconds": max_generate_audio_seconds,
		"model_dir"                 : os.path.join("pretrained_models",
												   model_name),
		"result_dir"                : "exp/inspiremusic",
		"output_fn"                 : generate_filename(),
		"format"                    : "wav",
		"time_start" : time_start,
		"time_end": time_end,
		"fade_out_duration": 1.0,
	}

	if args["time_start"] is None:
		args["time_start"] = 0.0
	args["time_end"] = args["time_start"] + args["max_generate_audio_seconds"]

	print(args)
	return args
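
# Note: the returned dict always satisfies
#   time_end == time_start + max_generate_audio_seconds,
# so callers control the window length via max_generate_audio_seconds alone.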


def trim_audio(audio_file, cut_seconds=5):
	"""Trim an audio prompt to its first cut_seconds seconds and save it as wav."""
	audio, sr = torchaudio.load(audio_file)
	num_samples = int(cut_seconds * sr)
	trimmed_audio = audio[:, :num_samples]
	output_path = os.path.join(os.getcwd(), "audio_prompt_" + generate_filename() + ".wav")
	torchaudio.save(output_path, trimmed_audio, sr)
	return output_path
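
# For example, a 3-minute upload used as a continuation prompt is reduced to
# its first 5 seconds before being fed to the model (illustrative path):
#   short_prompt = trim_audio("/path/to/upload.wav", cut_seconds=5)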

@spaces.GPU(duration=120)
def music_generation(args):
	"""Load the selected InspireMusic model and run one inference pass on a GPU worker."""
	set_env_variables()
	model = InspireMusicUnified(
			model_name=args["model_name"],
			model_dir=args["model_dir"],
			min_generate_audio_seconds=args["min_generate_audio_seconds"],
			max_generate_audio_seconds=args["max_generate_audio_seconds"],
			sample_rate=24000,
			output_sample_rate=args["output_sample_rate"],
			load_jit=True,
			load_onnx=False,
			fast=args["fast"],
			result_dir=args["result_dir"])

	output_path = model.inference(
			task=args["task"],
			text=args["text"],
			audio_prompt=args["audio_prompt"],
			chorus=args["chorus"],
			time_start=args["time_start"],
			time_end=args["time_end"],
			output_fn=args["output_fn"],
			max_audio_prompt_length=args["max_audio_prompt_length"],
			fade_out_duration=args["fade_out_duration"],
			output_format=args["format"],
			fade_out_mode=args["fade_out"],
			trim=args["trim"])
	return output_path
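
# End-to-end sketch of what the two Gradio callbacks below do (illustrative
# values, not executed here; requires a GPU and the downloaded checkpoints):
#   wav_path = music_generation(get_args(
#       task="text-to-music",
#       text="A gentle acoustic guitar melody",
#       max_generate_audio_seconds=30.0))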


def demo_inspiremusic_t2m(text, model_name, chorus,
					 output_sample_rate, max_generate_audio_seconds):
	args = get_args(
			task='text-to-music', text=text, audio=None,
			model_name=model_name, chorus=chorus,
			output_sample_rate=output_sample_rate,
			max_generate_audio_seconds=max_generate_audio_seconds)
	return music_generation(args)

def demo_inspiremusic_con(text, audio, model_name, chorus,
					 output_sample_rate, max_generate_audio_seconds):
	args = get_args(
			task='continuation', text=text, audio=trim_audio(audio, cut_seconds=5),
			model_name=model_name, chorus=chorus,
			output_sample_rate=output_sample_rate,
			max_generate_audio_seconds=max_generate_audio_seconds)
	return music_generation(args)

def main():
	"""Build and launch the Gradio demo UI."""
	with gr.Blocks(theme=gr.themes.Soft()) as demo:
		gr.Markdown("""
		# InspireMusic
		- Support text-to-music, music continuation, audio super-resolution, audio reconstruction tasks with high audio quality, with available sampling rates of 24kHz, 48kHz. 
		- Support long audio generation in multiple output audio formats, i.e., wav, flac, mp3, m4a.
		- Open-source [InspireMusic-Base](https://modelscope.cn/models/iic/InspireMusic/summary), [InspireMusic-Base-24kHz](https://modelscope.cn/models/iic/InspireMusic-Base-24kHz/summary), [InspireMusic-1.5B](https://modelscope.cn/models/iic/InspireMusic-1.5B/summary), [InspireMusic-1.5B-24kHz](https://modelscope.cn/models/iic/InspireMusic-1.5B-24kHz/summary), [InspireMusic-1.5B-Long](https://modelscope.cn/models/iic/InspireMusic-1.5B-Long/summary) models for music generation.
		- Currently only support English text prompts.
		""")

		with gr.Row(equal_height=True):
			model_name = gr.Dropdown(["InspireMusic-1.5B-Long", "InspireMusic-1.5B", "InspireMusic-1.5B-24kHz", "InspireMusic-Base", "InspireMusic-Base-24kHz"], label="Select Model Name", value="InspireMusic-Base")
			chorus = gr.Dropdown(["intro", "verse", "chorus", "outro"],
								 label="Chorus Mode", value="intro")
			output_sample_rate = gr.Dropdown([48000, 24000],
											 label="Output Audio Sample Rate (Hz)",
											 value=48000)
			max_generate_audio_seconds = gr.Slider(10, 120,
												   label="Generate Audio Length (s)",
												   value=30)

		# with gr.Row(equal_height=True):
		text_input = gr.Textbox(label="Input Text (For Text-to-Music Task)", value="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.")
		music_output = gr.Audio(label="Text to Music Output", type="filepath")

		button = gr.Button("Text to Music")
		button.click(demo_inspiremusic_t2m,
						  inputs=[text_input, model_name,
								  chorus,
								  output_sample_rate,
								  max_generate_audio_seconds],
						  outputs=music_output)

		audio_input = gr.Audio(label="Input Audio Prompt (For Music Continuation Task)",
								   type="filepath")

		music_con_output = gr.Audio(label="Music Continuation Output", type="filepath")
		generate_button = gr.Button("Music Continuation")
		generate_button.click(demo_inspiremusic_con,
							  inputs=[text_input, audio_input, model_name,
									  chorus,
									  output_sample_rate,
									  max_generate_audio_seconds],
							  outputs=music_con_output)
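	# When running locally rather than on a hosted Space, demo.launch(share=True)
	# would additionally expose a temporary public link (standard Gradio option).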
	demo.launch()

if __name__ == '__main__':
	# subprocess.run(['bash', 'install.sh'], shell=True)
	main()