awacke1 commited on
Commit
75808a5
·
verified ·
1 Parent(s): 5eeabb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -186
app.py CHANGED
@@ -2,25 +2,19 @@ import random
2
  import argparse
3
  import os
4
  import glob
 
5
  import rtmidi
6
  import gradio as gr
7
  import numpy as np
8
- import MIDI
9
- import base64
10
- import io
11
- import soundfile as sf # Placeholder for audio rendering
12
  from huggingface_hub import hf_hub_download
 
13
  from midi_synthesizer import MidiSynthesizer
 
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  in_space = os.getenv("SYSTEM") == "spaces"
17
 
18
- SONG_DATA = {
19
- "title": "Do You Believe in Love",
20
- "progression": ["G", "D", "Em", "C"],
21
- "lyrics": ["I was walking down a one-way street", "Just a-looking for someone to meet"]
22
- }
23
-
24
  class MIDIDeviceManager:
25
  def __init__(self):
26
  self.midiout = rtmidi.MidiOut()
@@ -51,63 +45,24 @@ class MIDIManager:
51
  def __init__(self):
52
  self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
53
  self.synthesizer = MidiSynthesizer(self.soundfont_path)
54
- self.loaded_midi = {} # midi_id: (file_path, midi_obj)
55
- self.modified_files = [] # Stores (midi_base64, audio_base64) tuples
56
- self.example_variations = {} # midi_id: (midi_base64, audio_base64) for pre-generated examples
57
  self.is_playing = False
58
- self.instruments = self.random_instrument_set()
59
- self.drum_beat = self.create_drum_beat()
60
- self.starter_midi = self.create_starter_midi()
61
- self.loaded_midi["starter"] = ("Starter MIDI", self.starter_midi)
62
- self.preload_default_midi()
63
- self.load_example_midis() # Pre-generate example variations
64
-
65
- def random_instrument_set(self):
66
- instrument_pool = [0, 24, 32, 48] # Piano, Guitar, Bass, Strings
67
- return random.sample(instrument_pool, 4)
68
-
69
- def create_drum_beat(self):
70
- return [(36, 100, 0), (42, 80, 50), (38, 90, 100), (42, 80, 150)] # Kick, hi-hat, snare, hi-hat
71
-
72
- def create_starter_midi(self):
73
- midi = MIDI.MIDIFile(5)
74
- for i, inst in enumerate(self.instruments):
75
- midi.addTrack()
76
- midi.addProgramChange(i, 0, 0, inst)
77
- for t in range(0, 400, 100):
78
- note = random.randint(60, 84)
79
- midi.addNote(i, 0, note, t, 100, 100)
80
- midi.addTrack()
81
- for note, vel, time in self.drum_beat:
82
- midi.addNote(4, 9, note, time, 100, vel)
83
- return midi
84
-
85
- def preload_default_midi(self):
86
- default_path = "default.mid"
87
- if os.path.exists(default_path):
88
- midi_id = "default"
89
- midi = MIDI.load(default_path)
90
- self.loaded_midi[midi_id] = (default_path, midi)
91
- midi_data, audio_data = self.generate_variation(midi_id)
92
- self.play_with_loop(midi_data)
93
-
94
- def load_example_midis(self):
95
- examples = {}
96
- for file_path in glob.glob("*.mid") + glob.glob("*.midi"):
97
- if file_path == "default.mid":
98
- continue
99
- midi_id = f"example_{len(examples)}"
100
- midi = MIDI.load(file_path)
101
- self.loaded_midi[midi_id] = (file_path, midi)
102
- # Pre-generate variation for each example
103
- midi_data, audio_data = self.generate_variation(midi_id)
104
- examples[midi_id] = (midi_data, audio_data)
105
- self.example_variations = examples
106
- return examples
107
 
108
  def load_midi(self, file_path):
109
  midi = MIDI.load(file_path)
110
- midi_id = f"midi_{len(self.loaded_midi) - len(self.example_variations) - 1}"
111
  self.loaded_midi[midi_id] = (file_path, midi)
112
  return midi_id
113
 
@@ -122,13 +77,13 @@ class MIDIManager:
122
  instruments.add(event.program)
123
  return notes, list(instruments)
124
 
125
- def generate_variation(self, midi_id, length_factor=2, variation=0.3):
126
  if midi_id not in self.loaded_midi:
127
- return None, None
128
  _, midi = self.loaded_midi[midi_id]
129
  notes, instruments = self.extract_notes_and_instruments(midi)
