# Spaces: Running on Zero
import spaces | |
import random | |
import argparse | |
import glob | |
import json | |
import os | |
import time | |
import rtmidi | |
from concurrent.futures import ThreadPoolExecutor | |
import gradio as gr | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from huggingface_hub import hf_hub_download | |
from transformers import DynamicCache | |
import MIDI | |
from midi_model import MIDIModel, MIDIModelConfig | |
from midi_synthesizer import MidiSynthesizer | |
# Largest value a signed 32-bit integer can hold.
MAX_SEED = np.iinfo(np.int32).max
# True when the SYSTEM environment variable equals "spaces"
# (presumably set by the Hugging Face Spaces runtime -- TODO confirm).
in_space = os.environ.get("SYSTEM") == "spaces"
# Chord to emoji mapping.
# Keys are chord names (root letter + chord-type suffix); values are the
# glyphs shown on the matching keyboard button.
# NOTE(review): the glyph literals look mis-decoded (mojibake) in this copy
# of the file; they are reproduced unchanged -- confirm against the original.
CHORD_EMOJIS = {
    'C': 'π΅',
    'Cm': 'πΆ',
    'C7': 'πΌ',
    'Cmaj7': 'πΉ',
    'Cm7': 'π»',
    'D': 'π₯',
    'Dm': 'πͺ',
    'D7': 'π·',
    'Dmaj7': 'πΊ',
    'Dm7': 'πͺ',
    'E': 'πΈ',
    'Em': 'π»',
    'E7': 'π΅',
    'Emaj7': 'πΆ',
    'Em7': 'πΌ',
    'F': 'πΉ',
    'Fm': 'πΈ',
    'F7': 'π»',
    'Fmaj7': 'π·',
    'Fm7': 'πΊ',
    'G': 'πͺ',
    'Gm': 'π΅',
    'G7': 'πΆ',
    'Gmaj7': 'πΌ',
    'Gm7': 'πΉ',
    'A': 'πΈ',
    'Am': 'π»',
    'A7': 'π·',
    'Amaj7': 'πΊ',
    'Am7': 'πͺ',
    'B': 'π΅',
    'Bm': 'πΆ',
    'B7': 'πΌ',
    'Bmaj7': 'πΉ',
    'Bm7': 'π»',
}
# Chord note definitions (MIDI note numbers)
# MIDI note number of each supported root (C=60 up to B=71).
_ROOT_PITCHES = {'C': 60, 'D': 62, 'E': 64, 'F': 65, 'G': 67, 'A': 69, 'B': 71}
# Semitone offsets above the root for each chord-type suffix.
_CHORD_INTERVALS = {
    '': (0, 4, 7),          # major triad, e.g. C = C, E, G
    'm': (0, 3, 7),         # minor triad, e.g. Cm = C, Eb, G
    '7': (0, 4, 7, 10),     # dominant seventh, e.g. C7 = C, E, G, Bb
    'maj7': (0, 4, 7, 11),  # major seventh, e.g. Cmaj7 = C, E, G, B
    'm7': (0, 3, 7, 10),    # minor seventh, e.g. Cm7 = C, Eb, G, Bb
}
# Maps every "<root><suffix>" chord name to its list of MIDI note numbers,
# ordered root by root (C..B) and, within a root, '', 'm', '7', 'maj7', 'm7'.
CHORD_NOTES = {
    f"{root}{suffix}": [base + step for step in steps]
    for root, base in _ROOT_PITCHES.items()
    for suffix, steps in _CHORD_INTERVALS.items()
}
# MIDI device manager
class MIDIDeviceManager:
    """Own one ``rtmidi.MidiOut`` handle and send simple messages through it.

    At most one output port is open at a time. ``current_port`` holds the
    index of the open port, or ``None`` when no port is open; every send
    method is a silent no-op while it is ``None``.
    """

    def __init__(self):
        self.midi_out = rtmidi.MidiOut()
        self.available_ports = self.midi_out.get_ports()
        self.current_port = None
        # Name of the open port, so a re-used index that now refers to a
        # different device (after a refresh) is detected in open_port().
        self._current_port_name = None

    def get_available_devices(self):
        """Re-query the output ports, cache them, and return the list."""
        self.available_ports = self.midi_out.get_ports()
        return self.available_ports

    def open_port(self, port_index):
        """Open the MIDI output port at ``port_index``.

        Returns ``False`` when the index is outside the cached port list and
        ``True`` otherwise. Asking for the port that is already open is a
        no-op, so callers may invoke this before every message without paying
        for a close/re-open cycle each time.
        """
        if not 0 <= port_index < len(self.available_ports):
            return False
        port_name = self.available_ports[port_index]
        if self.current_port == port_index and self._current_port_name == port_name:
            return True
        if self.current_port is not None:
            self.midi_out.close_port()
            # Forget the old port right away: if the open below raises, the
            # manager must not keep claiming that a (now closed) port is open.
            self.current_port = None
            self._current_port_name = None
        self.midi_out.open_port(port_index)
        self.current_port = port_index
        self._current_port_name = port_name
        return True

    def _send(self, message):
        """Send ``message`` if a port is open; otherwise do nothing."""
        if self.current_port is not None:
            self.midi_out.send_message(message)

    def send_note_on(self, note, velocity=64, channel=0):
        """Send MIDI note on message (status byte ``0x90 + channel``)."""
        self._send([0x90 + channel, note, velocity])

    def send_note_off(self, note, velocity=0, channel=0):
        """Send MIDI note off message (status byte ``0x80 + channel``)."""
        self._send([0x80 + channel, note, velocity])

    def send_program_change(self, program, channel=0):
        """Send MIDI program change message (status byte ``0xC0 + channel``)."""
        self._send([0xC0 + channel, program])

    def play_chord(self, chord_name, velocity=80, channel=0, duration=None):
        """Play a chord by name with optional automatic release.

        Chord names missing from ``CHORD_NOTES`` are ignored. When
        ``duration`` is given, the call blocks for that many seconds
        (``time.sleep``) and then sends the matching note-off messages;
        otherwise the notes are left sounding until ``release_chord``.
        """
        notes = CHORD_NOTES.get(chord_name)
        if notes is None:
            return
        for note in notes:
            self.send_note_on(note, velocity, channel)
        if duration is not None:
            # Automatic note off after duration
            time.sleep(duration)
            for note in notes:
                self.send_note_off(note, 0, channel)

    def release_chord(self, chord_name, channel=0):
        """Release all notes in a chord; unknown chord names are ignored."""
        for note in CHORD_NOTES.get(chord_name, ()):
            self.send_note_off(note, 0, channel)

    def close(self):
        """Close current MIDI port (no-op when none is open)."""
        if self.current_port is not None:
            self.midi_out.close_port()
            self.current_port = None
            self._current_port_name = None
