# midi-composer / app.py
# awacke1's picture
# Update app.py
# c43c03a verified
# raw
# history blame
# 16.6 kB
import spaces
import random
import argparse
import glob
import json
import os
import time
import rtmidi
from concurrent.futures import ThreadPoolExecutor
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import DynamicCache
import MIDI
from midi_model import MIDIModel, MIDIModelConfig
from midi_synthesizer import MidiSynthesizer
# Largest value representable by a signed 32-bit integer; judging by the name
# it is meant as the upper bound for random seeds (not referenced in this view).
MAX_SEED = np.iinfo(np.int32).max
# True when the SYSTEM environment variable is exactly "spaces"
# (presumably set by Hugging Face Spaces - confirm against the deployment).
in_space = os.getenv("SYSTEM") == "spaces"
# Chord name -> emoji shown on the matching virtual-keyboard button.
# The literals are real Unicode emoji; the previous values were the UTF-8 bytes
# of these characters mis-decoded through a single-byte codepage.
CHORD_EMOJIS = {
    'C': '🎵', 'Cm': '🎶', 'C7': '🎼', 'Cmaj7': '🎹', 'Cm7': '🎻',
    'D': '🥁', 'Dm': '🪘', 'D7': '🎷', 'Dmaj7': '🎺', 'Dm7': '🪕',
    'E': '🎸', 'Em': '🎻', 'E7': '🎵', 'Emaj7': '🎶', 'Em7': '🎼',
    'F': '🎹', 'Fm': '🎸', 'F7': '🎻', 'Fmaj7': '🎷', 'Fm7': '🎺',
    'G': '🪗', 'Gm': '🎵', 'G7': '🎶', 'Gmaj7': '🎼', 'Gm7': '🎹',
    'A': '🎸', 'Am': '🎻', 'A7': '🎷', 'Amaj7': '🎺', 'Am7': '🪕',
    'B': '🎵', 'Bm': '🎶', 'B7': '🎼', 'Bmaj7': '🎹', 'Bm7': '🎻',
}
# Chord note definitions (MIDI note numbers), generated instead of hand-listed.
# Root pitch of each chord family, starting at 60 for C.
_ROOT_PITCHES = {'C': 60, 'D': 62, 'E': 64, 'F': 65, 'G': 67, 'A': 69, 'B': 71}
# Semitone offsets from the root for every supported chord quality.
_CHORD_INTERVALS = {
    '': (0, 4, 7),          # major triad
    'm': (0, 3, 7),         # minor triad
    '7': (0, 4, 7, 10),     # dominant seventh
    'maj7': (0, 4, 7, 11),  # major seventh
    'm7': (0, 3, 7, 10),    # minor seventh
}
# Keys are "<root><quality>" (e.g. 'C', 'Cm', 'Cmaj7'); each value is a fresh
# list of note numbers. Iteration order is root-major, quality-minor.
CHORD_NOTES = {
    f"{root}{quality}": [base + offset for offset in offsets]
    for root, base in _ROOT_PITCHES.items()
    for quality, offsets in _CHORD_INTERVALS.items()
}
# MIDI device manager
class MIDIDeviceManager:
    """Wrap one ``rtmidi.MidiOut`` and remember which output port is open.

    All ``send_*`` methods are silent no-ops while no port is open, so callers
    may invoke them unconditionally.
    """

    def __init__(self):
        self.midi_out = rtmidi.MidiOut()
        self.available_ports = self.midi_out.get_ports()
        self.current_port = None
        # Name of the port that is open; lets open_port() notice when the same
        # index now refers to a different device after a refresh.
        self._open_port_name = None

    def get_available_devices(self):
        """Re-query the backend and return the list of output port names."""
        self.available_ports = self.midi_out.get_ports()
        return self.available_ports

    def open_port(self, port_index):
        """Open the output port at *port_index*.

        Returns:
            True when the port is open afterwards, False when *port_index* is
            None or outside the cached ``available_ports`` list.
        """
        if port_index is None or not 0 <= port_index < len(self.available_ports):
            return False
        port_name = self.available_ports[port_index]
        if port_index == self.current_port and port_name == self._open_port_name:
            # Already connected to this device: skip the close/reopen cycle
            # that would otherwise run on every single chord click.
            return True
        # close() clears current_port first, so a failing open below cannot
        # leave the manager claiming a port it no longer holds.
        self.close()
        self.midi_out.open_port(port_index)
        self.current_port = port_index
        self._open_port_name = port_name
        return True

    @staticmethod
    def _clamp(value, upper):
        """Coerce *value* to int and clamp it into ``0..upper``."""
        return max(0, min(int(value), upper))

    def _send(self, status, channel, *data):
        """Send one channel-voice message; do nothing when no port is open.

        The channel is clamped to 0-15 and data bytes to 0-127 so that an
        out-of-range argument cannot spill into the status nibble and turn the
        message into a different MIDI message type.
        """
        if self.current_port is None:
            return
        message = [status | self._clamp(channel, 0x0F)]
        message.extend(self._clamp(byte, 0x7F) for byte in data)
        self.midi_out.send_message(message)

    def send_note_on(self, note, velocity=64, channel=0):
        """Send MIDI note on message"""
        self._send(0x90, channel, note, velocity)

    def send_note_off(self, note, velocity=0, channel=0):
        """Send MIDI note off message"""
        self._send(0x80, channel, note, velocity)

    def send_program_change(self, program, channel=0):
        """Send MIDI program change message"""
        self._send(0xC0, channel, program)

    def play_chord(self, chord_name, velocity=80, channel=0, duration=None):
        """Play a chord by name with optional automatic release.

        Unknown chord names are ignored. When *duration* (seconds) is given the
        call blocks for that long and then sends the note-offs; otherwise the
        notes are left sounding until ``release_chord`` is called.
        """
        notes = CHORD_NOTES.get(chord_name)
        if notes is None:
            return
        for note in notes:
            self.send_note_on(note, velocity, channel)
        if duration is None:
            return
        try:
            time.sleep(duration)
        finally:
            # Release even if the sleep is interrupted, so no note hangs.
            for note in notes:
                self.send_note_off(note, 0, channel)

    def release_chord(self, chord_name, channel=0):
        """Release all notes in a chord; unknown chord names are ignored."""
        for note in CHORD_NOTES.get(chord_name, ()):
            self.send_note_off(note, 0, channel)

    def close(self):
        """Close current MIDI port"""
        if self.current_port is not None:
            # Reset the bookkeeping before closing so state stays consistent
            # even if the backend call raises.
            self.current_port = None
            self._open_port_name = None
            self.midi_out.close_port()