130
  new_notes = []
131
- for _ in range(int(length_factor)):
132
  for note, vel, time in notes:
133
  if random.random() < variation:
134
  new_note = min(127, max(0, note + random.randint(-2, 2)))
@@ -147,75 +102,91 @@ class MIDIManager:
147
  midi_output = io.BytesIO()
148
  new_midi.writeFile(midi_output)
149
  midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
150
-
151
- temp_midi = 'temp.mid'
152
- with open(temp_midi, 'wb') as f:
153
- f.write(midi_output.getvalue())
154
- audio_output = io.BytesIO()
155
- self.synthesizer.play_midi(new_midi)
156
- audio_data = None # Placeholder; see Notes
157
- if os.path.exists(temp_midi):
158
- os.remove(temp_midi)
159
-
160
- self.modified_files.append((midi_data, audio_data))
161
- return midi_data, audio_data
162
 
163
- def apply_synth_effect(self, midi_data, effect, intensity):
164
- midi = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
165
- if effect == "tempo":
166
- factor = 1 + (intensity - 0.5) * 0.4
167
- for track in midi.tracks:
168
- for event in track.events:
169
- event.time = int(event.time * factor)
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  midi_output = io.BytesIO()
172
- midi.writeFile(midi_output)
173
  midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
174
-
175
- temp_midi = 'temp.mid'
176
- with open(temp_midi, 'wb') as f:
177
- f.write(midi_output.getvalue())
178
- audio_output = io.BytesIO()
179
- self.synthesizer.play_midi(midi)
180
- audio_data = None # Placeholder
181
- if os.path.exists(temp_midi):
182
- os.remove(temp_midi)
183
-
184
- self.modified_files.append((midi_data, audio_data))
185
- return midi_data, audio_data
186
 
187
  def play_with_loop(self, midi_data):
188
  self.is_playing = True
189
  midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
190
  while self.is_playing:
191
  self.synthesizer.play_midi(midi_file)
192
- return "Stopped"
193
 
194
  def stop_playback(self):
195
  self.is_playing = False
196
- return "Stopping..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def create_download_list():
199
  html = "<h3>Downloads</h3><ul>"
200
- for i, (midi_data, audio_data) in enumerate(midi_processor.modified_files):
201
- html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="midi_{i}.mid">MIDI {i}</a>'
202
- if audio_data:
203
- html += f' | <a href="data:audio/wav;base64,{audio_data}" download="audio_{i}.wav">Audio {i}</a>'
204
- html += '</li>'
205
  html += "</ul>"
206
  return html
207
 
208
- def get_midi_choices():
209
- return [(os.path.basename(path), midi_id) for midi_id, (path, _) in midi_processor.loaded_midi.items()]
210
-
211
- def get_example_choices():
212
- return [(os.path.basename(path), midi_id) for midi_id, (path, _) in midi_processor.loaded_midi.items() if midi_id.startswith("example")]
213
-
214
  if __name__ == "__main__":
215
  parser = argparse.ArgumentParser()
216
  parser.add_argument("--port", type=int, default=7860)
217
  parser.add_argument("--share", action="store_true")
218
- parser.add_argument("--batch", type=int, default=1)
219
  opt = parser.parse_args()
220
 
221
  midi_manager = MIDIDeviceManager()
@@ -227,96 +198,50 @@ if __name__ == "__main__":
227
  with gr.Tabs():
228
  # Tab 1: MIDI Prompt (Main Tab)
229
  with gr.Tab("MIDI Prompt"):
230
- midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
231
- loaded_display = gr.HTML(value="No files loaded")
232
- output = gr.Audio(label="Generated Preview", type="bytes", autoplay=True)
233
 
234
- def load_and_generate(files):
235
- html = "<h3>Loaded Files</h3>"
 
236
  midi_data = None
237
- for file in files or []:
238
  midi_id = midi_processor.load_midi(file.name)
239
- html += f"<div>{file.name} <button onclick=\"remove_midi('{midi_id}')\">X</button></div>"
240
- midi_data, _ = midi_processor.generate_variation(midi_id)
241
- return html, (io.BytesIO(base64.b64decode(midi_data)) if midi_data else None), get_midi_choices(), create_download_list()
 
242
 
243
- midi_files.change(load_and_generate, inputs=[midi_files],
244
- outputs=[loaded_display, output, gr.State(get_midi_choices()), "downloads"])
245
 
