awacke1 committed on
Commit
1151735
·
verified ·
1 Parent(s): 75808a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -177
app.py CHANGED
@@ -1,58 +1,58 @@
1
- import random
2
  import argparse
 
 
3
  import os
4
- import glob
5
- import json
6
- import rtmidi
7
- import gradio as gr
8
  import numpy as np
 
 
9
  import onnxruntime as rt
10
  from huggingface_hub import hf_hub_download
11
  import MIDI
12
  from midi_synthesizer import MidiSynthesizer
13
  from midi_tokenizer import MIDITokenizer
14
 
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
- in_space = os.getenv("SYSTEM") == "spaces"
 
17
 
 
18
  class MIDIDeviceManager:
19
  def __init__(self):
20
  self.midiout = rtmidi.MidiOut()
21
  self.midiin = rtmidi.MidiIn()
22
 
23
- def get_output_devices(self):
24
- return self.midiout.get_ports() or ["No MIDI output devices"]
25
-
26
- def get_input_devices(self):
27
- return self.midiin.get_ports() or ["No MIDI input devices"]
28
-
29
  def get_device_info(self):
30
- out_devices = self.get_output_devices()
31
- in_devices = self.get_input_devices()
32
- out_info = "\n".join([f"Out Port {i}: {name}" for i, name in enumerate(out_devices)]) if out_devices else "No MIDI output devices detected"
33
- in_info = "\n".join([f"In Port {i}: {name}" for i, name in enumerate(in_devices)]) if in_devices else "No MIDI input devices detected"
34
- return f"Output Devices:\n{out_info}\n\nInput Devices:\n{in_info}"
35
 
36
  def close(self):
37
  if self.midiout.is_port_open():
38
  self.midiout.close_port()
39
  if self.midiin.is_port_open():
40
  self.midiin.close_port()
41
- del self.midiout
42
- del self.midiin
43
 
 
44
  class MIDIManager:
45
  def __init__(self):
46
- self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
47
- self.synthesizer = MidiSynthesizer(self.soundfont_path)
48
- self.loaded_midi = {}
49
- self.modified_files = []
 
 
 
 
 
 
 
 
50
  self.is_playing = False
