Spaces:
Running
Running
| import spaces | |
| import random | |
| import argparse | |
| import glob | |
| import json | |
| import os | |
| import time | |
| import rtmidi | |
| from concurrent.futures import ThreadPoolExecutor | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
| from transformers import DynamicCache | |
| import MIDI | |
| from midi_model import MIDIModel, MIDIModelConfig | |
| from midi_synthesizer import MidiSynthesizer | |
# Largest signed 32-bit integer; presumably intended as an upper bound for
# random seeds (not referenced elsewhere in the visible code).
MAX_SEED = int(np.iinfo("int32").max)
# True only when the SYSTEM environment variable is exactly "spaces".
in_space = os.environ.get("SYSTEM") == "spaces"
# Chord-name -> emoji label used on the virtual keyboard buttons.
# NOTE(review): these glyphs were restored from mis-encoded (UTF-8 read as
# ISO-8859-7) text. All decoded unambiguously except the entries shown as
# 🪕, whose garbled form only identifies a glyph in the U+1FA80-U+1FABF
# block; 🪕 is a best-guess stand-in -- confirm against the intended design.
CHORD_EMOJIS = {
    'C': '🎵', 'Cm': '🎶', 'C7': '🎼', 'Cmaj7': '🎹', 'Cm7': '🎻',
    'D': '🥁', 'Dm': '🪕', 'D7': '🎷', 'Dmaj7': '🎺', 'Dm7': '🪕',
    'E': '🎸', 'Em': '🎻', 'E7': '🎵', 'Emaj7': '🎶', 'Em7': '🎼',
    'F': '🎹', 'Fm': '🎸', 'F7': '🎻', 'Fmaj7': '🎷', 'Fm7': '🎺',
    'G': '🪕', 'Gm': '🎵', 'G7': '🎶', 'Gmaj7': '🎼', 'Gm7': '🎹',
    'A': '🎸', 'Am': '🎻', 'A7': '🎷', 'Amaj7': '🎺', 'Am7': '🪕',
    'B': '🎵', 'Bm': '🎶', 'B7': '🎼', 'Bmaj7': '🎹', 'Bm7': '🎻',
}
# Chord note definitions (MIDI note numbers). Each chord is voiced in root
# position, stacking the listed semitone intervals on a root between MIDI 60
# (C) and 71 (B). Key order: roots C..B, and for each root the suffixes
# '', 'm', '7', 'maj7', 'm7'.
CHORD_NOTES = {
    root + suffix: [root_pitch + offset for offset in intervals]
    for root, root_pitch in (
        ('C', 60), ('D', 62), ('E', 64), ('F', 65),
        ('G', 67), ('A', 69), ('B', 71),
    )
    for suffix, intervals in (
        ('', (0, 4, 7)),          # major triad
        ('m', (0, 3, 7)),         # minor triad
        ('7', (0, 4, 7, 10)),     # dominant seventh
        ('maj7', (0, 4, 7, 11)),  # major seventh
        ('m7', (0, 3, 7, 10)),    # minor seventh
    )
}
# MIDI device manager
class MIDIDeviceManager:
    """Wrap a single ``rtmidi.MidiOut`` and send note/chord messages through it.

    At most one output port is open at a time; ``current_port`` holds its index
    into ``available_ports`` or ``None`` when nothing is open. Every ``send_*``
    helper is a silent no-op while no port is open.
    """

    def __init__(self):
        self.midi_out = rtmidi.MidiOut()
        # Cached port names; refreshed by get_available_devices().
        self.available_ports = self.midi_out.get_ports()
        self.current_port = None

    @staticmethod
    def _data_byte(value):
        """Return ``value`` as an int clamped to the MIDI data-byte range 0..127."""
        return max(0, min(127, int(value)))

    @staticmethod
    def _status(kind, channel):
        """Combine a status nibble (e.g. 0x90) with a channel number 0..15."""
        # Mask the channel so an out-of-range value cannot spill into the
        # message-type nibble (0x90 + 16 would become 0xA0, a different message).
        return kind | (int(channel) & 0x0F)

    def get_available_devices(self):
        """Re-query and return the list of MIDI output port names."""
        self.available_ports = self.midi_out.get_ports()
        return self.available_ports

    def open_port(self, port_index):
        """Open the output port at ``port_index``, closing any port already open.

        Returns:
            True if the port was opened; False if ``port_index`` is outside the
            cached ``available_ports`` list (the current port is left as is).
        """
        if not 0 <= port_index < len(self.available_ports):
            return False
        if self.current_port is not None:
            self.midi_out.close_port()
            # Clear before reopening so a failing open_port() below does not
            # leave us claiming a port that has already been closed.
            self.current_port = None
        self.midi_out.open_port(port_index)
        self.current_port = port_index
        return True

    def send_note_on(self, note, velocity=64, channel=0):
        """Send a Note On message (ignored when no port is open)."""
        if self.current_port is not None:
            self.midi_out.send_message([
                self._status(0x90, channel),
                self._data_byte(note),
                self._data_byte(velocity),
            ])

    def send_note_off(self, note, velocity=0, channel=0):
        """Send a Note Off message (ignored when no port is open)."""
        if self.current_port is not None:
            self.midi_out.send_message([
                self._status(0x80, channel),
                self._data_byte(note),
                self._data_byte(velocity),
            ])

    def send_program_change(self, program, channel=0):
        """Send a Program Change message (ignored when no port is open)."""
        if self.current_port is not None:
            self.midi_out.send_message([
                self._status(0xC0, channel),
                self._data_byte(program),
            ])

    def play_chord(self, chord_name, velocity=80, channel=0, duration=None):
        """Start every note of ``chord_name`` (a key of CHORD_NOTES).

        If ``duration`` is given, block for that many seconds and then release
        the notes; otherwise the notes keep sounding until release_chord().
        Unknown chord names are ignored.
        """
        notes = CHORD_NOTES.get(chord_name)
        if notes is None:
            return
        for note in notes:
            self.send_note_on(note, velocity, channel)
        if duration is not None:
            try:
                time.sleep(duration)
            finally:
                # Release even if the wait is interrupted, to avoid stuck notes.
                for note in notes:
                    self.send_note_off(note, 0, channel)

    def release_chord(self, chord_name, channel=0):
        """Send Note Off for every note of ``chord_name``; unknown names are ignored."""
        for note in CHORD_NOTES.get(chord_name, ()):
            self.send_note_off(note, 0, channel)

    def close(self):
        """Close the currently open port, if any."""
        if self.current_port is not None:
            self.midi_out.close_port()
            self.current_port = None
