import argparse import base64 import io import os import random import numpy as np import gradio as gr import rtmidi import onnxruntime as rt from huggingface_hub import hf_hub_download import MIDI from midi_synthesizer import MidiSynthesizer from midi_tokenizer import MIDITokenizer # Constants MAX_SEED = np.iinfo(np.int32).max IN_SPACE = os.getenv("SYSTEM") == "spaces" MAX_LENGTH = 1024 # Maximum tokens for generation # MIDI Device Manager class MIDIDeviceManager: def __init__(self): self.midiout = rtmidi.MidiOut() self.midiin = rtmidi.MidiIn() def get_device_info(self): out_ports = self.midiout.get_ports() or ["No MIDI output devices"] in_ports = self.midiin.get_ports() or ["No MIDI input devices"] return f"Output Devices:\n{'\n'.join(out_ports)}\n\nInput Devices:\n{'\n'.join(in_ports)}" def close(self): if self.midiout.is_port_open(): self.midiout.close_port() if self.midiin.is_port_open(): self.midiin.close_port() del self.midiout, self.midiin # MIDI Processor with ONNX Generation class MIDIManager: def __init__(self): self.soundfont = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2") self.synthesizer = MidiSynthesizer(self.soundfont) self.tokenizer = self._load_tokenizer("skytnt/midi-model") self.model_base = rt.InferenceSession( hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] ) self.model_token = rt.InferenceSession( hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] ) self.generated_files = [] # Store base64-encoded MIDI data self.is_playing = False def _load_tokenizer(self, repo_id): config_path = hf_hub_download(repo_id=repo_id, filename="config.json") with open(config_path, "r") as f: config = json.load(f) tokenizer = MIDITokenizer(config["tokenizer"]["version"]) tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"]) return tokenizer def load_midi(self, file_path): try: return MIDI.load(file_path) except Exception as e: raise ValueError(f"Failed to load MIDI file: {e}") def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20): # Tokenize input MIDI mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data)) input_tensor = np.array([mid_seq], dtype=np.int64) cur_len = input_tensor.shape[1] generator = np.random.RandomState(random.randint(0, MAX_SEED)) # Generate up to MAX_LENGTH while cur_len < MAX_LENGTH: inputs = {"x": input_tensor[:, -1:]} # Last token hidden = self.model_base.run(None, inputs)[0] # Base model output logits = self.model_token.run(None, {"hidden": hidden})[0] # Token model output probs = softmax(logits / temp, axis=-1) next_token = sample_top_p_k(probs, top_p, top_k, generator) input_tensor = np.concatenate([input_tensor, next_token], axis=1) cur_len += 1 # Detokenize and save as MIDI new_seq = input_tensor[0].tolist() new_midi = self.tokenizer.detokenize(new_seq) midi_output = io.BytesIO() MIDI.score2midi(new_midi, midi_output) midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8') self.generated_files.append(midi_data) return midi_data def play_midi(self, midi_data): self.is_playing = True midi_bytes = base64.b64decode(midi_data) midi_file = MIDI.load(io.BytesIO(midi_bytes)) audio = io.BytesIO() self.synthesizer.render_midi(midi_file, audio) audio.seek(0) return audio def stop(self): self.is_playing = False # Helper Functions def softmax(x, axis): exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) return exp_x / np.sum(exp_x, axis=axis, keepdims=True) def sample_top_p_k(probs, p, k, generator): probs_idx = np.argsort(-probs, axis=-1) probs_sort = np.take_along_axis(probs, probs_idx, axis=-1) probs_sum = np.cumsum(probs_sort, axis=-1) mask = probs_sum - probs_sort > p probs_sort[mask] = 0.0 probs_sort[:, k:] = 0.0 # Top-k filtering probs_sort /= probs_sort.sum(axis=-1, keepdims=True) next_token = generator.choice(probs.shape[-1], p=probs_sort[0]) return np.array([[next_token]]) # UI Functions def process_midi_upload(files): if not files: return None, "No file uploaded", "" file = files[0] # Process first file try: midi_data = midi_processor.load_midi(file.name) generated_midi = midi_processor.generate_variation(midi_data) audio = midi_processor.play_midi(generated_midi) download_html = create_download_list() return audio, "Generated and playing", download_html except Exception as e: return None, f"Error: {e}", "" def create_download_list(): if not midi_processor.generated_files: return "
No generated files yet.
" html = "