51
- self.tokenizer = self.load_tokenizer("skytnt/midi-model")
52
- self.model_base = rt.InferenceSession(hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
53
- self.model_token = rt.InferenceSession(hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
54
 
55
- def load_tokenizer(self, repo_id):
56
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
57
  with open(config_path, "r") as f:
58
  config = json.load(f)
@@ -61,190 +61,131 @@ class MIDIManager:
61
  return tokenizer
62
 
63
  def load_midi(self, file_path):
64
- midi = MIDI.load(file_path)
65
- midi_id = f"midi_{len(self.loaded_midi)}"
66
- self.loaded_midi[midi_id] = (file_path, midi)
67
- return midi_id
68
-
69
- def extract_notes_and_instruments(self, midi):
70
- notes = []
71
- instruments = set()
72
- for track in midi.tracks:
73
- for event in track.events:
74
- if event.type == 'note_on' and event.velocity > 0:
75
- notes.append((event.note, event.velocity, event.time))
76
- if hasattr(event, 'program'):
77
- instruments.add(event.program)
78
- return notes, list(instruments)
79
-
80
- def generate_variation(self, midi_id, length_factor=10, variation=0.3):
81
- if midi_id not in self.loaded_midi:
82
- return None
83
- _, midi = self.loaded_midi[midi_id]
84
- notes, instruments = self.extract_notes_and_instruments(midi)
85
- new_notes = []
86
- for _ in range(int(length_factor)): # Max length: 10x repetition
87
- for note, vel, time in notes:
88
- if random.random() < variation:
89
- new_note = min(127, max(0, note + random.randint(-2, 2)))
90
- new_vel = min(127, max(0, vel + random.randint(-10, 10)))
91
- new_notes.append((new_note, new_vel, time))
92
- else:
93
- new_notes.append((note, vel, time))
94
-
95
- new_midi = MIDI.MIDIFile(len(instruments) or 1)
96
- for i, inst in enumerate(instruments or [0]):
97
- new_midi.addTrack()
98
- new_midi.addProgramChange(i, 0, 0, inst)
99
- for note, vel, time in new_notes:
100
- new_midi.addNote(i, 0, note, time, 100, vel)
101
-
102
- midi_output = io.BytesIO()
103
- new_midi.writeFile(midi_output)
104
- midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
105
- self.modified_files.append(midi_data)
106
- return midi_data
107
-
108
- def generate_onnx(self, midi_id, max_len=1024, temp=1.0, top_p=0.98, top_k=20):
109
- if midi_id not in self.loaded_midi:
110
- return None
111
- _, mid = self.loaded_midi[midi_id]
112
- mid_seq = self.tokenizer.tokenize(MIDI.midi2score(mid))
113
- mid = np.asarray([mid_seq], dtype=np.int64)
114
- generator = np.random.RandomState(random.randint(0, MAX_SEED))
115
-
116
- # Simplified ONNX generation from app_onnx.py
117
- input_tensor = mid
118
  cur_len = input_tensor.shape[1]
119
- model = [self.model_base, self.model_token, self.tokenizer]
120
-
121
- while cur_len < max_len:
122
- inputs = {"x": rt.OrtValue.ortvalue_from_numpy(input_tensor[:, -1:], device_type="cuda")}
123
- outputs = {"hidden": rt.OrtValue.ortvalue_from_shape_and_type((1, 1, 1024), np.float32, device_type="cuda")}
124
- io_binding = model[0].io_binding()
125
- for name, val in inputs.items():
126
- io_binding.bind_ortvalue_input(name, val)
127
- for name in outputs:
128
- io_binding.bind_ortvalue_output(name, outputs[name])
129
- model[0].run_with_iobinding(io_binding)
130
- hidden = outputs["hidden"].numpy()[:, -1:]
131
-
132
- logits = model[1].run(None, {"hidden": hidden})[0]
133
- scores = softmax(logits / temp, -1)
134
- next_token = sample_top_p_k(scores, top_p, top_k, generator)
135
  input_tensor = np.concatenate([input_tensor, next_token], axis=1)
136
  cur_len += 1
137
 
138
- mid_seq = input_tensor.tolist()[0]
139
- new_midi = self.tokenizer.detokenize(mid_seq)
 
140
  midi_output = io.BytesIO()
141
  MIDI.score2midi(new_midi, midi_output)
142
  midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
143
- self.modified_files.append(midi_data)
144
  return midi_data
145
 
146
- def play_with_loop(self, midi_data):
147
  self.is_playing = True
148
- midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
149
- while self.is_playing:
150
- self.synthesizer.play_midi(midi_file)
151
-
152
- def stop_playback(self):
 
 
 
153
  self.is_playing = False
154
- return "Playback stopped"
155
 
 
156
  def softmax(x, axis):
157
- x_max = np.max(x, axis=axis, keepdims=True)
158
- exp_x_shifted = np.exp(x - x_max)
159
- return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
160
 
161
- def sample_top_p_k(probs, p, k, generator=None):
162
- if generator is None:
163
- generator = np.random
164
  probs_idx = np.argsort(-probs, axis=-1)
165
- probs_sort = np.take_along_axis(probs, probs_idx, -1)
166
  probs_sum = np.cumsum(probs_sort, axis=-1)
167
  mask = probs_sum - probs_sort > p
168
  probs_sort[mask] = 0.0
169
- mask = np.zeros(probs_sort.shape[-1])
170
- mask[:k] = 1
171
- probs_sort *= mask
172
- probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
173
- shape = probs_sort.shape
174
- probs_sort_flat = probs_sort.reshape(-1, shape[-1])
175
- probs_idx_flat = probs_idx.reshape(-1, shape[-1])
176
- next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
177
- return next_token.reshape(*shape[:-1])
 
 
 
 
 
 
 
 
 
178
 
179
  def create_download_list():
180
- html = "<h3>Downloads</h3><ul>"
181
- for i, midi_data in enumerate(midi_processor.modified_files):
182
- html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="midi_{i}.mid">MIDI {i}</a></li>'
 
 
183
  html += "</ul>"
184
  return html
185
 
 
 
 
 
 
 
 
 
186
  if __name__ == "__main__":
187
- parser = argparse.ArgumentParser()
188
  parser.add_argument("--port", type=int, default=7860)
189
  parser.add_argument("--share", action="store_true")
190
- opt = parser.parse_args()
191
 
192
- midi_manager = MIDIDeviceManager()
193
  midi_processor = MIDIManager()
194
 
195
- with gr.Blocks(theme=gr.themes.Soft()) as app:
196
- gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
197
-
198
  with gr.Tabs():
199
- # Tab 1: MIDI Prompt (Main Tab)
200
  with gr.Tab("MIDI Prompt"):
201
- midi_upload = gr.File(label="Upload MIDI File", file_count="multiple")
202
- output = gr.Audio(label="Generated MIDI", type="bytes", autoplay=True)
203
  status = gr.Textbox(label="Status", value="Ready", interactive=False)
204
-
205
- def process_midi(files):
206
- if not files:
207
- return None, "No file uploaded"
208
- midi_data = None
209
- for file in files:
210
- midi_id = midi_processor.load_midi(file.name)
211
- # Use ONNX generation for advanced synthesis
212
- midi_data = midi_processor.generate_onnx(midi_id, max_len=1024)
213
- midi_processor.play_with_loop(midi_data)
214
- return io.BytesIO(base64.b64decode(midi_data)), "Playing", create_download_list()
215
-
216
- midi_upload.change(process_midi, inputs=[midi_upload],
217
- outputs=[output, status, "downloads"])
218
-
219
- # Tab 2: Downloads
220
  with gr.Tab("Downloads", elem_id="downloads"):
221
- downloads = gr.HTML(value="No generated files yet")
222
 
223
- # Tab 3: Devices
224
  with gr.Tab("Devices"):
225
- device_info = gr.Textbox(label="Connected MIDI Devices", value=midi_manager.get_device_info(), interactive=False)
226
  refresh_btn = gr.Button("Refresh Devices")
227
  stop_btn = gr.Button("Stop Playback")
228
-
229
- def refresh_devices():
230
- return midi_manager.get_device_info()
231
-
232
- refresh_btn.click(refresh_devices, inputs=None, outputs=[device_info])
233
- stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])
234
-
235
- gr.Markdown("""
236
- <div style='text-align: center; margin-top: 20px;'>
237
- <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
238
- <strong>Hugging Face</strong><br>
239
- <a href='https://huggingface.co/models'>Models</a> |
240
- <a href='https://huggingface.co/datasets'>Datasets</a> |
241
- <a href='https://huggingface.co/spaces'>Spaces</a> |
242
- <a href='https://huggingface.co/posts'>Posts</a> |
243
- <a href='https://huggingface.co/docs'>Docs</a> |
244
- <a href='https://huggingface.co/enterprise'>Enterprise</a> |
245
- <a href='https://huggingface.co/pricing'>Pricing</a>
246
- </div>
247
- """)
248
-
249
- app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
250
- midi_manager.close()
 
 
1
  import argparse
  import base64
  import io
  import json
  import os
  import random

  import gradio as gr
  import numpy as np
  import onnxruntime as rt
  import rtmidi
  from huggingface_hub import hf_hub_download

  import MIDI
  from midi_synthesizer import MidiSynthesizer
  from midi_tokenizer import MIDITokenizer
