"""Gradio front end that turns tagged lyrics into a song with the YuE models.

The UI collects a genre description and lyrics annotated with ``[verse]``,
``[chorus]`` and ``[bridge]`` tags, derives generation parameters from the
lyrics, and runs ``infer.py`` in a subprocess to produce an MP3 file.
"""

import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np  # noqa: F401  (not referenced in this file; import retained)
import torch

# Logging setup: everything goes both to a log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('yue_generation.log'),
        logging.StreamHandler()
    ]
)

# Section tags recognised in lyrics, in matching priority order.
SECTION_TYPES = ('verse', 'chorus', 'bridge')

# Estimated seconds of audio per lyric line, by section type.
SECONDS_PER_LINE = {'verse': 4, 'chorus': 6, 'bridge': 5}

# Stage-1 model identifiers, one per supported language family.
MODEL_EN = "m-a-p/YuE-s1-7B-anneal-en-cot"
MODEL_JP_KR = "m-a-p/YuE-s1-7B-anneal-jp-kr-cot"
MODEL_ZH = "m-a-p/YuE-s1-7B-anneal-zh-cot"
STAGE2_MODEL = "m-a-p/YuE-s2-1B-general"

# Sampling temperature recorded in the per-model configuration.
MODEL_TEMPERATURES = {MODEL_EN: 0.8, MODEL_JP_KR: 0.7, MODEL_ZH: 0.7}

# Hangul (jamo + syllables) and Japanese kana (hiragana + katakana).
_HANGUL_OR_KANA_RE = re.compile(
    r'[\u3131-\u318E\uAC00-\uD7A3\u3040-\u309F\u30A0-\u30FF]'
)
# CJK unified ideographs (used by Chinese, but also by Japanese kanji).
_HAN_RE = re.compile(r'[\u4e00-\u9fff]')

XCODEC_DIR = './inference/xcodec_mini_infer'
OUTPUT_DIR = "./output"

CUSTOM_CSS = """
#main-container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
#header {
    text-align: center;
    margin-bottom: 30px;
    background: linear-gradient(135deg, #6366f1, #a855f7);
    padding: 20px;
    border-radius: 15px;
    color: white;
}
.input-section {
    background: #f8fafc;
    padding: 20px;
    border-radius: 15px;
    margin-bottom: 20px;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1);
}
.output-section {
    background: #f0f9ff;
    padding: 20px;
    border-radius: 15px;
    margin-bottom: 20px;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1);
}
#generate-btn {
    background: linear-gradient(135deg, #6366f1, #a855f7);
    border: none;
    padding: 15px 30px;
    border-radius: 10px;
    color: white;
    font-weight: bold;
    cursor: pointer;
    transition: all 0.3s ease;
}
#generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}
.info-box {
    background: #fff;
    padding: 15px;
    border-radius: 10px;
    border: 1px solid #e2e8f0;
    margin: 10px 0;
}
.status-section {
    background: #fff;
    padding: 15px;
    border-radius: 10px;
    margin-top: 15px;
    border: 1px solid #e2e8f0;
}
"""

# Markdown blocks are kept at column 0 so that headings and lists render as
# Markdown regardless of how the UI framework treats leading indentation.
HEADER_MARKDOWN = """
# 🎵 AI Song Creator 'Open SUNO'
### Transform Your Lyrics into Complete Songs with Music
Create professional songs from your lyrics in multiple languages
"""

HELP_MARKDOWN = """
### 🎵 How to Use
1. **Enter Genre & Style**: Describe the musical style you want
2. **Input Lyrics**: Write your lyrics using section tags
3. **Generate**: Click the Generate button and wait for your music!

### 🌏 Supported Languages
- English
- Korean (한국어)
- Japanese (日本語)
- Chinese (中文)

### ⚡ Tips for Best Results
- Be specific with genre descriptions
- Include emotion and instrument preferences
- Properly tag your lyrics sections
- Include both verse and chorus sections
"""

EXAMPLE_LYRICS_EN = """[verse]
In the quiet of the evening, shadows start to fall
Whispers of the night wind echo through the hall
Lost within the silence, I hear your gentle voice
Guiding me back homeward, making my heart rejoice

[chorus]
Don't let this moment fade, hold me close tonight
With you here beside me, everything's alright
Can't imagine life alone, don't want to let you go
Stay with me forever, let our love just flow"""

