import spaces
import random
import argparse
import glob
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import DynamicCache

import MIDI
from midi_model import MIDIModel, MIDIModelConfig
from midi_synthesizer import MidiSynthesizer

MAX_SEED = np.iinfo(np.int32).max
in_space = os.getenv("SYSTEM") == "spaces"
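# MAX_SEED bounds user-facing random seeds. A minimal sketch (illustrative, not
# wired into the UI below) of turning a seed into the `generator` argument that
# generate() accepts:
#   seed = random.randint(0, MAX_SEED)
#   generator = torch.Generator("cpu").manual_seed(seed)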
# Chord to emoji mapping
CHORD_EMOJIS = {
    'A': '🎸',
    'Am': '🎻',
    'B': '🎹',
    'Bm': '🎷',
    'C': '🎵',
    'Cm': '🎶',
    'D': '🥁',
    'Dm': '🪕',
    'E': '🎤',
    'Em': '🎧',
    'F': '🪗',
    'Fm': '🎺',
    'G': '🪘',
    'Gm': '🎻'
}
# Progression patterns
PROGRESSION_PATTERNS = {
    "12-bar-blues": ["I", "I", "I", "I", "IV", "IV", "I", "I", "V", "IV", "I", "V"],
    "pop-verse": ["I", "V", "vi", "IV"],
    "pop-chorus": ["I", "IV", "V", "vi"],
    "jazz": ["ii", "V", "I"],
    "ballad": ["I", "vi", "IV", "V"]
}
# Roman numeral to 0-based scale degree. These are indices into the 7-note
# scale lists used below, not semitone offsets.
ROMAN_TO_DEGREE = {
    "I": 0,
    "ii": 1,
    "iii": 2,
    "IV": 3,
    "V": 4,
    "vi": 5,
    "vii": 6
}
def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
             disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
    tokenizer = model.tokenizer
    if disable_channels is not None:
        disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
    else:
        disable_channels = []
    max_token_seq = tokenizer.max_token_seq
    if prompt is None:
        # Start from a single BOS row, repeated across the batch
        input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
        input_tensor[0, 0] = tokenizer.bos_id  # bos
        input_tensor = input_tensor.unsqueeze(0)
        input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
    else:
        if len(prompt.shape) == 2:
            prompt = prompt[None, :]
            prompt = np.repeat(prompt, repeats=batch_size, axis=0)
        elif prompt.shape[0] == 1:
            prompt = np.repeat(prompt, repeats=batch_size, axis=0)
        elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
            raise ValueError(f"invalid shape for prompt, {prompt.shape}")
        prompt = prompt[..., :max_token_seq]
        if prompt.shape[-1] < max_token_seq:
            prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                            mode="constant", constant_values=tokenizer.pad_id)
        input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
    # Generation loop, simplified for brevity: a real implementation would run
    # the model autoregressively and sample each event row using the
    # temperature/top-p/top-k parameters above (see the sample_top_p_k sketch below).
    cur_len = input_tensor.shape[1]
    while cur_len < max_len:
        with torch.no_grad():
            # Stand-in for the model forward pass and sampling step
            next_token_seq = torch.ones((batch_size, 1, max_token_seq), dtype=torch.long, device=model.device)
        input_tensor = torch.cat([input_tensor, next_token_seq], dim=1)
        cur_len += 1
        yield next_token_seq[:, 0].cpu().numpy()
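# The sketch below shows the kind of sampling step the placeholder loop above
# elides. It is an illustrative, self-contained implementation of temperature
# plus top-k/top-p (nucleus) filtering over a logits tensor; the actual model's
# sampling code may differ.
def sample_top_p_k(logits, temp=1.0, top_p=0.98, top_k=20, generator=None):
    """Sample token ids of shape (batch, 1) from logits of shape (batch, vocab)."""
    probs = torch.softmax(logits / max(temp, 1e-5), dim=-1)
    # Top-k: keep only the k most likely tokens (sorted descending)
    k = min(top_k, probs.shape[-1]) if top_k > 0 else probs.shape[-1]
    top_probs, top_idx = torch.topk(probs, k=k, dim=-1)
    # Top-p: zero out the tail once cumulative probability exceeds top_p
    cum = torch.cumsum(top_probs, dim=-1)
    top_probs = top_probs.masked_fill(cum - top_probs > top_p, 0.0)
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    # Draw one sample and map it back to the original vocabulary index
    choice = torch.multinomial(top_probs, 1, generator=generator)
    return torch.gather(top_idx, -1, choice)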
def create_msg(name, data):
    return {"name": name, "data": data}


def send_msgs(msgs):
    return json.dumps(msgs)
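# Example (illustrative): the front-end receives a JSON list of
# {"name", "data"} dicts, e.g.
#   send_msgs([create_msg("visualizer_clear", [0, "v2"])])
# returns '[{"name": "visualizer_clear", "data": [0, "v2"]}]'.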
def create_chord_progressions(root_chord, progression_type):
    """Convert a Roman numeral progression to actual chords starting from root_chord."""
    major_scale = ["C", "D", "E", "F", "G", "A", "B"]
    minor_scale = ["Cm", "Dm", "Em", "Fm", "Gm", "Am", "Bm"]
    # Find root index in the major scale (defaults to C if not found)
    root_idx = 0
    for i, chord in enumerate(major_scale):
        if chord == root_chord:
            root_idx = i
            break
    # Get progression pattern
    pattern = PROGRESSION_PATTERNS.get(progression_type, PROGRESSION_PATTERNS["pop-verse"])
    # Generate the actual chord progression
    progression = []
    for numeral in pattern:
        # Lowercase numerals (ii, vi, ...) denote minor chords
        is_minor = numeral.islower()
        degree = ROMAN_TO_DEGREE.get(numeral, 0)
        # Walk `degree` steps up the scale from the root
        chord_idx = (root_idx + degree) % 7
        if is_minor:
            progression.append(minor_scale[chord_idx])
        else:
            progression.append(major_scale[chord_idx])
    return progression
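# Examples (with the degree table above):
#   create_chord_progressions("C", "jazz")      -> ["Dm", "G", "C"]
#   create_chord_progressions("G", "pop-verse") -> ["G", "D", "Em", "C"]
# The mapping is diatonic over natural notes only, so keys that need sharps or
# flats (e.g. D major's F#m) are approximated with their natural-note chords.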
# Chord name to MIDI note numbers (root-position triads around middle C),
# shared by the block-chord and arpeggio patterns below
CHORD_NOTES = {
    'C': [60, 64, 67],   # C major (C, E, G)
    'Cm': [60, 63, 67],  # C minor (C, Eb, G)
    'D': [62, 66, 69],   # D major (D, F#, A)
    'Dm': [62, 65, 69],  # D minor (D, F, A)
    'E': [64, 68, 71],   # E major (E, G#, B)
    'Em': [64, 67, 71],  # E minor (E, G, B)
    'F': [65, 69, 72],   # F major (F, A, C)
    'Fm': [65, 68, 72],  # F minor (F, Ab, C)
    'G': [67, 71, 74],   # G major (G, B, D)
    'Gm': [67, 70, 74],  # G minor (G, Bb, D)
    'A': [69, 73, 76],   # A major (A, C#, E)
    'Am': [69, 72, 76],  # A minor (A, C, E)
    'B': [71, 75, 78],   # B major (B, D#, F#)
    'Bm': [71, 74, 78]   # B minor (B, D, F#)
}


def create_chord_events(chord, duration=480, velocity=80):
    """Create MIDI events for a chord (all notes struck together)."""
    events = []
    if chord in CHORD_NOTES:
        notes = CHORD_NOTES[chord]
        # Note-on events (simultaneous)
        for note in notes:
            events.append(['note_on', 0, 0, 0, 0, note, velocity])
        # Note-off events after `duration` ticks
        for note in notes:
            events.append(['note_off', duration, 0, 0, 0, note, 0])
    return events
def create_chord_sequence(tokenizer, chords, pattern="simple", duration=480):
    """Create a sequence of chord events with a pattern"""
    events = []
    for chord in chords:
        if pattern == "simple":
            # Just play the block chord
            events.extend(create_chord_events(chord, duration))
        elif pattern == "arpeggio":
            # Arpeggiate the chord, one note per quarter of the bar
            if chord in CHORD_NOTES:
                notes = CHORD_NOTES[chord]
                for i, note in enumerate(notes):
                    events.append(['note_on', 0 if i == 0 else duration // 4, 0, 0, 0, note, 80])
                    events.append(['note_off', duration // 4, 0, 0, 0, note, 0])
                # Add a final zero-velocity note as a rest to complete the bar
                events.append(['note_on', 0, 0, 0, 0, notes[0], 0])
                events.append(['note_off', duration // 4, 0, 0, 0, notes[0], 0])
    # Convert events to tokens
    tokens = []
    for event in events:
        tokens.append(tokenizer.event2tokens(event))
    return tokens
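# Example (illustrative): four bars of arpeggiated pop changes in C, assuming a
# tokenizer with the event2tokens interface used above:
#   chords = create_chord_progressions("C", "pop-verse")   # ["C", "G", "Am", "F"]
#   tokens = create_chord_sequence(tokenizer, chords, pattern="arpeggio")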
def add_chord_sequence(model_name, mid_seq, root_chord="C", progression_type="pop-verse", pattern="simple"):
    """Add a chord sequence to the MIDI sequence"""
    tokenizer = models[model_name].tokenizer
    # Generate chord progression
    chord_progression = create_chord_progressions(root_chord, progression_type)
    # Create chord sequence tokens
    tokens = create_chord_sequence(tokenizer, chord_progression, pattern)
    # Initialize one independent sequence per batch output if needed
    # (avoid `[seq] * N`, which would alias the same list N times)
    if mid_seq is None:
        bos_row = [tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)
        mid_seq = [[list(bos_row)] for _ in range(OUTPUT_BATCH_SIZE)]
    # Append the chord tokens to every batch sequence
    for seq in mid_seq:
        seq.extend(tokens)
    return mid_seq
def create_song_structure(model_name, root_chord="C"):
    """Create a complete song structure: intro, verse, chorus, outro."""
    tokenizer = models[model_name].tokenizer
    # Initialize one independent sequence per batch output
    bos_row = [tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)
    mid_seq = [[list(bos_row)] for _ in range(OUTPUT_BATCH_SIZE)]
    # Add intro
    tokens = create_chord_sequence(tokenizer, create_chord_progressions(root_chord, "pop-verse"), "arpeggio")
    # Add verse
    tokens += create_chord_sequence(tokenizer, create_chord_progressions(root_chord, "pop-verse"), "simple")
    # Add chorus
    tokens += create_chord_sequence(tokenizer, create_chord_progressions(root_chord, "pop-chorus"), "simple")
    # Add outro
    tokens += create_chord_sequence(tokenizer, create_chord_progressions(root_chord, "ballad"), "arpeggio")
    for seq in mid_seq:
        seq.extend(tokens)
    return mid_seq
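# Example (illustrative): build a full intro/verse/chorus/outro sequence in G
# for the first registered model and store it in the UI state:
#   mid_seq = create_song_structure(list(models.keys())[0], "G")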
def load_javascript(dir="javascript"):
    # Inline every JS file into the page by monkey-patching Gradio's template
    # response, so the MIDI visualizer scripts load with the app
    scripts_list = glob.glob(f"{dir}/*.js")
    javascript = ""
    for path in scripts_list:
        with open(path, "r", encoding="utf8") as jsfile:
            js_content = jsfile.read()
        js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
                                        f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
        javascript += f"\n<!-- {path} --><script>{js_content}</script>"
    template_response_ori = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        res = template_response_ori(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
def render_audio(model_name, mid_seq, should_render_audio):
    if (not should_render_audio) or mid_seq is None:
        outputs = [None] * OUTPUT_BATCH_SIZE
        return tuple(outputs)
    tokenizer = models[model_name].tokenizer
    outputs = []
    if not os.path.exists("outputs"):
        os.mkdir("outputs")
    # Render each batch item in parallel, then collect (sample_rate, waveform) pairs
    audio_futures = []
    for i in range(OUTPUT_BATCH_SIZE):
        mid = tokenizer.detokenize(mid_seq[i])
        audio_future = thread_pool.submit(synthesis_task, mid)
        audio_futures.append(audio_future)
    for future in audio_futures:
        outputs.append((44100, future.result()))
    if OUTPUT_BATCH_SIZE == 1:
        return outputs[0]
    return tuple(outputs)


def synthesis_task(mid):
    return synthesizer.synthesis(MIDI.score2opus(mid))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
    parser.add_argument("--batch", type=int, default=4, help="batch size")
    parser.add_argument("--max-gen", type=int, default=1024, help="max number of tokens to generate")
    opt = parser.parse_args()
    OUTPUT_BATCH_SIZE = opt.batch
    # Download the soundfont and set up audio synthesis
    soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
    thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
    synthesizer = MidiSynthesizer(soundfont_path)
    models_info = {
        "generic pretrain model (tv2o-medium) by skytnt": [
            "skytnt/midi-model-tv2o-medium", {}
        ]
    }
    models = {}
    # Initialize models (simplified: the LoRA weights in `loras` are not loaded).
    # Models stay on CPU at startup; on ZeroGPU spaces the GPU is only attached
    # inside functions decorated with @spaces.GPU.
    for name, (repo_id, loras) in models_info.items():
        model = MIDIModel.from_pretrained(repo_id)
        model.to(device="cpu", dtype=torch.float32)
        models[name] = model
    load_javascript()
    app = gr.Blocks(theme=gr.themes.Soft())
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>🎵 Chord-Emoji MIDI Composer 🎵</h1>")
        js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
        js_msg.change(None, [js_msg], [], js="""
            (msg_json) => {
                let msgs = JSON.parse(msg_json);
                executeCallbacks(msgReceiveCallbacks, msgs);
                return [];
            }
            """)
        input_model = gr.Dropdown(label="Select Model", choices=list(models.keys()),
                                  type="value", value=list(models.keys())[0])
        # Main chord progression section
        with gr.Tabs():
            with gr.TabItem("Chord Progressions") as tab1:
                with gr.Row():
                    root_chord = gr.Dropdown(label="Root Chord", choices=["C", "D", "E", "F", "G", "A", "B"],
                                             value="C")
                    progression_type = gr.Dropdown(label="Progression Type",
                                                   choices=list(PROGRESSION_PATTERNS.keys()),
                                                   value="pop-verse")
                # Emoji-chord button grid: two rows of seven chord buttons
                gr.Markdown("### Chord Buttons - Click to Add Individual Chords")
                with gr.Row():
                    chord_buttons_major = []
                    for chord in ["C", "D", "E", "F", "G", "A", "B"]:
                        emoji = CHORD_EMOJIS.get(chord, "🎵")
                        btn = gr.Button(f"{emoji} {chord}", size="sm")
                        chord_buttons_major.append((chord, btn))
                with gr.Row():
                    chord_buttons_minor = []
                    for chord in ["Cm", "Dm", "Em", "Fm", "Gm", "Am", "Bm"]:
                        emoji = CHORD_EMOJIS.get(chord, "🎵")
                        btn = gr.Button(f"{emoji} {chord}", size="sm")
                        chord_buttons_minor.append((chord, btn))
                # Song structure buttons
                gr.Markdown("### Song Structure Patterns - Click to Add a Pattern")
                with gr.Row():
                    intro_btn = gr.Button("🎵 Intro", variant="primary")
                    verse_btn = gr.Button("🎸 Verse", variant="primary")
                    chorus_btn = gr.Button("🎹 Chorus", variant="primary")
                    bridge_btn = gr.Button("🎷 Bridge", variant="primary")
                    outro_btn = gr.Button("🪕 Outro", variant="primary")
                with gr.Row():
                    blues_btn = gr.Button("🎺 12-Bar Blues", variant="primary")
                    jazz_btn = gr.Button("🎻 Jazz Pattern", variant="primary")
                    ballad_btn = gr.Button("🎤 Ballad", variant="primary")
                with gr.Row():
                    pattern_type = gr.Radio(label="Pattern Style",
                                            choices=["simple", "arpeggio"],
                                            value="simple")
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear Sequence", variant="secondary")
                    play_btn = gr.Button("▶️ Play Current Sequence", variant="primary")
with gr.TabItem("Custom MIDI Settings") as tab2: | |
input_instruments = gr.Dropdown(label="πͺ Instruments (auto if empty)", | |
choices=["Acoustic Grand", "Electric Piano", "Violin", "Guitar"], | |
multiselect=True, type="value") | |
input_bpm = gr.Slider(label="BPM (beats per minute)", minimum=60, maximum=180, | |
step=1, value=120) | |
# Output section | |
output_midi_seq = gr.State() | |
output_continuation_state = gr.State([0]) | |
midi_outputs = [] | |
audio_outputs = [] | |
with gr.Tabs(elem_id="output_tabs"): | |
for i in range(OUTPUT_BATCH_SIZE): | |
with gr.TabItem(f"Output {i + 1}") as tab: | |
output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}") | |
output_audio = gr.Audio(label="Output Audio", format="mp3", elem_id=f"midi_audio_{i}") | |
output_midi = gr.File(label="Output MIDI", file_types=[".mid"]) | |
midi_outputs.append(output_midi) | |
audio_outputs.append(output_audio) | |
        # Connect chord buttons: each appends a progression rooted at that chord.
        # Gradio passes input values positionally, so the lambda takes the live
        # inputs first and binds the chord name as a default argument.
        for chord, btn in chord_buttons_major + chord_buttons_minor:
            btn.click(
                fn=lambda m, seq, pt, chord=chord: add_chord_sequence(m, seq, chord, "ballad", pt),
                inputs=[input_model, output_midi_seq, pattern_type],
                outputs=[output_midi_seq]
            )
        # Connect song structure buttons
        intro_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "pop-verse", "arpeggio"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        verse_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "pop-verse", "simple"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        chorus_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "pop-chorus", "simple"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        bridge_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "jazz", "simple"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        outro_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "ballad", "arpeggio"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        blues_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "12-bar-blues", "simple"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        jazz_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "jazz", "simple"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        ballad_btn.click(
            fn=lambda m, seq, rc: add_chord_sequence(m, seq, rc, "ballad", "simple"),
            inputs=[input_model, output_midi_seq, root_chord],
            outputs=[output_midi_seq]
        )
        # Clear button: reset to one BOS-only sequence per batch output
        def clear_sequence(model_name):
            tokenizer = models[model_name].tokenizer
            bos_row = [tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)
            return [[list(bos_row)] for _ in range(OUTPUT_BATCH_SIZE)]

        clear_btn.click(
            fn=clear_sequence,
            inputs=[input_model],
            outputs=[output_midi_seq]
        )
        # Play functionality: push events to the visualizer, then render audio
        def prepare_playback(model_name, mid_seq):
            if mid_seq is None:
                return mid_seq, [], send_msgs([])
            tokenizer = models[model_name].tokenizer
            msgs = []
            for i in range(OUTPUT_BATCH_SIZE):
                events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
                msgs += [
                    create_msg("visualizer_clear", [i, tokenizer.version]),
                    create_msg("visualizer_append", [i, events]),
                    create_msg("visualizer_end", i)
                ]
            return mid_seq, mid_seq, send_msgs(msgs)

        play_btn.click(
            fn=prepare_playback,
            inputs=[input_model, output_midi_seq],
            outputs=[output_midi_seq, output_continuation_state, js_msg]
        ).then(
            fn=render_audio,
            inputs=[input_model, output_midi_seq, gr.State(True)],
            outputs=audio_outputs
        )
    app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
    thread_pool.shutdown()