import gradio as gr import subprocess import os import shutil import tempfile import torch import logging import numpy as np import re from concurrent.futures import ThreadPoolExecutor from functools import lru_cache # 로깅 설정 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('yue_generation.log'), logging.StreamHandler() ] ) ################################ # 기존에 정의된 함수 및 로직들 # ################################ def optimize_gpu_settings(): if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = True torch.backends.cudnn.deterministic = False torch.cuda.empty_cache() torch.cuda.set_device(0) torch.cuda.Stream(0) os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' logging.info(f"Using GPU: {torch.cuda.get_device_name(0)}") logging.info(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB") if 'L40S' in torch.cuda.get_device_name(0): torch.cuda.set_per_process_memory_fraction(0.95) def analyze_lyrics(lyrics, repeat_chorus=2): lines = [line.strip() for line in lyrics.split('\n') if line.strip()] sections = { 'verse': 0, 'chorus': 0, 'bridge': 0, 'total_lines': len(lines) } current_section = None section_lines = { 'verse': [], 'chorus': [], 'bridge': [] } last_section = None for i, line in enumerate(lines): if '[verse]' in line.lower() or '[chorus]' in line.lower() or '[bridge]' in line.lower(): last_section = i for i, line in enumerate(lines): lower_line = line.lower() if '[verse]' in lower_line: if current_section: section_lines[current_section].extend(lines[last_section_start:i]) current_section = 'verse' sections['verse'] += 1 last_section_start = i + 1 continue elif '[chorus]' in lower_line: if current_section: section_lines[current_section].extend(lines[last_section_start:i]) current_section = 'chorus' sections['chorus'] += 1 last_section_start = i + 1 continue elif '[bridge]' in lower_line: if current_section: section_lines[current_section].extend(lines[last_section_start:i]) current_section = 'bridge' sections['bridge'] += 1 last_section_start = i + 1 continue if current_section and 'last_section_start' in locals() and last_section_start < len(lines): section_lines[current_section].extend(lines[last_section_start:]) if sections['chorus'] > 0 and repeat_chorus > 1: original_chorus = section_lines['chorus'][:] for _ in range(repeat_chorus - 1): section_lines['chorus'].extend(original_chorus) logging.info( f"Section line counts - Verse: {len(section_lines['verse'])}, " f"Chorus: {len(section_lines['chorus'])}, " f"Bridge: {len(section_lines['bridge'])}" ) return sections, (sections['verse'] + sections['chorus'] + sections['bridge']), len(lines), section_lines def calculate_generation_params(lyrics): sections, total_sections, total_lines, section_lines = analyze_lyrics(lyrics) time_per_line = { 'verse': 4, 'chorus': 6, 'bridge': 5 } section_durations = {} for section_type in ['verse', 'chorus', 'bridge']: lines_count = len(section_lines[section_type]) section_durations[section_type] = lines_count * time_per_line[section_type] total_duration = sum(duration for duration in section_durations.values()) total_duration = max(60, int(total_duration * 1.2)) base_tokens = 3000 tokens_per_line = 200 extra_tokens = 1000 total_tokens = base_tokens + (total_lines * tokens_per_line) + extra_tokens if sections['chorus'] > 0: num_segments = 4 else: num_segments = 3 max_tokens = min(12000, total_tokens) return { 'max_tokens': max_tokens, 'num_segments': num_segments, 'sections': sections, 'section_lines': section_lines, 'estimated_duration': total_duration, 'section_durations': section_durations, 'has_chorus': sections['chorus'] > 0 } def detect_and_select_model(text): if re.search(r'[\u3131-\u318E\uAC00-\uD7A3]', text): return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot" elif re.search(r'[\u4e00-\u9fff]', text): return "m-a-p/YuE-s1-7B-anneal-zh-cot" elif re.search(r'[\u3040-\u309F\u30A0-\u30FF]', text): return "m-a-p/YuE-s1-7B-anneal-jp-kr-cot" else: return "m-a-p/YuE-s1-7B-anneal-en-cot" def install_flash_attn(): try: if not torch.cuda.is_available(): logging.warning("GPU not available, skipping flash-attn installation") return False cuda_version = torch.version.cuda if cuda_version is None: logging.warning("CUDA not available, skipping flash-attn installation") return False logging.info(f"Detected CUDA version: {cuda_version}") try: import flash_attn logging.info("flash-attn already installed") return True except ImportError: logging.info("Installing flash-attn...") subprocess.run( ["pip", "install", "flash-attn", "--no-build-isolation"], check=True, capture_output=True ) logging.info("flash-attn installed successfully!") return True except Exception as e: logging.warning(f"Failed to install flash-attn: {e}") return False def initialize_system(): optimize_gpu_settings() with ThreadPoolExecutor(max_workers=4) as executor: futures = [] futures.append(executor.submit(install_flash_attn)) from huggingface_hub import snapshot_download folder_path = './inference/xcodec_mini_infer' os.makedirs(folder_path, exist_ok=True) logging.info(f"Created folder at: {folder_path}") futures.append(executor.submit( snapshot_download, repo_id="m-a-p/xcodec_mini_infer", local_dir="./inference/xcodec_mini_infer", resume_download=True )) for future in futures: future.result() try: os.chdir("./inference") logging.info(f"Working directory changed to: {os.getcwd()}") except FileNotFoundError as e: logging.error(f"Directory error: {e}") raise @lru_cache(maxsize=100) def get_cached_file_path(content_hash, prefix): return create_temp_file(content_hash, prefix) def empty_output_folder(output_dir): try: shutil.rmtree(output_dir) os.makedirs(output_dir) logging.info(f"Output folder cleaned: {output_dir}") except Exception as e: logging.error(f"Error cleaning output folder: {e}") raise def create_temp_file(content, prefix, suffix=".txt"): temp_file = tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=prefix, suffix=suffix) content = content.strip() + "\n\n" content = content.replace("\r\n", "\n").replace("\r", "\n") temp_file.write(content) temp_file.close() logging.debug(f"Temporary file created: {temp_file.name}") return temp_file.name def get_last_mp3_file(output_dir): mp3_files = [f for f in os.listdir(output_dir) if f.endswith('.mp3')] if not mp3_files: logging.warning("No MP3 files found") return None mp3_files_with_path = [os.path.join(output_dir, f) for f in mp3_files] mp3_files_with_path.sort(key=os.path.getmtime, reverse=True) return mp3_files_with_path[0] def get_audio_duration(file_path): try: import librosa duration = librosa.get_duration(path=file_path) return duration except Exception as e: logging.error(f"Failed to get audio duration: {e}") return None def optimize_model_selection(lyrics, genre): model_path = detect_and_select_model(lyrics) params = calculate_generation_params(lyrics) has_chorus = params['sections']['chorus'] > 0 model_config = { "m-a-p/YuE-s1-7B-anneal-en-cot": { "max_tokens": params['max_tokens'], "temperature": 0.8, "batch_size": 16, "num_segments": params['num_segments'], "estimated_duration": params['estimated_duration'] }, "m-a-p/YuE-s1-7B-anneal-jp-kr-cot": { "max_tokens": params['max_tokens'], "temperature": 0.7, "batch_size": 16, "num_segments": params['num_segments'], "estimated_duration": params['estimated_duration'] }, "m-a-p/YuE-s1-7B-anneal-zh-cot": { "max_tokens": params['max_tokens'], "temperature": 0.7, "batch_size": 16, "num_segments": params['num_segments'], "estimated_duration": params['estimated_duration'] } } if has_chorus: for config in model_config.values(): config['max_tokens'] = int(config['max_tokens'] * 1.5) return model_path, model_config[model_path], params def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens): genre_txt_path = None lyrics_txt_path = None try: model_path, config, params = optimize_model_selection(lyrics_txt_content, genre_txt_content) logging.info(f"Selected model: {model_path}") logging.info(f"Lyrics analysis: {params}") has_chorus = params['sections']['chorus'] > 0 estimated_duration = params.get('estimated_duration', 90) # 세그먼트 및 토큰 수 설정 if has_chorus: actual_max_tokens = min(12000, int(config['max_tokens'] * 1.3)) # 30% 더 많은 토큰 actual_num_segments = min(5, params['num_segments'] + 2) # 추가 세그먼트 else: actual_max_tokens = min(10000, int(config['max_tokens'] * 1.2)) actual_num_segments = min(4, params['num_segments'] + 1) logging.info(f"Estimated duration: {estimated_duration} seconds") logging.info(f"Has chorus sections: {has_chorus}") logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}") genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_") lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_") output_dir = "./output" os.makedirs(output_dir, exist_ok=True) empty_output_folder(output_dir) command = [ "python", "infer.py", "--stage1_model", model_path, "--stage2_model", "m-a-p/YuE-s2-1B-general", "--genre_txt", genre_txt_path, "--lyrics_txt", lyrics_txt_path, "--run_n_segments", str(actual_num_segments), "--stage2_batch_size", "16", "--output_dir", output_dir, "--cuda_idx", "0", "--max_new_tokens", str(actual_max_tokens), "--disable_offload_model" ] env = os.environ.copy() if torch.cuda.is_available(): env.update({ "CUDA_VISIBLE_DEVICES": "0", "CUDA_HOME": "/usr/local/cuda", "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}", "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}", "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512", "CUDA_LAUNCH_BLOCKING": "0" }) # transformers 캐시 마이그레이션 처리 (버전에 따라 동작하지 않을 수 있음) try: from transformers.utils import move_cache move_cache() except Exception as e: logging.warning(f"Cache migration warning (non-critical): {e}") process = subprocess.run( command, env=env, check=False, capture_output=True, text=True ) logging.info(f"Command output: {process.stdout}") if process.stderr: logging.error(f"Command error: {process.stderr}") if process.returncode != 0: logging.error(f"Command failed with return code: {process.returncode}") logging.error(f"Command: {' '.join(command)}") raise RuntimeError(f"Inference failed: {process.stderr}") last_mp3 = get_last_mp3_file(output_dir) if last_mp3: try: duration = get_audio_duration(last_mp3) logging.info(f"Generated audio file: {last_mp3}") if duration: logging.info(f"Audio duration: {duration:.2f} seconds") logging.info(f"Expected duration: {estimated_duration} seconds") if duration < estimated_duration * 0.8: logging.warning( f"Generated audio is shorter than expected: {duration:.2f}s < {estimated_duration:.2f}s" ) except Exception as e: logging.warning(f"Failed to get audio duration: {e}") return last_mp3 else: logging.warning("No output audio file generated") return None except Exception as e: logging.error(f"Inference error: {e}") raise finally: for path in [genre_txt_path, lyrics_txt_path]: if path and os.path.exists(path): try: os.remove(path) logging.debug(f"Removed temporary file: {path}") except Exception as e: logging.warning(f"Failed to remove temporary file {path}: {e}") ##################################### # 아래부터 Gradio UI 및 main() 부분 # ##################################### def update_info(lyrics): """가사 변경 시 추정 정보를 업데이트하는 함수.""" if not lyrics: return "No lyrics entered", "No sections detected" params = calculate_generation_params(lyrics) duration = params['estimated_duration'] sections = params['sections'] return ( f"Estimated duration: {duration:.1f} seconds", f"Verses: {sections['verse']}, Chorus: {sections['chorus']} (Expected full length including chorus)" ) def main(): # 먼저 시스템 초기화 (GPU 최적화, 필요한 모델 다운로드 등) initialize_system() with gr.Blocks(css=""" /* 전체 배경 및 컨테이너 스타일 */ body { background-color: #f5f5f5; } .gradio-container { max-width: 1000px; margin: auto !important; background-color: #ffffff; border-radius: 8px; padding: 20px; box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); } /* 텍스트 크기, 마진 조정 */ h1, h2, h3 { margin: 0; padding: 0; } p { margin: 5px 0; } /* 예제 블록 스타일 */ .gr-examples { background-color: #fafafa; border-radius: 8px; padding: 10px; } """) as demo: # 상단 헤더 gr.HTML("""