EXAMPLE_LYRICS_KO = """[verse]
언젠가 마주한 눈빛 속에서

[chorus]
다시 한 번 내게 말해줘

[verse]
어두운 밤을 지날 때마다

[chorus]
다시 한 번 내게 말해줘"""


def optimize_gpu_settings() -> None:
    """Apply process-wide CUDA/cuDNN tuning when a GPU is available.

    Does nothing on CPU-only hosts. On GPU hosts it enables TF32 matmul and
    cuDNN autotuning, selects device 0, logs the device name and memory, and
    caps the per-process memory fraction at 95% when the device name contains
    ``L40S``.
    """
    if not torch.cuda.is_available():
        return

    # Set the allocator option before the first CUDA call in this function so
    # it is in place as early as possible.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False

    torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    device_name = torch.cuda.get_device_name(0)
    total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    logging.info(f"Using GPU: {device_name}")
    logging.info(f"Available GPU memory: {total_memory_gb:.2f} GB")

    # L40S-specific setting.
    if 'L40S' in device_name:
        torch.cuda.set_per_process_memory_fraction(0.95)


def _section_tag(line: str) -> Optional[str]:
    """Return the section type whose ``[tag]`` appears in *line*, else None."""
    lowered = line.lower()
    for section_type in SECTION_TYPES:
        if f'[{section_type}]' in lowered:
            return section_type
    return None


def analyze_lyrics(
    lyrics: str, repeat_chorus: int = 2
) -> Tuple[Dict[str, int], int, int, Dict[str, List[str]]]:
    """Split tagged lyrics into sections and count them.

    Blank lines are discarded and remaining lines are stripped. A line that
    contains ``[verse]``, ``[chorus]`` or ``[bridge]`` (case-insensitive)
    starts a new section and is itself not stored as a lyric line. Lines that
    appear before the first tag belong to no section and are ignored.

    Args:
        lyrics: Raw lyrics text; ``None`` is treated as empty.
        repeat_chorus: When greater than 1 and at least one chorus exists,
            the collected chorus lines are repeated this many times in total.

    Returns:
        A 4-tuple ``(sections, total_sections, total_lines, section_lines)``:
        ``sections`` maps each section type to its tag count plus
        ``'total_lines'``; ``total_sections`` is the sum of the three tag
        counts; ``total_lines`` counts all non-blank lines including tag
        lines; ``section_lines`` maps each section type to its lyric lines.
    """
    lines = [
        stripped
        for stripped in (raw.strip() for raw in (lyrics or '').split('\n'))
        if stripped
    ]

    sections = {'verse': 0, 'chorus': 0, 'bridge': 0, 'total_lines': len(lines)}
    section_lines: Dict[str, List[str]] = {name: [] for name in SECTION_TYPES}

    current_section = None
    for line in lines:
        tag = _section_tag(line)
        if tag is not None:
            current_section = tag
            sections[tag] += 1
        elif current_section is not None:
            section_lines[current_section].append(line)

    # Chorus repetition.
    if sections['chorus'] > 0 and repeat_chorus > 1:
        section_lines['chorus'] = section_lines['chorus'] * repeat_chorus

    logging.info(f"Section line counts - Verse: {len(section_lines['verse'])}, "
                 f"Chorus: {len(section_lines['chorus'])}, "
                 f"Bridge: {len(section_lines['bridge'])}")

    total_sections = sections['verse'] + sections['chorus'] + sections['bridge']
    return sections, total_sections, len(lines), section_lines


