Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,538 Bytes
5297a72 e8f8f3c 1151735 92acab8 5297a72 75808a5 dbc6fc5 75808a5 92acab8 5297a72 92acab8 b50d4ec 5297a72 b50d4ec e8f8f3c b50d4ec 5297a72 1151735 e8f8f3c 5297a72 e8f8f3c 1151735 b50d4ec 92acab8 5297a72 92acab8 5297a72 1151735 5297a72 75808a5 1151735 5297a72 75808a5 dbc6fc5 92acab8 5297a72 1151735 5297a72 1151735 75808a5 5297a72 75808a5 1151735 5297a72 92acab8 1151735 5297a72 1151735 5297a72 1151735 5297a72 1151735 5297a72 e8f8f3c 5297a72 e8f8f3c 1151735 e8f8f3c 5297a72 e8f8f3c 1151735 e8f8f3c 5297a72 75808a5 1151735 5297a72 e8f8f3c 1151735 75808a5 5297a72 75808a5 5297a72 1151735 5297a72 1151735 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import gradio as gr
import json
import rtmidi
import os
import argparse
import base64
import io
import numpy as np
from huggingface_hub import hf_hub_download
import onnxruntime as rt
import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Match the JavaScript constant
MIDI_OUTPUT_BATCH_SIZE = 4
class MIDIDeviceManager:
    """Manages MIDI input/output devices through python-rtmidi handles."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a human-readable listing of the available MIDI ports.

        Returns:
            A string with an "Output Devices:" section followed by an
            "Input Devices:" section, one port name per line. When a direction
            has no ports, a single placeholder line is shown instead.
        """
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string expression
        # is a SyntaxError on Python versions before 3.12.
        out_text = "\n".join(out_ports)
        in_text = "\n".join(in_ports)
        return f"Output Devices:\n{out_text}\n\nInput Devices:\n{in_text}"

    def close(self):
        """Close any open MIDI ports and release the rtmidi handles.

        Safe to call more than once: after the first call the handles are
        ``None`` and later calls do nothing.
        """
        for attr in ("midiout", "midiin"):
            port = getattr(self, attr, None)
            if port is None:
                continue
            if port.is_port_open():
                port.close_port()
            # Drop the reference so the rtmidi object can be freed, but keep the
            # attribute so a repeated close() does not raise AttributeError.
            setattr(self, attr, None)
