# midi-composer/app.py
import gradio as gr
import json
import rtmidi
import tempfile
import argparse
import base64
import io
import numpy as np
from huggingface_hub import hf_hub_download
import onnxruntime as rt
import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Match the JavaScript constant
MIDI_OUTPUT_BATCH_SIZE = 4
class MIDIDeviceManager:
    """Manages MIDI input/output devices."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Returns a string listing available MIDI devices."""
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: backslashes inside f-string expressions
        # are a syntax error before Python 3.12.
        out_list = "\n".join(out_ports)
        in_list = "\n".join(in_ports)
        return f"Output Devices:\n{out_list}\n\nInput Devices:\n{in_list}"

    def close(self):
        """Closes open MIDI ports."""
        if self.midiout.is_port_open():
            self.midiout.close_port()
        if self.midiin.is_port_open():
            self.midiin.close_port()
        del self.midiout, self.midiin
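# A minimal usage sketch for the python-rtmidi ports wrapped above (the port
# index 0 and the note values are illustrative assumptions, not app defaults):
#
#   manager = MIDIDeviceManager()
#   if manager.midiout.get_ports():
#       manager.midiout.open_port(0)
#       manager.midiout.send_message([0x90, 60, 112])  # note-on, middle C
#       manager.midiout.send_message([0x80, 60, 0])    # note-off
#   manager.close()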
class MIDIManager:
    """Handles MIDI processing, generation, and playback."""

    def __init__(self):
        # Load soundfont and models from Hugging Face
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.tokenizer = self._load_tokenizer("skytnt/midi-model")
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.generated_files = []  # base64-encoded MIDI bytes, one entry per generation

    def _load_tokenizer(self, repo_id):
        """Loads the MIDI tokenizer configuration."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer
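    # The keys read above imply a config.json of roughly this shape (the
    # values shown are illustrative assumptions, not copied from the repo):
    #
    #   {"tokenizer": {"version": "v2", "optimise_midi": true}}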
    def load_midi(self, file_path):
        """Loads a MIDI file from the given path."""
        return MIDI.load(file_path)

    def generate_onnx(self, midi_data):
        """Generates a MIDI variation using ONNX models."""
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        max_len = 1024
        # Simplified autoregressive loop: feed only the newest token each step
        # (no hidden-state cache), then sample the next token from the logits.
        while cur_len < max_len:
            inputs = {"x": input_tensor[:, -1:]}
            hidden = self.model_base.run(None, inputs)[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = self._softmax(logits, axis=-1)
            next_token = self._sample_top_p_k(probs, 0.98, 20)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        new_seq = input_tensor[0].tolist()
        generated_midi = self.tokenizer.detokenize(new_seq)
        # Store base64-encoded MIDI data for downloads
        midi_bytes = MIDI.save(generated_midi)
        self.generated_files.append(base64.b64encode(midi_bytes).decode('utf-8'))
        return generated_midi
    def play_midi(self, midi_b64):
        """Renders base64-encoded MIDI data to a WAV file and returns its path."""
        midi_bytes = base64.b64decode(midi_b64)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        # Render into a temp file so the gr.Audio output (type="filepath") can
        # serve it; assumes render_midi writes WAV data to the file object.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            self.synthesizer.render_midi(midi_file, tmp)
            return tmp.name
    @staticmethod
    def _softmax(x, axis):
        """Computes softmax probabilities."""
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
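    # Worked example: _softmax(np.array([[0.0, 1.0, 2.0]]), axis=-1)
    # -> approximately [[0.090, 0.245, 0.665]]; each row sums to 1.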
    @staticmethod
    def _sample_top_p_k(probs, p, k):
        """Samples a token id with top-k then top-p (nucleus) filtering.
        Assumes probs has shape (1, vocab_size), as the original placeholder did."""
        dist = probs.reshape(-1)
        top_idx = np.argsort(dist)[::-1][:k]         # k most likely tokens
        top_probs = dist[top_idx]
        keep = np.cumsum(top_probs) - top_probs < p  # nucleus mask
        top_probs = top_probs[keep] / top_probs[keep].sum()
        return np.array([[np.random.choice(top_idx[keep], p=top_probs)]], dtype=np.int64)
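    # For example, with dist = [0.5, 0.3, 0.1, 0.06, 0.04], p=0.7 and k=3 the
    # top-k step keeps [0.5, 0.3, 0.1]; the nucleus mask then keeps [0.5, 0.3]
    # (preceding mass 0.0 and 0.5 are < 0.7), renormalized to [0.625, 0.375].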
def process_midi(files):
    """Processes uploaded MIDI files and yields updates for Gradio components."""
    if not files:
        yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
        return
    for idx, file in enumerate(files):
        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
        midi_data = midi_processor.load_midi(file.name)
        generated_midi = midi_processor.generate_onnx(midi_data)
        # Placeholder for MIDI events; in practice, extract them from
        # generated_midi (see the score_to_events sketch after this function).
        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
        events = [
            ["note", 0, 0, 0, 60, 100, 1000],  # Example event
        ]
        # Prepare updates list: [js_msg, audio0, midi0, audio1, midi1, ...]
        updates = [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
        # Clear visualizer
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
        yield updates
        # Send MIDI events
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
        yield updates
        # Finalize visualizer and update audio/MIDI outputs; generate_onnx
        # stored the base64-encoded bytes of this result in generated_files.
        midi_b64 = midi_processor.generated_files[-1]
        audio_update = midi_processor.play_midi(midi_b64)
        # Write the MIDI bytes to a temp file so gr.File can serve a download
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
            tmp.write(base64.b64decode(midi_b64))
            midi_path = tmp.name
        midi_update = gr.update(value=midi_path, label=f"Generated MIDI {output_idx}")
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
        updates[1 + 2 * output_idx] = audio_update  # Audio component
        updates[2 + 2 * output_idx] = midi_update   # MIDI file component
        yield updates
    # Final yield to ensure all components are in a stable state
    yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
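# A hypothetical sketch of the event conversion the placeholder above elides.
# It assumes the detokenized score follows the MIDI.py-style layout
# [ticks_per_beat, track0, track1, ...] with note events shaped
# ['note', start, duration, channel, pitch, velocity]; verify against the
# actual MIDI module before wiring it into process_midi.
def score_to_events(score):
    """Flattens a score into ["note", delta_time, track, channel, pitch,
    velocity, duration] rows for the JS visualizer."""
    events = []
    for track_idx, track in enumerate(score[1:]):
        last_start = 0
        notes = sorted((e for e in track if e[0] == "note"), key=lambda e: e[1])
        for _, start, duration, channel, pitch, velocity in notes:
            events.append(["note", start - last_start, track_idx, channel, pitch, velocity, duration])
            last_start = start
    return events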
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer App")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Share the app publicly")
    opt = parser.parse_args()

    device_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        # Hidden textbox for sending messages to JS
        js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")
        with gr.Tabs():
            # MIDI Prompt Tab
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
                generate_btn = gr.Button("Generate")
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
            # Outputs Tab
            with gr.Tab("Outputs"):
                output_audios = []
                output_midis = []
                for i in range(MIDI_OUTPUT_BATCH_SIZE):
                    with gr.Column():
                        gr.Markdown(f"## Output {i+1}")
                        gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                        # play_midi returns a WAV path, so use type="filepath"
                        # ("bytes" is not a valid gr.Audio type).
                        output_audio = gr.Audio(label="Generated Audio", type="filepath", autoplay=True, elem_id=f"midi_audio_{i}")
                        output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
                        output_audios.append(output_audio)
                        output_midis.append(output_midi)
            # Devices Tab
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="Connected MIDI Devices", value=device_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                refresh_btn.click(fn=lambda: device_manager.get_device_info(), outputs=[device_info])
        # Interleave outputs as [js_msg, audio0, midi0, audio1, midi1, ...] to
        # match the index layout process_midi writes into.
        outputs = [js_msg]
        for audio, midi in zip(output_audios, output_midis):
            outputs.extend([audio, midi])
        # Bind the generate button to the processing function
        generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)

    # Launch the app; close MIDI ports when the server shuts down
    try:
        app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    finally:
        device_manager.close()
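# Run locally with, e.g.:  python app.py --port 7860 --share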