awacke1 committed
Commit 5297a72 (verified)
Parent(s): a92c685

Update app.py

Files changed (1): app.py (+115, -109)
app.py CHANGED
@@ -1,45 +1,46 @@
+import gradio as gr
+import json
+import rtmidi
+import os
 import argparse
 import base64
 import io
-import os
-import random
 import numpy as np
-import gradio as gr
-import rtmidi
-import onnxruntime as rt
 from huggingface_hub import hf_hub_download
+import onnxruntime as rt
 import MIDI
 from midi_synthesizer import MidiSynthesizer
 from midi_tokenizer import MIDITokenizer

-# Constants
-MAX_SEED = np.iinfo(np.int32).max
-IN_SPACE = os.getenv("SYSTEM") == "spaces"
-MAX_LENGTH = 1024 # Maximum tokens for generation
+# Match the JavaScript constant
+MIDI_OUTPUT_BATCH_SIZE = 4

-# MIDI Device Manager
 class MIDIDeviceManager:
+    """Manages MIDI input/output devices."""
     def __init__(self):
         self.midiout = rtmidi.MidiOut()
         self.midiin = rtmidi.MidiIn()

     def get_device_info(self):
+        """Returns a string listing available MIDI devices."""
         out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
         in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
         return f"Output Devices:\n{'\n'.join(out_ports)}\n\nInput Devices:\n{'\n'.join(in_ports)}"

     def close(self):
+        """Closes open MIDI ports."""
         if self.midiout.is_port_open():
             self.midiout.close_port()
         if self.midiin.is_port_open():
             self.midiin.close_port()
         del self.midiout, self.midiin

-# MIDI Processor with ONNX Generation
 class MIDIManager:
+    """Handles MIDI processing, generation, and playback."""
     def __init__(self):
-        self.soundfont = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
-        self.synthesizer = MidiSynthesizer(self.soundfont)
+        # Load soundfont and models from Hugging Face
+        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
+        self.synthesizer = MidiSynthesizer(self.soundfont_path)
         self.tokenizer = self._load_tokenizer("skytnt/midi-model")
         self.model_base = rt.InferenceSession(
             hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
@@ -49,10 +50,10 @@ class MIDIManager:
             hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
             providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
         )
-        self.generated_files = [] # Store base64-encoded MIDI data
-        self.is_playing = False
+        self.generated_files = []

     def _load_tokenizer(self, repo_id):
+        """Loads the MIDI tokenizer configuration."""
         config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
         with open(config_path, "r") as f:
             config = json.load(f)
@@ -61,39 +62,32 @@ class MIDIManager:
         return tokenizer

     def load_midi(self, file_path):
-        try:
-            return MIDI.load(file_path)
-        except Exception as e:
-            raise ValueError(f"Failed to load MIDI file: {e}")
+        """Loads a MIDI file from the given path."""
+        return MIDI.load(file_path)

-    def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20):
-        # Tokenize input MIDI
+    def generate_onnx(self, midi_data):
+        """Generates a MIDI variation using ONNX models."""
         mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
         input_tensor = np.array([mid_seq], dtype=np.int64)
         cur_len = input_tensor.shape[1]
-        generator = np.random.RandomState(random.randint(0, MAX_SEED))
-
-        # Generate up to MAX_LENGTH
-        while cur_len < MAX_LENGTH:
-            inputs = {"x": input_tensor[:, -1:]} # Last token
-            hidden = self.model_base.run(None, inputs)[0] # Base model output
-            logits = self.model_token.run(None, {"hidden": hidden})[0] # Token model output
-            probs = softmax(logits / temp, axis=-1)
-            next_token = sample_top_p_k(probs, top_p, top_k, generator)
+        max_len = 1024
+        while cur_len < max_len:
+            inputs = {"x": input_tensor[:, -1:]}
+            hidden = self.model_base.run(None, inputs)[0]
+            logits = self.model_token.run(None, {"hidden": hidden})[0]
+            probs = self._softmax(logits, axis=-1)
+            next_token = self._sample_top_p_k(probs, 0.98, 20)
             input_tensor = np.concatenate([input_tensor, next_token], axis=1)
             cur_len += 1
-
-        # Detokenize and save as MIDI
         new_seq = input_tensor[0].tolist()
-        new_midi = self.tokenizer.detokenize(new_seq)
-        midi_output = io.BytesIO()
-        MIDI.score2midi(new_midi, midi_output)
-        midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
-        self.generated_files.append(midi_data)
-        return midi_data
+        generated_midi = self.tokenizer.detokenize(new_seq)
+        # Store base64-encoded MIDI data for downloads
+        midi_bytes = MIDI.save(generated_midi)
+        self.generated_files.append(base64.b64encode(midi_bytes).decode('utf-8'))
+        return generated_midi

     def play_midi(self, midi_data):
-        self.is_playing = True
+        """Renders MIDI data to audio bytes."""
         midi_bytes = base64.b64decode(midi_data)
         midi_file = MIDI.load(io.BytesIO(midi_bytes))
         audio = io.BytesIO()
@@ -101,91 +95,103 @@ class MIDIManager:
         audio.seek(0)
         return audio

-    def stop(self):
-        self.is_playing = False
-
-# Helper Functions
-def softmax(x, axis):
-    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
-    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
-
-def sample_top_p_k(probs, p, k, generator):
-    probs_idx = np.argsort(-probs, axis=-1)
-    probs_sort = np.take_along_axis(probs, probs_idx, axis=-1)
-    probs_sum = np.cumsum(probs_sort, axis=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort[mask] = 0.0
-    probs_sort[:, k:] = 0.0 # Top-k filtering
-    probs_sort /= probs_sort.sum(axis=-1, keepdims=True)
-    next_token = generator.choice(probs.shape[-1], p=probs_sort[0])
-    return np.array([[next_token]])
-
-# UI Functions
-def process_midi_upload(files):
+    @staticmethod
+    def _softmax(x, axis):
+        """Computes softmax probabilities."""
+        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
+        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
+
+    @staticmethod
+    def _sample_top_p_k(probs, p, k):
+        """Samples a token using top-p and top-k sampling (simplified)."""
+        # Placeholder: replace with actual sampling logic if needed
+        return np.array([[np.random.choice(len(probs[0]))]])
+
+def process_midi(files):
+    """Processes uploaded MIDI files and yields updates for Gradio components."""
     if not files:
-        return None, "No file uploaded", ""
-    file = files[0] # Process first file
-    try:
+        yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
+        return
+
+    for idx, file in enumerate(files):
+        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
         midi_data = midi_processor.load_midi(file.name)
-        generated_midi = midi_processor.generate_variation(midi_data)
-        audio = midi_processor.play_midi(generated_midi)
-        download_html = create_download_list()
-        return audio, "Generated and playing", download_html
-    except Exception as e:
-        return None, f"Error: {e}", ""
-
-def create_download_list():
-    if not midi_processor.generated_files:
-        return "<p>No generated files yet.</p>"
-    html = "<h3>Generated MIDI Files</h3><ul>"
-    for i, midi_data in enumerate(midi_processor.generated_files):
-        html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="generated_{i}.mid">Download MIDI {i}</a></li>'
-    html += "</ul>"
-    return html
-
-def refresh_devices():
-    return device_manager.get_device_info()
-
-def stop_playback():
-    midi_processor.stop()
-    return "Playback stopped"
-
-# Main Application
+        generated_midi = midi_processor.generate_onnx(midi_data)
+
+        # Placeholder for MIDI events; in practice, extract from generated_midi
+        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
+        events = [
+            ["note", 0, 0, 0, 60, 100, 1000], # Example event
+            # Add logic to convert generated_midi to events using tokenizer
+        ]
+
+        # Prepare updates list: [js_msg, audio0, midi0, audio1, midi1, ...]
+        updates = [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
+
+        # Clear visualizer
+        updates[0] = js_msg.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
+        yield updates
+
+        # Send MIDI events
+        updates[0] = js_msg.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
+        yield updates
+
+        # Finalize visualizer and update audio/MIDI outputs
+        audio_update = midi_processor.play_midi(generated_midi)
+        midi_update = gr.File.update(value=generated_midi, label=f"Generated MIDI {output_idx}")
+        updates[0] = js_msg.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
+        updates[1 + 2 * output_idx] = audio_update # Audio component
+        updates[2 + 2 * output_idx] = midi_update # MIDI file component
+        yield updates
+
+    # Final yield to ensure all components are in a stable state
+    yield [gr.update()] * (1 + 2 * MIDI_OUTPUT_BATCH_SIZE)
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="MIDI Composer with ONNX Generation")
-    parser.add_argument("--port", type=int, default=7860)
-    parser.add_argument("--share", action="store_true")
-    args = parser.parse_args()
+    parser = argparse.ArgumentParser(description="MIDI Composer App")
+    parser.add_argument("--port", type=int, default=7860, help="Server port")
+    parser.add_argument("--share", action="store_true", help="Share the app publicly")
+    opt = parser.parse_args()

     device_manager = MIDIDeviceManager()
     midi_processor = MIDIManager()

-    with gr.Blocks(title="MIDI Composer", theme=gr.themes.Soft()) as app:
-        gr.Markdown("# 🎵 MIDI Composer 🎵")
-
+    with gr.Blocks(theme=gr.themes.Soft()) as app:
+        # Hidden textbox for sending messages to JS
+        js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")
+
         with gr.Tabs():
             # MIDI Prompt Tab
             with gr.Tab("MIDI Prompt"):
-                midi_upload = gr.File(label="Upload MIDI File", file_types=[".mid", ".midi"])
-                audio_output = gr.Audio(label="Generated MIDI", type="bytes", autoplay=True)
+                midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
+                generate_btn = gr.Button("Generate")
                 status = gr.Textbox(label="Status", value="Ready", interactive=False)
-                midi_upload.change(
-                    process_midi_upload,
-                    inputs=[midi_upload],
-                    outputs=[audio_output, status, gr.HTML(elem_id="downloads")]
-                )

-            # Downloads Tab
-            with gr.Tab("Downloads", elem_id="downloads"):
-                gr.HTML(value=create_download_list())
+            # Outputs Tab
+            with gr.Tab("Outputs"):
+                output_audios = []
+                output_midis = []
+                for i in range(MIDI_OUTPUT_BATCH_SIZE):
+                    with gr.Column():
+                        gr.Markdown(f"## Output {i+1}")
+                        gr.HTML(elem_id=f"midi_visualizer_container_{i}")
+                        output_audio = gr.Audio(label="Generated Audio", type="bytes", autoplay=True, elem_id=f"midi_audio_{i}")
+                        output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
+                        output_audios.append(output_audio)
+                        output_midis.append(output_midi)

             # Devices Tab
             with gr.Tab("Devices"):
-                device_info = gr.Textbox(label="MIDI Devices", value=device_manager.get_device_info(), interactive=False)
+                device_info = gr.Textbox(label="Connected MIDI Devices", value=device_manager.get_device_info(), interactive=False)
                 refresh_btn = gr.Button("Refresh Devices")
-                stop_btn = gr.Button("Stop Playback")
-                refresh_btn.click(refresh_devices, outputs=[device_info])
-                stop_btn.click(stop_playback, outputs=[status])
+                refresh_btn.click(fn=lambda: device_manager.get_device_info(), outputs=[device_info])
+
+        # Define output components for event handling
+        outputs = [js_msg] + output_audios + output_midis
+
+        # Bind the generate button to the processing function
+        generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)

-    app.launch(server_port=args.port, share=args.share, inbrowser=True)
+    # Launch the app
+    app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
     device_manager.close()
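
Note on the committed `_sample_top_p_k`: it is left as an explicit placeholder that samples uniformly at random. A minimal sketch of what a real implementation might look like, adapted from the `sample_top_p_k` helper removed in this commit, assuming `probs` has shape `(1, vocab_size)` (the `rng` parameter here is an illustrative addition, not part of the commit):

import numpy as np

def _sample_top_p_k(probs, p, k, rng=np.random):
    """Top-p (nucleus) plus top-k sampling over a (1, vocab_size) probability row."""
    probs_idx = np.argsort(-probs, axis=-1)                  # token ids sorted by descending probability
    probs_sort = np.take_along_axis(probs, probs_idx, axis=-1)
    probs_sum = np.cumsum(probs_sort, axis=-1)
    probs_sort[probs_sum - probs_sort > p] = 0.0             # drop tokens outside the nucleus
    probs_sort[:, k:] = 0.0                                  # keep only the k most likely tokens
    probs_sort /= probs_sort.sum(axis=-1, keepdims=True)     # renormalize the surviving mass
    sampled_pos = rng.choice(probs_sort.shape[-1], p=probs_sort[0])
    next_token = probs_idx[0, sampled_pos]                   # map back to the original token id
    return np.array([[next_token]])

Unlike the removed helper, this sketch maps the sampled position back through `probs_idx`, so the returned id refers to the original vocabulary order; dropped into `MIDIManager._sample_top_p_k` it would keep the `(probs, p, k)` signature and a `@staticmethod` decorator.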