def calculate_generation_params(lyrics: str) -> Dict[str, Any]:
    """Derive token budget, segment count and duration estimate from lyrics.

    Args:
        lyrics: Raw lyrics text with section tags.

    Returns:
        A dict with keys ``max_tokens``, ``num_segments``, ``sections``,
        ``section_lines``, ``estimated_duration``, ``section_durations`` and
        ``has_chorus``.
    """
    sections, _total_sections, total_lines, section_lines = analyze_lyrics(lyrics)

    # Estimated seconds per section type.
    section_durations = {
        section_type: len(section_lines[section_type]) * SECONDS_PER_LINE[section_type]
        for section_type in SECTION_TYPES
    }

    # 20% headroom on top of the summed estimate, never below 60 seconds.
    total_duration = max(60, int(sum(section_durations.values()) * 1.2))

    # Token budget: fixed base + per-line allowance + extra headroom.
    base_tokens = 3000
    tokens_per_line = 200
    extra_tokens = 1000
    total_tokens = base_tokens + (total_lines * tokens_per_line) + extra_tokens

    has_chorus = sections['chorus'] > 0
    # Lyrics with a chorus get one more segment than lyrics without.
    num_segments = 4 if has_chorus else 3

    return {
        'max_tokens': min(12000, total_tokens),
        'num_segments': num_segments,
        'sections': sections,
        'section_lines': section_lines,
        'estimated_duration': total_duration,
        'section_durations': section_durations,
        'has_chorus': has_chorus
    }


def detect_and_select_model(text: str) -> str:
    """Pick the stage-1 model identifier matching the script used in *text*.

    Hangul or Japanese kana selects the jp-kr model. Han ideographs without
    any Hangul/kana select the zh model. Everything else selects the en
    model. Kana is checked before Han because Japanese text normally mixes
    kana with kanji, which share the Han code-point range with Chinese.

    Args:
        text: Lyrics text; ``None`` is treated as empty.

    Returns:
        The model repository identifier.
    """
    text = text or ''
    if _HANGUL_OR_KANA_RE.search(text):
        return MODEL_JP_KR
    if _HAN_RE.search(text):
        return MODEL_ZH
    return MODEL_EN


def install_flash_attn() -> bool:
    """Ensure the ``flash_attn`` package is importable, installing it if needed.

    Installation is skipped when no GPU is available or when torch reports no
    CUDA version. Failures are logged as warnings and never raised.

    Returns:
        True when ``flash_attn`` was already importable or pip finished
        successfully; False otherwise.
    """
    try:
        if not torch.cuda.is_available():
            logging.warning("GPU not available, skipping flash-attn installation")
            return False

        cuda_version = torch.version.cuda
        if cuda_version is None:
            logging.warning("CUDA not available, skipping flash-attn installation")
            return False

        logging.info(f"Detected CUDA version: {cuda_version}")

        try:
            import flash_attn  # noqa: F401
        except ImportError:
            logging.info("Installing flash-attn...")
        else:
            logging.info("flash-attn already installed")
            return True

        # Run pip through the current interpreter so the package lands in the
        # environment this process is actually using.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,
            capture_output=True
        )
        logging.info("flash-attn installed successfully!")
        return True
    except subprocess.CalledProcessError as e:
        pip_stderr = e.stderr.decode(errors="replace") if e.stderr else ""
        logging.warning(f"Failed to install flash-attn: {e}")
        if pip_stderr:
            logging.warning(f"pip stderr: {pip_stderr}")
        return False
    except Exception as e:
        logging.warning(f"Failed to install flash-attn: {e}")
        return False


def initialize_system() -> None:
    """Prepare the runtime: GPU settings, flash-attn, codec download, chdir.

    Runs the flash-attn installation and the ``m-a-p/xcodec_mini_infer``
    snapshot download concurrently, waits for both, then changes the working
    directory to ``./inference``.

    Raises:
        FileNotFoundError: If ``./inference`` does not exist.
        Exception: Any exception raised by the snapshot download.
    """
    optimize_gpu_settings()

    from huggingface_hub import snapshot_download

    os.makedirs(XCODEC_DIR, exist_ok=True)
    logging.info(f"Created folder at: {XCODEC_DIR}")

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(install_flash_attn),
            executor.submit(
                snapshot_download,
                repo_id="m-a-p/xcodec_mini_infer",
                local_dir=XCODEC_DIR,
                resume_download=True
            ),
        ]
        for future in futures:
            # result() re-raises any exception from the worker.
            future.result()

    try:
        os.chdir("./inference")
        logging.info(f"Working directory changed to: {os.getcwd()}")
    except FileNotFoundError as e:
        logging.error(f"Directory error: {e}")
        raise


@lru_cache(maxsize=100)
def get_cached_file_path(content_hash, prefix):
    """Return a temp-file path for the given arguments, memoised (100 entries).

    NOTE(review): the first argument is passed to ``create_temp_file`` as the
    file *content*, despite being named ``content_hash`` -- confirm intent.
    A cached path is returned even if the file has since been deleted. This
    function is not called anywhere in this file.
    """
    return create_temp_file(content_hash, prefix)


