import spaces
import random
import argparse
import glob
import json
import os
import time
import rtmidi
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import DynamicCache

import MIDI
from midi_model import MIDIModel, MIDIModelConfig
from midi_synthesizer import MidiSynthesizer

MAX_SEED = np.iinfo(np.int32).max
in_space = os.getenv("SYSTEM") == "spaces"

# Root notes shown on the virtual keyboard, in display order.
ROOT_NOTES = ('C', 'D', 'E', 'F', 'G', 'A', 'B')

# Chord to emoji mapping (decorative button labels only).
CHORD_EMOJIS = {
    'C': '🎵', 'Cm': '🎶', 'C7': '🎼', 'Cmaj7': '🎹', 'Cm7': '🎻',
    'D': '🥁', 'Dm': '🪘', 'D7': '🎷', 'Dmaj7': '🎺', 'Dm7': '🪕',
    'E': '🎸', 'Em': '🎻', 'E7': '🎵', 'Emaj7': '🎶', 'Em7': '🎼',
    'F': '🎹', 'Fm': '🎸', 'F7': '🎻', 'Fmaj7': '🎷', 'Fm7': '🎺',
    'G': '🪗', 'Gm': '🎵', 'G7': '🎶', 'Gmaj7': '🎼', 'Gm7': '🎹',
    'A': '🎸', 'Am': '🎻', 'A7': '🎷', 'Amaj7': '🎺', 'Am7': '🪕',
    'B': '🎵', 'Bm': '🎶', 'B7': '🎼', 'Bmaj7': '🎹', 'Bm7': '🎻',
}

# Chord note definitions as MIDI note numbers, in root position with roots
# from C (60) up to B (71).
CHORD_NOTES = {
    'C': [60, 64, 67],            # C major (C, E, G)
    'Cm': [60, 63, 67],           # C minor (C, Eb, G)
    'C7': [60, 64, 67, 70],       # C7 (C, E, G, Bb)
    'Cmaj7': [60, 64, 67, 71],    # Cmaj7 (C, E, G, B)
    'Cm7': [60, 63, 67, 70],      # Cm7 (C, Eb, G, Bb)
    'D': [62, 66, 69],            # D major (D, F#, A)
    'Dm': [62, 65, 69],           # D minor (D, F, A)
    'D7': [62, 66, 69, 72],       # D7 (D, F#, A, C)
    'Dmaj7': [62, 66, 69, 73],    # Dmaj7 (D, F#, A, C#)
    'Dm7': [62, 65, 69, 72],      # Dm7 (D, F, A, C)
    'E': [64, 68, 71],            # E major (E, G#, B)
    'Em': [64, 67, 71],           # E minor (E, G, B)
    'E7': [64, 68, 71, 74],       # E7 (E, G#, B, D)
    'Emaj7': [64, 68, 71, 75],    # Emaj7 (E, G#, B, D#)
    'Em7': [64, 67, 71, 74],      # Em7 (E, G, B, D)
    'F': [65, 69, 72],            # F major (F, A, C)
    'Fm': [65, 68, 72],           # F minor (F, Ab, C)
    'F7': [65, 69, 72, 75],       # F7 (F, A, C, Eb)
    'Fmaj7': [65, 69, 72, 76],    # Fmaj7 (F, A, C, E)
    'Fm7': [65, 68, 72, 75],      # Fm7 (F, Ab, C, Eb)
    'G': [67, 71, 74],            # G major (G, B, D)
    'Gm': [67, 70, 74],           # G minor (G, Bb, D)
    'G7': [67, 71, 74, 77],       # G7 (G, B, D, F)
    'Gmaj7': [67, 71, 74, 78],    # Gmaj7 (G, B, D, F#)
    'Gm7': [67, 70, 74, 77],      # Gm7 (G, Bb, D, F)
    'A': [69, 73, 76],            # A major (A, C#, E)
    'Am': [69, 72, 76],           # A minor (A, C, E)
    'A7': [69, 73, 76, 79],       # A7 (A, C#, E, G)
    'Amaj7': [69, 73, 76, 80],    # Amaj7 (A, C#, E, G#)
    'Am7': [69, 72, 76, 79],      # Am7 (A, C, E, G)
    'B': [71, 75, 78],            # B major (B, D#, F#)
    'Bm': [71, 74, 78],           # B minor (B, D, F#)
    'B7': [71, 75, 78, 81],       # B7 (B, D#, F#, A)
    'Bmaj7': [71, 75, 78, 82],    # Bmaj7 (B, D#, F#, A#)
    'Bm7': [71, 74, 78, 81],      # Bm7 (B, D, F#, A)
}

# Markdown shown in the queue panel when nothing has been queued.
EMPTY_QUEUE_MARKDOWN = "### Current Chord Queue\n*No chords in queue*"