class MIDIManager:
    """Handles MIDI loading, ONNX-based generation, and audio rendering."""

    def __init__(self):
        # Load soundfont, tokenizer config and ONNX models from the Hugging Face
        # Hub; ONNX Runtime tries CUDA first and falls back to CPU.
        repo_id = "skytnt/midi-model"
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.soundfont_path = hf_hub_download(repo_id=repo_id, filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.tokenizer = self._load_tokenizer(repo_id)
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id=repo_id, filename="onnx/model_base.onnx"),
            providers=providers,
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id=repo_id, filename="onnx/model_token.onnx"),
            providers=providers,
        )
        # Base64-encoded bytes of every generated MIDI, in generation order.
        # NOTE(review): this list grows for the lifetime of the process.
        self.generated_files = []

    def _load_tokenizer(self, repo_id):
        """Build a MIDITokenizer from the repo's ``config.json``.

        Reads ``config["tokenizer"]["version"]`` for the constructor and
        ``config["tokenizer"]["optimise_midi"]`` for ``set_optimise_midi``.
        """
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load a MIDI file from ``file_path`` via ``MIDI.load``."""
        return MIDI.load(file_path)

    def generate_onnx(self, midi_data, max_len=1024, top_p=0.98, top_k=20):
        """Generate a MIDI variation of ``midi_data`` with the ONNX models.

        Args:
            midi_data: Object returned by ``load_midi``; converted with
                ``MIDI.midi2score`` before tokenization.
            max_len: Generation stops once the token sequence reaches this length.
            top_p: Nucleus-sampling probability threshold.
            top_k: Maximum number of candidate tokens considered per step.

        Returns:
            The detokenized MIDI object. Its ``MIDI.save`` bytes are also
            appended, base64-encoded, to ``self.generated_files``.
        """
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        # NOTE(review): assumes tokenize() returns a flat token list so this is
        # (1, seq_len); per-event token lists would make it 3-D and break the
        # concatenate below — confirm against MIDITokenizer.
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        while cur_len < max_len:
            # NOTE(review): only the most recent token is fed and no past state is
            # carried between steps, so the model sees no prior context — confirm
            # against model_base's expected inputs.
            hidden = self.model_base.run(None, {"x": input_tensor[:, -1:]})[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = self._softmax(logits, axis=-1)
            next_token = self._sample_top_p_k(probs, top_p, top_k)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        new_seq = input_tensor[0].tolist()
        generated_midi = self.tokenizer.detokenize(new_seq)
        # Store base64-encoded MIDI data for downloads
        midi_bytes = MIDI.save(generated_midi)
        self.generated_files.append(base64.b64encode(midi_bytes).decode('utf-8'))
        return generated_midi

    def play_midi(self, midi_data):
        """Render base64-encoded MIDI data to audio.

        Args:
            midi_data: Base64 text or bytes of a MIDI file.

        Returns:
            An ``io.BytesIO`` rewound to position 0 holding whatever
            ``synthesizer.render_midi`` wrote.
        """
        midi_bytes = base64.b64decode(midi_data)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        audio = io.BytesIO()
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    @staticmethod
    def _softmax(x, axis=-1):
        """Numerically stable softmax of ``x`` along ``axis``."""
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    @staticmethod
    def _sample_top_p_k(probs, p, k, generator=None):
        """Sample one token per distribution using top-p and top-k filtering.

        Args:
            probs: Array of probabilities whose last axis is the vocabulary;
                all leading axes are treated as independent distributions.
            p: Nucleus threshold. A token is dropped once the probability mass
                ranked above it already exceeds ``p``; the most likely token is
                always kept for ``p >= 0``.
            k: Keep at most ``k`` of the most likely tokens (clamped to
                ``[1, vocab_size]``).
            generator: Optional ``numpy.random.Generator`` (or anything with a
                compatible ``choice``); defaults to the global ``np.random``.

        Returns:
            int64 array of shape ``(num_distributions, 1)``. For a single
            distribution this is ``(1, 1)``, ready to concatenate on axis 1.
        """
        rng = np.random if generator is None else generator
        probs = np.asarray(probs, dtype=np.float64)
        vocab = probs.shape[-1]
        flat = probs.reshape(-1, vocab)
        order = np.argsort(-flat, axis=-1)
        sorted_probs = np.take_along_axis(flat, order, axis=-1)
        # Top-p: zero out tokens whose preceding cumulative mass already exceeds p.
        cumulative = np.cumsum(sorted_probs, axis=-1)
        sorted_probs = np.where(cumulative - sorted_probs > p, 0.0, sorted_probs)
        # Top-k: keep at most k candidates, but always at least one.
        k = max(1, min(int(k), vocab))
        sorted_probs[:, k:] = 0.0
        sorted_probs /= sorted_probs.sum(axis=-1, keepdims=True)
        tokens = np.array(
            [rng.choice(idx, p=row) for row, idx in zip(sorted_probs, order)],
            dtype=np.int64,
        )
        return tokens.reshape(-1, 1)
def process_midi(files):
    """Generate a variation for each uploaded MIDI file and stream UI updates.

    Each yielded list is aligned with the ``outputs`` registered on the
    Generate button: ``[js_msg, audio_0 .. audio_{B-1}, midi_0 .. midi_{B-1}]``
    where ``B`` is ``MIDI_OUTPUT_BATCH_SIZE``. Files beyond ``B`` reuse output
    slots round-robin.

    Args:
        files: Uploaded files from ``gr.File(file_count="multiple")``, either
            objects with a ``.name`` path (Gradio 3) or plain paths (Gradio 4).
    """
    import tempfile  # local: only needed to hand gr.File a real path

    n_outputs = 1 + 2 * MIDI_OUTPUT_BATCH_SIZE

    def _noop_updates():
        # Fresh list per yield so earlier yields are never mutated afterwards.
        return [gr.update() for _ in range(n_outputs)]

    def _js_message(name, data):
        # Slot 0 is the hidden textbox the front-end JS listens to.
        return gr.update(value=json.dumps([{"name": name, "data": data}]))

    if not files:
        yield _noop_updates()
        return

    for idx, file in enumerate(files):
        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
        audio_slot = 1 + output_idx
        midi_slot = 1 + MIDI_OUTPUT_BATCH_SIZE + output_idx

        path = getattr(file, "name", file)
        midi_data = midi_processor.load_midi(path)
        generated_midi = midi_processor.generate_onnx(midi_data)
        midi_bytes = MIDI.save(generated_midi)

        # TODO: derive visualizer events from generated_midi via the tokenizer.
        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
        events = [
            ["note", 0, 0, 0, 60, 100, 1000],  # Example event
        ]

        # Clear visualizer
        updates = _noop_updates()
        updates[0] = _js_message("visualizer_clear", [output_idx, "v2"])
        yield updates

        # Send MIDI events
        updates = _noop_updates()
        updates[0] = _js_message("visualizer_append", [output_idx, events])
        yield updates

        # play_midi expects base64-encoded MIDI, not the detokenized object.
        audio_stream = midi_processor.play_midi(base64.b64encode(midi_bytes))
        # gr.File needs a filesystem path; delete=False so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
            tmp.write(midi_bytes)

        # Finalize visualizer and update audio/MIDI outputs
        updates = _noop_updates()
        updates[0] = _js_message("visualizer_end", output_idx)
        # NOTE(review): gr.Audio cannot postprocess a file-like object; raw bytes
        # are accepted by Gradio 4's Audio.postprocess — confirm for your version.
        updates[audio_slot] = gr.update(value=audio_stream.getvalue())
        updates[midi_slot] = gr.update(value=tmp.name, label=f"Generated MIDI {output_idx}")
        yield updates

    # Final yield to ensure all components are in a stable state
    yield _noop_updates()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer App")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Share the app publicly")
    opt = parser.parse_args()

    device_manager = MIDIDeviceManager()
    try:
        midi_processor = MIDIManager()
        with gr.Blocks(theme=gr.themes.Soft()) as app:
            # Hidden textbox for sending messages to JS
            js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")
            with gr.Tabs():
                # MIDI Prompt Tab
                with gr.Tab("MIDI Prompt"):
                    midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
                    generate_btn = gr.Button("Generate")
                    status = gr.Textbox(label="Status", value="Ready", interactive=False)
                # Outputs Tab
                with gr.Tab("Outputs"):
                    output_audios = []
                    output_midis = []
                    for i in range(MIDI_OUTPUT_BATCH_SIZE):
                        with gr.Column():
                            gr.Markdown(f"## Output {i+1}")
                            gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                            # gr.Audio only accepts type="numpy" or "filepath";
                            # "bytes" raises at construction time.
                            output_audio = gr.Audio(
                                label="Generated Audio",
                                type="filepath",
                                autoplay=True,
                                elem_id=f"midi_audio_{i}",
                            )
                            output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
                            output_audios.append(output_audio)
                            output_midis.append(output_midi)
                # Devices Tab
                with gr.Tab("Devices"):
                    device_info = gr.Textbox(
                        label="Connected MIDI Devices",
                        value=device_manager.get_device_info(),
                        interactive=False,
                    )
                    refresh_btn = gr.Button("Refresh Devices")
                    refresh_btn.click(fn=device_manager.get_device_info, outputs=[device_info])
            # Output layout: [js_msg, audio_0..audio_{B-1}, midi_0..midi_{B-1}]
            outputs = [js_msg] + output_audios + output_midis
            # Bind the generate button to the processing function
            generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)
        # Launch the app (blocks until the server is shut down)
        app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    finally:
        # Release MIDI ports even if model loading, UI setup or launch fails.
        device_manager.close()