def empty_output_folder(output_dir: str) -> None:
    """Remove everything inside *output_dir*, leaving an empty directory.

    A missing directory is simply created.

    Raises:
        Exception: Any filesystem error is logged and re-raised.
    """
    try:
        if os.path.isdir(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir, exist_ok=True)
        logging.info(f"Output folder cleaned: {output_dir}")
    except Exception as e:
        logging.error(f"Error cleaning output folder: {e}")
        raise


def create_temp_file(content: str, prefix: str, suffix: str = ".txt") -> str:
    """Write *content* to a new temporary text file and return its path.

    The content is stripped, terminated with two newlines, and its line
    endings are normalised to ``\\n``. The file is written as UTF-8 so that
    non-ASCII lyrics do not depend on the platform's default encoding. The
    caller is responsible for deleting the file.
    """
    normalized = (content.strip() + "\n\n").replace("\r\n", "\n").replace("\r", "\n")
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", prefix=prefix, suffix=suffix, encoding="utf-8"
    ) as temp_file:
        temp_file.write(normalized)
        temp_path = temp_file.name
    logging.debug(f"Temporary file created: {temp_path}")
    return temp_path


def get_last_mp3_file(output_dir: str) -> Optional[str]:
    """Return the most recently modified ``.mp3`` directly inside *output_dir*.

    Returns None (and logs a warning) when the directory holds no MP3 file.
    """
    mp3_paths = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.mp3')
    ]
    if not mp3_paths:
        logging.warning("No MP3 files found")
        return None
    return max(mp3_paths, key=os.path.getmtime)


def get_audio_duration(file_path: str) -> Optional[float]:
    """Return the duration of an audio file in seconds, or None on failure.

    Uses ``librosa`` (imported lazily); any error is logged, not raised.
    """
    try:
        import librosa
        return librosa.get_duration(path=file_path)
    except Exception as e:
        logging.error(f"Failed to get audio duration: {e}")
        return None


def _build_inference_command(
    model_path: str,
    genre_txt_path: str,
    lyrics_txt_path: str,
    num_segments: int,
    max_tokens: int,
    output_dir: str,
) -> List[str]:
    """Assemble the argument list used to launch ``infer.py``."""
    return [
        # Use the running interpreter so the child sees the same packages.
        sys.executable, "infer.py",
        "--stage1_model", model_path,
        "--stage2_model", STAGE2_MODEL,
        "--genre_txt", genre_txt_path,
        "--lyrics_txt", lyrics_txt_path,
        "--run_n_segments", str(num_segments),
        "--stage2_batch_size", "16",
        "--output_dir", output_dir,
        "--cuda_idx", "0",
        "--max_new_tokens", str(max_tokens),
        "--disable_offload_model"
    ]


def _build_subprocess_env() -> Dict[str, str]:
    """Copy the current environment, adding CUDA settings when a GPU exists."""
    env = os.environ.copy()
    if torch.cuda.is_available():
        env.update({
            "CUDA_VISIBLE_DEVICES": "0",
            "CUDA_HOME": "/usr/local/cuda",
            "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
            "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
            "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512",
            "CUDA_LAUNCH_BLOCKING": "0"
        })
    return env


def _log_duration_check(audio_path: str, estimated_duration: float) -> None:
    """Log the generated file's duration and warn if it is under 80% of the estimate."""
    try:
        duration = get_audio_duration(audio_path)
        logging.info(f"Generated audio file: {audio_path}")
        if duration:
            logging.info(f"Audio duration: {duration:.2f} seconds")
            logging.info(f"Expected duration: {estimated_duration} seconds")
            if duration < estimated_duration * 0.8:
                logging.warning(
                    f"Generated audio is shorter than expected: "
                    f"{duration:.2f}s < {estimated_duration:.2f}s"
                )
    except Exception as e:
        logging.warning(f"Failed to get audio duration: {e}")


def _remove_temp_files(paths: List[Optional[str]]) -> None:
    """Delete each existing path in *paths*; failures are logged, not raised."""
    for path in paths:
        if path and os.path.exists(path):
            try:
                os.remove(path)
                logging.debug(f"Removed temporary file: {path}")
            except Exception as e:
                logging.warning(f"Failed to remove temporary file {path}: {e}")