class MIDIDeviceManager:
    """Wrap one ``rtmidi.MidiOut`` and keep at most one output port open."""

    def __init__(self):
        self.midi_out = rtmidi.MidiOut()
        self.available_ports = self.midi_out.get_ports()
        self.current_port = None  # index of the open port, or None when closed

    def get_available_devices(self):
        """Return the current list of MIDI output port names.

        If the list changed since the last query while a port is open, that
        port is closed, because its index may now refer to another device.
        """
        ports = self.midi_out.get_ports()
        if ports != self.available_ports and self.current_port is not None:
            self.close()
        self.available_ports = ports
        return ports

    def open_port(self, port_index):
        """Open the MIDI port at ``port_index`` (an index into ``available_ports``).

        Re-selecting the port that is already open is a no-op. That lets
        callers invoke this before every note without reconnecting each time.

        Returns:
            True if the requested port is open afterwards, False if the index
            is None or out of range.
        """
        if port_index is None:
            return False
        port_index = int(port_index)
        if port_index == self.current_port:
            return True
        if not 0 <= port_index < len(self.available_ports):
            return False
        self.close()
        self.midi_out.open_port(port_index)
        self.current_port = port_index
        return True

    @staticmethod
    def _status(kind, channel):
        """Combine a status nibble (0x80, 0x90, 0xC0, ...) with a 0-15 channel."""
        return kind | (int(channel) & 0x0F)

    @staticmethod
    def _data(value):
        """Clamp a MIDI data byte (note, velocity, program) into 0-127."""
        return max(0, min(127, int(value)))

    def send_note_on(self, note, velocity=64, channel=0):
        """Send a MIDI note-on message if a port is open."""
        if self.current_port is not None:
            self.midi_out.send_message(
                [self._status(0x90, channel), self._data(note), self._data(velocity)])

    def send_note_off(self, note, velocity=0, channel=0):
        """Send a MIDI note-off message if a port is open."""
        if self.current_port is not None:
            self.midi_out.send_message(
                [self._status(0x80, channel), self._data(note), self._data(velocity)])

    def send_program_change(self, program, channel=0):
        """Send a MIDI program-change message if a port is open."""
        if self.current_port is not None:
            self.midi_out.send_message(
                [self._status(0xC0, channel), self._data(program)])

    def play_chord(self, chord_name, velocity=80, channel=0, duration=None):
        """Sound every note of ``chord_name``; unknown names are ignored.

        When ``duration`` (seconds) is given, this blocks for that long and
        then releases the notes. Otherwise the notes are left sounding.
        """
        notes = CHORD_NOTES.get(chord_name)
        if notes is None:
            return
        for note in notes:
            self.send_note_on(note, velocity, channel)
        if duration is not None:
            time.sleep(duration)
            for note in notes:
                self.send_note_off(note, 0, channel)

    def release_chord(self, chord_name, channel=0):
        """Send note-off for every note of ``chord_name``; unknown names are ignored."""
        for note in CHORD_NOTES.get(chord_name, ()):
            self.send_note_off(note, 0, channel)

    def close(self):
        """Close the currently open port, if any."""
        if self.current_port is not None:
            self.midi_out.close_port()
            self.current_port = None


# Global MIDI manager shared by all Gradio callbacks.
midi_manager = MIDIDeviceManager()


def create_msg(name, data):
    """Build a ``{"name", "data"}`` message dict for the front-end JS."""
    return {"name": name, "data": data}


def send_msgs(msgs):
    """Serialize a list of messages to JSON for the hidden ``msg_receiver`` box."""
    return json.dumps(msgs)


def create_chord_events(chord, duration=480, velocity=80):
    """Create MIDI events for a chord; unknown chord names yield ``[]``.

    NOTE(review): the event layout ``[type, delta, 0, 0, 0, note, velocity]``
    is not consumed anywhere in this file; confirm it matches the expected
    tokenizer format before relying on it.
    """
    events = []
    if chord in CHORD_NOTES:
        notes = CHORD_NOTES[chord]
        for note in notes:
            events.append(['note_on', 0, 0, 0, 0, note, velocity])
        for note in notes:
            events.append(['note_off', duration, 0, 0, 0, note, 0])
    return events


def format_chord_queue(chord_queue):
    """Render the chord queue as the Markdown shown in the queue panel."""
    if not chord_queue:
        return EMPTY_QUEUE_MARKDOWN
    return "### Current Chord Queue\n" + " → ".join(chord_queue)


def add_chord_to_queue(chord_name, chord_queue, max_queue_size=8):
    """Return a new queue with ``chord_name`` appended.

    Only the newest ``max_queue_size`` entries are kept (no cap when
    ``max_queue_size <= 0``). The input list is not mutated.
    """
    queue = list(chord_queue or [])
    queue.append(chord_name)
    if max_queue_size > 0 and len(queue) > max_queue_size:
        queue = queue[-max_queue_size:]
    return queue