# Global MIDI manager: the single instance that the module-level playback
# helpers (play_chord_on_device, play_chord_sequence, refresh_midi_devices)
# look up by name.
midi_manager = MIDIDeviceManager()
def create_msg(name, data):
    """Bundle a message name and its payload into a ``{"name", "data"}`` dict."""
    return dict(name=name, data=data)
def send_msgs(msgs):
    """Serialize ``msgs`` to a JSON string with ``json.dumps`` defaults."""
    encoded = json.dumps(msgs)
    return encoded
def create_chord_events(chord, duration=480, velocity=80):
    """Create MIDI events for a chord.

    Returns one ``'note_on'`` event per chord note followed by one
    ``'note_off'`` event per chord note, each a 7-item list. ``duration`` is
    stored in the second field of every note-off event and ``velocity`` in
    the last field of every note-on event. A chord name that is not in
    ``CHORD_NOTES`` yields an empty list.
    """
    pitches = CHORD_NOTES[chord] if chord in CHORD_NOTES else []
    note_ons = [['note_on', 0, 0, 0, 0, pitch, velocity] for pitch in pitches]
    # Note off events after specified duration
    note_offs = [['note_off', duration, 0, 0, 0, pitch, 0] for pitch in pitches]
    return note_ons + note_offs
def add_chord_to_queue(chord_name, chord_queue, max_queue_size=8):
    """Append ``chord_name`` to ``chord_queue`` in place and return the queue.

    When the queue already holds ``max_queue_size`` or more entries, the
    oldest entry (index 0) is dropped first.
    """
    has_room = len(chord_queue) < max_queue_size
    if not has_room:
        del chord_queue[0]  # evict the oldest chord
    chord_queue.append(chord_name)
    return chord_queue
