import gradio as gr
import subprocess
import os
import shutil
import tempfile
import torch
import logging
import numpy as np
import re
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from datetime import datetime
# Logging configuration
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('yue_generation.log'),
logging.StreamHandler()
]
)
def optimize_gpu_settings():
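    """Enable TF32 and cuDNN autotuning and tune CUDA memory settings (L40S-aware)."""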
if torch.cuda.is_available():
        # Optimize GPU memory management
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = False
        # Memory settings tuned for the L40S
torch.cuda.empty_cache()
torch.cuda.set_device(0)
        # Create a CUDA stream on device 0 (note: the stream object is not retained)
        torch.cuda.Stream(0)
        # Memory allocator tuning (also inherited by the inference subprocess)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
        # L40S-specific settings
if 'L40S' in torch.cuda.get_device_name(0):
torch.cuda.set_per_process_memory_fraction(0.95)
def analyze_lyrics(lyrics, repeat_chorus=2):
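    """Parse lyrics annotated with [verse]/[chorus]/[bridge] tags.

    Returns (section counts, total number of sections, total line count,
    lines grouped by section), with chorus lines repeated `repeat_chorus` times.
    """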
lines = [line.strip() for line in lyrics.split('\n') if line.strip()]
sections = {
'verse': 0,
'chorus': 0,
'bridge': 0,
'total_lines': len(lines)
}
current_section = None
section_lines = {
'verse': [],
'chorus': [],
'bridge': []
}
    # Start index of the lines belonging to the section currently being collected;
    # initialized so lyrics without any section tag are handled safely
    last_section_start = 0
for i, line in enumerate(lines):
lower_line = line.lower()
        # Handle section tags
if '[verse]' in lower_line:
            if current_section:  # store the previous section's lines
section_lines[current_section].extend(lines[last_section_start:i])
current_section = 'verse'
sections['verse'] += 1
last_section_start = i + 1
continue
elif '[chorus]' in lower_line:
if current_section:
section_lines[current_section].extend(lines[last_section_start:i])
current_section = 'chorus'
sections['chorus'] += 1
last_section_start = i + 1
continue
elif '[bridge]' in lower_line:
if current_section:
section_lines[current_section].extend(lines[last_section_start:i])
current_section = 'bridge'
sections['bridge'] += 1
last_section_start = i + 1
continue
    # Append the final section's lines
if current_section and last_section_start < len(lines):
section_lines[current_section].extend(lines[last_section_start:])
    # Repeat the chorus if requested
if sections['chorus'] > 0 and repeat_chorus > 1:
original_chorus = section_lines['chorus'][:]
for _ in range(repeat_chorus - 1):
section_lines['chorus'].extend(original_chorus)
    # Log per-section line counts
logging.info(f"Section line counts - Verse: {len(section_lines['verse'])}, "
f"Chorus: {len(section_lines['chorus'])}, "
f"Bridge: {len(section_lines['bridge'])}")
return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), len(lines), section_lines
def calculate_generation_params(lyrics):
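    """Derive duration, token, and segment budgets from the lyric structure."""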
sections, total_sections, total_lines, section_lines = analyze_lyrics(lyrics)
    # Base timing estimates (seconds per lyric line)
    time_per_line = {
        'verse': 4,   # ~4 seconds per verse line
        'chorus': 6,  # ~6 seconds per chorus line
        'bridge': 5   # ~5 seconds per bridge line
    }
    # Estimate each section's duration (including the final section)
section_durations = {}
for section_type in ['verse', 'chorus', 'bridge']:
lines_count = len(section_lines[section_type])
section_durations[section_type] = lines_count * time_per_line[section_type]
    # Total duration with 20% headroom, at least 60 seconds
    total_duration = sum(section_durations.values())
    total_duration = max(60, int(total_duration * 1.2))
    # Token budget (extra tokens reserved for the final section)
    base_tokens = 3000
    tokens_per_line = 200
    extra_tokens = 1000
    total_tokens = base_tokens + (total_lines * tokens_per_line) + extra_tokens
    # Segment count: four segments when a chorus is present, otherwise three
    num_segments = 4 if sections['chorus'] > 0 else 3
    # Cap the total token count
    max_tokens = min(12000, total_tokens)
return {
'max_tokens': max_tokens,
'num_segments': num_segments,
'sections': sections,
'section_lines': section_lines,
'estimated_duration': total_duration,
'section_durations': section_durations,
'has_chorus': sections['chorus'] > 0
}
def detect_and_select_model(text):
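    """Pick the YuE stage-1 model by detecting the script of the lyrics via Unicode ranges."""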
    # Korean: Hangul jamo and syllables
    if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text):
        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
    # Japanese: hiragana/katakana (checked before Han so Japanese text
    # containing kanji is not misrouted to the Chinese model)
    elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text):
        return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
    # Chinese: CJK unified ideographs
    elif re.search(r'[\u4e00-\u9fff]', text):
        return "m-a-p/YuE-s1-7B-anneal-zh-cot"
    else:
        return "m-a-p/YuE-s1-7B-anneal-en-cot"
def install_flash_attn():
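    """Install flash-attn when a CUDA GPU is present; return True on success."""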
try:
if not torch.cuda.is_available():
logging.warning("GPU not available, skipping flash-attn installation")
return False
cuda_version = torch.version.cuda
if cuda_version is None:
logging.warning("CUDA not available, skipping flash-attn installation")
return False
logging.info(f"Detected CUDA version: {cuda_version}")
try:
import flash_attn
logging.info("flash-attn already installed")
return True
except ImportError:
logging.info("Installing flash-attn...")
subprocess.run(
["pip", "install", "flash-attn", "--no-build-isolation"],
check=True,
capture_output=True
)
logging.info("flash-attn installed successfully!")
return True
except Exception as e:
logging.warning(f"Failed to install flash-attn: {e}")
return False
def download_missing_files():
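    """Download the xcodec_mini_infer decoder files and YuE model snapshots from the Hugging Face Hub."""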
try:
from huggingface_hub import hf_hub_download, snapshot_download
        # Download the xcodec_mini_infer files directly
        repo_id = "hf-internal-testing/xcodec_mini_infer"  # corrected repository path
files_to_download = {
"config.json": "config/config.json",
"vocal_decoder.pth": "checkpoints/vocal_decoder.pth",
"inst_decoder.pth": "checkpoints/inst_decoder.pth"
}
xcodec_dir = "./xcodec_mini_infer"
os.makedirs(xcodec_dir, exist_ok=True)
os.makedirs(os.path.join(xcodec_dir, "checkpoints"), exist_ok=True)
for target_name, source_path in files_to_download.items():
try:
downloaded_path = hf_hub_download(
repo_id=repo_id,
filename=source_path,
cache_dir="./models/cache",
force_download=True,
local_files_only=False
)
logging.info(f"Downloaded {source_path} to: {downloaded_path}")
# ํŒŒ์ผ์„ ์˜ฌ๋ฐ”๋ฅธ ์œ„์น˜๋กœ ๋ณต์‚ฌ
target_path = os.path.join(xcodec_dir, target_name)
shutil.copy2(downloaded_path, target_path)
logging.info(f"Copied to: {target_path}")
except Exception as e:
logging.error(f"Error downloading {source_path}: {e}")
                # Try an alternate repository
try:
alt_repo_id = "facebook/musicgen-small"
downloaded_path = hf_hub_download(
repo_id=alt_repo_id,
filename=source_path,
cache_dir="./models/cache",
force_download=True
)
target_path = os.path.join(xcodec_dir, target_name)
shutil.copy2(downloaded_path, target_path)
logging.info(f"Downloaded from alternate source to: {target_path}")
except Exception as alt_e:
logging.error(f"Error with alternate download: {alt_e}")
raise
        # Download the YuE models
models = [
"m-a-p/YuE-s1-7B-anneal-jp-kr-cot",
"m-a-p/YuE-s1-7B-anneal-en-cot",
"m-a-p/YuE-s1-7B-anneal-zh-cot",
"m-a-p/YuE-s2-1B-general"
]
for model in models:
model_name = model.split('/')[-1]
            model_path = snapshot_download(
                repo_id=model,
                local_dir=f"./models/{model_name}",
                cache_dir="./models/cache",
                force_download=True  # resume_download removed (deprecated; conflicts with force_download)
            )
logging.info(f"Downloaded {model_name} to: {model_path}")
# ํŒŒ์ผ ์กด์žฌ ๋ฐ ํฌ๊ธฐ ํ™•์ธ
for target_name in files_to_download.keys():
file_path = os.path.join(xcodec_dir, target_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Failed to download {target_name}")
file_size = os.path.getsize(file_path)
if file_size == 0:
raise FileNotFoundError(f"Downloaded file is empty: {target_name}")
logging.info(f"Verified {target_name}: {file_size} bytes")
logging.info("All required models downloaded successfully")
except Exception as e:
logging.error(f"Error downloading models: {e}")
raise
def check_model_files():
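    """Ensure the xcodec_mini_infer files exist locally; download and verify them if missing."""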
base_dir = os.getcwd()
xcodec_dir = os.path.join(base_dir, "xcodec_mini_infer")
    # Required files
required_files = {
"config.json": "config.json",
"vocal_decoder.pth": "vocal_decoder.pth",
"inst_decoder.pth": "inst_decoder.pth"
}
# ํŒŒ์ผ ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ
missing = False
for file_name in required_files.keys():
file_path = os.path.join(xcodec_dir, file_name)
if not os.path.exists(file_path):
missing = True
logging.warning(f"Missing file: {file_path}")
if missing:
logging.info("Downloading missing files...")
download_missing_files()
        # Re-verify the files after downloading
for file_name in required_files.keys():
file_path = os.path.join(xcodec_dir, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Failed to download or locate required file: {file_name}")
else:
file_size = os.path.getsize(file_path)
if file_size == 0:
raise FileNotFoundError(f"Downloaded file is empty: {file_name}")
logging.info(f"Verified {file_name}: {file_size} bytes")
def initialize_system():
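    """Prepare the working directory and download all required models in parallel before the UI starts."""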
optimize_gpu_settings()
try:
        # Create the base directory structure
base_dir = os.path.abspath("./inference")
os.makedirs(base_dir, exist_ok=True)
os.makedirs(os.path.join(base_dir, "models"), exist_ok=True)
        # Switch the working directory
os.chdir(base_dir)
logging.info(f"Working directory changed to: {os.getcwd()}")
from huggingface_hub import snapshot_download, hf_hub_download
        # Download the xcodec_mini_infer files directly
xcodec_dir = os.path.join(base_dir, "xcodec_mini_infer")
os.makedirs(xcodec_dir, exist_ok=True)
        # Required files to fetch
required_files = {
"config.json": "config.json",
"vocal_decoder.pth": "vocal_decoder.pth",
"inst_decoder.pth": "inst_decoder.pth"
}
for file_name in required_files.keys():
try:
file_path = os.path.join(xcodec_dir, file_name)
if not os.path.exists(file_path):
downloaded_path = hf_hub_download(
repo_id="m-a-p/xcodec_mini_infer",
filename=file_name,
local_dir=xcodec_dir,
force_download=True
)
if downloaded_path != file_path:
shutil.copy2(downloaded_path, file_path)
logging.info(f"Downloaded {file_name} to {file_path}")
except Exception as e:
logging.error(f"Error downloading {file_name}: {e}")
raise
        # Download the YuE models
models = [
"m-a-p/YuE-s1-7B-anneal-jp-kr-cot",
"m-a-p/YuE-s1-7B-anneal-en-cot",
"m-a-p/YuE-s1-7B-anneal-zh-cot",
"m-a-p/YuE-s2-1B-general"
]
with ThreadPoolExecutor(max_workers=4) as executor:
futures = []
            # Install Flash Attention
futures.append(executor.submit(install_flash_attn))
# ๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ
for model in models:
model_name = model.split('/')[-1]
model_path = os.path.join(base_dir, "models", model_name)
futures.append(executor.submit(
snapshot_download,
repo_id=model,
local_dir=model_path,
force_download=True
))
            # Wait for all tasks to finish
for future in futures:
future.result()
# ํŒŒ์ผ ์กด์žฌ ํ™•์ธ
for file_name, _ in required_files.items():
file_path = os.path.join(xcodec_dir, file_name)
if not os.path.exists(file_path):
raise FileNotFoundError(f"Required file still missing after download: {file_path}")
else:
file_size = os.path.getsize(file_path)
logging.info(f"Verified {file_name}: {file_size} bytes")
logging.info("System initialization completed successfully")
except Exception as e:
logging.error(f"Directory error: {e}")
raise
@lru_cache(maxsize=100)
def get_cached_file_path(content_hash, prefix):
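    """Memoize temp-file creation so identical (content, prefix) pairs reuse a single file."""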
return create_temp_file(content_hash, prefix)
def empty_output_folder(output_dir):
try:
shutil.rmtree(output_dir)
os.makedirs(output_dir)
logging.info(f"Output folder cleaned: {output_dir}")
except Exception as e:
logging.error(f"Error cleaning output folder: {e}")
raise
def create_temp_file(content, prefix, suffix=".txt"):
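    """Write content to a named temporary file with normalized newlines and return its path."""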
    temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", encoding="utf-8", prefix=prefix, suffix=suffix)
content = content.strip() + "\n\n"
content = content.replace("\r\n", "\n").replace("\r", "\n")
temp_file.write(content)
temp_file.close()
logging.debug(f"Temporary file created: {temp_file.name}")
return temp_file.name
def get_last_mp3_file(output_dir):
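    """Return the most recently modified MP3 in output_dir, or None if there is none."""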
mp3_files = [f for f in os.listdir(output_dir) if f.endswith('.mp3')]
if not mp3_files:
logging.warning("No MP3 files found")
return None
mp3_files_with_path = [os.path.join(output_dir, f) for f in mp3_files]
mp3_files_with_path.sort(key=os.path.getmtime, reverse=True)
return mp3_files_with_path[0]
def get_audio_duration(file_path):
try:
import librosa
duration = librosa.get_duration(path=file_path)
return duration
except Exception as e:
logging.error(f"Failed to get audio duration: {e}")
return None
def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
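    """Run the two-stage YuE pipeline via infer.py and return the generated MP3 path.

    Note: the incoming num_segments/max_new_tokens values are superseded by
    budgets recomputed from the lyric analysis.
    """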
    check_model_files()  # verify (and if needed download) the required model files
genre_txt_path = None
lyrics_txt_path = None
try:
model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content)
logging.info(f"Selected model: {model_path}")
logging.info(f"Lyrics analysis: {params}")
has_chorus = params['sections']['chorus'] > 0
estimated_duration = params.get('estimated_duration', 90)
        # Segment and token budgets (boosted when a chorus is present)
        if has_chorus:
            actual_max_tokens = min(12000, int(config['max_tokens'] * 1.3))  # 30% more tokens
            actual_num_segments = min(5, params['num_segments'] + 2)  # extra segments
        else:
            actual_max_tokens = min(10000, int(config['max_tokens'] * 1.2))
            actual_num_segments = min(4, params['num_segments'] + 1)
logging.info(f"Estimated duration: {estimated_duration} seconds")
logging.info(f"Has chorus sections: {has_chorus}")
logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")
genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)
empty_output_folder(output_dir)
command = [
"python", "infer.py",
"--stage1_model", model_path, # ์›๋ž˜ ๋ชจ๋ธ ๊ฒฝ๋กœ ์‚ฌ์šฉ
"--stage2_model", "m-a-p/YuE-s2-1B-general",
"--genre_txt", genre_txt_path,
"--lyrics_txt", lyrics_txt_path,
"--run_n_segments", str(actual_num_segments),
"--stage2_batch_size", "16",
"--output_dir", output_dir,
"--cuda_idx", "0",
"--max_new_tokens", str(actual_max_tokens),
"--disable_offload_model"
]
env = os.environ.copy()
if torch.cuda.is_available():
env.update({
"CUDA_VISIBLE_DEVICES": "0",
"CUDA_HOME": "/usr/local/cuda",
"PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
"LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
"CUDA_LAUNCH_BLOCKING": "0",
"TRANSFORMERS_CACHE": "./models/cache",
"HF_HOME": "./models/cache"
})
        # Handle the transformers cache migration
try:
from transformers.utils import move_cache
move_cache()
except Exception as e:
logging.warning(f"Cache migration warning (non-critical): {e}")
process = subprocess.run(
command,
env=env,
check=False,
capture_output=True,
text=True
)
logging.info(f"Command output: {process.stdout}")
if process.stderr:
logging.error(f"Command error: {process.stderr}")
if process.returncode != 0:
logging.error(f"Command failed with return code: {process.returncode}")
logging.error(f"Command: {' '.join(command)}")
raise RuntimeError(f"Inference failed: {process.stderr}")
last_mp3 = get_last_mp3_file(output_dir)
if last_mp3:
try:
duration = get_audio_duration(last_mp3)
logging.info(f"Generated audio file: {last_mp3}")
if duration:
logging.info(f"Audio duration: {duration:.2f} seconds")
logging.info(f"Expected duration: {estimated_duration} seconds")
if duration < estimated_duration * 0.8:
logging.warning(f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s")
except Exception as e:
logging.warning(f"Failed to get audio duration: {e}")
return last_mp3
else:
logging.warning("No output audio file generated")
return None
except Exception as e:
logging.error(f"Inference error: {e}")
raise
finally:
for path in [genre_txt_path, lyrics_txt_path]:
if path and os.path.exists(path):
try:
os.remove(path)
logging.debug(f"Removed temporary file: {path}")
except Exception as e:
logging.warning(f"Failed to remove temporary file {path}: {e}")
def optimize_model_selection(lyrics, genre):
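    """Select a model for the lyrics and build its generation config (token budget boosted 1.5x when a chorus is present)."""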
model_path = detect_and_select_model(lyrics)
params = calculate_generation_params(lyrics)
has_chorus = params['sections']['chorus'] > 0
    tokens_per_segment = params['max_tokens'] // params['num_segments']  # note: currently unused
model_config = {
"m-a-p/YuE-s1-7B-anneal-en-cot": {
"max_tokens": params['max_tokens'],
"temperature": 0.8,
"batch_size": 16,
"num_segments": params['num_segments'],
"estimated_duration": params['estimated_duration']
},
"m-a-p/YuE-s1-7B-anneal-jp-kr-cot": {
"max_tokens": params['max_tokens'],
"temperature": 0.7,
"batch_size": 16,
"num_segments": params['num_segments'],
"estimated_duration": params['estimated_duration']
},
"m-a-p/YuE-s1-7B-anneal-zh-cot": {
"max_tokens": params['max_tokens'],
"temperature": 0.7,
"batch_size": 16,
"num_segments": params['num_segments'],
"estimated_duration": params['estimated_duration']
}
}
if has_chorus:
for config in model_config.values():
config['max_tokens'] = int(config['max_tokens'] * 1.5)
return model_path, model_config[model_path], params
css = """
#main-container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
#header {
text-align: center;
margin-bottom: 30px;
}
#genre-input, #lyrics-input {
border-radius: 8px;
}
#generate-btn {
margin-top: 20px;
min-height: 45px;
}
.label {
font-weight: bold;
}
.example-container {
background: #f8f9fa;
padding: 15px;
border-radius: 8px;
margin: 10px 0;
}
"""
def main():
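    """Build the Gradio Blocks UI and return the demo app."""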
with gr.Blocks(theme=gr.themes.Soft(
primary_hue="indigo",
secondary_hue="purple",
neutral_hue="slate",
font=["Arial", "sans-serif"]
), css=css) as demo:
with gr.Column(elem_id="main-container"):
# ํ—ค๋” ์„น์…˜
with gr.Row(elem_id="header"):
                gr.Markdown(
                    """
                    # 🎵 Open SUNO: Transform Your Lyrics into Complete Songs
                    ### Create complete songs from your lyrics in multiple languages
                    """,
                    elem_id="title"
                )
# ๋ฉ”์ธ ์ปจํ…์ธ ๋ฅผ ํƒญ์œผ๋กœ ๊ตฌ์„ฑ
with gr.Tabs() as tabs:
# ์ƒ์„ฑ ํƒญ
with gr.TabItem("โœจ Create Music", id="create"):
with gr.Row():
                        # Input section
with gr.Column(scale=1):
genre_txt = gr.Textbox(
label="๐ŸŽธ Music Genre & Style",
placeholder="e.g., K-pop bright energetic synth dance electronic...",
elem_id="genre-input"
)
lyrics_txt = gr.Textbox(
label="๐Ÿ“ Lyrics",
placeholder="Enter lyrics with section tags: [verse], [chorus], [bridge]...",
lines=10,
elem_id="lyrics-input"
)
                            # Info display section
with gr.Row():
with gr.Column(scale=1):
duration_info = gr.Label(
label="โฑ๏ธ Estimated Duration",
elem_id="duration-info"
)
with gr.Column(scale=1):
sections_info = gr.Label(
label="๐Ÿ“Š Section Analysis",
elem_id="sections-info"
)
# ์ƒ์„ฑ ๋ฒ„ํŠผ
submit_btn = gr.Button(
"๐ŸŽผ Generate Music",
variant="primary",
elem_id="generate-btn"
)
                        # Output section
with gr.Column(scale=1):
music_out = gr.Audio(
label="๐ŸŽต Generated Music",
elem_id="music-output"
)
                            # Generation status display
progress = gr.Textbox(
label="Generation Status",
interactive=False,
elem_id="progress-status"
)
                # History tab
                with gr.TabItem("📚 History", id="history"):
with gr.Row():
history_container = gr.HTML("""
<div id="history-container" style="width: 100%; padding: 10px;">
<h3>๐ŸŽต Generation History</h3>
<div id="history-list"></div>
</div>
""")
# ํžˆ์Šคํ† ๋ฆฌ ์ƒํƒœ ์ €์žฅ
history_state = gr.State([])
            # Examples section
            with gr.Accordion("📖 Examples", open=False):
gr.Examples(
examples=[
[
"female blues airy vocal bright vocal piano sad romantic guitar jazz",
"""[verse]
In the quiet of the evening, shadows start to fall
Whispers of the night wind echo through the hall
Lost within the silence, I hear your gentle voice
Guiding me back homeward, making my heart rejoice
[chorus]
Don't let this moment fade, hold me close tonight
"""
],
[
"K-pop bright energetic synth dance electronic",
"""[verse]
언젠가 마주한 눈빛 속에서
어두운 밤을 지날 때마다
[chorus]
다시 한 번 내게 말해줘
"""
]
],
inputs=[genre_txt, lyrics_txt]
)
            # Help & information section
            with gr.Accordion("ℹ️ Help & Information", open=False):
                gr.Markdown(
                    """
                    ### 🎵 How to Use
                    1. **Enter Genre & Style**: Describe the musical style you want (e.g., "K-pop", "Jazz", "Rock")
                    2. **Input Lyrics**: Write your lyrics using section tags:
                       - Use `[verse]` for verses
                       - Use `[chorus]` for choruses
                       - Use `[bridge]` for bridges
                    3. **Generate**: Click the Generate button and wait for your music!

                    ### 🌍 Supported Languages
                    - English
                    - Korean (한국어)
                    - Japanese (日本語)
                    - Chinese (中文)

                    ### ⚡ Tips
                    - Be specific with your genre descriptions
                    - Include emotion and instrument preferences
                    - Make sure to properly tag your lyrics sections
                    - For best results, include both verse and chorus sections
                    """
                )
            # Hidden state variables
            num_segments = gr.State(value=2)  # default: 2
            max_new_tokens = gr.State(value=4000)  # default: 4000
            # System initialization (downloads models on first launch)
initialize_system()
def update_info(lyrics):
if not lyrics:
return "No lyrics entered", "No sections detected"
params = calculate_generation_params(lyrics)
duration = params['estimated_duration']
sections = params['sections']
return (
f"โฑ๏ธ Estimated: {duration:.1f} seconds",
f"๐Ÿ“Š Verses: {sections['verse']}, Chorus: {sections['chorus']}"
)
def update_history(audio_path, genre, lyrics, history):
if audio_path:
new_entry = {
"audio": audio_path,
"genre": genre,
"lyrics": lyrics,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S") # datetime ์ง์ ‘ ์‚ฌ์šฉ
}
history = [new_entry] + (history or [])
history_html = "<div class='history-entries'>"
for entry in history:
history_html += f"""
<div class='history-entry' style='margin: 10px 0; padding: 10px; border: 1px solid #ddd; border-radius: 8px;'>
<audio controls src='{entry["audio"]}'></audio>
<div style='margin-top: 5px;'><strong>Genre:</strong> {entry["genre"]}</div>
<div style='margin-top: 5px;'><strong>Lyrics:</strong><pre>{entry["lyrics"]}</pre></div>
<div style='color: #666; font-size: 0.9em;'>{entry["timestamp"]}</div>
</div>
"""
history_html += "</div>"
return history, history_html
return history, ""
def generate_with_progress(genre, lyrics, segments, tokens, history):
try:
status_text = "๐ŸŽต Starting generation..."
result = infer(genre, lyrics, segments, tokens)
if result:
status_text = "โœ… Generation complete!"
new_history, history_html = update_history(result, genre, lyrics, history)
return result, new_history, history_html, status_text
else:
status_text = "โŒ Generation failed"
return None, history, "", status_text
except Exception as e:
status_text = f"โŒ Error: {str(e)}"
return None, history, "", status_text
        # Event handlers
lyrics_txt.change(
fn=update_info,
inputs=[lyrics_txt],
outputs=[duration_info, sections_info]
)
        # Wire up the submit-button click event
submit_btn.click(
fn=generate_with_progress,
inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens, history_state],
outputs=[music_out, history_state, history_container, progress]
)
return demo
if __name__ == "__main__":
demo = main()
demo.queue(max_size=20).launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_api=True,
show_error=True,
max_threads=8
)