14
 
15
+ # Constants
16
  MAX_SEED = np.iinfo(np.int32).max
17
+ IN_SPACE = os.getenv("SYSTEM") == "spaces"
18
+ MAX_LENGTH = 1024 # Maximum tokens for generation
19
 
20
+ # MIDI Device Manager
21
  class MIDIDeviceManager:
22
  def __init__(self):
23
  self.midiout = rtmidi.MidiOut()
24
  self.midiin = rtmidi.MidiIn()
25
 
 
 
 
 
 
 
26
  def get_device_info(self):
27
+ out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
28
+ in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
29
+ return f"Output Devices:\n{'\n'.join(out_ports)}\n\nInput Devices:\n{'\n'.join(in_ports)}"
 
 
30
 
31
  def close(self):
32
  if self.midiout.is_port_open():
33
  self.midiout.close_port()
34
  if self.midiin.is_port_open():
35
  self.midiin.close_port()
36
+ del self.midiout, self.midiin
 
37
 
38
+ # MIDI Processor with ONNX Generation
39
  class MIDIManager:
40
  def __init__(self):
41
+ self.soundfont = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
42
+ self.synthesizer = MidiSynthesizer(self.soundfont)
43
+ self.tokenizer = self._load_tokenizer("skytnt/midi-model")
44
+ self.model_base = rt.InferenceSession(
45
+ hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
46
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
47
+ )
48
+ self.model_token = rt.InferenceSession(
49
+ hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
50
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
51
+ )
52
+ self.generated_files = [] # Store base64-encoded MIDI data
53
  self.is_playing = False
 
 
 