def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
    """Generate a song from genre text and lyrics by running ``infer.py``.

    Args:
        genre_txt_content: Genre/style description; ``None`` is treated as
            an empty string.
        lyrics_txt_content: Lyrics with section tags; must not be blank.
        num_segments: Accepted for interface compatibility with the UI; the
            segment count actually used is derived from the lyrics.
        max_new_tokens: Accepted for interface compatibility with the UI; the
            token budget actually used is derived from the lyrics.

    Returns:
        Path of the newest MP3 in the output directory, or None when the
        subprocess succeeded but produced no MP3.

    Raises:
        ValueError: If the lyrics are empty or whitespace only.
        RuntimeError: If the inference subprocess exits with a non-zero code.
    """
    if not lyrics_txt_content or not lyrics_txt_content.strip():
        raise ValueError("Lyrics must not be empty")
    genre_txt_content = genre_txt_content or ""

    genre_txt_path = None
    lyrics_txt_path = None
    try:
        model_path, config, params = optimize_model_selection(
            lyrics_txt_content, genre_txt_content
        )
        logging.info(f"Selected model: {model_path}")
        logging.info(f"Lyrics analysis: {params}")

        has_chorus = params['sections']['chorus'] > 0
        estimated_duration = params.get('estimated_duration', 90)

        # Segment and token budget: lyrics with a chorus get more of both.
        if has_chorus:
            actual_max_tokens = min(12000, int(config['max_tokens'] * 1.3))
            actual_num_segments = min(5, params['num_segments'] + 2)
        else:
            actual_max_tokens = min(10000, int(config['max_tokens'] * 1.2))
            actual_num_segments = min(4, params['num_segments'] + 1)

        logging.info(f"Estimated duration: {estimated_duration} seconds")
        logging.info(f"Has chorus sections: {has_chorus}")
        logging.info(f"Using segments: {actual_num_segments}, tokens: {actual_max_tokens}")

        genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
        lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")

        output_dir = OUTPUT_DIR
        os.makedirs(output_dir, exist_ok=True)
        empty_output_folder(output_dir)

        command = _build_inference_command(
            model_path, genre_txt_path, lyrics_txt_path,
            actual_num_segments, actual_max_tokens, output_dir,
        )
        env = _build_subprocess_env()

        # transformers cache migration; failure here is not fatal.
        try:
            from transformers.utils import move_cache
            move_cache()
        except Exception as e:
            logging.warning(f"Cache migration warning (non-critical): {e}")

        process = subprocess.run(
            command,
            env=env,
            check=False,
            capture_output=True,
            text=True
        )

        logging.info(f"Command output: {process.stdout}")
        if process.stderr:
            logging.error(f"Command error: {process.stderr}")

        if process.returncode != 0:
            logging.error(f"Command failed with return code: {process.returncode}")
            logging.error(f"Command: {' '.join(command)}")
            raise RuntimeError(f"Inference failed: {process.stderr}")

        last_mp3 = get_last_mp3_file(output_dir)
        if not last_mp3:
            logging.warning("No output audio file generated")
            return None

        _log_duration_check(last_mp3, estimated_duration)
        return last_mp3

    except Exception as e:
        logging.error(f"Inference error: {e}")
        raise
    finally:
        _remove_temp_files([genre_txt_path, lyrics_txt_path])


def optimize_model_selection(lyrics, genre):
    """Choose the stage-1 model and its configuration for the given lyrics.

    Args:
        lyrics: Lyrics text used for language detection and parameter sizing.
        genre: Accepted for interface compatibility; not used.

    Returns:
        A 3-tuple ``(model_path, model_config, params)`` where
        ``model_config`` holds ``max_tokens`` (scaled by 1.5 when the lyrics
        contain a chorus), ``temperature``, ``batch_size``, ``num_segments``
        and ``estimated_duration``, and ``params`` is the result of
        ``calculate_generation_params``.
    """
    model_path = detect_and_select_model(lyrics)
    params = calculate_generation_params(lyrics)

    max_tokens = params['max_tokens']
    if params['sections']['chorus'] > 0:
        max_tokens = int(max_tokens * 1.5)

    model_config = {
        "max_tokens": max_tokens,
        "temperature": MODEL_TEMPERATURES[model_path],
        "batch_size": 16,
        "num_segments": params['num_segments'],
        "estimated_duration": params['estimated_duration']
    }
    return model_path, model_config, params


