# Hugging Face Spaces app — "Running on Zero" (ZeroGPU) deployment.
import argparse
import base64
import io
import json
import os

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Match the JavaScript constant: number of parallel output slots rendered in the UI.
MIDI_OUTPUT_BATCH_SIZE = 4
class MIDIDeviceManager:
    """Manages MIDI input/output devices via python-rtmidi."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a human-readable listing of available MIDI in/out ports."""
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string expression
        # is a SyntaxError before Python 3.12 (PEP 701 lifted the restriction).
        out_list = "\n".join(out_ports)
        in_list = "\n".join(in_ports)
        return f"Output Devices:\n{out_list}\n\nInput Devices:\n{in_list}"

    def close(self):
        """Close any open MIDI ports and release the underlying handles."""
        if self.midiout.is_port_open():
            self.midiout.close_port()
        if self.midiin.is_port_open():
            self.midiin.close_port()
        # Drop the rtmidi objects so their native resources are freed promptly.
        del self.midiout, self.midiin
class MIDIManager:
    """Handles MIDI processing, generation, and playback."""

    def __init__(self):
        # Load soundfont and ONNX models from the Hugging Face Hub.
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.tokenizer = self._load_tokenizer("skytnt/midi-model")
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        # Base64-encoded copies of every MIDI file generated this session (for downloads).
        self.generated_files = []

    def _load_tokenizer(self, repo_id):
        """Load the MIDI tokenizer described by *repo_id*'s config.json."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load a MIDI file from the given path."""
        return MIDI.load(file_path)

    def generate_onnx(self, midi_data):
        """Generate a MIDI variation from *midi_data* using the ONNX models.

        Returns the detokenized MIDI object; a base64-encoded serialized copy
        is also appended to ``self.generated_files``.
        """
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        max_len = 1024  # hard cap on the generated sequence length
        while cur_len < max_len:
            # NOTE(review): only the last token is fed each step and no session
            # state is carried over — confirm the ONNX models maintain their own
            # KV cache, otherwise all earlier context is discarded.
            inputs = {"x": input_tensor[:, -1:]}
            hidden = self.model_base.run(None, inputs)[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = self._softmax(logits, axis=-1)
            next_token = self._sample_top_p_k(probs, 0.98, 20)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        new_seq = input_tensor[0].tolist()
        generated_midi = self.tokenizer.detokenize(new_seq)
        # Store base64-encoded MIDI data for downloads.
        midi_bytes = MIDI.save(generated_midi)
        self.generated_files.append(base64.b64encode(midi_bytes).decode('utf-8'))
        return generated_midi

    def play_midi(self, midi_data):
        """Render base64-encoded MIDI data to an in-memory audio stream.

        NOTE(review): this expects a base64 *string* (it decodes first) —
        verify callers pass that rather than a raw MIDI object.
        """
        midi_bytes = base64.b64decode(midi_data)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        audio = io.BytesIO()
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    @staticmethod
    def _softmax(x, axis):
        """Numerically stable softmax along *axis*.

        Declared @staticmethod: the original definition had no ``self``
        parameter, so the ``self._softmax(logits, axis=-1)`` call site raised
        a TypeError (``self`` was bound to ``x`` and ``axis`` got two values).
        """
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    @staticmethod
    def _sample_top_p_k(probs, p, k):
        """Sample a token id; simplified stand-in for top-p/top-k sampling.

        @staticmethod for the same reason as ``_softmax`` — the method-style
        call site otherwise passes one argument too many.
        """
        # Placeholder: uniform random choice over the vocabulary.
        # Replace with actual nucleus/top-k sampling logic if needed.
        return np.array([[np.random.choice(len(probs[0]))]])
def process_midi(files):
    """Process uploaded MIDI file(s) and stream updates to the Gradio UI.

    Yields lists of component updates matching the ``outputs`` list defined in
    ``__main__``: ``[js_msg] + output_audios + output_midis`` — i.e. one
    js-message slot, then all MIDI_OUTPUT_BATCH_SIZE audio components, then
    all MIDI_OUTPUT_BATCH_SIZE file components (grouped, not interleaved).
    """
    n_outputs = 1 + 2 * MIDI_OUTPUT_BATCH_SIZE

    def no_op():
        # A fresh "leave every component as-is" update list.
        return [gr.update() for _ in range(n_outputs)]

    if not files:
        yield no_op()
        return

    for idx, file in enumerate(files):
        # Distribute files round-robin across the available output slots.
        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
        midi_data = midi_processor.load_midi(file.name)
        generated_midi = midi_processor.generate_onnx(midi_data)

        # Placeholder for MIDI events; in practice, extract from generated_midi.
        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
        events = [
            ["note", 0, 0, 0, 60, 100, 1000],  # Example event
            # TODO: convert generated_midi to events using the tokenizer.
        ]

        updates = no_op()
        # Use gr.update(...) — the per-component Component.update() classmethod
        # (js_msg.update, gr.File.update) was deprecated and removed in Gradio 4.
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
        yield updates

        # Send MIDI events to the client-side visualizer.
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
        yield updates

        # Finalize visualizer and update audio/MIDI outputs.
        # NOTE(review): play_midi() base64-decodes its argument, but
        # generate_onnx() returns a MIDI object — confirm the payload type.
        audio_update = midi_processor.play_midi(generated_midi)
        midi_update = gr.update(value=generated_midi, label=f"Generated MIDI {output_idx}")
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
        # The outputs list is grouped ([js_msg, audio0..audioN, midi0..midiN]),
        # so slot i's audio lives at 1+i and its file at 1+BATCH+i. The old
        # interleaved indices (1+2*i / 2+2*i) hit the wrong components.
        updates[1 + output_idx] = audio_update
        updates[1 + MIDI_OUTPUT_BATCH_SIZE + output_idx] = midi_update
        yield updates

    # Final yield leaves all components in a stable (unchanged) state.
    yield no_op()
if __name__ == "__main__":
    # Command-line options for the server.
    parser = argparse.ArgumentParser(description="MIDI Composer App")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Share the app publicly")
    opt = parser.parse_args()

    device_manager = MIDIDeviceManager()
    # Downloads the soundfont and ONNX models from the Hugging Face Hub on startup.
    midi_processor = MIDIManager()

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        # Hidden textbox used as a message channel to client-side JS
        # (elem_id "msg_receiver" is presumably watched by a script — confirm).
        js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")
        with gr.Tabs():
            # MIDI Prompt Tab: upload files and trigger generation.
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
                generate_btn = gr.Button("Generate")
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
            # Outputs Tab: one column per output slot (MIDI_OUTPUT_BATCH_SIZE total).
            with gr.Tab("Outputs"):
                output_audios = []
                output_midis = []
                for i in range(MIDI_OUTPUT_BATCH_SIZE):
                    with gr.Column():
                        gr.Markdown(f"## Output {i+1}")
                        # Container targeted by the client-side MIDI visualizer.
                        gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                        # NOTE(review): gr.Audio's documented `type` values are
                        # "numpy"/"filepath" — confirm "bytes" is accepted by
                        # the installed Gradio version.
                        output_audio = gr.Audio(label="Generated Audio", type="bytes", autoplay=True, elem_id=f"midi_audio_{i}")
                        output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
                        output_audios.append(output_audio)
                        output_midis.append(output_midi)
            # Devices Tab: show and refresh the connected MIDI devices.
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="Connected MIDI Devices", value=device_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                refresh_btn.click(fn=lambda: device_manager.get_device_info(), outputs=[device_info])
        # Output components for event handling, grouped as
        # [js_msg, audio0..audioN, midi0..midiN]; process_midi's yielded lists
        # must follow this exact ordering.
        outputs = [js_msg] + output_audios + output_midis
        # Bind the generate button to the (generator) processing function.
        generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)
    # launch() blocks until the server stops; release MIDI ports afterwards.
    app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    device_manager.close()