# Shared output-port manager used by the chord/queue callbacks. Constructing
# it creates an rtmidi.MidiOut handle as soon as this module is imported.
midi_manager: MIDIDeviceManager = MIDIDeviceManager()
def create_msg(name, data):
    """Build a message dict with exactly the keys ``"name"`` and ``"data"``."""
    return dict(name=name, data=data)
def send_msgs(msgs):
    """Serialize ``msgs`` to a JSON string using ``json.dumps`` defaults."""
    payload = json.dumps(msgs)
    return payload
def create_chord_events(chord, duration=480, velocity=80):
    """Build note_on/note_off event lists for ``chord``.

    Returns all note_on events (time field 0, the given velocity) followed by
    all note_off events (time field ``duration``, velocity 0), one of each per
    pitch in ``CHORD_NOTES[chord]``. An unknown chord yields an empty list.

    NOTE(review): the 7-field layout ``[type, time, 0, 0, 0, pitch, velocity]``
    is not consumed anywhere in the visible code -- confirm the target format
    and whether ``duration`` is a delta or absolute time.
    """
    if chord not in CHORD_NOTES:
        return []
    pitches = CHORD_NOTES[chord]
    note_ons = [['note_on', 0, 0, 0, 0, pitch, velocity] for pitch in pitches]
    note_offs = [['note_off', duration, 0, 0, 0, pitch, 0] for pitch in pitches]
    return note_ons + note_offs