def play_chord_on_device(chord_name, midi_device_index, velocity=80, channel=0, duration=0.5):
    """Play ``chord_name`` on the selected device for ``duration`` seconds.

    Nothing is sent when no device is selected or the index is invalid.
    Returns ``chord_name`` unchanged.
    """
    if midi_device_index is not None and midi_device_index >= 0:
        if midi_manager.open_port(midi_device_index):
            midi_manager.play_chord(chord_name, velocity=velocity, channel=channel,
                                    duration=duration)
    return chord_name


def play_chord_sequence(chord_queue, midi_device_index, tempo=120, velocity=80, channel=0):
    """Play each queued chord for one beat at ``tempo`` BPM (blocking).

    Raises:
        ValueError: if a device is selected and ``tempo`` is not positive.
    """
    if midi_device_index is None or midi_device_index < 0:
        return chord_queue
    if tempo is None or tempo <= 0:
        raise ValueError(f"tempo must be a positive BPM value, got {tempo!r}")
    beat_duration = 60.0 / tempo  # seconds per beat
    if midi_manager.open_port(midi_device_index):
        for chord in chord_queue:
            midi_manager.play_chord(chord, velocity=velocity, channel=channel,
                                    duration=beat_duration)
            time.sleep(0.05)  # small gap so consecutive chords are distinct
    return chord_queue


def send_program_change_on_device(program, channel, midi_device_index=None):
    """Send a program change, opening the selected device first if one is given."""
    if midi_device_index is not None and midi_device_index >= 0:
        midi_manager.open_port(midi_device_index)
    midi_manager.send_program_change(program, channel)


def refresh_midi_devices():
    """Return a Gradio update that refreshes the device dropdown's choices."""
    # gr.Dropdown.update was removed in Gradio 4; gr.update works across versions.
    return gr.update(choices=midi_manager.get_available_devices())


def hf_hub_download_retry(repo_id, filename, max_retries=30, delay=1.0):
    """Download ``filename`` from ``repo_id``, retrying failed attempts.

    Waits ``delay`` seconds between attempts (not after the last one).

    Raises:
        The last exception from ``hf_hub_download`` if every attempt fails.
    """
    print(f"downloading {repo_id} {filename}")
    last_error = None
    for attempt in range(1, max(1, max_retries) + 1):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as e:  # hub/network errors vary; retry all of them
            last_error = e
            if attempt < max_retries:
                time.sleep(delay)
    raise last_error