def _update_info(lyrics):
    """Return the (duration, sections) label texts for the current lyrics."""
    if not lyrics:
        return "No lyrics entered", "No sections detected"

    params = calculate_generation_params(lyrics)
    duration = params['estimated_duration']
    sections = params['sections']
    return (
        f"⏱️ Duration: {duration:.1f} seconds",
        f"📊 Verses: {sections['verse']}, Chorus: {sections['chorus']}"
    )


def main():
    """Build the Gradio Blocks UI and return it without launching it."""
    theme = gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
        font=["Arial", "sans-serif"]
    )

    with gr.Blocks(theme=theme, css=CUSTOM_CSS) as demo:
        with gr.Column(elem_id="main-container"):
            # Header
            with gr.Row(elem_id="header"):
                gr.Markdown(HEADER_MARKDOWN)

            # Main content
            with gr.Row():
                # Input column
                with gr.Column(scale=1, elem_classes="input-section"):
                    gr.Markdown("### 📝 Input Your Song Details")
                    genre_txt = gr.Textbox(
                        label="🎸 Music Genre & Style",
                        placeholder="e.g., K-pop bright energetic synth dance electronic...",
                        elem_id="genre-input"
                    )
                    lyrics_txt = gr.Textbox(
                        label="📝 Lyrics",
                        placeholder="Enter lyrics with section tags: [verse], [chorus], [bridge]...",
                        lines=10,
                        elem_id="lyrics-input"
                    )

                    # Live analysis of the lyrics
                    with gr.Row():
                        with gr.Column(scale=1):
                            duration_info = gr.Label(
                                label="⏱️ Estimated Duration",
                                elem_classes="info-box"
                            )
                        with gr.Column(scale=1):
                            sections_info = gr.Label(
                                label="📊 Section Analysis",
                                elem_classes="info-box"
                            )

                    submit_btn = gr.Button(
                        "🎼 Generate Music",
                        variant="primary",
                        elem_id="generate-btn"
                    )

                # Output column
                with gr.Column(scale=1, elem_classes="output-section"):
                    gr.Markdown("### 🎵 Generated Music")
                    music_out = gr.Audio(
                        label="Generated Song",
                        elem_id="music-output"
                    )

                    # Status area; both inputs below are hidden and read-only.
                    with gr.Group(elem_classes="status-section"):
                        gr.Markdown("### 🔄 Generation Status")
                        num_segments = gr.Number(
                            label="Song Segments",
                            value=2,
                            interactive=False,
                            visible=False
                        )
                        max_new_tokens = gr.Number(
                            label="Tokens",
                            value=4000,
                            interactive=False,
                            visible=False
                        )

            # Examples
            with gr.Accordion("📖 Examples", open=False):
                gr.Examples(
                    examples=[
                        [
                            "female blues airy vocal bright vocal piano sad romantic guitar jazz",
                            EXAMPLE_LYRICS_EN
                        ],
                        [
                            "K-pop bright energetic synth dance electronic",
                            EXAMPLE_LYRICS_KO
                        ]
                    ],
                    inputs=[genre_txt, lyrics_txt]
                )

            # Help
            with gr.Accordion("ℹ️ Help & Information", open=False):
                gr.Markdown(HELP_MARKDOWN)

        # Event wiring
        lyrics_txt.change(
            fn=_update_info,
            inputs=[lyrics_txt],
            outputs=[duration_info, sections_info]
        )

        submit_btn.click(
            fn=infer,
            inputs=[genre_txt, lyrics_txt, num_segments, max_new_tokens],
            outputs=[music_out]
        )

    return demo


if __name__ == "__main__":
    # NOTE(review): initialize_system() is defined above but is not called
    # here, so the codec download and the chdir into ./inference do not run
    # at startup. infer() launches "infer.py" by relative path -- confirm
    # whether the app is expected to be started from inside ./inference or
    # whether initialize_system() should be invoked first.
    demo = main()
    demo.queue(max_size=20).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_api=True,
        show_error=True,
        max_threads=8
    )