def add_chord_to_queue(chord_name, chord_queue, max_queue_size=8):
    """Append ``chord_name`` to ``chord_queue``, keeping only the newest entries.

    The list is modified in place and also returned, so it can be fed straight
    back into a Gradio ``State``.

    Args:
        chord_name: Chord to append.
        chord_queue: Mutable list of chord names, oldest first.
        max_queue_size: Maximum length after appending; values <= 0 leave the
            queue empty.

    Returns:
        The same ``chord_queue`` list object.
    """
    chord_queue.append(chord_name)
    # Trim however many oldest items exceed the limit. Popping a single item
    # would fail on an empty list with a zero limit and would never shrink a
    # queue that was already over-full.
    overflow = len(chord_queue) - max(max_queue_size, 0)
    if overflow > 0:
        del chord_queue[:overflow]
    return chord_queue
def play_chord_on_device(chord_name, midi_device_index):
    """Open the chosen output (when a valid index is given) and sound a chord.

    The chord is played through ``midi_manager`` for 0.5 s. ``play_chord`` is
    called even when no index is given, so it uses whatever port is already
    open, if any.

    Returns:
        ``chord_name`` unchanged.
    """
    has_device = midi_device_index is not None and midi_device_index >= 0
    if has_device:
        midi_manager.open_port(midi_device_index)
    midi_manager.play_chord(chord_name, duration=0.5)
    return chord_name
def play_chord_sequence(chord_queue, midi_device_index, tempo=120):
    """Play every chord in ``chord_queue`` on the selected device, in order.

    Each chord is held for one beat (60 / ``tempo`` seconds), followed by a
    0.05 s gap. Nothing is played when ``midi_device_index`` is None or
    negative.

    Returns:
        ``chord_queue`` unchanged.
    """
    if midi_device_index is None or midi_device_index < 0:
        return chord_queue
    seconds_per_beat = 60.0 / tempo
    midi_manager.open_port(midi_device_index)
    for name in chord_queue:
        midi_manager.play_chord(name, duration=seconds_per_beat)
        time.sleep(0.05)  # short gap so consecutive chords are distinguishable
    return chord_queue
def refresh_midi_devices():
    """Re-scan MIDI output ports and return a choices update for the dropdown.

    ``gr.Dropdown.update`` no longer exists in Gradio 4+ (this app uses
    Gradio 5 features such as ``ssr_mode``). ``gr.update`` is the supported way
    to return property updates from an event handler.
    """
    return gr.update(choices=midi_manager.get_available_devices())