246
- # Tab 2: Examples
247
- with gr.Tab("Examples"):
248
- example_select = gr.Dropdown(label="Select Example", choices=get_example_choices(), value=None)
249
- example_output = gr.Audio(label="Example Preview", type="bytes", autoplay=True)
250
-
251
- def load_example(midi_id):
252
- if not midi_id or midi_id not in midi_processor.example_variations:
253
- return None
254
- midi_data, audio_data = midi_processor.example_variations[midi_id]
255
- midi_processor.play_with_loop(midi_data)
256
- return io.BytesIO(base64.b64decode(midi_data))
257
-
258
- example_select.change(load_example, inputs=[example_select], outputs=[example_output])
259
 
260
- # Tab 3: Generate & Perform
261
- with gr.Tab("Generate & Perform"):
262
- midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value="starter")
263
- length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
264
- variation = gr.Slider(0, 1, value=0.3, label="Variation")
265
- generate_btn = gr.Button("Generate")
266
- effect = gr.Radio(["tempo"], label="Synth Effect", value="tempo")
267
- intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
268
- apply_btn = gr.Button("Apply Effect")
269
  stop_btn = gr.Button("Stop Playback")
270
- output = gr.Audio(label="Preview", type="bytes", autoplay=True)
271
- status = gr.Textbox(label="Status", value="Ready")
272
- midi_device = gr.Dropdown(label="MIDI Output Device", choices=midi_manager.get_output_devices(), type="index")
273
- tempo = gr.Slider(label="Tempo (BPM)", minimum=40, maximum=200, value=120, step=1)
274
- device_info = gr.Textbox(label="Connected MIDI Devices", value=midi_manager.get_device_info(), readonly=True)
275
- refresh_btn = gr.Button("🔄 Refresh MIDI Devices")
276
-
277
- def update_dropdown(choices):
278
- return gr.update(choices=choices)
279
-
280
- gr.State(get_midi_choices()).change(update_dropdown, inputs=[gr.State()], outputs=[midi_select])
281
-
282
- def generate(midi_id, length, var):
283
- if not midi_id:
284
- return None, "Select a MIDI file", create_download_list()
285
- midi_data, audio_data = midi_processor.generate_variation(midi_id, length, var)
286
- midi_processor.play_with_loop(midi_data)
287
- return io.BytesIO(base64.b64decode(midi_data)), "Playing", create_download_list()
288
-
289
- def apply_effect(midi_data, fx, inten):
290
- if not midi_data:
291
- return None, "Generate a MIDI first", create_download_list()
292
- new_midi_data, audio_data = midi_processor.apply_synth_effect(midi_data.decode('utf-8'), fx, inten)
293
- midi_processor.play_with_loop(new_midi_data)
294
- return io.BytesIO(base64.b64decode(new_midi_data)), "Playing", create_download_list()
295
 
296
  def refresh_devices():
297
- return midi_manager.get_output_devices(), midi_manager.get_device_info()
298
 
299
- generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
300
- outputs=[output, status, "downloads"])
301
- apply_btn.click(apply_effect, inputs=[output, effect, intensity],
302
- outputs=[output, status, "downloads"])
303
  stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])
304
- refresh_btn.click(refresh_devices, inputs=None, outputs=[midi_device, device_info])
305
-
306
- # Tab 4: Downloads
307
- with gr.Tab("Downloads", elem_id="downloads"):
308
- downloads = gr.HTML(value=create_download_list())
309
 
310
  gr.Markdown("""
311
  <div style='text-align: center; margin-top: 20px;'>
312
  <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
313
  <strong>Hugging Face</strong><br>
314
- <a href='https://huggingface.co/models'>Models</a> |
315
- <a href='https://huggingface.co/datasets'>Datasets</a> |
316
- <a href='https://huggingface.co/spaces'>Spaces</a> |
317
- <a href='https://huggingface.co/posts'>Posts</a> |
318
- <a href='https://huggingface.co/docs'>Docs</a> |
319
- <a href='https://huggingface.co/enterprise'>Enterprise</a> |
320
  <a href='https://huggingface.co/pricing'>Pricing</a>
321
  </div>
322
  """)
 
2
  import argparse
3
  import os
4
  import glob
5
+ import json
6
  import rtmidi
7
  import gradio as gr
8
  import numpy as np
9
+ import onnxruntime as rt
 
 
 
10
  from huggingface_hub import hf_hub_download
11
+ import MIDI
12
  from midi_synthesizer import MidiSynthesizer
