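"""Gradio Space for text-to-music generation with facebook/musicgen-large,
served through elastic_models' compressed inference backend."""
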
import gradio as gr
import torch
import gc
import numpy as np
import random
import os
import tempfile
import soundfile as sf

os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'

from transformers import AutoProcessor, pipeline
from elastic_models.transformers import MusicgenForConditionalGeneration

def set_seed(seed: int = 42):
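    """Seed Python, NumPy, and torch RNGs for reproducible sampling."""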
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def cleanup_gpu():
    """Clean up GPU memory to avoid TensorRT conflicts."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

def cleanup_temp_files():
    """Clean up old temporary audio files."""
    import glob
    import time
    temp_dir = tempfile.gettempdir()
    cutoff_time = time.time() - 3600
    for temp_file in glob.glob(os.path.join(temp_dir, "tmp*.wav")):
        try:
            if os.path.getctime(temp_file) < cutoff_time:
                os.remove(temp_file)
                print(f"[CLEANUP] Removed old temp file: {temp_file}")
        except OSError:
            pass
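
# Module-level singletons so the model and processor are loaded only once per process.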
_generator = None
_processor = None

def load_model():
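    """Lazily load the processor, compressed MusicGen model, and text-to-audio pipeline."""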
    global _generator, _processor
    if _generator is None:
        print("[MODEL] Starting model initialization...")
        cleanup_gpu()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[MODEL] Using device: {device}")
        print("[MODEL] Loading processor...")
        _processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
        print("[MODEL] Loading model...")
        model = MusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
            device=device,
            mode="S",
            __paged=True,
        )
        model.eval()
        print("[MODEL] Creating pipeline...")
        _generator = pipeline(
            task="text-to-audio",
            model=model,
            tokenizer=_processor.tokenizer,
            device=device,
        )
        print("[MODEL] Model initialization completed successfully")
    return _generator, _processor

def calculate_max_tokens(duration_seconds):
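    """Convert a duration in seconds into a generation token budget.

    MusicGen's audio codec runs at roughly 50 frames per second, so the
    model needs about 50 new tokens per second of generated audio.
    """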
    token_rate = 50
    max_new_tokens = int(duration_seconds * token_rate)
    print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
    return max_new_tokens

def generate_music(text_prompt, duration=10, guidance_scale=3.0):
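    """Generate audio for a text prompt.

    Returns a (sample_rate, int16 numpy array) tuple, matching the
    gr.Audio(type="numpy") output component, or None on failure.
    """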
    try:
        generator, processor = load_model()
        print("[GENERATION] Starting generation...")
        print(f"[GENERATION] Prompt: '{text_prompt}'")
        print(f"[GENERATION] Duration: {duration}s")
        print(f"[GENERATION] Guidance scale: {guidance_scale}")
        cleanup_gpu()
        import time
        set_seed(42)
        print("[GENERATION] Using seed: 42")
        max_new_tokens = calculate_max_tokens(duration)
        generation_params = {
            'do_sample': True,
            'guidance_scale': guidance_scale,
            'max_new_tokens': max_new_tokens,
            'min_new_tokens': max_new_tokens,
            'cache_implementation': 'paged',
        }
        prompts = [text_prompt]
        outputs = generator(
            prompts,
            batch_size=1,
            generate_kwargs=generation_params
        )
        print("[GENERATION] Generation completed successfully")
        output = outputs[0]
        audio_data = output['audio']
        sample_rate = output['sampling_rate']
        print(f"[GENERATION] Audio shape: {audio_data.shape}")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Audio type: {type(audio_data)}")
        # Move the waveform off the GPU if the pipeline returned a torch tensor.
        if hasattr(audio_data, 'cpu'):
            audio_data = audio_data.cpu().numpy()
            print(f"[GENERATION] Audio shape after tensor conversion: {audio_data.shape}")
        if len(audio_data.shape) == 3:
            audio_data = audio_data[0]
        if len(audio_data.shape) == 2:
            if audio_data.shape[0] < audio_data.shape[1]:
                audio_data = audio_data.T
            if audio_data.shape[1] > 1:
                audio_data = audio_data[:, 0]
            else:
                audio_data = audio_data.flatten()
        audio_data = audio_data.flatten()
        print(f"[GENERATION] Audio shape after flattening: {audio_data.shape}")
        max_val = np.max(np.abs(audio_data))
        if max_val > 0:
            audio_data = audio_data / max_val * 0.95  # scale to 95% to avoid clipping
        audio_data = (audio_data * 32767).astype(np.int16)
        print(f"[GENERATION] Final audio shape: {audio_data.shape}")
        print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
        print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
        print(f"[GENERATION] Sample rate: {sample_rate}")
        timestamp = int(time.time() * 1000)
        temp_filename = f"generated_music_{timestamp}.wav"
        temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
        sf.write(temp_path, audio_data, sample_rate)
        if os.path.exists(temp_path):
            file_size = os.path.getsize(temp_path)
            print(f"[GENERATION] Audio saved to: {temp_path}")
            print(f"[GENERATION] File size: {file_size} bytes")
            print(f"[GENERATION] Returning numpy tuple: ({sample_rate}, audio_array)")
            return (sample_rate, audio_data)
        else:
            print(f"[ERROR] Failed to create audio file: {temp_path}")
            return None
    except Exception as e:
        print(f"[ERROR] Generation failed: {str(e)}")
        cleanup_gpu()
        return None
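
# Build the Gradio interface.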
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
    gr.Markdown("# 🎵 MusicGen Large Music Generator")
    gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model with elastic compression.")
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Music Description",
                placeholder="Enter a description of the music you want to generate",
                lines=3,
                value="A groovy funk bassline with a tight drum beat"
            )
            with gr.Row():
                duration = gr.Slider(
                    minimum=5,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)"
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.5,
                    label="Guidance Scale",
                    info="Higher values follow prompt more closely"
                )
            generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg")
        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Music",
                type="numpy"
            )
            with gr.Accordion("Tips", open=False):
                gr.Markdown("""
                - Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
                - Higher guidance scale = follows prompt more closely
                - Lower guidance scale = more creative/varied results
                - Duration is limited to 30 seconds for faster generation
                """)
    generate_btn.click(
        fn=generate_music,
        inputs=[text_input, duration, guidance_scale],
        outputs=audio_output,
        show_progress=True
    )
    gr.Examples(
        examples=[
            "A groovy funk bassline with a tight drum beat",
            "Relaxing acoustic guitar melody",
            "Electronic dance music with heavy bass",
            "Classical violin concerto",
            "Reggae with steel drums and bass",
            "Rock ballad with electric guitar solo",
            "Jazz piano improvisation with brushed drums",
            "Ambient synthwave with retro vibes",
        ],
        inputs=text_input,
        label="Example Prompts"
    )
    gr.Markdown("---")
    gr.Markdown("""
    <div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;">
    <strong>Limitations:</strong><br>
    • The model is not able to generate realistic vocals.<br>
    • The model has been trained with English descriptions and will not perform as well in other languages.<br>
    • The model does not perform equally well for all music styles and cultures.<br>
    • The model sometimes generates end of songs, collapsing to silence.<br>
    • It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
    </div>
    """)

if __name__ == "__main__":
    cleanup_temp_files()
    demo.launch()