def load_javascript(dir="javascript", batch_size=None):
    """Inject every ``{dir}/*.js`` file into Gradio's HTML ``<head>``.

    ``batch_size`` replaces the ``MIDI_OUTPUT_BATCH_SIZE`` constant in the
    scripts. It defaults to the module-level ``OUTPUT_BATCH_SIZE`` (set
    under ``__main__``), falling back to 4.
    NOTE(review): this monkeypatches ``gr.routes.templates.TemplateResponse``,
    a Gradio internal that may change between releases.
    """
    if batch_size is None:
        batch_size = globals().get("OUTPUT_BATCH_SIZE", 4)
    javascript = ""
    for path in sorted(glob.glob(f"{dir}/*.js")):  # sorted: deterministic order
        with open(path, "r", encoding="utf8") as jsfile:
            js_content = jsfile.read()
        js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
                                        f"const MIDI_OUTPUT_BATCH_SIZE={batch_size};")
        javascript += f"\n<!-- {path} --><script>{js_content}</script>"

    template_response_ori = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        res = template_response_ori(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        res.init_headers()  # body length changed; recompute Content-Length
        return res

    gr.routes.templates.TemplateResponse = template_response


def create_virtual_keyboard(chord_types, root_notes=ROOT_NOTES):
    """Map ``root -> chord_type -> (chord_name, emoji)`` for the keyboard UI."""
    return {
        root: {
            chord_type: (f"{root}{chord_type}",
                         CHORD_EMOJIS.get(f"{root}{chord_type}", "🎵"))
            for chord_type in chord_types
        }
        for root in root_notes
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
    parser.add_argument("--batch", type=int, default=4, help="batch size")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()
    OUTPUT_BATCH_SIZE = opt.batch

    # Reuse the module-level midi_manager; creating a second one would leak
    # the first rtmidi.MidiOut handle.

    # Initialize models (simplified version)
    soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
    thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
    synthesizer = MidiSynthesizer(soundfont_path)

    # Chord types used in the virtual keyboard ('' = major triad)
    chord_types = ['', 'm', '7', 'maj7', 'm7']
    keyboard = create_virtual_keyboard(chord_types)

    keyboard_css = """
    .chord-button { margin: 4px; min-width: 80px; height: 60px; font-size: 18px;
                    font-weight: bold; border-radius: 8px; transition: all 0.2s; }
    .chord-button:active { transform: scale(0.95); }
    .chord-queue { padding: 10px; background: #f5f5f5; border-radius: 8px;
                   min-height: 50px; font-size: 16px; margin-bottom: 15px; }
    .root-c { background-color: #FFCDD2; }
    .root-d { background-color: #F8BBD0; }
    .root-e { background-color: #E1BEE7; }
    .root-f { background-color: #D1C4E9; }
    .root-g { background-color: #C5CAE9; }
    .root-a { background-color: #BBDEFB; }
    .root-b { background-color: #B3E5FC; }
    """

    load_javascript(batch_size=OUTPUT_BATCH_SIZE)
    app = gr.Blocks(theme=gr.themes.Soft(), css=keyboard_css)
    with app:
        gr.Markdown("<h1 style='text-align: center'>🎵 Real-Time MIDI Chord Keyboard 🎵</h1>")
        js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
        js_msg.change(None, [js_msg], [], js="""
        (msg_json) =>{
            let msgs = JSON.parse(msg_json);
            executeCallbacks(msgReceiveCallbacks, msgs);
            return [];
        }
        """)

        # MIDI Device Settings
        with gr.Row():
            with gr.Column(scale=3):
                midi_device = gr.Dropdown(label="MIDI Output Device",
                                          choices=midi_manager.get_available_devices(),
                                          type="index")
                refresh_button = gr.Button("🔄 Refresh MIDI Devices")
            with gr.Column(scale=1):
                tempo = gr.Slider(label="Tempo (BPM)", minimum=40, maximum=200, value=120, step=1)

        # Chord Queue Display
        chord_queue = gr.State([])
        queue_display = gr.Markdown(EMPTY_QUEUE_MARKDOWN, elem_classes=["chord-queue"])
        play_queue_button = gr.Button("▶️ Play Chord Sequence", variant="primary", size="lg")
        clear_queue_button = gr.Button("🗑️ Clear Queue", variant="secondary")

        # Virtual Keyboard - one row per root note
        gr.Markdown("## Virtual Chord Keyboard")
        chord_buttons = []  # (button, chord_name); wired once all inputs exist
        for root in ROOT_NOTES:
            with gr.Row():
                gr.Markdown(f"### {root}")
                for chord_type in chord_types:
                    chord_name, emoji = keyboard[root][chord_type]
                    button = gr.Button(f"{emoji} {chord_name}",
                                       elem_classes=[f"chord-button root-{root.lower()}"])
                    chord_buttons.append((button, chord_name))

        # MIDI Generation Settings (for advanced users)
        with gr.Accordion("Advanced MIDI Settings", open=False):
            with gr.Row():
                midi_channel = gr.Slider(label="MIDI Channel", minimum=0, maximum=15, value=0, step=1)
                instrument = gr.Dropdown(label="Instrument", choices=[(f"{i}: {name}", i) for i, name in enumerate([
                    "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
                    "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
                    "Clavinet", "Celesta", "Glockenspiel", "Music Box", "Vibraphone",
                    "Marimba", "Xylophone", "Tubular Bells", "Dulcimer"
                ])], value=0)
                velocity = gr.Slider(label="Velocity", minimum=1, maximum=127, value=80, step=1)
            program_change_button = gr.Button("Send Program Change")

        # Event wiring happens here, after the velocity/channel controls above
        # exist, so chord playback can honour them. Layout is unaffected.
        for button, chord_name in chord_buttons:
            button.click(
                fn=play_chord_on_device,
                inputs=[gr.State(chord_name), midi_device, velocity, midi_channel],
                outputs=None,
            ).then(
                fn=add_chord_to_queue,
                inputs=[gr.State(chord_name), chord_queue],
                outputs=[chord_queue],
            ).then(
                fn=format_chord_queue,
                inputs=[chord_queue],
                outputs=[queue_display],
            )

        refresh_button.click(fn=refresh_midi_devices, inputs=None, outputs=[midi_device])

        play_queue_button.click(
            fn=play_chord_sequence,
            inputs=[chord_queue, midi_device, tempo, velocity, midi_channel],
            outputs=[chord_queue],
        )

        clear_queue_button.click(
            fn=lambda: ([], EMPTY_QUEUE_MARKDOWN),
            inputs=None,
            outputs=[chord_queue, queue_display],
        )

        program_change_button.click(
            fn=send_program_change_on_device,
            inputs=[instrument, midi_channel, midi_device],
            outputs=None,
        )

    try:
        app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
    finally:
        # Clean up MIDI connections when the app closes (or launch fails)
        midi_manager.close()