# Global MIDI manager, created at import time and used by the module-level
# playback helpers (play_chord_on_device, play_chord_sequence,
# refresh_midi_devices).
midi_manager = MIDIDeviceManager()
def create_msg(name, data):
    """Bundle a message *name* and its *data* payload into a two-key dict."""
    message = dict(name=name, data=data)
    return message
def send_msgs(msgs):
    """Serialize *msgs* to a JSON string using the default ``json.dumps`` settings."""
    encoded = json.dumps(msgs)
    return encoded
def create_chord_events(chord, duration=480, velocity=80):
    """Build the note-on / note-off event lists for a named chord.

    Returns one ``['note_on', 0, 0, 0, 0, note, velocity]`` entry per chord
    note followed by one ``['note_off', duration, 0, 0, 0, note, 0]`` entry per
    note, or an empty list when *chord* is not a key of ``CHORD_NOTES``.
    NOTE(review): the second field looks like a tick offset - confirm against
    the consumer of these events.
    """
    if chord not in CHORD_NOTES:
        return []
    pitches = CHORD_NOTES[chord]
    presses = [['note_on', 0, 0, 0, 0, pitch, velocity] for pitch in pitches]
    releases = [['note_off', duration, 0, 0, 0, pitch, 0] for pitch in pitches]
    return presses + releases
def add_chord_to_queue(chord_name, chord_queue, max_queue_size=8):
    """Append *chord_name* to *chord_queue*, keeping only the newest entries.

    The list is modified in place and also returned. After the call it holds
    at most ``max_queue_size`` items; the oldest ones are discarded first. A
    queue that was already longer than the cap is trimmed down to it, and a
    non-positive cap yields an empty queue instead of raising.
    """
    chord_queue.append(chord_name)
    overflow = len(chord_queue) - max(max_queue_size, 0)
    if overflow > 0:
        # Drop every surplus entry from the front, not just one.
        del chord_queue[:overflow]
    return chord_queue
def play_chord_on_device(chord_name, midi_device_index):
    """Play *chord_name* for half a second on the selected MIDI output.

    Nothing is played when *midi_device_index* is None or negative, or when the
    port cannot be opened (index out of range); previously a failed open still
    played the chord on whatever port happened to be open. Returns
    *chord_name* unchanged in every case.
    """
    if midi_device_index is None or midi_device_index < 0:
        return chord_name
    if midi_manager.open_port(midi_device_index):
        midi_manager.play_chord(chord_name, duration=0.5)
    return chord_name
def play_chord_sequence(chord_queue, midi_device_index, tempo=120):
    """Play every chord in *chord_queue*, one beat each, at *tempo* BPM.

    Nothing is played when *midi_device_index* is None or negative, or when the
    port cannot be opened. Returns *chord_queue* unchanged.

    Raises:
        ValueError: if a device is selected and *tempo* is not positive
            (previously a ZeroDivisionError, or a late error from
            ``time.sleep`` after note-ons had already been sent).
    """
    if midi_device_index is None or midi_device_index < 0:
        return chord_queue
    if tempo <= 0:
        raise ValueError(f"tempo must be positive, got {tempo!r}")
    if not midi_manager.open_port(midi_device_index):
        return chord_queue
    beat_duration = 60.0 / tempo  # seconds per beat
    for chord in chord_queue:
        midi_manager.play_chord(chord, duration=beat_duration)
        # Add a small gap between chords
        time.sleep(0.05)
    return chord_queue