def hf_hub_download_retry(repo_id, filename, max_retries=30, retry_delay=1.0):
    """Download ``filename`` from Hub repo ``repo_id``, retrying on any error.

    Args:
        repo_id: Repository id passed to ``hf_hub_download``.
        filename: File path inside the repository.
        max_retries: Total number of download attempts (default 30).
        retry_delay: Seconds to sleep between failed attempts. The pause gives
            transient network/server errors a chance to clear instead of
            burning every attempt instantly. Pass 0 for no pause.

    Returns:
        The value returned by ``hf_hub_download`` on the first success.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The last download error, once every attempt has failed.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    print(f"downloading {repo_id} {filename}")
    last_err = None
    for attempt in range(1, max_retries + 1):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as e:  # any failure is retried; the last one is re-raised
            last_err = e
            if attempt < max_retries and retry_delay > 0:
                time.sleep(retry_delay)
    raise last_err
def load_javascript(dir="javascript", batch_size=None):
    """Inject every ``<dir>/*.js`` script into the ``<head>`` of Gradio pages.

    In each script, the literal ``const MIDI_OUTPUT_BATCH_SIZE=4;`` is rewritten
    to the configured batch size. Gradio's ``TemplateResponse`` is then wrapped
    so every rendered page gets the scripts inserted just before ``</head>``.

    Args:
        dir: Directory to glob for ``*.js`` files.
        batch_size: Value substituted into the JS constant. Defaults to the
            module-level ``OUTPUT_BATCH_SIZE``, which is set in the ``__main__``
            block and is only read when at least one script is found.
    """
    chunks = []
    # glob() order depends on the filesystem; sort so scripts are always
    # injected in the same order.
    for path in sorted(glob.glob(f"{dir}/*.js")):
        with open(path, "r", encoding="utf8") as jsfile:
            js_content = jsfile.read()
        size = OUTPUT_BATCH_SIZE if batch_size is None else batch_size
        js_content = js_content.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
                                        f"const MIDI_OUTPUT_BATCH_SIZE={size};")
        chunks.append(f"\n<!-- {path} --><script>{js_content}</script>")
    javascript = "".join(chunks)

    template_response_ori = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        res = template_response_ori(*args, **kwargs)
        res.body = res.body.replace(
            b'</head>', f'{javascript}</head>'.encode("utf8"))
        # Rebuild headers after changing the body (presumably so
        # Content-Length matches the new body -- confirm).
        res.init_headers()
        return res

    gr.routes.templates.TemplateResponse = template_response
def create_virtual_keyboard(chord_types, root_notes=None):
    """Build the button layout for the virtual chord keyboard.

    Args:
        chord_types: Chord suffixes appended to each root (e.g. '', 'm', '7').
        root_notes: Root names to include, in display order. Defaults to the
            natural notes C D E F G A B.

    Returns:
        ``{root: {chord_type: (chord_name, emoji)}}``, where ``chord_name`` is
        ``root + chord_type`` and ``emoji`` comes from ``CHORD_EMOJIS``. Chords
        missing from that table fall back to 🎵.
    """
    if root_notes is None:
        root_notes = ['C', 'D', 'E', 'F', 'G', 'A', 'B']
    buttons = {}
    for root in root_notes:
        buttons[root] = {}
        for chord_type in chord_types:
            chord_name = f"{root}{chord_type}"
            emoji = CHORD_EMOJIS.get(chord_name, "🎵")
            buttons[root][chord_type] = (chord_name, emoji)
    return buttons
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
    parser.add_argument("--batch", type=int, default=4, help="batch size")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()
    # Read by load_javascript() to patch the JS batch-size constant.
    OUTPUT_BATCH_SIZE = opt.batch

    # The module-level ``midi_manager`` is reused here. Instantiating a second
    # MIDIDeviceManager would orphan the first rtmidi.MidiOut without closing it.

    # Initialize models (simplified version)
    # NOTE(review): soundfont_path / thread_pool / synthesizer are not used by
    # the UI below -- confirm they are still needed before the startup download.
    soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
    thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
    synthesizer = MidiSynthesizer(soundfont_path)

    # Chord suffixes offered for every root on the virtual keyboard.
    chord_types = ['', 'm', '7', 'maj7', 'm7']
    keyboard = create_virtual_keyboard(chord_types)

    # CSS for the virtual keyboard; .root-* colours each row by root note.
    keyboard_css = """
    .chord-button {
        margin: 4px;
        min-width: 80px;
        height: 60px;
        font-size: 18px;
        font-weight: bold;
        border-radius: 8px;
        transition: all 0.2s;
    }
    .chord-button:active {
        transform: scale(0.95);
    }
    .chord-queue {
        padding: 10px;
        background: #f5f5f5;
        border-radius: 8px;
        min-height: 50px;
        font-size: 16px;
        margin-bottom: 15px;
    }
    .root-c { background-color: #FFCDD2; }
    .root-d { background-color: #F8BBD0; }
    .root-e { background-color: #E1BEE7; }
    .root-f { background-color: #D1C4E9; }
    .root-g { background-color: #C5CAE9; }
    .root-a { background-color: #BBDEFB; }
    .root-b { background-color: #B3E5FC; }
    """

    def _render_queue(queue):
        """Return Markdown for the queue panel; the header shows even when empty."""
        body = " → ".join(queue) if queue else "*No chords in queue*"
        return f"### Current Chord Queue\n{body}"

    def _on_program_change(program, channel, device_index):
        """Open the selected output (if any) and send a Program Change.

        Widget values are coerced to int because Gradio may deliver floats,
        and a cleared selection (None) is ignored.
        """
        if program is None or channel is None:
            return
        if device_index is not None and device_index >= 0:
            midi_manager.open_port(device_index)
        midi_manager.send_program_change(int(program), int(channel))

    load_javascript()
    app = gr.Blocks(theme=gr.themes.Soft(), css=keyboard_css)
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>🎵 Real-Time MIDI Chord Keyboard 🎵</h1>")
        js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
        js_msg.change(None, [js_msg], [], js="""
        (msg_json) =>{
            let msgs = JSON.parse(msg_json);
            executeCallbacks(msgReceiveCallbacks, msgs);
            return [];
        }
        """)

        # MIDI Device Settings
        with gr.Row():
            with gr.Column(scale=3):
                midi_device = gr.Dropdown(label="MIDI Output Device",
                                          choices=midi_manager.get_available_devices(),
                                          type="index")
                refresh_button = gr.Button("🔄 Refresh MIDI Devices")
            with gr.Column(scale=1):
                tempo = gr.Slider(label="Tempo (BPM)",
                                  minimum=40,
                                  maximum=200,
                                  value=120,
                                  step=1)

        # Chord Queue Display
        chord_queue = gr.State([])
        queue_display = gr.Markdown(_render_queue([]), elem_classes=["chord-queue"])
        play_queue_button = gr.Button("▶️ Play Chord Sequence", variant="primary", size="lg")
        clear_queue_button = gr.Button("🗑️ Clear Queue", variant="secondary")

        # Virtual Keyboard - one row per root note
        gr.Markdown("## Virtual Chord Keyboard")
        for root, chords in keyboard.items():
            with gr.Row():
                gr.Markdown(f"### {root}")
                for chord_name, emoji in chords.values():
                    button = gr.Button(f"{emoji} {chord_name}",
                                       elem_classes=["chord-button", f"root-{root.lower()}"])
                    # Play the chord immediately, then append it to the queue
                    # and refresh the queue panel.
                    button.click(
                        fn=play_chord_on_device,
                        inputs=[gr.State(chord_name), midi_device],
                        outputs=None
                    ).then(
                        fn=add_chord_to_queue,
                        inputs=[gr.State(chord_name), chord_queue],
                        outputs=[chord_queue]
                    ).then(
                        fn=_render_queue,
                        inputs=[chord_queue],
                        outputs=[queue_display]
                    )

        refresh_button.click(
            fn=refresh_midi_devices,
            inputs=None,
            outputs=[midi_device]
        )
        play_queue_button.click(
            fn=play_chord_sequence,
            inputs=[chord_queue, midi_device, tempo],
            outputs=[chord_queue]
        )
        # Reset the queue state and its display in a single event.
        clear_queue_button.click(
            fn=lambda: ([], _render_queue([])),
            inputs=None,
            outputs=[chord_queue, queue_display]
        )

        # MIDI Generation Settings (for advanced users)
        with gr.Accordion("Advanced MIDI Settings", open=False):
            with gr.Row():
                midi_channel = gr.Slider(label="MIDI Channel",
                                         minimum=0,
                                         maximum=15,
                                         value=0,
                                         step=1)
                instrument = gr.Dropdown(label="Instrument",
                                         choices=[(f"{i}: {name}", i) for i, name in enumerate([
                                             "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
                                             "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
                                             "Clavinet", "Celesta", "Glockenspiel", "Music Box", "Vibraphone",
                                             "Marimba", "Xylophone", "Tubular Bells", "Dulcimer"
                                         ])],
                                         value=0)
                # NOTE(review): velocity is not wired into chord playback yet.
                velocity = gr.Slider(label="Velocity",
                                     minimum=1,
                                     maximum=127,
                                     value=80,
                                     step=1)
            program_change_button = gr.Button("Send Program Change")
            program_change_button.click(
                fn=_on_program_change,
                inputs=[instrument, midi_channel, midi_device],
                outputs=None
            )

    try:
        app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
    finally:
        # Release the MIDI port even if the server exits with an error.
        midi_manager.close()