def play_chord_on_device(chord_name, midi_device_index):
    """Play a chord on the selected MIDI device and return ``chord_name``.

    Nothing is played when ``midi_device_index`` is ``None`` or negative, or
    when the manager refuses to open that index. Skipping in the latter case
    keeps the chord from sounding on whichever port was open before. The
    call blocks for the 0.5 s the chord is held.
    """
    if midi_device_index is None or midi_device_index < 0:
        return chord_name
    if midi_manager.open_port(midi_device_index):
        midi_manager.play_chord(chord_name, duration=0.5)
    return chord_name
def play_chord_sequence(chord_queue, midi_device_index, tempo=120):
    """Play a sequence of chords at the specified tempo and return the queue.

    Each chord is held for one beat (``60 / tempo`` seconds), followed by a
    50 ms gap. The call blocks until the whole queue has been played.
    Nothing is played, and no port is opened, when the queue is empty, when
    ``midi_device_index`` is ``None`` or negative, or when the manager
    refuses to open that index.
    """
    if midi_device_index is None or midi_device_index < 0 or not chord_queue:
        return chord_queue
    # Calculate timing based on tempo (beats per minute)
    beat_duration = 60.0 / tempo  # seconds per beat
    if not midi_manager.open_port(midi_device_index):
        # Do not fall back to a previously opened, different device.
        return chord_queue
    for chord in chord_queue:
        midi_manager.play_chord(chord, duration=beat_duration)
        # Add a small gap between chords
        time.sleep(0.05)
    return chord_queue
def refresh_midi_devices():
    """Refresh the list of available MIDI devices.

    Returns a Gradio update that replaces the dropdown's choices with the
    current port list. ``gr.update`` is used because the per-component
    ``gr.Dropdown.update`` classmethod no longer exists in Gradio 4+, the
    API generation the rest of this file is written against (``js=`` event
    argument, ``ssr_mode`` launch option).
    """
    return gr.update(choices=midi_manager.get_available_devices())