def refresh_midi_devices():
    """Re-scan MIDI outputs and return an update for the device dropdown.

    Uses ``gr.update`` rather than ``gr.Dropdown.update``: the per-component
    ``update`` class methods no longer exist in the Gradio versions this file
    targets (it relies on ``js=`` event arguments and ``ssr_mode``).
    """
    return gr.update(choices=midi_manager.get_available_devices())
def hf_hub_download_retry(repo_id, filename):
    """Download *filename* from *repo_id*, trying up to 30 times.

    Returns the value of the first successful ``hf_hub_download`` call. If all
    attempts fail, the exception from the last attempt is re-raised.
    """
    print(f"downloading {repo_id} {filename}")
    last_error = None
    for _attempt in range(30):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as exc:
            last_error = exc
    if last_error:
        raise last_error
def load_javascript(dir="javascript"):
    """Inject every ``*.js`` file under *dir* into the pages Gradio serves.

    Each script has its ``MIDI_OUTPUT_BATCH_SIZE`` constant rewritten to the
    module-level ``OUTPUT_BATCH_SIZE`` (assigned in the ``__main__`` section)
    and is wrapped in a ``<script>`` tag. ``gr.routes.templates.TemplateResponse``
    is then replaced by a wrapper that inserts those tags just before
    ``</head>`` in every response body.
    """
    tags = []
    for script_path in glob.glob(f"{dir}/*.js"):
        with open(script_path, "r", encoding="utf8") as handle:
            source = handle.read()
        source = source.replace(
            "const MIDI_OUTPUT_BATCH_SIZE=4;",
            f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};",
        )
        tags.append(f"\n<!-- {script_path} --><script>{source}</script>")
    javascript = "".join(tags)
    # The injected bytes never change after this point, so encode them once.
    head_with_scripts = f'{javascript}</head>'.encode("utf8")
    original_template_response = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        response = original_template_response(*args, **kwargs)
        response.body = response.body.replace(b'</head>', head_with_scripts)
        # Recompute headers because the body was modified.
        response.init_headers()
        return response

    gr.routes.templates.TemplateResponse = template_response
