awacke1 committed on
Commit
c43c03a
·
verified ·
1 Parent(s): 30755d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +316 -426
app.py CHANGED
@@ -5,6 +5,7 @@ import glob
5
  import json
6
  import os
7
  import time
 
8
  from concurrent.futures import ThreadPoolExecutor
9
 
10
  import gradio as gr
@@ -23,89 +24,128 @@ in_space = os.getenv("SYSTEM") == "spaces"
23
 
24
  # Chord to emoji mapping
25
  CHORD_EMOJIS = {
26
- 'A': '๐ŸŽธ',
27
- 'Am': '๐ŸŽป',
28
- 'B': '๐ŸŽน',
29
- 'Bm': '๐ŸŽท',
30
- 'C': '๐ŸŽต',
31
- 'Cm': '๐ŸŽถ',
32
- 'D': '๐Ÿฅ',
33
- 'Dm': '๐Ÿช˜',
34
- 'E': '๐ŸŽค',
35
- 'Em': '๐ŸŽง',
36
- 'F': '๐Ÿช•',
37
- 'Fm': '๐ŸŽบ',
38
- 'G': '๐Ÿช—',
39
- 'Gm': '๐ŸŽป'
40
  }
41
 
42
- # Progression patterns
43
- PROGRESSION_PATTERNS = {
44
- "12-bar-blues": ["I", "I", "I", "I", "IV", "IV", "I", "I", "V", "IV", "I", "V"],
45
- "pop-verse": ["I", "V", "vi", "IV"],
46
- "pop-chorus": ["I", "IV", "V", "vi"],
47
- "jazz": ["ii", "V", "I"],
48
- "ballad": ["I", "vi", "IV", "V"]
49
- }
50
-
51
- # Roman numeral to chord offset mapping (in major scale)
52
- ROMAN_TO_OFFSET = {
53
- "I": 0,
54
- "ii": 2,
55
- "iii": 4,
56
- "IV": 5,
57
- "V": 7,
58
- "vi": 9,
59
- "vii": 11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  }
61
 
62
- @torch.inference_mode()
63
- def generate(model: MIDIModel, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
64
- disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
65
- tokenizer = model.tokenizer
66
- if disable_channels is not None:
67
- disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
68
- else:
69
- disable_channels = []
70
- max_token_seq = tokenizer.max_token_seq
71
- if prompt is None:
72
- input_tensor = torch.full((1, max_token_seq), tokenizer.pad_id, dtype=torch.long, device=model.device)
73
- input_tensor[0, 0] = tokenizer.bos_id # bos
74
- input_tensor = input_tensor.unsqueeze(0)
75
- input_tensor = torch.cat([input_tensor] * batch_size, dim=0)
76
- else:
77
- if len(prompt.shape) == 2:
78
- prompt = prompt[None, :]
79
- prompt = np.repeat(prompt, repeats=batch_size, axis=0)
80
- elif prompt.shape[0] == 1:
81
- prompt = np.repeat(prompt, repeats=batch_size, axis=0)
82
- elif len(prompt.shape) != 3 or prompt.shape[0] != batch_size:
83
- raise ValueError(f"invalid shape for prompt, {prompt.shape}")
84
- prompt = prompt[..., :max_token_seq]
85
- if prompt.shape[-1] < max_token_seq:
86
- prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
87
- mode="constant", constant_values=tokenizer.pad_id)
88
- input_tensor = torch.from_numpy(prompt).to(dtype=torch.long, device=model.device)
89
-
90
- # Basic generation logic - simplified for brevity
91
- # In a real implementation, you'd keep more of the original generation code
92
- tokens_generated = []
93
- cur_len = input_tensor.shape[1]
94
- while cur_len < max_len:
95
- # Generate next token sequence
96
- with torch.no_grad():
97
- # This is simplified - actual implementation would use the model logic
98
- next_token_seq = torch.ones((batch_size, 1, max_token_seq), dtype=torch.long, device=model.device)
99
 
100
- tokens_generated.append(next_token_seq)
101
- input_tensor = torch.cat([input_tensor, next_token_seq[:, 0].unsqueeze(1)], dim=1)
102
- cur_len += 1
103
-
104
- yield next_token_seq[:, 0].cpu().numpy()
105
-
106
- # Exit condition (simplified)
107
- if cur_len >= max_len:
108
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def create_msg(name, data):
111
  return {"name": name, "data": data}
@@ -113,170 +153,67 @@ def create_msg(name, data):
113
  def send_msgs(msgs):
114
  return json.dumps(msgs)
115
 
116
- def get_chord_progressions(root_chord, progression_type):
117
- """Convert a roman numeral progression to actual chords starting from root"""
118
- major_scale = ["C", "D", "E", "F", "G", "A", "B"]
119
- minor_scale = ["Cm", "Dm", "Em", "Fm", "Gm", "Am", "Bm"]
120
-
121
- # Find root index in major scale
122
- root_idx = 0
123
- for i, chord in enumerate(major_scale):
124
- if chord == root_chord:
125
- root_idx = i
126
- break
127
-
128
- # Get progression pattern
129
- pattern = PROGRESSION_PATTERNS.get(progression_type, PROGRESSION_PATTERNS["pop-verse"])
130
-
131
- # Generate actual chord progression
132
- progression = []
133
- for numeral in pattern:
134
- is_minor = numeral.islower()
135
- # Remove m if present in the numeral
136
- base_numeral = numeral.replace("m", "")
137
- # Get offset
138
- offset = ROMAN_TO_OFFSET.get(base_numeral, 0)
139
-
140
- # Calculate actual chord index
141
- chord_idx = (root_idx + offset) % 7
142
-
143
- # Add chord to progression
144
- if is_minor:
145
- progression.append(minor_scale[chord_idx])
146
- else:
147
- progression.append(major_scale[chord_idx])
148
-
149
- return progression
150
-
151
  def create_chord_events(chord, duration=480, velocity=80):
