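"""Gradio demo: text-to-music generation with facebook/musicgen-large.

Generates four variants per prompt with an elastic_models-compressed model
and, on the first request per setting, times the stock HF model once to
report an estimated cost saving.
"""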
import gradio as gr
import torch
import gc
import numpy as np
import random
import os
import tempfile
import soundfile as sf
import time
os.environ['ELASTIC_LOG_LEVEL'] = 'DEBUG'
from transformers import AutoProcessor, pipeline
from elastic_models.transformers import MusicgenForConditionalGeneration
MODEL_CONFIG = {
    'cost_per_hour': 1.8,  # assumed GPU rental rate, USD per hour
}
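# Cache of baseline (uncompressed) generation times, keyed per prompt/settings,
# so the slow original model only has to be timed once per configuration.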
original_time_cache = {}
def set_seed(seed: int = 42):
    """Seed every RNG in use so repeated generations are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def cleanup_gpu():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()
def cleanup_temp_files():
    """Remove generated audio files older than one hour from the temp dir."""
    import glob
    temp_dir = tempfile.gettempdir()
    cutoff_time = time.time() - 3600
    # Match the files written by generate_music(); the previous "tmp*.wav"
    # pattern never matched them, since they are named "generated_music_*.wav".
    for temp_file in glob.glob(os.path.join(temp_dir, "generated_music_*.wav")):
        try:
            if os.path.getctime(temp_file) < cutoff_time:
                os.remove(temp_file)
                print(f"[CLEANUP] Removed old temp file: {temp_file}")
        except OSError:
            pass
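# Lazily-initialized singletons: the compressed pipeline used for generation,
# and the stock HF pipeline that is loaded only to measure baseline timing.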
_generator = None
_processor = None
_original_generator = None
_original_processor = None
def load_model():
global _generator, _processor
if _generator is None:
print("[MODEL] Starting model initialization...")
cleanup_gpu()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[MODEL] Using device: {device}")
print("[MODEL] Loading processor...")
_processor = AutoProcessor.from_pretrained(
"facebook/musicgen-large"
)
print("[MODEL] Loading model...")
model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-large",
torch_dtype=torch.float16,
device=device,
mode="S",
__paged=True,
)
model.eval()
print("[MODEL] Creating pipeline...")
_generator = pipeline(
task="text-to-audio",
model=model,
tokenizer=_processor.tokenizer,
device=device,
)
print("[MODEL] Model initialization completed successfully")
return _generator, _processor
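# Loads the unmodified HF MusicGen model; used once per setting to measure the
# baseline generation time for the cost comparison, then unloaded.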
def load_original_model():
global _original_generator, _original_processor
if _original_generator is None:
print("[ORIGINAL MODEL] Starting original model initialization...")
cleanup_gpu()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[ORIGINAL MODEL] Using device: {device}")
print("[ORIGINAL MODEL] Loading processor...")
_original_processor = AutoProcessor.from_pretrained(
"facebook/musicgen-large"
)
from transformers import MusicgenForConditionalGeneration as HFMusicgenForConditionalGeneration
print("[ORIGINAL MODEL] Loading original model...")
        # HF from_pretrained does not accept a `device` kwarg (unlike the
        # elastic_models wrapper above), so move the model explicitly.
        model = HFMusicgenForConditionalGeneration.from_pretrained(
            "facebook/musicgen-large",
            torch_dtype=torch.float16,
        ).to(device)
model.eval()
print("[ORIGINAL MODEL] Creating pipeline...")
_original_generator = pipeline(
task="text-to-audio",
model=model,
tokenizer=_original_processor.tokenizer,
device=device,
)
print("[ORIGINAL MODEL] Original model initialization completed successfully")
return _original_generator, _original_processor
def calculate_max_tokens(duration_seconds):
    # MusicGen's audio codec produces roughly 50 frames per second, so about
    # 50 new tokens are needed per second of generated audio.
    token_rate = 50
    max_new_tokens = int(duration_seconds * token_rate)
print(f"[MODEL] Duration: {duration_seconds}s -> Tokens: {max_new_tokens} (rate: {token_rate})")
return max_new_tokens
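# Single-variant generation path. Not wired into the UI below (the button uses
# generate_music_batch), but kept as a reference implementation that returns
# the (sample_rate, int16 ndarray) tuple gr.Audio(type="numpy") expects.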
def generate_music(text_prompt, duration=10, guidance_scale=3.0):
try:
generator, processor = load_model()
print(f"[GENERATION] Starting generation...")
print(f"[GENERATION] Prompt: '{text_prompt}'")
print(f"[GENERATION] Duration: {duration}s")
print(f"[GENERATION] Guidance scale: {guidance_scale}")
cleanup_gpu()
        set_seed(42)
        print("[GENERATION] Using seed: 42")
max_new_tokens = calculate_max_tokens(duration)
generation_params = {
'do_sample': True,
'guidance_scale': guidance_scale,
'max_new_tokens': max_new_tokens,
'min_new_tokens': max_new_tokens,
'cache_implementation': 'paged',
}
prompts = [text_prompt]
outputs = generator(
prompts,
batch_size=1,
generate_kwargs=generation_params
)
print(f"[GENERATION] Generation completed successfully")
output = outputs[0]
audio_data = output['audio']
sample_rate = output['sampling_rate']
print(f"[GENERATION] Audio shape: {audio_data.shape}")
print(f"[GENERATION] Sample rate: {sample_rate}")
print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
print(f"[GENERATION] Audio is numpy: {type(audio_data)}")
if hasattr(audio_data, 'cpu'):
audio_data = audio_data.cpu().numpy()
print(f"[GENERATION] Audio shape after tensor conversion: {audio_data.shape}")
        # Normalize to a 1-D mono waveform: drop the batch axis, put samples on
        # axis 0, and keep only the first channel of multi-channel output.
        if len(audio_data.shape) == 3:
            audio_data = audio_data[0]
        if len(audio_data.shape) == 2:
            if audio_data.shape[0] < audio_data.shape[1]:
                audio_data = audio_data.T
            if audio_data.shape[1] > 1:
                audio_data = audio_data[:, 0]
            else:
                audio_data = audio_data.flatten()
        audio_data = audio_data.flatten()
        print(f"[GENERATION] Audio shape after flattening: {audio_data.shape}")
        # Peak-normalize to 95% full scale, then convert to 16-bit PCM.
        max_val = np.max(np.abs(audio_data))
        if max_val > 0:
            audio_data = audio_data / max_val * 0.95
        audio_data = (audio_data * 32767).astype(np.int16)
print(f"[GENERATION] Final audio shape: {audio_data.shape}")
print(f"[GENERATION] Audio range: [{np.min(audio_data)}, {np.max(audio_data)}]")
print(f"[GENERATION] Audio dtype: {audio_data.dtype}")
print(f"[GENERATION] Sample rate: {sample_rate}")
timestamp = int(time.time() * 1000)
temp_filename = f"generated_music_{timestamp}.wav"
temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
sf.write(temp_path, audio_data, sample_rate)
if os.path.exists(temp_path):
file_size = os.path.getsize(temp_path)
print(f"[GENERATION] Audio saved to: {temp_path}")
print(f"[GENERATION] File size: {file_size} bytes")
            # Return the (sample_rate, int16 array) tuple expected by
            # gr.Audio(type="numpy"); the WAV on disk is kept for debugging and
            # is swept later by cleanup_temp_files().
            return (sample_rate, audio_data)
else:
print(f"[ERROR] Failed to create audio file: {temp_path}")
return None
except Exception as e:
print(f"[ERROR] Generation failed: {str(e)}")
cleanup_gpu()
return None
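# Converts wall-clock generation time into an estimated dollar cost using the
# flat GPU rate from MODEL_CONFIG.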
def calculate_generation_cost(generation_time_seconds, mode='S'):
    # `mode` is currently unused: the same hourly GPU rate is applied to both
    # the compressed and the original model.
    hours = generation_time_seconds / 3600
    return hours * MODEL_CONFIG['cost_per_hour']
def calculate_cost_savings(compressed_time, original_time):
compressed_cost = calculate_generation_cost(compressed_time, 'S')
original_cost = calculate_generation_cost(original_time, 'original')
savings = original_cost - compressed_cost
savings_percent = (savings / original_cost * 100) if original_cost > 0 else 0
return {
'compressed_cost': compressed_cost,
'original_cost': original_cost,
'savings': savings,
'savings_percent': savings_percent
}
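# Note: Python's str hash is randomized per process (PYTHONHASHSEED), so these
# cache keys are only stable within a single run, which is all this app needs.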
def get_cache_key(prompt, duration, guidance_scale):
return f"{hash(prompt)}_{duration}_{guidance_scale}"
def generate_music_batch(text_prompt, duration=10, guidance_scale=3.0, model_mode="compressed"):
try:
cache_key = get_cache_key(text_prompt, duration, guidance_scale)
generator, processor = load_model()
model_name = "Compressed (S)"
print(f"[GENERATION] Starting batch generation using {model_name} model...")
print(f"[GENERATION] Prompt: '{text_prompt}'")
print(f"[GENERATION] Duration: {duration}s")
print(f"[GENERATION] Guidance scale: {guidance_scale}")
cleanup_gpu()
set_seed(42)
print(f"[GENERATION] Using seed: 42")
max_new_tokens = calculate_max_tokens(duration)
generation_params = {
'do_sample': True,
'guidance_scale': guidance_scale,
'max_new_tokens': max_new_tokens,
'min_new_tokens': max_new_tokens,
'cache_implementation': 'paged',
}
prompts = [text_prompt] * 4
start_time = time.time()
outputs = generator(
prompts,
batch_size=4,
generate_kwargs=generation_params
)
generation_time = time.time() - start_time
print(f"[GENERATION] Batch generation completed in {generation_time:.2f}s")
audio_variants = []
sample_rate = outputs[0]['sampling_rate']
for i, output in enumerate(outputs):
audio_data = output['audio']
print(f"[GENERATION] Processing variant {i+1} audio shape: {audio_data.shape}")
if hasattr(audio_data, 'cpu'):
audio_data = audio_data.cpu().numpy()
            # Same normalization as generate_music(): 1-D mono waveform,
            # peak-normalized and converted to 16-bit PCM.
            if len(audio_data.shape) == 3:
                audio_data = audio_data[0]
            if len(audio_data.shape) == 2:
                if audio_data.shape[0] < audio_data.shape[1]:
                    audio_data = audio_data.T
                if audio_data.shape[1] > 1:
                    audio_data = audio_data[:, 0]
                else:
                    audio_data = audio_data.flatten()
            audio_data = audio_data.flatten()
            max_val = np.max(np.abs(audio_data))
            if max_val > 0:
                audio_data = audio_data / max_val * 0.95
            audio_data = (audio_data * 32767).astype(np.int16)
audio_variants.append((sample_rate, audio_data))
print(f"[GENERATION] Variant {i+1} final shape: {audio_data.shape}")
comparison_message = ""
if cache_key in original_time_cache:
original_time = original_time_cache[cache_key]
cost_info = calculate_cost_savings(generation_time, original_time)
comparison_message = f"πŸ’° Cost Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%) - Compressed: ${cost_info['compressed_cost']:.4f} vs Original: ${cost_info['original_cost']:.4f}"
print(f"[COST] Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
else:
try:
print(f"[TIMING] Measuring original model speed for comparison...")
original_generator, original_processor = load_original_model()
original_start = time.time()
original_outputs = original_generator(
prompts,
batch_size=4,
generate_kwargs=generation_params
)
original_time = time.time() - original_start
original_time_cache[cache_key] = original_time
cost_info = calculate_cost_savings(generation_time, original_time)
comparison_message = f"πŸ’° Cost Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%) - Compressed: ${cost_info['compressed_cost']:.4f} vs Original: ${cost_info['original_cost']:.4f}"
print(f"[COST] First comparison - Savings: ${cost_info['savings']:.4f} ({cost_info['savings_percent']:.1f}%)")
print(f"[TIMING] Original: {original_time:.2f}s, Compressed: {generation_time:.2f}s")
                # Deleting the locals alone would not free GPU memory: the
                # module-level globals set by load_original_model() still hold
                # the weights, so reset them as well.
                del original_outputs, original_generator, original_processor
                global _original_generator, _original_processor
                _original_generator = None
                _original_processor = None
                cleanup_gpu()
                print("[CLEANUP] Original model unloaded after timing measurement")
except Exception as e:
print(f"[WARNING] Could not measure original timing: {e}")
compressed_cost = calculate_generation_cost(generation_time, 'S')
comparison_message = f"πŸ’Έ Compressed Cost: ${compressed_cost:.4f} (could not compare with original)"
generation_info = f"βœ… Generated 4 variants in {generation_time:.2f}s\n{comparison_message}"
return audio_variants[0], audio_variants[1], audio_variants[2], audio_variants[3], generation_info
except Exception as e:
print(f"[ERROR] Batch generation failed: {str(e)}")
cleanup_gpu()
error_msg = f"❌ Generation failed: {str(e)}"
return None, None, None, None, error_msg
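# Gradio interface: prompt and sliders in, four audio variants plus a
# status/cost message out.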
with gr.Blocks(title="MusicGen Large - Music Generation") as demo:
gr.Markdown("# 🎡 MusicGen Large Music Generator")
gr.Markdown("Generate music from text descriptions using Facebook's MusicGen Large model accelerated by TheStage for 2.3x faster performance")
with gr.Row():
with gr.Column():
text_input = gr.Textbox(
label="Music Description",
placeholder="Enter a description of the music you want to generate",
lines=3,
value="A groovy funk bassline with a tight drum beat"
)
with gr.Row():
duration = gr.Slider(
minimum=5,
maximum=30,
value=10,
step=1,
label="Duration (seconds)"
)
guidance_scale = gr.Slider(
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
label="Guidance Scale",
info="Higher values follow prompt more closely"
)
generate_btn = gr.Button("🎡 Generate 4 Music Variants", variant="primary", size="lg")
with gr.Column():
generation_info = gr.Markdown("Ready to generate music variants with cost comparison vs original model")
with gr.Row():
audio_output1 = gr.Audio(label="Variant 1", type="numpy")
audio_output2 = gr.Audio(label="Variant 2", type="numpy")
with gr.Row():
audio_output3 = gr.Audio(label="Variant 3", type="numpy")
audio_output4 = gr.Audio(label="Variant 4", type="numpy")
with gr.Accordion("Tips", open=False):
gr.Markdown("""
- Be specific in your descriptions (e.g., "slow blues guitar with harmonica")
- Higher guidance scale = follows prompt more closely
- Lower guidance scale = more creative/varied results
- Duration is limited to 30 seconds for faster generation
""")
def generate_simple(text_prompt, duration, guidance_scale):
return generate_music_batch(text_prompt, duration, guidance_scale, "compressed")
generate_btn.click(
fn=generate_simple,
inputs=[text_input, duration, guidance_scale],
outputs=[audio_output1, audio_output2, audio_output3, audio_output4, generation_info],
show_progress=True
)
gr.Examples(
examples=[
"A groovy funk bassline with a tight drum beat",
"Relaxing acoustic guitar melody",
"Electronic dance music with heavy bass",
"Classical violin concerto",
"Reggae with steel drums and bass",
"Rock ballad with electric guitar solo",
"Jazz piano improvisation with brushed drums",
"Ambient synthwave with retro vibes",
],
inputs=text_input,
label="Example Prompts"
)
gr.Markdown("---")
gr.Markdown("""
<div style="text-align: center; color: #666; font-size: 12px; margin-top: 2rem;">
<strong>Limitations:</strong><br>
β€’ The model is not able to generate realistic vocals.<br>
β€’ The model has been trained with English descriptions and will not perform as well in other languages.<br>
β€’ The model does not perform equally well for all music styles and cultures.<br>
β€’ The model sometimes generates end of songs, collapsing to silence.<br>
β€’ It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
</div>
""")
if __name__ == "__main__":
cleanup_temp_files()
demo.launch()