def create_virtual_keyboard(chord_types):
    """Create virtual keyboard buttons organized by root note and chord type.

    Returns a dict mapping each root note ('C'..'B') to a dict that maps every
    entry of *chord_types* to a ``(chord_name, emoji)`` tuple, where
    ``chord_name`` is the root followed by the chord-type suffix. Chords
    missing from ``CHORD_EMOJIS`` fall back to the musical-note emoji (the
    previous fallback literal was a mis-decoded copy of that character).
    """
    root_notes = ['C', 'D', 'E', 'F', 'G', 'A', 'B']
    return {
        root: {
            chord_type: (
                f"{root}{chord_type}",
                CHORD_EMOJIS.get(f"{root}{chord_type}", "🎵"),
            )
            for chord_type in chord_types
        }
        for root in root_notes
    }
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the Gradio UI and serve it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
    parser.add_argument("--batch", type=int, default=4, help="batch size")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()
    # Module-level name read by load_javascript().
    OUTPUT_BATCH_SIZE = opt.batch
    # The module-level ``midi_manager`` is reused here; constructing a second
    # MIDIDeviceManager would only leak the first one's MidiOut handle.
    # Initialize models (simplified version)
    # NOTE: soundfont/thread_pool/synthesizer are not referenced by the UI below.
    soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
    thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
    synthesizer = MidiSynthesizer(soundfont_path)
    # Define chord types to use in the virtual keyboard
    chord_types = ['', 'm', '7', 'maj7', 'm7']
    # Create virtual keyboard structure
    keyboard = create_virtual_keyboard(chord_types)
    # Define CSS for the virtual keyboard
    keyboard_css = """
.chord-button {
margin: 4px;
min-width: 80px;
height: 60px;
font-size: 18px;
font-weight: bold;
border-radius: 8px;
transition: all 0.2s;
}
.chord-button:active {
transform: scale(0.95);
}
.chord-queue {
padding: 10px;
background: #f5f5f5;
border-radius: 8px;
min-height: 50px;
font-size: 16px;
margin-bottom: 15px;
}
.root-c { background-color: #FFCDD2; }
.root-d { background-color: #F8BBD0; }
.root-e { background-color: #E1BEE7; }
.root-f { background-color: #D1C4E9; }
.root-g { background-color: #C5CAE9; }
.root-a { background-color: #BBDEFB; }
.root-b { background-color: #B3E5FC; }
"""

    def format_queue(queue):
        """Render the chord queue as the Markdown shown above the keyboard."""
        if not queue:
            # Same text the clear button shows, heading included.
            return "### Current Chord Queue\n*No chords in queue*"
        return "### Current Chord Queue\n" + " → ".join(queue)

    load_javascript()
    app = gr.Blocks(theme=gr.themes.Soft(), css=keyboard_css)
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>🎵 Real-Time MIDI Chord Keyboard 🎵</h1>")
        # Hidden textbox whose change event forwards JSON messages to the
        # callbacks defined by the injected JavaScript.
        js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
        js_msg.change(None, [js_msg], [], js="""
(msg_json) =>{
let msgs = JSON.parse(msg_json);
executeCallbacks(msgReceiveCallbacks, msgs);
return [];
}
""")
        # MIDI Device Settings
        with gr.Row():
            with gr.Column(scale=3):
                midi_device = gr.Dropdown(label="MIDI Output Device",
                                          choices=midi_manager.get_available_devices(),
                                          type="index")
                refresh_button = gr.Button("🔄 Refresh MIDI Devices")
            with gr.Column(scale=1):
                tempo = gr.Slider(label="Tempo (BPM)",
                                  minimum=40,
                                  maximum=200,
                                  value=120,
                                  step=1)
        # Chord Queue Display
        chord_queue = gr.State([])
        queue_display = gr.Markdown("### Current Chord Queue\n*No chords in queue*",
                                    elem_classes=["chord-queue"])
        # Play queue button
        play_queue_button = gr.Button("▶️ Play Chord Sequence", variant="primary", size="lg")
        # Clear queue button
        clear_queue_button = gr.Button("🗑️ Clear Queue", variant="secondary")
        # Virtual Keyboard - Create sections for each root note
        gr.Markdown("## Virtual Chord Keyboard")
        for root in ['C', 'D', 'E', 'F', 'G', 'A', 'B']:
            with gr.Row():
                gr.Markdown(f"### {root}")
                for chord_type in chord_types:
                    chord_name, emoji = keyboard[root][chord_type]
                    # Two separate class names so both CSS rules apply.
                    button = gr.Button(f"{emoji} {chord_name}",
                                       elem_classes=["chord-button", f"root-{root.lower()}"])
                    # One State per button carries the chord name into both handlers.
                    chord_state = gr.State(chord_name)
                    # Connect the button to play the chord immediately, then
                    # add it to the queue and refresh the queue display.
                    button.click(
                        fn=play_chord_on_device,
                        inputs=[chord_state, midi_device],
                        outputs=None
                    ).then(
                        fn=add_chord_to_queue,
                        inputs=[chord_state, chord_queue],
                        outputs=[chord_queue]
                    ).then(
                        fn=format_queue,
                        inputs=[chord_queue],
                        outputs=[queue_display]
                    )
        # Connect refresh button
        refresh_button.click(
            fn=refresh_midi_devices,
            inputs=None,
            outputs=[midi_device]
        )
        # Connect play queue button
        play_queue_button.click(
            fn=play_chord_sequence,
            inputs=[chord_queue, midi_device, tempo],
            outputs=[chord_queue]
        )
        # Connect clear queue button
        clear_queue_button.click(
            fn=lambda: [],
            inputs=None,
            outputs=[chord_queue]
        ).then(
            fn=lambda: "### Current Chord Queue\n*No chords in queue*",
            inputs=None,
            outputs=[queue_display]
        )
        # MIDI Generation Settings (for advanced users)
        with gr.Accordion("Advanced MIDI Settings", open=False):
            with gr.Row():
                midi_channel = gr.Slider(label="MIDI Channel",
                                         minimum=0,
                                         maximum=15,
                                         value=0,
                                         step=1)
                instrument = gr.Dropdown(label="Instrument",
                                         choices=[(f"{i}: {name}", i) for i, name in enumerate([
                                             "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
                                             "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
                                             "Clavinet", "Celesta", "Glockenspiel", "Music Box", "Vibraphone",
                                             "Marimba", "Xylophone", "Tubular Bells", "Dulcimer"
                                         ])],
                                         value=0)
                velocity = gr.Slider(label="Velocity",
                                     minimum=1,
                                     maximum=127,
                                     value=80,
                                     step=1)
            # Program change button
            program_change_button = gr.Button("Send Program Change")
            program_change_button.click(
                fn=lambda inst, chan: midi_manager.send_program_change(inst, chan),
                inputs=[instrument, midi_channel],
                outputs=None
            )
    try:
        app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
    finally:
        # Clean up MIDI connections when the app closes, even if launch()
        # exits through an exception such as KeyboardInterrupt.
        midi_manager.close()