152
  """Create MIDI events for a chord"""
153
  events = []
154
- chord_notes = {
155
- 'C': [60, 64, 67], # C major (C, E, G)
156
- 'Cm': [60, 63, 67], # C minor (C, Eb, G)
157
- 'D': [62, 66, 69], # D major (D, F#, A)
158
- 'Dm': [62, 65, 69], # D minor (D, F, A)
159
- 'E': [64, 68, 71], # E major (E, G#, B)
160
- 'Em': [64, 67, 71], # E minor (E, G, B)
161
- 'F': [65, 69, 72], # F major (F, A, C)
162
- 'Fm': [65, 68, 72], # F minor (F, Ab, C)
163
- 'G': [67, 71, 74], # G major (G, B, D)
164
- 'Gm': [67, 70, 74], # G minor (G, Bb, D)
165
- 'A': [69, 73, 76], # A major (A, C#, E)
166
- 'Am': [69, 72, 76], # A minor (A, C, E)
167
- 'B': [71, 75, 78], # B major (B, D#, F#)
168
- 'Bm': [71, 74, 78] # B minor (B, D, F#)
169
- }
170
 
171
- if chord in chord_notes:
172
- notes = chord_notes[chord]
173
  # Note on events
174
  for note in notes:
175
  events.append(['note_on', 0, 0, 0, 0, note, velocity])
176
 
177
- # Note off events
178
  for note in notes:
179
  events.append(['note_off', duration, 0, 0, 0, note, 0])
180
 
181
  return events
182
 
