# midi-composer / app.py
import gradio as gr
import json
import rtmidi
import os
import argparse
import base64
import io
import tempfile
import numpy as np
from huggingface_hub import hf_hub_download
import onnxruntime as rt
import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Match the JavaScript constant
MIDI_OUTPUT_BATCH_SIZE = 4
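# process_midi yields one flat list of component updates per step, laid out as
# [js_msg, audio0, midi0, audio1, midi1, ...], hence lists of length
# 1 + 2 * MIDI_OUTPUT_BATCH_SIZE throughout.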
class MIDIDeviceManager:
"""Manages MIDI input/output devices."""
def __init__(self):
self.midiout = rtmidi.MidiOut()
self.midiin = rtmidi.MidiIn()
def get_device_info(self):
"""Returns a string listing available MIDI devices."""
out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
return f"Output Devices:\n{'\n'.join(out_ports)}\n\nInput Devices:\n{'\n'.join(in_ports)}"
def close(self):
"""Closes open MIDI ports."""
if self.midiout.is_port_open():
self.midiout.close_port()
if self.midiin.is_port_open():
self.midiin.close_port()
del self.midiout, self.midiin
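# A minimal usage sketch (hypothetical interactive session):
#   dm = MIDIDeviceManager()
#   print(dm.get_device_info())  # lists rtmidi output/input port names
#   dm.close()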
class MIDIManager:
"""Handles MIDI processing, generation, and playback."""
def __init__(self):
# Load soundfont and models from Hugging Face
self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
self.synthesizer = MidiSynthesizer(self.soundfont_path)
self.tokenizer = self._load_tokenizer("skytnt/midi-model")
self.model_base = rt.InferenceSession(
hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
self.model_token = rt.InferenceSession(
hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
self.generated_files = []
def _load_tokenizer(self, repo_id):
"""Loads the MIDI tokenizer configuration."""
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
with open(config_path, "r") as f:
config = json.load(f)
tokenizer = MIDITokenizer(config["tokenizer"]["version"])
tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
return tokenizer
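    # The config.json is assumed to carry the two keys read above, roughly:
    #   {"tokenizer": {"version": "v2", "optimise_midi": true}, ...}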
    def load_midi(self, file_path):
        """Reads raw MIDI bytes from the given path (MIDI.midi2score consumes bytes)."""
        with open(file_path, "rb") as f:
            return f.read()
def generate_onnx(self, midi_data):
"""Generates a MIDI variation using ONNX models."""
mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
input_tensor = np.array([mid_seq], dtype=np.int64)
cur_len = input_tensor.shape[1]
max_len = 1024
while cur_len < max_len:
            # Feed only the most recent token; assumes the exported ONNX graph
            # carries its own recurrent state (a simplification of the upstream model).
            inputs = {"x": input_tensor[:, -1:]}
hidden = self.model_base.run(None, inputs)[0]
logits = self.model_token.run(None, {"hidden": hidden})[0]
probs = self._softmax(logits, axis=-1)
next_token = self._sample_top_p_k(probs, 0.98, 20)
input_tensor = np.concatenate([input_tensor, next_token], axis=1)
cur_len += 1
new_seq = input_tensor[0].tolist()
generated_midi = self.tokenizer.detokenize(new_seq)
        # Store base64-encoded MIDI bytes for downloads; MIDI.score2midi
        # serializes the detokenized score back to standard MIDI bytes.
        midi_bytes = MIDI.score2midi(generated_midi)
self.generated_files.append(base64.b64encode(midi_bytes).decode('utf-8'))
return generated_midi
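    # The loop above samples with top_p=0.98 and top_k=20 per step; lower values
    # make the continuation more conservative, higher values more varied.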
    def play_midi(self, midi_score):
        """Renders an in-memory MIDI score to a WAV byte stream."""
        audio = io.BytesIO()
        # Assumes MidiSynthesizer.render_midi accepts the score object produced
        # by the tokenizer's detokenize step and writes WAV data to the stream.
        self.synthesizer.render_midi(midi_score, audio)
        audio.seek(0)
        return audio
@staticmethod
def _softmax(x, axis):
"""Computes softmax probabilities."""
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
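    # e.g. _softmax(np.array([[0.0, 1.0]]), axis=-1) ~ [[0.269, 0.731]]; each
    # row sums to 1 regardless of the input scale.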
    @staticmethod
    def _sample_top_p_k(probs, p, k):
        """Samples a token using combined top-k and top-p (nucleus) filtering."""
        flat = probs.reshape(-1)                       # assume a single distribution
        order = np.argsort(flat)[::-1][:k]             # indices of the k most likely tokens
        top = flat[order]
        cutoff = int(np.searchsorted(np.cumsum(top), p)) + 1  # smallest prefix with mass >= p
        top = top[:cutoff] / top[:cutoff].sum()        # renormalize the kept probabilities
        choice = np.random.choice(order[:cutoff], p=top)
        return np.array([[choice]], dtype=np.int64)
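    # e.g. with probs = [[0.6, 0.3, 0.05, 0.05]], p=0.9, k=20, the cumulative
    # mass reaches 0.9 after two tokens, so sampling is restricted to those two.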
def process_midi(files):
"""Processes uploaded MIDI files and yields updates for Gradio components."""
if not files:
yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
return
for idx, file in enumerate(files):
output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
midi_data = midi_processor.load_midi(file.name)
generated_midi = midi_processor.generate_onnx(midi_data)
# Placeholder for MIDI events; in practice, extract from generated_midi
# Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
events = [
["note", 0, 0, 0, 60, 100, 1000], # Example event
# Add logic to convert generated_midi to events using tokenizer
]
# Prepare updates list: [js_msg, audio0, midi0, audio1, midi1, ...]
updates = [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
# Clear visualizer
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
yield updates
# Send MIDI events
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
yield updates
        # Finalize the visualizer and update the audio/MIDI outputs.
        # gr.Audio and gr.File expect file paths, so write both to temp files.
        audio_path = os.path.join(tempfile.gettempdir(), f"output_{output_idx}.wav")
        with open(audio_path, "wb") as f:
            f.write(midi_processor.play_midi(generated_midi).read())
        midi_path = os.path.join(tempfile.gettempdir(), f"output_{output_idx}.mid")
        with open(midi_path, "wb") as f:
            f.write(MIDI.score2midi(generated_midi))
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
        updates[1 + 2 * output_idx] = gr.update(value=audio_path)  # audio component
        updates[2 + 2 * output_idx] = gr.update(value=midi_path, label=f"Generated MIDI {output_idx}")  # MIDI file component
yield updates
# Final yield to ensure all components are in a stable state
yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
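# process_midi is a generator, so Gradio streams each yielded updates list to
# the UI in order: clear the visualizer, append events, then finalize with the
# rendered audio and downloadable MIDI file.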
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MIDI Composer App")
parser.add_argument("--port", type=int, default=7860, help="Server port")
parser.add_argument("--share", action="store_true", help="Share the app publicly")
opt = parser.parse_args()
device_manager = MIDIDeviceManager()
midi_processor = MIDIManager()
with gr.Blocks(theme=gr.themes.Soft()) as app:
# Hidden textbox for sending messages to JS
js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")
with gr.Tabs():
# MIDI Prompt Tab
with gr.Tab("MIDI Prompt"):
midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
generate_btn = gr.Button("Generate")
status = gr.Textbox(label="Status", value="Ready", interactive=False)
# Outputs Tab
with gr.Tab("Outputs"):
output_audios = []
output_midis = []
for i in range(MIDI_OUTPUT_BATCH_SIZE):
with gr.Column():
gr.Markdown(f"## Output {i+1}")
gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                        output_audio = gr.Audio(label="Generated Audio", type="filepath", autoplay=True, elem_id=f"midi_audio_{i}")
output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
output_audios.append(output_audio)
output_midis.append(output_midi)
# Devices Tab
with gr.Tab("Devices"):
device_info = gr.Textbox(label="Connected MIDI Devices", value=device_manager.get_device_info(), interactive=False)
refresh_btn = gr.Button("Refresh Devices")
                refresh_btn.click(fn=device_manager.get_device_info, outputs=[device_info])
        # Interleave audio/MIDI pairs so indices match the updates layout in
        # process_midi: [js_msg, audio0, midi0, audio1, midi1, ...]
        outputs = [js_msg]
        for audio_out, midi_out in zip(output_audios, output_midis):
            outputs.extend([audio_out, midi_out])
# Bind the generate button to the processing function
generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)
    # Launch the app; close MIDI ports even if launch exits via Ctrl+C.
    try:
        app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    finally:
        device_manager.close()