54
 
55
+ def _load_tokenizer(self, repo_id):
56
  config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
57
  with open(config_path, "r") as f:
58
  config = json.load(f)
 
61
  return tokenizer
62
 
63
  def load_midi(self, file_path):
64
+ try:
65
+ return MIDI.load(file_path)
66
+ except Exception as e:
67
+ raise ValueError(f"Failed to load MIDI file: {e}")
68
+
69
+ def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20):
70
+ # Tokenize input MIDI
71
+ mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
72
+ input_tensor = np.array([mid_seq], dtype=np.int64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  cur_len = input_tensor.shape[1]
74
+ generator = np.random.RandomState(random.randint(0, MAX_SEED))
75
+
76
+ # Generate up to MAX_LENGTH
77
+ while cur_len < MAX_LENGTH:
78
+ inputs = {"x": input_tensor[:, -1:]} # Last token
79
+ hidden = self.model_base.run(None, inputs)[0] # Base model output
80
+ logits = self.model_token.run(None, {"hidden": hidden})[0] # Token model output
81
+ probs = softmax(logits / temp, axis=-1)
82
+ next_token = sample_top_p_k(probs, top_p, top_k, generator)
 
 
 
 
 
 
 
83
  input_tensor = np.concatenate([input_tensor, next_token], axis=1)
84
  cur_len += 1
85
 
86
+ # Detokenize and save as MIDI
87
+ new_seq = input_tensor[0].tolist()
88
+ new_midi = self.tokenizer.detokenize(new_seq)
89
  midi_output = io.BytesIO()
90
  MIDI.score2midi(new_midi, midi_output)
91
  midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
92
+ self.generated_files.append(midi_data)
93
  return midi_data
94
 
95
+ def play_midi(self, midi_data):
96
  self.is_playing = True
97
+ midi_bytes = base64.b64decode(midi_data)
98
+ midi_file = MIDI.load(io.BytesIO(midi_bytes))
99
+ audio = io.BytesIO()
100
+ self.synthesizer.render_midi(midi_file, audio)
101
+ audio.seek(0)
102
+ return audio
103
+
104
+ def stop(self):
105
  self.is_playing = False
 
106
 
107
+ # Helper Functions
108
  def softmax(x, axis):
109
+ exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
110
+ return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
 
111
 
112
+ def sample_top_p_k(probs, p, k, generator):
 
 
113
  probs_idx = np.argsort(-probs, axis=-1)
114
+ probs_sort = np.take_along_axis(probs, probs_idx, axis=-1)
115
  probs_sum = np.cumsum(probs_sort, axis=-1)
116
  mask = probs_sum - probs_sort > p
117
  probs_sort[mask] = 0.0
118
+ probs_sort[:, k:] = 0.0 # Top-k filtering
119
+ probs_sort /= probs_sort.sum(axis=-1, keepdims=True)
120
+ next_token = generator.choice(probs.shape[-1], p=probs_sort[0])
121
+ return np.array([[next_token]])
122
+
123
+ # UI Functions
124
+ def process_midi_upload(files):
125
+ if not files:
126
+ return None, "No file uploaded", ""
127
+ file = files[0] # Process first file
128
+ try:
129
+ midi_data = midi_processor.load_midi(file.name)
130
+ generated_midi = midi_processor.generate_variation(midi_data)
131
+ audio = midi_processor.play_midi(generated_midi)
132
+ download_html = create_download_list()
133
+ return audio, "Generated and playing", download_html
134
+ except Exception as e:
135
+ return None, f"Error: {e}", ""
136
 
137
  def create_download_list():
138
+ if not midi_processor.generated_files:
139
+ return "<p>No generated files yet.</p>"
140
+ html = "<h3>Generated MIDI Files</h3><ul>"
141
+ for i, midi_data in enumerate(midi_processor.generated_files):
142
+ html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="generated_{i}.mid">Download MIDI {i}</a></li>'
143
  html += "</ul>"
144
  return html
145
 
146
+ def refresh_devices():
147
+ return device_manager.get_device_info()
148
+
149
+ def stop_playback():
150
+ midi_processor.stop()
151
+ return "Playback stopped"
152
+
153
+ # Main Application
154
  if __name__ == "__main__":
155
+ parser = argparse.ArgumentParser(description="MIDI Composer with ONNX Generation")
156
  parser.add_argument("--port", type=int, default=7860)
157
  parser.add_argument("--share", action="store_true")
158
+ args = parser.parse_args()
159
 
160
+ device_manager = MIDIDeviceManager()
161
  midi_processor = MIDIManager()
162
 
163
+ with gr.Blocks(title="MIDI Composer", theme=gr.themes.Soft()) as app:
164
+ gr.Markdown("# 🎵 MIDI Composer 🎵")
165
+
166
  with gr.Tabs():
167
+ # MIDI Prompt Tab
168
  with gr.Tab("MIDI Prompt"):
169
+ midi_upload = gr.File(label="Upload MIDI File", file_types=[".mid", ".midi"])
170
+ audio_output = gr.Audio(label="Generated MIDI", type="bytes", autoplay=True)
171
  status = gr.Textbox(label="Status", value="Ready", interactive=False)
172
+ midi_upload.change(
173
+ process_midi_upload,
174
+ inputs=[midi_upload],
175
+ outputs=[audio_output, status, gr.HTML(elem_id="downloads")]
176
+ )
177
+
178
+ # Downloads Tab
 
 
 
 
 
 
 
 
 
179
  with gr.Tab("Downloads", elem_id="downloads"):
180
+ gr.HTML(value=create_download_list())
181
 
182
+ # Devices Tab
183
  with gr.Tab("Devices"):
184
+ device_info = gr.Textbox(label="MIDI Devices", value=device_manager.get_device_info(), interactive=False)
185
  refresh_btn = gr.Button("Refresh Devices")
186
  stop_btn = gr.Button("Stop Playback")
187
+ refresh_btn.click(refresh_devices, outputs=[device_info])
188
+ stop_btn.click(stop_playback, outputs=[status])
189
+
190
+ app.launch(server_port=args.port, share=args.share, inbrowser=True)
191
+ device_manager.close()