183
- def create_chord_sequence(tokenizer, chords, pattern="simple", duration=480):
184
- """Create a sequence of chord events with a pattern"""
185
- events = []
186
-
187
- for chord in chords:
188
- if pattern == "simple":
189
- # Just play the chord
190
- events.extend(create_chord_events(chord, duration))
191
- elif pattern == "arpeggio":
192
- # Arpeggiate the chord
193
- chord_notes = {
194
- 'C': [60, 64, 67],
195
- 'Cm': [60, 63, 67],
196
- 'D': [62, 66, 69],
197
- 'Dm': [62, 65, 69],
198
- 'E': [64, 68, 71],
199
- 'Em': [64, 67, 71],
200
- 'F': [65, 69, 72],
201
- 'Fm': [65, 68, 72],
202
- 'G': [67, 71, 74],
203
- 'Gm': [67, 70, 74],
204
- 'A': [69, 73, 76],
205
- 'Am': [69, 72, 76],
206
- 'B': [71, 75, 78],
207
- 'Bm': [71, 74, 78]
208
- }
209
-
210
- if chord in chord_notes:
211
- notes = chord_notes[chord]
212
- for i, note in enumerate(notes):
213
- events.append(['note_on', 0 if i == 0 else duration//4, 0, 0, 0, note, 80])
214
- events.append(['note_off', duration//4, 0, 0, 0, note, 0])
215
-
216
- # Add final pause to complete the bar
217
- events.append(['note_on', 0, 0, 0, 0, notes[0], 0])
218
- events.append(['note_off', duration//4, 0, 0, 0, notes[0], 0])
219
-
220
- # Convert events to tokens
221
- tokens = []
222
- for event in events:
223
- tokens.append(tokenizer.event2tokens(event))
224
-
225
- return tokens
226
 
227
- def add_chord_sequence(model_name, mid_seq, root_chord="C", progression_type="pop-verse", pattern="simple"):
228
- """Add a chord sequence to the MIDI sequence"""
229
- tokenizer = models[model_name].tokenizer
230
-
231
- # Generate chord progression
232
- chord_progression = create_chord_progressions(root_chord, progression_type)
233
-
234
- # Create chord sequence tokens
235
- tokens = create_chord_sequence(tokenizer, chord_progression, pattern)
236
-
237
- # Add tokens to sequence
238
- if mid_seq is None:
239
- mid_seq = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
240
- mid_seq = [mid_seq] * OUTPUT_BATCH_SIZE
241
-
242
- # Add tokens to the first sequence
243
- mid_seq[0].extend(tokens)
244
-
245
- return mid_seq
246
 
247
- def create_song_structure(model_name, root_chord="C"):
248
- """Create a complete song structure with verse, chorus, etc."""
249
- tokenizer = models[model_name].tokenizer
250
-
251
- # Initialize sequence
252
- mid_seq = [[tokenizer.bos_id] + [tokenizer.pad_id] * (tokenizer.max_token_seq - 1)]
253
- mid_seq = [mid_seq] * OUTPUT_BATCH_SIZE
254
-
255
- # Add intro
256
- intro_tokens = create_chord_sequence(tokenizer,
257
- create_chord_progressions(root_chord, "pop-verse"),
258
- "arpeggio")
259
- mid_seq[0].extend(intro_tokens)
260
-
261
- # Add verse
262
- verse_tokens = create_chord_sequence(tokenizer,
263
- create_chord_progressions(root_chord, "pop-verse"),
264
- "simple")
265
- mid_seq[0].extend(verse_tokens)
266
-
267
- # Add chorus
268
- chorus_tokens = create_chord_sequence(tokenizer,
269
- create_chord_progressions(root_chord, "pop-chorus"),
270
- "simple")
271
- mid_seq[0].extend(chorus_tokens)
272
-
273
- # Add outro
274
- outro_tokens = create_chord_sequence(tokenizer,
275
- create_chord_progressions(root_chord, "ballad"),
276
- "arpeggio")
277
- mid_seq[0].extend(outro_tokens)
278
 
279
- return mid_seq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  def load_javascript(dir="javascript"):
282
  scripts_list = glob.glob(f"{dir}/*.js")
@@ -298,27 +235,19 @@ def load_javascript(dir="javascript"):
298
 
299
  gr.routes.templates.TemplateResponse = template_response
300
 
301
- def render_audio(model_name, mid_seq, should_render_audio):
302
- if (not should_render_audio) or mid_seq is None:
303
- outputs = [None] * OUTPUT_BATCH_SIZE
304
- return tuple(outputs)
305
- tokenizer = models[model_name].tokenizer
306
- outputs = []
307
- if not os.path.exists("outputs"):
308
- os.mkdir("outputs")
309
- audio_futures = []
310
- for i in range(OUTPUT_BATCH_SIZE):
311
- mid = tokenizer.detokenize(mid_seq[i])
312
- audio_future = thread_pool.submit(synthesis_task, mid)
313
- audio_futures.append(audio_future)
314
- for future in audio_futures:
315
- outputs.append((44100, future.result()))
316
- if OUTPUT_BATCH_SIZE == 1:
317
- return outputs[0]
318
- return tuple(outputs)
319
-
320
- def synthesis_task(mid):
321
- return synthesizer.synthesis(MIDI.score2opus(mid))
322
 
323
  if __name__ == "__main__":
324
  parser = argparse.ArgumentParser()
@@ -330,29 +259,56 @@ if __name__ == "__main__":
330
  opt = parser.parse_args()
331
  OUTPUT_BATCH_SIZE = opt.batch
332
 
 
 
 
333
  # Initialize models (simplified version)
334
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
335
  thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
336
  synthesizer = MidiSynthesizer(soundfont_path)
337
 
338
- models_info = {
339
- "generic pretrain model (tv2o-medium) by skytnt": [
340
- "skytnt/midi-model-tv2o-medium", {}
341
- ]
342
- }
343
 
344
- models = {}
345
- # Initialize models (simplified)
346
- for name, (repo_id, loras) in models_info.items():
347
- model = MIDIModel.from_pretrained(repo_id)
348
- model.to(device="cpu", dtype=torch.float32)
349
- models[name] = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  load_javascript()
352
- app = gr.Blocks(theme=gr.themes.Soft())
353
 
354
  with app:
355
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>๐ŸŽต Chord-Emoji MIDI Composer ๐ŸŽต</h1>")
356
 
357
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
358
  js_msg.change(None, [js_msg], [], js="""
@@ -363,183 +319,117 @@ if __name__ == "__main__":
363
  }
364
  """)
365
 
366
- input_model = gr.Dropdown(label="Select Model", choices=list(models.keys()),
367
- type="value", value=list(models.keys())[0])
368
-
369
- # Main chord progression section
370
- with gr.Tabs():
371
- with gr.TabItem("Chord Progressions") as tab1:
372
- with gr.Row():
373
- root_chord = gr.Dropdown(label="Root Chord", choices=["C", "D", "E", "F", "G", "A", "B"],
374
- value="C")
375
- progression_type = gr.Dropdown(label="Progression Type",
376
- choices=list(PROGRESSION_PATTERNS.keys()),
377
- value="pop-verse")
378
-
379
- # Emoji-Chord Button Grid - Create a 2x7 grid of chord buttons
380
- gr.Markdown("### Chord Buttons - Click to Add Individual Chords")
381
-
382
- with gr.Row():
383
- chord_buttons_major = []
384
- for chord in ["C", "D", "E", "F", "G", "A", "B"]:
385
- emoji = CHORD_EMOJIS.get(chord, "๐ŸŽต")
386
- btn = gr.Button(f"{emoji} {chord}", size="sm")
387
- chord_buttons_major.append((chord, btn))
388
-
389
- with gr.Row():
390
- chord_buttons_minor = []
391
- for chord in ["Cm", "Dm", "Em", "Fm", "Gm", "Am", "Bm"]:
392
- emoji = CHORD_EMOJIS.get(chord, "๐ŸŽต")
393
- btn = gr.Button(f"{emoji} {chord}", size="sm")
394
- chord_buttons_minor.append((chord, btn))
395
-
396
- # Song structure buttons
397
- gr.Markdown("### Song Structure Patterns - Click to Add a Pattern")
398
- with gr.Row():
399
- intro_btn = gr.Button("๐ŸŽต Intro", variant="primary")
400
- verse_btn = gr.Button("๐ŸŽธ Verse", variant="primary")
401
- chorus_btn = gr.Button("๐ŸŽน Chorus", variant="primary")
402
- bridge_btn = gr.Button("๐ŸŽท Bridge", variant="primary")
403
- outro_btn = gr.Button("๐Ÿช— Outro", variant="primary")
404
-
405
- with gr.Row():
406
- blues_btn = gr.Button("๐ŸŽบ 12-Bar Blues", variant="primary")
407
- jazz_btn = gr.Button("๐ŸŽป Jazz Pattern", variant="primary")
408
- ballad_btn = gr.Button("๐ŸŽค Ballad", variant="primary")
409
-
410
- with gr.Row():
411
- pattern_type = gr.Radio(label="Pattern Style",
412
- choices=["simple", "arpeggio"],
413
- value="simple")
414
-
415
- with gr.Row():
416
- clear_btn = gr.Button("๐Ÿ—‘๏ธ Clear Sequence", variant="secondary")
417
- play_btn = gr.Button("โ–ถ๏ธ Play Current Sequence", variant="primary")
418
-
419
- with gr.TabItem("Custom MIDI Settings") as tab2:
420
- input_instruments = gr.Dropdown(label="๐Ÿช— Instruments (auto if empty)",
421
- choices=["Acoustic Grand", "Electric Piano", "Violin", "Guitar"],
422
- multiselect=True, type="value")
423
- input_bpm = gr.Slider(label="BPM (beats per minute)", minimum=60, maximum=180,
424
- step=1, value=120)
425
-
426
- # Output section
427
- output_midi_seq = gr.State()
428
- output_continuation_state = gr.State([0])
429
-
430
- midi_outputs = []
431
- audio_outputs = []
432
-
433
- with gr.Tabs(elem_id="output_tabs"):
434
- for i in range(OUTPUT_BATCH_SIZE):
435
- with gr.TabItem(f"Output {i + 1}") as tab:
436
- output_midi_visualizer = gr.HTML(elem_id=f"midi_visualizer_container_{i}")
437
- output_audio = gr.Audio(label="Output Audio", format="mp3", elem_id=f"midi_audio_{i}")
438
- output_midi = gr.File(label="Output MIDI", file_types=[".mid"])
439
- midi_outputs.append(output_midi)
440
- audio_outputs.append(output_audio)
441
-
442
- # Connect chord buttons to functions
443
- for chord, btn in chord_buttons_major + chord_buttons_minor:
444
- btn.click(
445
- fn=lambda chord=chord, m=input_model, seq=output_midi_seq, pt=pattern_type:
446
- add_chord_sequence(m, seq, chord, "ballad", pt.value),
447
- inputs=[input_model, output_midi_seq, pattern_type],
448
- outputs=[output_midi_seq]
449
- )
450
-
451
- # Connect song structure buttons
452
- intro_btn.click(
453
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
454
- add_chord_sequence(m, seq, rc.value, "pop-verse", "arpeggio"),
455
- inputs=[input_model, output_midi_seq, root_chord],
456
- outputs=[output_midi_seq]
457
- )
458
-
459
- verse_btn.click(
460
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
461
- add_chord_sequence(m, seq, rc.value, "pop-verse", "simple"),
462
- inputs=[input_model, output_midi_seq, root_chord],
463
- outputs=[output_midi_seq]
464
- )
465
 
466
- chorus_btn.click(
467
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
468
- add_chord_sequence(m, seq, rc.value, "pop-chorus", "simple"),
469
- inputs=[input_model, output_midi_seq, root_chord],
470
- outputs=[output_midi_seq]
471
- )
472
 
473
- bridge_btn.click(
474
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
475
- add_chord_sequence(m, seq, rc.value, "jazz", "simple"),
476
- inputs=[input_model, output_midi_seq, root_chord],
477
- outputs=[output_midi_seq]
478
- )
479
 
480
- outro_btn.click(
481
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
482
- add_chord_sequence(m, seq, rc.value, "ballad", "arpeggio"),
483
- inputs=[input_model, output_midi_seq, root_chord],
484
- outputs=[output_midi_seq]
485
- )
486
 
487
- blues_btn.click(
488
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
489
- add_chord_sequence(m, seq, rc.value, "12-bar-blues", "simple"),
490
- inputs=[input_model, output_midi_seq, root_chord],
491
- outputs=[output_midi_seq]
492
- )
493
 
494
- jazz_btn.click(
495
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
496
- add_chord_sequence(m, seq, rc.value, "jazz", "simple"),
497
- inputs=[input_model, output_midi_seq, root_chord],
498
- outputs=[output_midi_seq]
499
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
 
501
- ballad_btn.click(
502
- fn=lambda m=input_model, seq=output_midi_seq, rc=root_chord:
503
- add_chord_sequence(m, seq, rc.value, "ballad", "simple"),
504
- inputs=[input_model, output_midi_seq, root_chord],
505
- outputs=[output_midi_seq]
506
  )
507
 
508
- # Clear and play buttons
509
- clear_btn.click(
510
- fn=lambda m=input_model: [[models[m].tokenizer.bos_id] +
511
- [models[m].tokenizer.pad_id] * (models[m].tokenizer.max_token_seq - 1)] * OUTPUT_BATCH_SIZE,
512
- inputs=[input_model],
513
- outputs=[output_midi_seq]
514
  )
515
 
516
- # Play functionality - render audio and visualize
517
- def prepare_playback(model_name, mid_seq):
518
- if mid_seq is None:
519
- return mid_seq, [], send_msgs([])
520
-
521
- tokenizer = models[model_name].tokenizer
522
- msgs = []
523
-
524
- for i in range(OUTPUT_BATCH_SIZE):
525
- events = [tokenizer.tokens2event(tokens) for tokens in mid_seq[i]]
526
- msgs += [
527
- create_msg("visualizer_clear", [i, tokenizer.version]),
528
- create_msg("visualizer_append", [i, events]),
529
- create_msg("visualizer_end", i)
530
- ]
531
-
532
- return mid_seq, mid_seq, send_msgs(msgs)
533
-
534
- play_btn.click(
535
- fn=prepare_playback,
536
- inputs=[input_model, output_midi_seq],
537
- outputs=[output_midi_seq, output_continuation_state, js_msg]
538
  ).then(
539
- fn=render_audio,
540
- inputs=[input_model, output_midi_seq, gr.State(True)],
541
- outputs=audio_outputs
542
  )
543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
545
- thread_pool.shutdown()
 
 
 
5
  import json
6
  import os
7
  import time
8
+ import rtmidi
9
  from concurrent.futures import ThreadPoolExecutor
10
 
11
  import gradio as gr
 
24
 
25
  # Chord to emoji mapping
26
# Emoji shown next to each chord name, keyed by root note plus chord quality
# ('' major, 'm' minor, '7' dominant 7th, 'maj7' major 7th, 'm7' minor 7th).
# The values are real Unicode emoji; the earlier text contained their UTF-8
# bytes mis-decoded through a legacy code page, which rendered as garbage.
CHORD_EMOJIS = {
    'C': '🎵', 'Cm': '🎶', 'C7': '🎼', 'Cmaj7': '🎹', 'Cm7': '🎻',
    'D': '🥁', 'Dm': '🪘', 'D7': '🎷', 'Dmaj7': '🎺', 'Dm7': '🪕',
    'E': '🎸', 'Em': '🎻', 'E7': '🎵', 'Emaj7': '🎶', 'Em7': '🎼',
    'F': '🎹', 'Fm': '🎸', 'F7': '🎻', 'Fmaj7': '🎷', 'Fm7': '🎺',
    'G': '🪗', 'Gm': '🎵', 'G7': '🎶', 'Gmaj7': '🎼', 'Gm7': '🎹',
    'A': '🎸', 'Am': '🎻', 'A7': '🎷', 'Amaj7': '🎺', 'Am7': '🪕',
    'B': '🎵', 'Bm': '🎶', 'B7': '🎼', 'Bmaj7': '🎹', 'Bm7': '🎻',
}
35
 
36
+ # Chord note definitions (MIDI note numbers)
37
# MIDI note numbers for every supported chord, keyed by chord name.
# Each chord is its root pitch (C=60 ... B=71) plus a fixed interval pattern
# in semitones: major 0-4-7, minor 0-3-7, dominant 7th 0-4-7-10,
# major 7th 0-4-7-11, minor 7th 0-3-7-10.
CHORD_NOTES = {
    f"{root}{quality}": [base + step for step in steps]
    for root, base in (
        ('C', 60), ('D', 62), ('E', 64), ('F', 65),
        ('G', 67), ('A', 69), ('B', 71),
    )
    for quality, steps in (
        ('', (0, 4, 7)),
        ('m', (0, 3, 7)),
        ('7', (0, 4, 7, 10)),
        ('maj7', (0, 4, 7, 11)),
        ('m7', (0, 3, 7, 10)),
    )
}
80
 
81
+ # MIDI device manager
82
class MIDIDeviceManager:
    """Own one ``rtmidi.MidiOut`` and remember which output port is open.

    All ``send_*`` methods are silent no-ops while no port is open, so UI
    callbacks can call them unconditionally.

    Attributes:
        midi_out: the underlying ``rtmidi.MidiOut`` object.
        available_ports: port names from the most recent ``get_ports()`` call.
        current_port: index of the open port, or ``None`` when nothing is open.
    """

    def __init__(self):
        self.midi_out = rtmidi.MidiOut()
        self.available_ports = self.midi_out.get_ports()
        self.current_port = None

    @staticmethod
    def _clamp(value, upper):
        """Coerce *value* to an int inside ``0..upper``.

        Callers may pass floats or out-of-range numbers; letting those through
        would corrupt the status byte (``0x90 + 16`` is a different message
        type) or the 7-bit data bytes.
        """
        return max(0, min(int(value), upper))

    def _send(self, status, channel, *data):
        """Send one channel-voice message if a port is open."""
        if self.current_port is None:
            return
        message = [status | self._clamp(channel, 0x0F)]
        message.extend(self._clamp(byte, 0x7F) for byte in data)
        self.midi_out.send_message(message)

    def get_available_devices(self):
        """Re-query and return the list of available MIDI output port names."""
        self.available_ports = self.midi_out.get_ports()
        return self.available_ports

    def open_port(self, port_index):
        """Open the MIDI output port at *port_index*.

        The port list is refreshed first so the bounds check reflects the
        devices present now rather than at construction time. Asking for the
        port that is already open is a no-op instead of a close/reopen cycle.

        Returns:
            True when the requested port is open on return, False when
            *port_index* is ``None`` or out of range.
        """
        ports = self.get_available_devices()
        if port_index is None or not 0 <= port_index < len(ports):
            return False
        if port_index == self.current_port:
            return True
        if self.current_port is not None:
            self.midi_out.close_port()
            # Clear first: if the open below raises, we must not keep
            # claiming that a (now closed) port is still open.
            self.current_port = None
        self.midi_out.open_port(port_index)
        self.current_port = port_index
        return True

    def send_note_on(self, note, velocity=64, channel=0):
        """Send a note-on (status ``0x90``) for *note* on *channel*."""
        self._send(0x90, channel, note, velocity)

    def send_note_off(self, note, velocity=0, channel=0):
        """Send a note-off (status ``0x80``) for *note* on *channel*."""
        self._send(0x80, channel, note, velocity)

    def send_program_change(self, program, channel=0):
        """Send a program change (status ``0xC0``) selecting *program*."""
        self._send(0xC0, channel, program)

    def play_chord(self, chord_name, velocity=80, channel=0, duration=None):
        """Sound every note of *chord_name*; unknown chord names are ignored.

        When *duration* (seconds) is given, this call blocks for that long and
        then releases the notes. The release runs in a ``finally`` so an
        interrupted sleep cannot leave notes hanging on the device. With
        ``duration=None`` the notes stay on until ``release_chord`` is called.
        """
        notes = CHORD_NOTES.get(chord_name)
        if notes is None:
            return
        for note in notes:
            self.send_note_on(note, velocity, channel)
        if duration is None:
            return
        try:
            time.sleep(duration)
        finally:
            for note in notes:
                self.send_note_off(note, 0, channel)

    def release_chord(self, chord_name, channel=0):
        """Send note-off for every note of *chord_name* (no-op if unknown)."""
        for note in CHORD_NOTES.get(chord_name, ()):
            self.send_note_off(note, 0, channel)

    def close(self):
        """Close the currently open port, if any."""
        if self.current_port is not None:
            self.midi_out.close_port()
            self.current_port = None
146
+
147
+ # Global MIDI manager
148
+ midi_manager = MIDIDeviceManager()
149
 
150
def create_msg(name, data):
    """Bundle a message name and its payload into a ``{"name", "data"}`` dict."""
    message = dict(name=name, data=data)
    return message
 
153
def send_msgs(msgs):
    """Serialize *msgs* to a JSON string."""
    payload = json.dumps(msgs)
    return payload
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
def create_chord_events(chord, duration=480, velocity=80):
    """Build the event rows for one named chord.

    The chord is looked up in ``CHORD_NOTES``; an unknown name produces an
    empty list. For a known chord the result holds one ``'note_on'`` row per
    pitch (time field 0, the given *velocity*) followed by one ``'note_off'``
    row per pitch (time field *duration*, velocity 0). Every row is a 7-item
    list of the form ``[kind, time, 0, 0, 0, pitch, velocity]``.
    """
    pitches = CHORD_NOTES.get(chord)
    if pitches is None:
        return []

    # All note-ons first so the chord starts together, then all note-offs.
    note_ons = [['note_on', 0, 0, 0, 0, pitch, velocity] for pitch in pitches]
    note_offs = [['note_off', duration, 0, 0, 0, pitch, 0] for pitch in pitches]
    return note_ons + note_offs
171
 
172
def add_chord_to_queue(chord_name, chord_queue, max_queue_size=8):
    """Return a new queue with *chord_name* appended, capped at *max_queue_size*.

    The oldest entries are dropped so the result never exceeds the cap, even
    when the incoming queue is already longer than it. The caller's list is
    left untouched; a fresh list is returned so shared UI state is never
    mutated behind the caller's back.

    Args:
        chord_name: chord to append as the newest entry.
        chord_queue: current queue, oldest first; ``None`` is treated as empty.
        max_queue_size: maximum length of the returned queue. A value of zero
            or less means nothing can be queued and yields an empty list.

    Returns:
        A new list holding at most *max_queue_size* chords, newest last.
    """
    if max_queue_size <= 0:
        return []
    existing = list(chord_queue) if chord_queue else []
    room_for_old = max_queue_size - 1
    # A "[-0:]" slice would keep everything, so the zero case is explicit.
    kept = existing[-room_for_old:] if room_for_old else []
    return kept + [chord_name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
def play_chord_on_device(chord_name, midi_device_index):
    """Play *chord_name* for half a second on the selected MIDI output.

    Nothing is played when no device is selected (``None`` or a negative
    index) or when the manager refuses the index. The result of
    ``open_port`` is checked so a rejected index cannot silently fall back to
    whichever port happened to be open before.

    Returns:
        *chord_name*, unchanged, so the call can feed a follow-up UI step.
    """
    if midi_device_index is None or midi_device_index < 0:
        return chord_name
    if midi_manager.open_port(midi_device_index):
        # Blocks for the duration; play_chord releases the notes itself.
        midi_manager.play_chord(chord_name, duration=0.5)
    return chord_name
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
def play_chord_sequence(chord_queue, midi_device_index, tempo=120):
    """Play every chord in *chord_queue*, one beat each, on the selected device.

    Args:
        chord_queue: chord names in playback order.
        midi_device_index: index of the MIDI output port; ``None`` or a
            negative value means no device is selected and nothing is played.
        tempo: beats per minute; each chord sounds for ``60 / tempo`` seconds.

    Returns:
        *chord_queue*, unchanged.

    Raises:
        ValueError: if playback is attempted with a tempo that is not positive
            (previously a ZeroDivisionError or a negative sleep).
    """
    no_device = midi_device_index is None or midi_device_index < 0
    if no_device or not chord_queue:
        return chord_queue

    if tempo <= 0:
        raise ValueError(f"tempo must be positive, got {tempo!r}")
    beat_duration = 60.0 / tempo  # seconds per beat

    # Do not play on a stale port if the requested one cannot be opened.
    if not midi_manager.open_port(midi_device_index):
        return chord_queue

    for chord in chord_queue:
        midi_manager.play_chord(chord, duration=beat_duration)
        # Small gap so consecutive chords are heard as separate attacks.
        time.sleep(0.05)

    return chord_queue
200
+
201
def refresh_midi_devices():
    """Return a Gradio update that reloads the MIDI device dropdown choices.

    Uses ``gr.update`` rather than ``gr.Dropdown.update``: the per-component
    ``update`` classmethods were removed in Gradio 4, and this file already
    relies on newer Gradio API (``js=`` on events, ``ssr_mode=`` on launch).
    """
    return gr.update(choices=midi_manager.get_available_devices())
204
+
205
def hf_hub_download_retry(repo_id, filename):
    """Download *filename* from *repo_id*, retrying on any exception.

    Makes up to 30 attempts back to back and returns the result of the first
    ``hf_hub_download`` call that succeeds. If every attempt fails, the
    exception from the last attempt is re-raised.
    """
    print(f"downloading {repo_id} {filename}")
    last_error = None
    for _attempt in range(30):
        try:
            return hf_hub_download(repo_id=repo_id, filename=filename)
        except Exception as exc:
            last_error = exc
    if last_error:
        raise last_error
217
 
218
  def load_javascript(dir="javascript"):
219
  scripts_list = glob.glob(f"{dir}/*.js")
 
235
 
236
  gr.routes.templates.TemplateResponse = template_response
237
 
238
def create_virtual_keyboard(chord_types):
    """Describe the chord buttons, grouped by root note and chord type.

    Returns a dict keyed by root note ('C' through 'B'). Each value is a dict
    keyed by chord-type suffix whose value is a ``(chord_name, emoji)`` tuple,
    where ``chord_name`` is the root followed by the suffix and ``emoji`` is
    looked up in ``CHORD_EMOJIS`` with a default for unknown chords.
    """
    return {
        root: {
            suffix: (f"{root}{suffix}", CHORD_EMOJIS.get(f"{root}{suffix}", "๐ŸŽต"))
            for suffix in chord_types
        }
        for root in ('C', 'D', 'E', 'F', 'G', 'A', 'B')
    }
 
 
 
 
 
 
 
 
251
 
252
  if __name__ == "__main__":
253
  parser = argparse.ArgumentParser()
 
259
  opt = parser.parse_args()
260
  OUTPUT_BATCH_SIZE = opt.batch
261
 
262
+ # Initialize MIDI device manager
263
+ midi_manager = MIDIDeviceManager()
264
+
265
  # Initialize models (simplified version)
266
  soundfont_path = hf_hub_download_retry(repo_id="skytnt/midi-model", filename="soundfont.sf2")
267
  thread_pool = ThreadPoolExecutor(max_workers=OUTPUT_BATCH_SIZE)
268
  synthesizer = MidiSynthesizer(soundfont_path)
269
 
270
+ # Define chord types to use in the virtual keyboard
271
+ chord_types = ['', 'm', '7', 'maj7', 'm7']
272
+
273
+ # Create virtual keyboard structure
274
+ keyboard = create_virtual_keyboard(chord_types)
275
 
276
+ # Define CSS for the virtual keyboard
277
+ keyboard_css = """
278
+ .chord-button {
279
+ margin: 4px;
280
+ min-width: 80px;
281
+ height: 60px;
282
+ font-size: 18px;
283
+ font-weight: bold;
284
+ border-radius: 8px;
285
+ transition: all 0.2s;
286
+ }
287
+ .chord-button:active {
288
+ transform: scale(0.95);
289
+ }
290
+ .chord-queue {
291
+ padding: 10px;
292
+ background: #f5f5f5;
293
+ border-radius: 8px;
294
+ min-height: 50px;
295
+ font-size: 16px;
296
+ margin-bottom: 15px;
297
+ }
298
+ .root-c { background-color: #FFCDD2; }
299
+ .root-d { background-color: #F8BBD0; }
300
+ .root-e { background-color: #E1BEE7; }
301
+ .root-f { background-color: #D1C4E9; }
302
+ .root-g { background-color: #C5CAE9; }
303
+ .root-a { background-color: #BBDEFB; }
304
+ .root-b { background-color: #B3E5FC; }
305
+ """
306
 
307
  load_javascript()
308
+ app = gr.Blocks(theme=gr.themes.Soft(), css=keyboard_css)
309
 
310
  with app:
311
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>๐ŸŽต Real-Time MIDI Chord Keyboard ๐ŸŽต</h1>")
312
 
313
  js_msg = gr.Textbox(elem_id="msg_receiver", visible=False)
314
  js_msg.change(None, [js_msg], [], js="""
 
319
  }
320
  """)
321
 
322
+ # MIDI Device Settings
323
+ with gr.Row():
324
+ with gr.Column(scale=3):
325
+ midi_device = gr.Dropdown(label="MIDI Output Device",
326
+ choices=midi_manager.get_available_devices(),
327
+ type="index")
328
+ refresh_button = gr.Button("๐Ÿ”„ Refresh MIDI Devices")
329
+
330
+ with gr.Column(scale=1):
331
+ tempo = gr.Slider(label="Tempo (BPM)",
332
+ minimum=40,
333
+ maximum=200,
334
+ value=120,
335
+ step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
+ # Chord Queue Display
338
+ chord_queue = gr.State([])
339
+ queue_display = gr.Markdown("### Current Chord Queue\n*No chords in queue*",
340
+ elem_classes=["chord-queue"])
 
 
341
 
342
+ # Play queue button
343
+ play_queue_button = gr.Button("โ–ถ๏ธ Play Chord Sequence", variant="primary", size="lg")
 
 
 
 
344
 
345
+ # Clear queue button
346
+ clear_queue_button = gr.Button("๐Ÿ—‘๏ธ Clear Queue", variant="secondary")
 
 
 
 
347
 
348
+ # Virtual Keyboard - Create sections for each root note
349
+ gr.Markdown("## Virtual Chord Keyboard")
 
 
 
 
350
 
351
+ for root in ['C', 'D', 'E', 'F', 'G', 'A', 'B']:
352
+ with gr.Row():
353
+ gr.Markdown(f"### {root}")
354
+ for chord_type in chord_types:
355
+ chord_name, emoji = keyboard[root][chord_type]
356
+ display_name = chord_name if chord_type == '' else chord_name
357
+ button = gr.Button(f"{emoji} {display_name}",
358
+ elem_classes=[f"chord-button root-{root.lower()}"])
359
+
360
+ # Connect the button to add chord to queue and play it immediately
361
+ button.click(
362
+ fn=play_chord_on_device,
363
+ inputs=[gr.State(chord_name), midi_device],
364
+ outputs=None
365
+ ).then(
366
+ fn=add_chord_to_queue,
367
+ inputs=[gr.State(chord_name), chord_queue],
368
+ outputs=[chord_queue]
369
+ ).then(
370
+ fn=lambda q: f"### Current Chord Queue\n" + " โ†’ ".join(q) if q else "*No chords in queue*",
371
+ inputs=[chord_queue],
372
+ outputs=[queue_display]
373
+ )
374
 
375
+ # Connect refresh button
376
+ refresh_button.click(
377
+ fn=refresh_midi_devices,
378
+ inputs=None,
379
+ outputs=[midi_device]
380
  )
381
 
382
+ # Connect play queue button
383
+ play_queue_button.click(
384
+ fn=play_chord_sequence,
385
+ inputs=[chord_queue, midi_device, tempo],
386
+ outputs=[chord_queue]
 
387
  )
388
 
389
+ # Connect clear queue button
390
+ clear_queue_button.click(
391
+ fn=lambda: [],
392
+ inputs=None,
393
+ outputs=[chord_queue]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  ).then(
395
+ fn=lambda: "### Current Chord Queue\n*No chords in queue*",
396
+ inputs=None,
397
+ outputs=[queue_display]
398
  )
399
 
400
+ # MIDI Generation Settings (for advanced users)
401
+ with gr.Accordion("Advanced MIDI Settings", open=False):
402
+ with gr.Row():
403
+ midi_channel = gr.Slider(label="MIDI Channel",
404
+ minimum=0,
405
+ maximum=15,
406
+ value=0,
407
+ step=1)
408
+
409
+ instrument = gr.Dropdown(label="Instrument",
410
+ choices=[(f"{i}: {name}", i) for i, name in enumerate([
411
+ "Acoustic Grand Piano", "Bright Acoustic Piano", "Electric Grand Piano",
412
+ "Honky-tonk Piano", "Electric Piano 1", "Electric Piano 2", "Harpsichord",
413
+ "Clavinet", "Celesta", "Glockenspiel", "Music Box", "Vibraphone",
414
+ "Marimba", "Xylophone", "Tubular Bells", "Dulcimer"
415
+ ])],
416
+ value=0)
417
+
418
+ velocity = gr.Slider(label="Velocity",
419
+ minimum=1,
420
+ maximum=127,
421
+ value=80,
422
+ step=1)
423
+
424
+ # Program change button
425
+ program_change_button = gr.Button("Send Program Change")
426
+ program_change_button.click(
427
+ fn=lambda inst, chan: midi_manager.send_program_change(inst, chan),
428
+ inputs=[instrument, midi_channel],
429
+ outputs=None
430
+ )
431
+
432
  app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True, ssr_mode=False)
433
+
434
+ # Clean up MIDI connections when the app closes
435
+ midi_manager.close()