import gradio as gr
import json
import rtmidi
import argparse
import base64
import io
import tempfile
import numpy as np
from huggingface_hub import hf_hub_download
import onnxruntime as rt

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer

# Match the JavaScript constant
MIDI_OUTPUT_BATCH_SIZE = 4


class MIDIDeviceManager:
    """Manages MIDI input/output devices."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Returns a string listing available MIDI devices."""
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: backslashes inside f-string expressions
        # are a SyntaxError on Python < 3.12.
        out_list = "\n".join(out_ports)
        in_list = "\n".join(in_ports)
        return f"Output Devices:\n{out_list}\n\nInput Devices:\n{in_list}"

    def close(self):
        """Closes open MIDI ports."""
        if self.midiout.is_port_open():
            self.midiout.close_port()
        if self.midiin.is_port_open():
            self.midiin.close_port()
        del self.midiout, self.midiin


class MIDIManager:
    """Handles MIDI processing, generation, and playback."""

    def __init__(self):
        # Load soundfont and models from Hugging Face
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.tokenizer = self._load_tokenizer("skytnt/midi-model")
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.generated_files = []

    def _load_tokenizer(self, repo_id):
        """Loads the MIDI tokenizer configuration."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Loads a MIDI file from the given path."""
        return MIDI.load(file_path)

    def generate_onnx(self, midi_data):
        """Generates a MIDI variation using ONNX models (simplified, single-stream)."""
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        max_len = 1024

        while cur_len < max_len:
            # No KV cache is kept here, so the whole sequence is re-encoded
            # each step and the last position's hidden state is used. A real
            # implementation would carry past key/values between steps.
            hidden = self.model_base.run(None, {"x": input_tensor})[0][:, -1:]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = self._softmax(logits, axis=-1)
            next_token = self._sample_top_p_k(probs, 0.98, 20)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1

        new_seq = input_tensor[0].tolist()
        generated_midi = self.tokenizer.detokenize(new_seq)
        # Store base64-encoded MIDI data for downloads
        midi_bytes = MIDI.save(generated_midi)
        self.generated_files.append(base64.b64encode(midi_bytes).decode("utf-8"))
        return generated_midi

    def play_midi(self, midi_data):
        """Renders MIDI data to a temporary WAV file and returns its path.

        Accepts a base64 string, raw bytes, or a MIDI object (the original
        always base64-decoded, which crashed on the MIDI objects that
        process_midi actually passes in).
        """
        if isinstance(midi_data, str):
            midi_bytes = base64.b64decode(midi_data)
        elif isinstance(midi_data, (bytes, bytearray)):
            midi_bytes = midi_data
        else:
            midi_bytes = MIDI.save(midi_data)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        # render_midi is assumed to accept any file-like sink, as the
        # original BytesIO usage implied.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio:
            self.synthesizer.render_midi(midi_file, audio)
            return audio.name

    @staticmethod
    def _softmax(x, axis):
        """Computes softmax probabilities."""
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
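
    # The sampler below replaces the original random-choice placeholder with
    # a minimal combined top-k / top-p (nucleus) implementation. The exact
    # shapes and tie-breaking are assumptions of this sketch, not the
    # reference sampler from skytnt/midi-model. Worked toy example: for row
    # probabilities [0.50, 0.25, 0.15, 0.10] with p=0.9 and k=3, the
    # cumulative sums are [0.50, 0.75, 0.90, 1.00], so the nucleus keeps ids
    # {0, 1, 2}; after renormalizing to [0.556, 0.278, 0.167], one of those
    # ids is drawn at random.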
    @staticmethod
    def _sample_top_p_k(probs, p, k):
        """Samples one token id per row using top-p (nucleus) plus top-k filtering."""
        probs = probs.reshape(probs.shape[0], -1)
        samples = []
        for row in probs:
            # Sort token probabilities in descending order.
            order = np.argsort(row)[::-1]
            sorted_probs = row[order]
            # Keep at most k tokens, and within that the smallest prefix
            # whose cumulative probability reaches p.
            cumulative = np.cumsum(sorted_probs)
            cutoff = min(k, int(np.searchsorted(cumulative, p)) + 1)
            keep_idx = order[:cutoff]
            keep_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
            samples.append(np.random.choice(keep_idx, p=keep_probs))
        return np.array(samples, dtype=np.int64)[:, None]


def process_midi(files):
    """Processes uploaded MIDI files and yields updates for Gradio components."""
    if not files:
        yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
        return

    for idx, file in enumerate(files):
        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
        # Gradio may hand back plain paths or tempfile wrappers depending on
        # version; handle both.
        path = file if isinstance(file, str) else file.name
        midi_data = midi_processor.load_midi(path)
        generated_midi = midi_processor.generate_onnx(midi_data)

        # Placeholder for MIDI events; in practice, extract from generated_midi
        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
        events = [
            ["note", 0, 0, 0, 60, 100, 1000],  # Example event
            # Add logic to convert generated_midi to events using the tokenizer
        ]

        # Prepare updates list: [js_msg, audio0, midi0, audio1, midi1, ...]
        updates = [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)

        # Clear visualizer
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
        yield updates

        # Send MIDI events
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
        yield updates

        # Finalize visualizer and update audio/MIDI outputs
        audio_path = midi_processor.play_midi(generated_midi)
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
            tmp.write(MIDI.save(generated_midi))
            midi_path = tmp.name
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
        updates[1 + 2 * output_idx] = gr.update(value=audio_path)  # Audio component
        updates[2 + 2 * output_idx] = gr.update(value=midi_path, label=f"Generated MIDI {output_idx}")  # MIDI file component
        yield updates

    # Final yield to ensure all components are in a stable state
    yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer App")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Share the app publicly")
    opt = parser.parse_args()

    device_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        # Hidden textbox for sending messages to JS
        js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")

        with gr.Tabs():
            # MIDI Prompt Tab
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
                generate_btn = gr.Button("Generate")
                status = gr.Textbox(label="Status", value="Ready", interactive=False)

            # Outputs Tab
            with gr.Tab("Outputs"):
                output_audios = []
                output_midis = []
                for i in range(MIDI_OUTPUT_BATCH_SIZE):
                    with gr.Column():
                        gr.Markdown(f"## Output {i + 1}")
                        gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                        output_audio = gr.Audio(label="Generated Audio", type="filepath",
                                                autoplay=True, elem_id=f"midi_audio_{i}")
                        output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
                        output_audios.append(output_audio)
                        output_midis.append(output_midi)

            # Devices Tab
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="Connected MIDI Devices",
                                         value=device_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                refresh_btn.click(fn=lambda: device_manager.get_device_info(), outputs=[device_info])

        # Define output components for event handling. The order must mirror
        # the `updates` list built in process_midi, i.e. interleaved as
        # [js_msg, audio0, midi0, audio1, midi1, ...].
        outputs = [js_msg] + [comp for pair in zip(output_audios, output_midis) for comp in pair]

        # Bind the generate button to the processing function
        generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)

    # Launch the app
    app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    device_manager.close()
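
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; assumes the companion JavaScript front-end,
# which is not part of this file):
#
#   python app.py --port 7860          # serve locally
#   python app.py --share              # additionally create a public link
#
# The hidden `msg_receiver` textbox is the bridge to that JS code: each
# update writes a JSON list of messages such as
#   [{"name": "visualizer_append", "data": [output_idx, events]}]
# which a listener attached to elem_id "msg_receiver" is expected to pick up
# and route to the per-output visualizer containers
# (midi_visualizer_container_0 .. 3, matching MIDI_OUTPUT_BATCH_SIZE).
# ---------------------------------------------------------------------------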