Open SUNO: Full-Song Generation (Multi-Language Support)

Enter your song details below and let the AI handle the music production!

""") with gr.Row(): # 왼쪽 입력 컬럼 with gr.Column(): genre_txt = gr.Textbox( label="Genre", placeholder="Enter music genre and style descriptions...", lines=2 ) lyrics_txt = gr.Textbox( label="Lyrics (Supports English, Korean, Japanese, Chinese)", placeholder="Enter song lyrics with [verse], [chorus], [bridge] tags...", lines=10 ) # 오른쪽 설정/정보 컬럼 with gr.Column(): with gr.Box(): gr.Markdown("### Generation Settings") num_segments = gr.Number( label="Number of Song Segments (Auto-adjusted)", value=2, minimum=1, maximum=4, step=1, interactive=False ) max_new_tokens = gr.Slider( label="Max New Tokens (Auto-adjusted)", minimum=500, maximum=32000, step=500, value=4000, interactive=False ) with gr.Box(): gr.Markdown("### Song Info") duration_info = gr.Label(label="Estimated Duration") sections_info = gr.Label(label="Section Information") submit_btn = gr.Button("Generate Music", variant="primary") # 생성된 오디오 출력 영역 with gr.Box(): music_out = gr.Audio(label="Generated Audio") # 예시 gr.Examples( examples=[ [ "female blues airy vocal bright vocal piano sad romantic guitar jazz", """[verse] In the quiet of the evening, shadows start to fall Whispers of the night wind echo through the hall Lost within the silence, I hear your gentle voice Guiding me back homeward, making my heart rejoice [chorus] Don't let this moment fade, hold me close tonight With you here beside me, everything's alright Can't imagine life alone, don't want to let you go Stay with me forever, let our love just flow """ ], [ "K-pop bright energetic synth dance electronic", """[verse] 언젠가 마주한 눈빛 속에서 [chorus] 다시 한 번 내게 말해줘 """ ] ], inputs=[genre_txt, lyrics_txt], outputs=[] ) # 가사 변경 시 추정 정보 업데이트 lyrics_txt.change( fn=update_info, inputs=[lyrics_txt], outputs=[duration_info, sections_info] ) # 버튼 클릭 시 infer 실행 submit_btn.click( fn=infer, inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens], outputs=[music_out] ) return demo if __name__ == "__main__": demo = main() demo.queue(max_size=20).launch( server_name="0.0.0.0", server_port=7860, share=True, show_api=True, show_error=True, max_threads=8 )