13
+ from midi_tokenizer import MIDITokenizer
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  in_space = os.getenv("SYSTEM") == "spaces"
17
 
 
 
 
 
 
 
18
  class MIDIDeviceManager:
19
  def __init__(self):
20
  self.midiout = rtmidi.MidiOut()
 
45
  def __init__(self):
46
  self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
47
  self.synthesizer = MidiSynthesizer(self.soundfont_path)
48
+ self.loaded_midi = {}
49
+ self.modified_files = []
 
50
  self.is_playing = False
51
+ self.tokenizer = self.load_tokenizer("skytnt/midi-model")
52
+ self.model_base = rt.InferenceSession(hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
53
+ self.model_token = rt.InferenceSession(hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
54
+
55
+ def load_tokenizer(self, repo_id):
56
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
57
+ with open(config_path, "r") as f:
58
+ config = json.load(f)
59
+ tokenizer = MIDITokenizer(config["tokenizer"]["version"])
60
+ tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
61
+ return tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def load_midi(self, file_path):
64
  midi = MIDI.load(file_path)
65
+ midi_id = f"midi_{len(self.loaded_midi)}"
66
  self.loaded_midi[midi_id] = (file_path, midi)
67
  return midi_id
68
 
 
77
  instruments.add(event.program)
78
  return notes, list(instruments)
79
 
80
+ def generate_variation(self, midi_id, length_factor=10, variation=0.3):
81
  if midi_id not in self.loaded_midi:
82
+ return None
83
  _, midi = self.loaded_midi[midi_id]
84
  notes, instruments = self.extract_notes_and_instruments(midi)
85
  new_notes = []
86
+ for _ in range(int(length_factor)): # Max length: 10x repetition
87
  for note, vel, time in notes:
88
  if random.random() < variation:
89
  new_note = min(127, max(0, note + random.randint(-2, 2)))
 
102
  midi_output = io.BytesIO()
103
  new_midi.writeFile(midi_output)
104
  midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
105
+ self.modified_files.append(midi_data)
106
+ return midi_data
 
 
 
 
 
 
 
 
 
 
107
 
108
+ def generate_onnx(self, midi_id, max_len=1024, temp=1.0, top_p=0.98, top_k=20):
109
+ if midi_id not in self.loaded_midi:
110
+ return None
111
+ _, mid = self.loaded_midi[midi_id]
112
+ mid_seq = self.tokenizer.tokenize(MIDI.midi2score(mid))
113
+ mid = np.asarray([mid_seq], dtype=np.int64)
114
+ generator = np.random.RandomState(random.randint(0, MAX_SEED))
115
 
116
+ # Simplified ONNX generation from app_onnx.py
117
+ input_tensor = mid
118
+ cur_len = input_tensor.shape[1]
119
+ model = [self.model_base, self.model_token, self.tokenizer]
120
+
121
+ while cur_len < max_len:
122
+ inputs = {"x": rt.OrtValue.ortvalue_from_numpy(input_tensor[:, -1:], device_type="cuda")}
123
+ outputs = {"hidden": rt.OrtValue.ortvalue_from_shape_and_type((1, 1, 1024), np.float32, device_type="cuda")}
124
+ io_binding = model[0].io_binding()
125
+ for name, val in inputs.items():
126
+ io_binding.bind_ortvalue_input(name, val)
127
+ for name in outputs:
128
+ io_binding.bind_ortvalue_output(name, outputs[name])
129
+ model[0].run_with_iobinding(io_binding)
130
+ hidden = outputs["hidden"].numpy()[:, -1:]
131
+
132
+ logits = model[1].run(None, {"hidden": hidden})[0]
133
+ scores = softmax(logits / temp, -1)
134
+ next_token = sample_top_p_k(scores, top_p, top_k, generator)
135
+ input_tensor = np.concatenate([input_tensor, next_token], axis=1)
136
+ cur_len += 1
137
+
138
+ mid_seq = input_tensor.tolist()[0]
139
+ new_midi = self.tokenizer.detokenize(mid_seq)
140
  midi_output = io.BytesIO()
141
+ MIDI.score2midi(new_midi, midi_output)
142
  midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
143
+ self.modified_files.append(midi_data)
144
+ return midi_data
 
 
 
 
 
 
 
 
 
 
145
 
146
  def play_with_loop(self, midi_data):
147
  self.is_playing = True
148
  midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
149
  while self.is_playing:
150
  self.synthesizer.play_midi(midi_file)
 
151
 
152
  def stop_playback(self):
153
  self.is_playing = False
154
+ return "Playback stopped"
155
+
156
+ def softmax(x, axis):
157
+ x_max = np.max(x, axis=axis, keepdims=True)
158
+ exp_x_shifted = np.exp(x - x_max)
159
+ return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
160
+
161
+ def sample_top_p_k(probs, p, k, generator=None):
162
+ if generator is None:
163
+ generator = np.random
164
+ probs_idx = np.argsort(-probs, axis=-1)
165
+ probs_sort = np.take_along_axis(probs, probs_idx, -1)
166
+ probs_sum = np.cumsum(probs_sort, axis=-1)
167
+ mask = probs_sum - probs_sort > p
168
+ probs_sort[mask] = 0.0
169
+ mask = np.zeros(probs_sort.shape[-1])
170
+ mask[:k] = 1
171
+ probs_sort *= mask
172
+ probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
173
+ shape = probs_sort.shape
174
+ probs_sort_flat = probs_sort.reshape(-1, shape[-1])
175
+ probs_idx_flat = probs_idx.reshape(-1, shape[-1])
176
+ next_token = np.stack([generator.choice(idxs, p=pvals) for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
177
+ return next_token.reshape(*shape[:-1])
178
 
179
  def create_download_list():
180
  html = "<h3>Downloads</h3><ul>"
181
+ for i, midi_data in enumerate(midi_processor.modified_files):
182
+ html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="midi_{i}.mid">MIDI {i}</a></li>'
 
 
 
183
  html += "</ul>"
184
  return html
185
 
 
 
 
 
 
 
186
  if __name__ == "__main__":
187
  parser = argparse.ArgumentParser()
188
  parser.add_argument("--port", type=int, default=7860)
189
  parser.add_argument("--share", action="store_true")
 
190
  opt = parser.parse_args()
191
 
192
  midi_manager = MIDIDeviceManager()
 
198
  with gr.Tabs():
199
  # Tab 1: MIDI Prompt (Main Tab)
200
  with gr.Tab("MIDI Prompt"):
201
+ midi_upload = gr.File(label="Upload MIDI File", file_count="multiple")
202
+ output = gr.Audio(label="Generated MIDI", type="bytes", autoplay=True)
203
+ status = gr.Textbox(label="Status", value="Ready", interactive=False)
204
 
205
+ def process_midi(files):
206
+ if not files:
207
+ return None, "No file uploaded"
208
  midi_data = None
209
+ for file in files:
210
  midi_id = midi_processor.load_midi(file.name)
211
+ # Use ONNX generation for advanced synthesis
212
+ midi_data = midi_processor.generate_onnx(midi_id, max_len=1024)
213
+ midi_processor.play_with_loop(midi_data)
214
+ return io.BytesIO(base64.b64decode(midi_data)), "Playing", create_download_list()
215
 
216
+ midi_upload.change(process_midi, inputs=[midi_upload],
217
+ outputs=[output, status, "downloads"])
218
 
219
+ # Tab 2: Downloads
220
+ with gr.Tab("Downloads", elem_id="downloads"):
221
+ downloads = gr.HTML(value="No generated files yet")
 
 
 
 
 
 
 
 
 
 
222
 
223
+ # Tab 3: Devices
224
+ with gr.Tab("Devices"):
225
+ device_info = gr.Textbox(label="Connected MIDI Devices", value=midi_manager.get_device_info(), interactive=False)
226
+ refresh_btn = gr.Button("Refresh Devices")
 
 
 
 
 
227
  stop_btn = gr.Button("Stop Playback")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
  def refresh_devices():
230
+ return midi_manager.get_device_info()
231
 
232
+ refresh_btn.click(refresh_devices, inputs=None, outputs=[device_info])
 
 
 
233
  stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])
 
 
 
 
 
234
 
235
  gr.Markdown("""
236
  <div style='text-align: center; margin-top: 20px;'>
237
  <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
238
  <strong>Hugging Face</strong><br>
239
+ <a href='https://huggingface.co/models'>Models</a> |
240
+ <a href='https://huggingface.co/datasets'>Datasets</a> |
241
+ <a href='https://huggingface.co/spaces'>Spaces</a> |
242
+ <a href='https://huggingface.co/posts'>Posts</a> |
243
+ <a href='https://huggingface.co/docs'>Docs</a> |
244
+ <a href='https://huggingface.co/enterprise'>Enterprise</a> |
245
  <a href='https://huggingface.co/pricing'>Pricing</a>
246
  </div>
247
  """)