awacke1 committed on
Commit
64015bc
·
verified ·
1 Parent(s): 1e9061c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -26
app.py CHANGED
@@ -11,7 +11,6 @@ import gradio as gr
11
  import numpy as np
12
  import torch
13
  import torch.nn.functional as F
14
- import tqdm
15
  from huggingface_hub import hf_hub_download
16
  from transformers import DynamicCache
17
 
@@ -22,7 +21,6 @@ from midi_synthesizer import MidiSynthesizer
22
  MAX_SEED = np.iinfo(np.int32).max
23
  in_space = os.getenv("SYSTEM") == "spaces"
24
 
25
-
26
  @torch.inference_mode()
27
  def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
28
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
@@ -118,15 +116,12 @@ def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0,
118
  if all(end):
119
  break
120
 
121
-
122
  def create_msg(name, data):
123
  return {"name": name, "data": data}
124
 
125
-
126
  def send_msgs(msgs):
127
  return json.dumps(msgs)
128
 
129
-
130
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
131
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
132
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
@@ -135,7 +130,6 @@ def get_duration(model_name, tab, mid_seq, continuation_state, continuation_sele
135
  t = gen_events // 14
136
  return t + 5
137
 
138
-
139
  @spaces.GPU(duration=get_duration)
140
  def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
141
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
@@ -246,7 +240,6 @@ def run(model_name, tab, mid_seq, continuation_state, continuation_select, instr
246
  t = time.time()
247
  yield mid_seq, continuation_state, seed, send_msgs([])
248
 
249
-
250
  def finish_run(model_name, mid_seq):
251
  if mid_seq is None:
252
  outputs = [None] * OUTPUT_BATCH_SIZE
@@ -267,10 +260,11 @@ def finish_run(model_name, mid_seq):
267
  create_msg("visualizer_end", i)]
268
  return *outputs, send_msgs(end_msgs)
269
 
270
-
271
  def synthesis_task(mid):
272
  return synthesizer.synthesis(MIDI.score2opus(mid))
273
 
 
 
274
  def render_audio(model_name, mid_seq, should_render_audio):
275
  if (not should_render_audio) or mid_seq is None:
276
  outputs = [None] * OUTPUT_BATCH_SIZE
@@ -351,7 +345,67 @@ drum_kits2number = {v: k for k, v in number2drum_kits.items()}
351
  key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
352
  'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
353
 
354
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  parser = argparse.ArgumentParser()
356
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
357
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
@@ -402,17 +456,7 @@ if __name__ == "__main__":
402
  load_javascript()
403
  app = gr.Blocks(theme=gr.themes.Soft())
404
  with app:
405
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer</h1>")
406
- gr.Markdown("![Visitors](https://api.visitorbadge.io/api/visitors?path=skytnt.midi-composer&style=flat)\n\n"
407
- "Midi event transformer for symbolic music generation\n\n"
408
- "Demo for [SkyTNT/midi-model](https://github.com/SkyTNT/midi-model)\n\n"
409
- "[Open In Colab]"
410
- "(https://colab.research.google.com/github/SkyTNT/midi-model/blob/main/demo.ipynb)"
411
- " or [download windows app](https://github.com/SkyTNT/midi-model/releases)"
412
- " for unlimited generation\n\n"
413
- "**Update v1.3**: MIDITokenizerV2 and new MidiVisualizer\n\n"
414
- "The current **best** model: generic pretrain model (tv2o-medium) by skytnt"
415
- )
416
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
417
  js_msg.change(None, [js_msg], [], js="""
418
  (msg_json) =>{
@@ -431,18 +475,24 @@ if __name__ == "__main__":
431
  input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
432
  value="None")
433
  input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
434
- step=1,
435
- value=0)
436
  input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
437
  value="auto",
438
  choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
439
- "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"]
440
- )
441
  input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
442
  value="auto",
443
  choices=["auto"] + key_signatures,
444
- type="index"
445
- )
 
 
 
 
 
 
 
 
446
  example1 = gr.Examples([
447
  [[], "None"],
448
  [["Acoustic Grand"], "None"],
@@ -457,6 +507,7 @@ if __name__ == "__main__":
457
  [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
458
  "Electric Bass(finger)"], "Standard"]
459
  ], [input_instruments, input_drum_kit])
 
460
  with gr.TabItem("midi prompt") as tab2:
461
  input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
462
  input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
@@ -470,6 +521,7 @@ if __name__ == "__main__":
470
  input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
471
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
472
  [input_midi, input_midi_events])
 
473
  with gr.TabItem("last output prompt") as tab3:
474
  gr.Markdown("Continue generating on the last output.")
475
  input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
@@ -530,5 +582,38 @@ if __name__ == "__main__":
530
  # queue=False)
531
  undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
532
  [output_midi_seq, output_continuation_state, js_msg], queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
534
  thread_pool.shutdown()
 
11
  import numpy as np
12
  import torch
13
  import torch.nn.functional as F
 
14
  from huggingface_hub import hf_hub_download
15
  from transformers import DynamicCache
16
 
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  in_space = os.getenv("SYSTEM") == "spaces"
23
 
 
24
  @torch.inference_mode()
25
  def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
26
  disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
 
116
  if all(end):
117
  break
118
 
 
119
  def create_msg(name, data):
120
  return {"name": name, "data": data}
121
 
 
122
  def send_msgs(msgs):
123
  return json.dumps(msgs)
124
 
 
125
  def get_duration(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm,
126
  time_sig, key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr,
127
  remove_empty_channels, seed, seed_rand, gen_events, temp, top_p, top_k, allow_cc):
 
130
  t = gen_events // 14
131
  return t + 5
132
 
 
133
  @spaces.GPU(duration=get_duration)
134
  def run(model_name, tab, mid_seq, continuation_state, continuation_select, instruments, drum_kit, bpm, time_sig,
135
  key_sig, mid, midi_events, reduce_cc_st, remap_track_channel, add_default_instr, remove_empty_channels,
 
240
  t = time.time()
241
  yield mid_seq, continuation_state, seed, send_msgs([])
242
 
 
243
  def finish_run(model_name, mid_seq):
244
  if mid_seq is None:
245
  outputs = [None] * OUTPUT_BATCH_SIZE
 
260
  create_msg("visualizer_end", i)]
261
  return *outputs, send_msgs(end_msgs)
262
 
 
263
  def synthesis_task(mid):
264
  return synthesizer.synthesis(MIDI.score2opus(mid))
265
 
266
+
267
+
268
  def render_audio(model_name, mid_seq, should_render_audio):
269
  if (not should_render_audio) or mid_seq is None:
270
  outputs = [None] * OUTPUT_BATCH_SIZE
 
345
  key_signatures = ['C♭', 'A♭m', 'G♭', 'E♭m', 'D♭', 'B♭m', 'A♭', 'Fm', 'E♭', 'Cm', 'B♭', 'Gm', 'F', 'Dm',
346
  'C', 'Am', 'G', 'Em', 'D', 'Bm', 'A', 'F♯m', 'E', 'C♯m', 'B', 'G♯m', 'F♯', 'D♯m', 'C♯', 'A♯m']
347
 
348
+
349
+
350
+
351
+
352
+
353
+ mid = tokenizer.detokenize(mid_seq[i])
354
+ audio_future = thread_pool.submit(synthesis_task, mid)
355
+ audio_futures.append(audio_future)
356
+ for future in audio_futures:
357
+ outputs.append((44100, future.result()))
358
+ if OUTPUT_BATCH_SIZE == 1:
359
+ return outputs[0]
360
+ return tuple(outputs)
361
+
362
def undo_continuation(model_name, mid_seq, continuation_state):
    """Roll the sequences back to the previous continuation point.

    Args:
        model_name: key into ``models``; that model's tokenizer decodes events.
        mid_seq: current token sequences, or None when nothing exists yet.
        continuation_state: history list; its last entry is either a list
            (used as the restored ``mid_seq``) or a value used as the slice
            end applied to every sequence in ``mid_seq``.

    Returns:
        ``(mid_seq, continuation_state, json_messages)``. When there is
        nothing to undo the inputs are returned unchanged with an empty
        message list.
    """
    # Nothing to roll back: no sequences yet, or fewer than two history entries.
    if mid_seq is None or len(continuation_state) < 2:
        return mid_seq, continuation_state, send_msgs([])

    tokenizer = models[model_name].tokenizer
    if isinstance(continuation_state[-1], list):
        # The last entry is a list: it becomes the restored sequences.
        mid_seq = continuation_state[-1]
    else:
        # Otherwise the last entry is a cut point applied to each sequence.
        mid_seq = [seq[:continuation_state[-1]] for seq in mid_seq]
    continuation_state = continuation_state[:-1]

    # Reset progress, then redraw every output's visualizer from scratch.
    msgs = [create_msg("progress", [0, 0])]
    for idx in range(OUTPUT_BATCH_SIZE):
        events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[idx]]
        msgs.append(create_msg("visualizer_clear", [idx, tokenizer.version]))
        msgs.append(create_msg("visualizer_append", [idx, events]))
        msgs.append(create_msg("visualizer_end", idx))
    return mid_seq, continuation_state, send_msgs(msgs)
378
+
379
def create_arpeggio_events(chord, pattern, duration=480):
    """Build a flat list of note_on/note_off events that arpeggiate a chord.

    Args:
        chord: chord name; the table below defines 'C', 'D', 'Am' and 'G'.
        pattern: iterable of indices into the chord's three-pitch list.
        duration: value stored in the second field of every 'note_off' event.

    Returns:
        A list with two events per pattern step: a 'note_on' (velocity 80)
        followed by the matching 'note_off'.

    Raises:
        KeyError: if ``chord`` is not in the table and ``pattern`` is non-empty.
        IndexError: if a pattern step is out of range for the chord.
    """
    # NOTE(review): confirm the tokenizer's event2tokens accepts
    # 'note_on'/'note_off' events in this field layout; not visible here.
    chord_pitches = {
        'C': [60, 64, 67],
        'D': [62, 66, 69],
        'Am': [57, 60, 64],
        'G': [55, 59, 62],
    }

    events = []
    for step in pattern:
        pitch = chord_pitches[chord][step]
        events.append(['note_on', 0, 0, 0, 0, pitch, 80])
        events.append(['note_off', duration, 0, 0, 0, pitch, 0])
    return events
396
+
397
def add_arpeggio_sequence(tokenizer, mid_seq, sequence, pattern):
    """Append tokenized arpeggio events for a chord sequence to ``mid_seq[0]``.

    Args:
        tokenizer: object whose ``event2tokens(event)`` turns one event into tokens.
        mid_seq: list of token sequences; only the first one is extended
            (in place). May be None or empty when nothing exists yet.
        sequence: iterable of chord names passed to ``create_arpeggio_events``.
        pattern: arpeggio pattern passed to ``create_arpeggio_events``.

    Returns:
        ``mid_seq`` itself, with the new tokens appended to its first sequence,
        or unchanged when it is None or empty.
    """
    # Other handlers in this file treat ``mid_seq is None`` as "nothing
    # generated yet". Indexing ``mid_seq[0]`` would raise in that case, so
    # hand the value back untouched instead of crashing the click handler.
    if not mid_seq:
        return mid_seq

    events = []
    for chord in sequence:
        events.extend(create_arpeggio_events(chord, pattern))

    # Tokenize everything first so a tokenizer error leaves mid_seq unmodified.
    tokens = [tokenizer.event2tokens(event) for event in events]
    mid_seq[0].extend(tokens)
    return mid_seq
405
+
406
+
407
+
408
+ if __name__ == "__main__":
409
  parser = argparse.ArgumentParser()
410
  parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
411
  parser.add_argument("--port", type=int, default=7860, help="gradio server port")
 
456
  load_javascript()
457
  app = gr.Blocks(theme=gr.themes.Soft())
458
  with app:
459
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Midi Composer with Arpeggios</h1>")
 
 
 
 
 
 
 
 
 
 
460
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
461
  js_msg.change(None, [js_msg], [], js="""
462
  (msg_json) =>{
 
475
  input_drum_kit = gr.Dropdown(label="🥁drum kit", choices=list(drum_kits2number.keys()), type="value",
476
  value="None")
477
  input_bpm = gr.Slider(label="BPM (beats per minute, auto if 0)", minimum=0, maximum=255,
478
+ step=1, value=0)
 
479
  input_time_sig = gr.Radio(label="time signature (only for tv2 models)",
480
  value="auto",
481
  choices=["auto", "4/4", "2/4", "3/4", "6/4", "7/4",
482
+ "2/2", "3/2", "4/2", "3/8", "5/8", "6/8", "7/8", "9/8", "12/8"])
 
483
  input_key_sig = gr.Radio(label="key signature (only for tv2 models)",
484
  value="auto",
485
  choices=["auto"] + key_signatures,
486
+ type="index")
487
+
488
+ with gr.Row():
489
+ arpeggio_intro = gr.Button("🎵 Intro Arpeggio", variant="primary")
490
+ arpeggio_verse = gr.Button("🎸 Verse Arpeggio", variant="primary")
491
+ arpeggio_chorus = gr.Button("🎹 Chorus Arpeggio", variant="primary")
492
+ arpeggio_outro = gr.Button("🎷 Outro Arpeggio", variant="primary")
493
+
494
+
495
+
496
  example1 = gr.Examples([
497
  [[], "None"],
498
  [["Acoustic Grand"], "None"],
 
507
  [["Electric Guitar(clean)", "Electric Guitar(muted)", "Overdriven Guitar", "Distortion Guitar",
508
  "Electric Bass(finger)"], "Standard"]
509
  ], [input_instruments, input_drum_kit])
510
+
511
  with gr.TabItem("midi prompt") as tab2:
512
  input_midi = gr.File(label="input midi", file_types=[".midi", ".mid"], type="binary")
513
  input_midi_events = gr.Slider(label="use first n midi events as prompt", minimum=1, maximum=512,
 
521
  input_remove_empty_channels = gr.Checkbox(label="remove channels without notes", value=False)
522
  example2 = gr.Examples([[file, 128] for file in glob.glob("example/*.mid")],
523
  [input_midi, input_midi_events])
524
+
525
  with gr.TabItem("last output prompt") as tab3:
526
  gr.Markdown("Continue generating on the last output.")
527
  input_continuation_select = gr.Radio(label="select output to continue generating", value="all",
 
582
  # queue=False)
583
  undo_btn.click(undo_continuation, [input_model, output_midi_seq, output_continuation_state],
584
  [output_midi_seq, output_continuation_state, js_msg], queue=False)
585
+
586
+
587
+
588
def add_intro_arpeggio(model_name, mid_seq):
    """Append the intro arpeggio (chords C, D, Am, G) to ``mid_seq``.

    Uses the tokenizer of ``models[model_name]`` and returns the result of
    ``add_arpeggio_sequence``.
    """
    tok = models[model_name].tokenizer
    # Step order per chord: root, third, fifth, third.
    return add_arpeggio_sequence(tok, mid_seq, ['C', 'D', 'Am', 'G'], [0, 1, 2, 1])
593
+
594
def add_verse_arpeggio(model_name, mid_seq):
    """Append the verse arpeggio (chords D, C, Am, G) to ``mid_seq``.

    Uses the tokenizer of ``models[model_name]`` and returns the result of
    ``add_arpeggio_sequence``.
    """
    tok = models[model_name].tokenizer
    # Step order per chord: root, fifth, third, fifth.
    return add_arpeggio_sequence(tok, mid_seq, ['D', 'C', 'Am', 'G'], [0, 2, 1, 2])
599
+
600
def add_chorus_arpeggio(model_name, mid_seq):
    """Append the chorus arpeggio (chords G, D, Am, C) to ``mid_seq``.

    Uses the tokenizer of ``models[model_name]`` and returns the result of
    ``add_arpeggio_sequence``.
    """
    tok = models[model_name].tokenizer
    # Step order per chord: root, third, fifth, third, root, fifth.
    return add_arpeggio_sequence(tok, mid_seq, ['G', 'D', 'Am', 'C'], [0, 1, 2, 1, 0, 2])
605
+
606
def add_outro_arpeggio(model_name, mid_seq):
    """Append the outro arpeggio (chords Am, G, D, C) to ``mid_seq``.

    Uses the tokenizer of ``models[model_name]`` and returns the result of
    ``add_arpeggio_sequence``.
    """
    tok = models[model_name].tokenizer
    # Step order per chord: fifth, third, root, third.
    return add_arpeggio_sequence(tok, mid_seq, ['Am', 'G', 'D', 'C'], [2, 1, 0, 1])
611
+
612
+ arpeggio_intro.click(add_intro_arpeggio, [input_model, output_midi_seq], output_midi_seq)
613
+ arpeggio_verse.click(add_verse_arpeggio, [input_model, output_midi_seq], output_midi_seq)
614
+ arpeggio_chorus.click(add_chorus_arpeggio, [input_model, output_midi_seq], output_midi_seq)
615
+ arpeggio_outro.click(add_outro_arpeggio, [input_model, output_midi_seq], output_midi_seq)
616
+
617
+
618
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
619
  thread_pool.shutdown()