def hf_hub_download_retry(repo_id, filename, max_retries=30, retry_delay=0.0):
    """Download ``filename`` from ``repo_id`` via ``hf_hub_download``, retrying on failure.

    Args:
        repo_id: Repository id passed through to ``hf_hub_download``.
        filename: File name passed through to ``hf_hub_download``.
        max_retries: Total number of attempts before giving up (default 30).
        retry_delay: Seconds to sleep between attempts (default 0, i.e.
            retry immediately).

    Returns:
        Whatever ``hf_hub_download`` returns on the first successful attempt.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The error from the last attempt once every attempt failed.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    print(f"downloading {repo_id} {filename}")
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as e:  # deliberate best-effort: any failure is retried
            last_error = e
            print(f"download attempt {attempt}/{max_retries} failed: {e!r}")
            if retry_delay > 0 and attempt < max_retries:
                time.sleep(retry_delay)
    raise last_error
def load_javascript(dir="javascript"):
    """Inline every ``*.js`` file under ``dir`` into pages served by Gradio.

    Each script has the literal ``const MIDI_OUTPUT_BATCH_SIZE=4;`` replaced
    with the current ``OUTPUT_BATCH_SIZE`` and is wrapped in a ``<script>``
    tag. ``gr.routes.templates.TemplateResponse`` is then replaced by a
    wrapper that inserts those tags just before ``</head>`` in the body of
    each response it builds.
    """
    snippets = []
    for path in glob.glob(f"{dir}/*.js"):
        with open(path, "r", encoding="utf8") as jsfile:
            source = jsfile.read()
        source = source.replace("const MIDI_OUTPUT_BATCH_SIZE=4;",
                                f"const MIDI_OUTPUT_BATCH_SIZE={OUTPUT_BATCH_SIZE};")
        snippets.append(f"\n<!-- {path} --><script>{source}</script>")
    javascript = "".join(snippets)
    head_with_scripts = f'{javascript}</head>'.encode("utf8")

    original_template_response = gr.routes.templates.TemplateResponse

    def template_response(*args, **kwargs):
        response = original_template_response(*args, **kwargs)
        response.body = response.body.replace(b'</head>', head_with_scripts)
        # Recompute headers because the body was changed after construction.
        response.init_headers()
        return response

    gr.routes.templates.TemplateResponse = template_response
def create_virtual_keyboard(chord_types):
    """Create virtual keyboard buttons organized by root note and chord type.

    Returns ``{root: {chord_type: (chord_name, emoji)}}`` for the roots C, D,
    E, F, G, A, B, where ``chord_name`` is the root followed by the chord
    type and ``emoji`` comes from ``CHORD_EMOJIS`` (with a fixed fallback
    glyph for names missing from that table).
    """
    def with_emoji(chord_name):
        return chord_name, CHORD_EMOJIS.get(chord_name, "π΅")

    return {
        root: {suffix: with_emoji(f"{root}{suffix}") for suffix in chord_types}
        for root in "CDEFGAB"
    }
if __name__ == "__main__":
    # Command-line options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
    parser.add_argument("--port", type=int, default=7860, help="gradio server port")
    parser.add_argument("--device", type=str, default="cuda", help="device to run model")
    parser.add_argument("--batch", type=int, default=4, help="batch size")
    parser.add_argument("--max-gen", type=int, default=1024, help="max")
    opt = parser.parse_args()
    # Read by load_javascript() when it patches the bundled scripts.
    OUTPUT_BATCH_SIZE = opt.batch

    # The module-level ``midi_manager`` created at import time is reused
    # here; building a second MIDIDeviceManager would only orphan the first
    # rtmidi output handle.

    # Initialize models (simplified version)
    soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
    thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
    synthesizer = MidiSynthesizer(soundfont_path)

    # Define chord types to use in the virtual keyboard
    chord_types = ['', 'm', '7', 'maj7', 'm7']
    # Create virtual keyboard structure
    keyboard = create_virtual_keyboard(chord_types)

    # Define CSS for the virtual keyboard
    keyboard_css = """
    .chord-button {
        margin: 4px;
        min-width: 80px;
        height: 60px;
        font-size: 18px;
        font-weight: bold;
        border-radius: 8px;
        transition: all 0.2s;
    }
    .chord-button:active {
        transform: scale(0.95);
    }
    .chord-queue {
        padding: 10px;
        background: #f5f5f5;
        border-radius: 8px;
        min-height: 50px;
        font-size: 16px;
        margin-bottom: 15px;
    }
    .root-c { background-color: #FFCDD2; }
    .root-d { background-color: #F8BBD0; }
    .root-e { background-color: #E1BEE7; }
    .root-f { background-color: #D1C4E9; }
    .root-g { background-color: #C5CAE9; }
    .root-a { background-color: #BBDEFB; }
    .root-b { background-color: #B3E5FC; }
    """

    QUEUE_HEADING = "### Current Chord Queue\n"
    EMPTY_QUEUE_TEXT = QUEUE_HEADING + "*No chords in queue*"

    def format_queue(queue):
        """Render the chord queue as Markdown; the heading is kept when the queue is empty."""
        if not queue:
            return EMPTY_QUEUE_TEXT
        return QUEUE_HEADING + " β ".join(queue)

    load_javascript()
    app = gr.Blocks(theme=gr.themes.Soft(), css=keyboard_css)
    with app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>π΅ Real-Time MIDI Chord Keyboard π΅</h1>")
        # Hidden textbox whose changes are handed to the front-end callbacks.
        js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
        js_msg.change(None, [js_msg], [], js="""
        (msg_json) =>{
            let msgs = JSON.parse(msg_json);
            executeCallbacks(msgReceiveCallbacks, msgs);
            return [];
        }
        """)

        # MIDI Device Settings
        with gr.Row():
            with gr.Column(scale=3):
                midi_device = gr.Dropdown(label="MIDI Output Device",
                                          choices=midi_manager.get_available_devices(),
                                          type="index")
                refresh_button = gr.Button("π Refresh MIDI Devices")
            with gr.Column(scale=1):
                tempo = gr.Slider(label="Tempo (BPM)",
                                  minimum=40,
                                  maximum=200,
                                  value=120,
                                  step=1)

        # Chord Queue Display
        chord_queue = gr.State([])
        queue_display = gr.Markdown(EMPTY_QUEUE_TEXT,
                                    elem_classes=["chord-queue"])
        # Play queue button
        play_queue_button = gr.Button("βΆοΈ Play Chord Sequence", variant="primary", size="lg")
        # Clear queue button
        clear_queue_button = gr.Button("ποΈ Clear Queue", variant="secondary")

        # Virtual Keyboard - Create sections for each root note
        gr.Markdown("## Virtual Chord Keyboard")
        for root in ['C', 'D', 'E', 'F', 'G', 'A', 'B']:
            with gr.Row():
                gr.Markdown(f"### {root}")
                for chord_type in chord_types:
                    chord_name, emoji = keyboard[root][chord_type]
                    # Two separate class names so both CSS rules apply.
                    button = gr.Button(f"{emoji} {chord_name}",
                                       elem_classes=["chord-button", f"root-{root.lower()}"])
                    # One State per button carries the chord name into both handlers.
                    chord_state = gr.State(chord_name)
                    # Connect the button to play the chord immediately, then queue it
                    button.click(
                        fn=play_chord_on_device,
                        inputs=[chord_state, midi_device],
                        outputs=None
                    ).then(
                        fn=add_chord_to_queue,
                        inputs=[chord_state, chord_queue],
                        outputs=[chord_queue]
                    ).then(
                        fn=format_queue,
                        inputs=[chord_queue],
                        outputs=[queue_display]
                    )

        # Connect refresh button
        refresh_button.click(
            fn=refresh_midi_devices,
            inputs=None,
            outputs=[midi_device]
        )
        # Connect play queue button
        play_queue_button.click(
            fn=play_chord_sequence,
            inputs=[chord_queue, midi_device, tempo],
            outputs=[chord_queue]
        )
        # Connect clear queue button
        clear_queue_button.click(
            fn=lambda: [],
            inputs=None,
            outputs=[chord_queue]
        ).then(
            fn=lambda: EMPTY_QUEUE_TEXT,
            inputs=None,
            outputs=[queue_display]
        )

        # MIDI Generation Settings (for advanced users)
        with gr.Accordion("Advanced MIDI Settings", open=False):
            with gr.Row():
                midi_channel = gr.Slider(label="MIDI Channel",
                                         minimum=0,
                                         maximum=15,
                                         value=0,
                                         step=1)
                instrument = gr.Dropdown(label="Instrument",
                                         choices=[(f"{i}: {name}", i) for i, name in enumerate([
                                             "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
                                             "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
                                             "Clavinet", "Celesta", "Glockenspiel", "Music Box", "Vibraphone",
                                             "Marimba", "Xylophone", "Tubular Bells", "Dulcimer"
                                         ])],
                                         value=0)
                # NOTE(review): this slider is not connected to any handler in
                # this block -- confirm whether chord playback should use it.
                velocity = gr.Slider(label="Velocity",
                                     minimum=1,
                                     maximum=127,
                                     value=80,
                                     step=1)
            # Program change button
            program_change_button = gr.Button("Send Program Change")
            program_change_button.click(
                fn=lambda inst, chan: midi_manager.send_program_change(inst, chan),
                inputs=[instrument, midi_channel],
                outputs=None
            )

    try:
        app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
    finally:
        # Clean up MIDI connections when the app closes, even if launch raised.
        midi_manager.close()