import gradio as gr import torch import gc import numpy as np import random import os import tempfile import soundfile as sf os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG' from transformers import AutoProcessor, pipeline from elastic_models.transformers import MusicgenForConditionalGeneration def set_seed(seed: int = 42): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def cleanup_gpu(): """Clean up GPU memory to avoid TensorRT conflicts.""" if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() gc.collect() def cleanup_temp_files(): """Clean up old temporary audio files.""" import glob import time temp_dir = tempfile.gettempdir() cutoff_time = time.time() - 3600 for temp_file in glob.glob(os.path.join(temp_dir, "tmp*.wav")): try: if os.path.getctime(temp_file) < cutoff_time: os.remove(temp_file) print(f"[CLEANUP] Removed old temp file: {temp_file}") except OSError: pass _generator = None _processor = None def load_model(): global _generator, _processor if _generator is None: print("[MODEL] Starting model initialization...") cleanup_gpu() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[MODEL] Using device: {device}") print("[MODEL] Loading processor...") _processor = AutoProcessor.from_pretrained( "facebook/musicgen-large" ) print("[MODEL] Loading model...") model = MusicgenForConditionalGeneration.from_pretrained( "facebook/musicgen-large", torch_dtype=torch.float16, device=device, mode="S", __paged=True, ) model.eval() print("[MODEL] Creating pipeline...") _generator = pipeline( task="text-to-audio", model=model, tokenizer=_processor.tokenizer, device=device, ) print("[MODEL] Model initialization completed successfully") return _generator, _processor def calculate_max_tokens(duration_seconds): token_rate = 50 max_new_tokens = int(duration_seconds * token_rate) print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})") return max_new_tokens def generate_music(text_prompt, duration=10, guidance_scale=3.0): try: generator, processor = load_model() print(f"[GENERATION] Starting generation...") print(f"[GENERATION] Prompt: '{text_prompt}'") print(f"[GENERATION] Duration: {duration}s") print(f"[GENERATION] Guidance scale: {guidance_scale}") cleanup_gpu() import time set_seed(42) print(f"[GENERATION] Using seed: {42}") max_new_tokens = calculate_max_tokens(duration) generation_params = { 'do_sample': True, 'guidance_scale': guidance_scale, 'max_new_tokens': max_new_tokens, 'min_new_tokens': max_new_tokens, 'cache_implementation': 'paged', } prompts = [text_prompt] outputs = generator( prompts, batch_size=1, generate_kwargs=generation_params ) print(f"[GENERATION] Generation completed successfully") output = outputs[0] audio_data = output['audio'] sample_rate = output['sampling_rate'] print(f"[GENERATION] Audio shape: {audio_data.shape}") print(f"[GENERATION] Sample rate: {sample_rate}") print(f"[GENERATION] Audio dtype: {audio_data.dtype}") print(f"[GENERATION] Audio is numpy: {type(audio_data)}") if hasattr(audio_data, 'cpu'): audio_data = audio_data.cpu().numpy() print(f"[GENERATION] Audio shape after tensor conversion: {audio_data.shape}") if len(audio_data.shape) == 3: audio_data = audio_data[0] if len(audio_data.shape) == 2: if audio_data.shape[0] < audio_data.shape[1]: audio_data = audio_data.T if audio_data.shape[1] > 1: audio_data = audio_data[:, 0] else: audio_data = audio_data.flatten() audio_data = audio_data.flatten() print(f"[GENERATION] Audio shape after flattening: {audio_data.shape}") max_val = np.max(np.abs(audio_data)) if max_val > 0: audio_data = audio_data / max_val * 0.95 # Scale to 95% to avoid clipping audio_data = (audio_data * 32767).astype(np.int16) ### print(f"[GENERATION] Final audio shape: {audio_data.shape}") print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]") print(f"[GENERATION] Audio dtype: {audio_data.dtype}") print(f"[GENERATION] Sample rate: {sample_rate}") timestamp = int(time.time() * 1000) temp_filename = f"generated_music_{timestamp}.wav" temp_path = os.path.join(tempfile.gettempdir(), temp_filename) sf.write(temp_path, audio_data, sample_rate) if os.path.exists(temp_path): file_size = os.path.getsize(temp_path) print(f"[GENERATION] Audio saved to: {temp_path}") print(f"[GENERATION] File size: {file_size} bytes") print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)") return (sample_rate, audio_data) else: print(f"[ERROR] Failed to create audio file: {temp_path}") return None except Exception as e: print(f"[ERROR] Generation failed: {str(e)}") cleanup_gpu() return None with gr.Blocks(title="MusicGen Large - Music Generation") as demo: gr.Markdown("# 🎵 MusicGen Large Music Generator") gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.") with gr.Row(): with gr.Column(): text_input = gr.Textbox( label="Music Description", placeholder="Enter a description of the music you want to generate", lines=3, value="A groovy funk bassline with a tight drum beat" ) with gr.Row(): duration = gr.Slider( minimum=5, maximum=30, value=10, step=1, label="Duration (seconds)" ) guidance_scale = gr.Slider( minimum=1.0, maximum=10.0, value=3.0, step=0.5, label="Guidance Scale", info="Higher values follow prompt more closely" ) generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg") with gr.Column(): audio_output = gr.Audio( label="Generated Music", type="numpy" ) with gr.Accordion("Tips", open=False): gr.Markdown(""" - Be specific in your descriptions (e.g., "slow blues guitar with harmonica") - Higher guidance scale = follows prompt more closely - Lower guidance scale = more creative/varied results - Duration is limited to 30 seconds for faster generation """) generate_btn.click( fn=generate_music, inputs=[text_input, duration, guidance_scale], outputs=audio_output, show_progress=True ) gr.Examples( examples=[ "A groovy funk bassline with a tight drum beat", "Relaxing acoustic guitar melody", "Electronic dance music with heavy bass", "Classical violin concerto", "Reggae with steel drums and bass", "Rock ballad with electric guitar solo", "Jazz piano improvisation with brushed drums", "Ambient synthwave with retro vibes", ], inputs=text_input, label="Example Prompts" ) gr.Markdown("---") gr.Markdown("""