avans06 committed on
Commit 2dbdd2e · 1 Parent(s): 0f9efaa

feat(separation): Implement advanced multi-stem separation and processing


This commit significantly enhances the audio separation capabilities by exposing the full 4-stem power of the Demucs model (vocals, drums, bass, other), providing users with granular control over the transcription and audio merging pipeline.

Users can now:
- Choose between a simple 'Accompaniment' mode and an advanced mode that controls each instrumental stem individually.
- Select multiple stems to be transcribed and automatically merged into a single MIDI file (see the sketch below).
- Re-merge any of the original audio stems into the final rendered track.
- Work with a UI that dynamically adapts to the selected mode for a cleaner experience.
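For illustration, the MIDI-merge step works roughly like the following minimal sketch. It uses pretty_midi, as app.py does; the helper name merge_stem_midis and its arguments are hypothetical, not the actual implementation.

import pretty_midi

def merge_stem_midis(stem_midi_paths, output_path):
    """Combine per-stem transcriptions, e.g. {'vocals': 'vocals.mid', 'other': 'other.mid'}, into one MIDI file."""
    merged = pretty_midi.PrettyMIDI()
    for stem_name, midi_path in stem_midi_paths.items():
        stem_midi = pretty_midi.PrettyMIDI(midi_path)
        for instrument in stem_midi.instruments:
            # Prefix the instrument name so each stem stays identifiable in the merged file.
            instrument.name = f"{stem_name.capitalize()} - {instrument.name}"
            merged.instruments.append(instrument)
    merged.write(output_path)

The pipeline in run_single_file_pipeline follows the same idea: each selected stem is transcribed separately, its instruments are renamed with a stem prefix, and everything is written to a single *_full_transcription.mid file.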

Files changed (1)
  1. app.py +221 -140
app.py CHANGED
@@ -104,9 +104,23 @@ class AppParameters:
 # Global Settings
 s8bit_preset_selector: str = "Custom"
 separate_vocals: bool = False
- remerge_vocals: bool = False
- transcription_target: str = "Transcribe Music (Accompaniment)"
- transcribe_both_stems: bool = False
+
+ # --- Advanced Separation and Merging Controls ---
+ enable_advanced_separation: bool = False # Controls visibility of advanced options
+ separate_drums: bool = True
+ separate_bass: bool = True
+ separate_other: bool = True
+
+ transcribe_vocals: bool = False
+ transcribe_drums: bool = False
+ transcribe_bass: bool = False
+ transcribe_other_or_accompaniment: bool = True # Default to transcribe 'other' as it's most common
+
+ merge_vocals_to_render: bool = False
+ merge_drums_to_render: bool = False
+ merge_bass_to_render: bool = False
+ merge_other_or_accompaniment: bool = False
+
 enable_stereo_processing: bool = False
 transcription_method: str = "General Purpose"
 
@@ -1333,10 +1347,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
 
 # --- Use the provided timestamp for unique filenames ---
 timestamped_base_name = f"{base_name}_{timestamp}"
-
- # This will store the other part if separation is performed
- other_part_tensor = None
- other_part_sr = None
+
 
 # --- Step 1: Check file type and transcribe if necessary ---
 if is_midi_input:
@@ -1385,25 +1396,19 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
 print(f"ERROR: Could not load {filename}. Skipping. FFmpeg error: {stderr}")
 return None # Return None to indicate failure
 
- # --- Demucs Vocal Separation Logic, now decides which stem to process ---
- if not params.separate_vocals or demucs_model is None:
- if params.separate_vocals and demucs_model is None:
- print("ERROR: Demucs model not loaded. Skipping separation.")
- # --- Standard Workflow: Transcribe the original full audio ---
- audio_to_transcribe_path = os.path.join(temp_dir, f"{timestamped_base_name}_original.flac")
- torchaudio.save(audio_to_transcribe_path, audio_tensor, native_sample_rate)
-
- update_progress(0.2, "Transcribing audio to MIDI...")
- midi_path_for_rendering = _transcribe_stem(audio_to_transcribe_path, f"{timestamped_base_name}_original", temp_dir, params)
- else:
+ # --- Demucs Vocal Separation Logic ---
+ # This block now handles multi-stem separation, transcription, and merging logic.
+ separated_stems = {} # This will store the audio tensors for merging
+
+ if params.separate_vocals and demucs_model is not None:
 # --- Vocal Separation Workflow ---
- update_progress(0.2, "Separating vocals with Demucs...")
- # Convert to a common format (stereo, float32) that demucs expects
+ update_progress(0.2, "Separating audio with Demucs...")
+ # Convert to the format Demucs expects (e.g., 44.1kHz, stereo)
 audio_tensor = convert_audio(audio_tensor, native_sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
-
+ # Move tensor to GPU if available for faster processing
 if torch.cuda.is_available():
 audio_tensor = audio_tensor.cuda()
-
+
 print("Separating audio with Demucs... This may take some time.")
 # --- Wrap the model call in a no_grad() context ---
 with torch.no_grad():
@@ -1411,88 +1416,84 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
 demucs_model,
 audio_tensor[None], # The input shape is [batch, channels, samples]
 device='cuda' if torch.cuda.is_available() else 'cpu',
- progress=True,
+ progress=True
 )[0] # Remove the batch dimension from the output
 
 # --- Clear CUDA cache immediately after use ---
 if torch.cuda.is_available():
 torch.cuda.empty_cache()
 print("CUDA cache cleared.")
-
- # --- Robust stem handling to prevent CUDA errors ---
- # Instead of complex GPU indexing, we create a dictionary of stems on the CPU.
- # This is safer and more robust across different hardware.
- sources = {}
- for i, source_name in enumerate(demucs_model.sources):
- sources[source_name] = all_stems[i]
-
- vocals_tensor = sources['vocals']
-
- # Sum the other stems to create the accompaniment.
- # This loop is safer than a single complex indexing operation.
- accompaniment_tensor = torch.zeros_like(vocals_tensor)
- for source_name, stem_tensor in sources.items():
- if source_name != 'vocals':
- accompaniment_tensor += stem_tensor
-
- # --- Save both stems to temporary files ---
- vocals_path = os.path.join(temp_dir, f"{base_name}_vocals.flac")
- accompaniment_path = os.path.join(temp_dir, f"{base_name}_accompaniment.flac")
- torchaudio.save(vocals_path, vocals_tensor.cpu(), demucs_model.samplerate)
- torchaudio.save(accompaniment_path, accompaniment_tensor.cpu(), demucs_model.samplerate)
-
- # --- Determine which stem is the primary target and which is the "other part" ---
- primary_target_path = vocals_path if params.transcription_target == "Transcribe Vocals" else accompaniment_path
- other_part_path = accompaniment_path if params.transcription_target == "Transcribe Vocals" else vocals_path
-
- # Store the audio tensor of the "other part" for potential audio re-merging
- other_part_tensor = accompaniment_tensor if params.transcription_target == "Transcribe Vocals" else vocals_tensor
- other_part_sr = demucs_model.samplerate
- print("Separation complete.")
-
- # --- Main Branching Logic: Transcribe one or both stems ---
- if not params.transcribe_both_stems:
- print(f"Transcribing primary target only: {os.path.basename(primary_target_path)}")
- update_progress(0.4, f"Transcribing primary target: {os.path.basename(primary_target_path)}")
- midi_path_for_rendering = _transcribe_stem(primary_target_path, os.path.splitext(os.path.basename(primary_target_path))[0], temp_dir, params)
+
+ sources = {name: stem for name, stem in zip(demucs_model.sources, all_stems)}
+
+ # --- Store original stems for potential re-merging ---
+ for name, tensor in sources.items():
+ separated_stems[name] = (tensor.cpu(), demucs_model.samplerate)
+
+ # --- Prepare Stems for Transcription ---
+ stems_to_transcribe = {}
+ if params.enable_advanced_separation:
+ # User is in advanced mode, handle each stem individually
+ if params.transcribe_vocals:
+ stems_to_transcribe['vocals'] = sources['vocals']
+ if params.transcribe_drums:
+ stems_to_transcribe['drums'] = sources['drums']
+ if params.transcribe_bass:
+ stems_to_transcribe['bass'] = sources['bass']
+ if params.transcribe_other_or_accompaniment:
+ stems_to_transcribe['other'] = sources['other']
 else:
- print("Transcribing BOTH stems and merging the MIDI results.")
-
- # Transcribe the primary target
- update_progress(0.4, "Transcribing primary stem...")
- midi_path_primary = _transcribe_stem(primary_target_path, os.path.splitext(os.path.basename(primary_target_path))[0], temp_dir, params)
+ # User is in simple mode, create a single 'accompaniment' stem
+ accompaniment_tensor = sources['drums'] + sources['bass'] + sources['other']
+ if params.transcribe_vocals:
+ stems_to_transcribe['vocals'] = sources['vocals']
+ if params.transcribe_other_or_accompaniment:
+ stems_to_transcribe['accompaniment'] = accompaniment_tensor
 
- # Transcribe the other part
- update_progress(0.5, "Transcribing second stem...")
- midi_path_other = _transcribe_stem(other_part_path, os.path.splitext(os.path.basename(other_part_path))[0], temp_dir, params)
-
- # Merge the two resulting MIDI files
- if midi_path_primary and midi_path_other:
- update_progress(0.55, "Merging transcribed MIDIs...")
- final_merged_midi_path = os.path.join(temp_dir, f"{base_name}_full_transcription.mid")
- print(f"Merging transcribed MIDI files into {os.path.basename(final_merged_midi_path)}")
-
- # A more robust MIDI merge is needed here
- primary_midi = pretty_midi.PrettyMIDI(midi_path_primary)
- other_midi = pretty_midi.PrettyMIDI(midi_path_other)
-
- # Add all instruments from the other midi to the primary one
- for instrument in other_midi.instruments:
- instrument.name = f"Other - {instrument.name}" # Rename to avoid confusion
- primary_midi.instruments.append(instrument)
-
- primary_midi.write(final_merged_midi_path)
- midi_path_for_rendering = final_merged_midi_path
- elif midi_path_primary:
- print("Warning: Transcription of the 'other' part failed. Using primary transcription only.")
- midi_path_for_rendering = midi_path_primary
- else:
- raise gr.Error("Transcription of the primary target failed. Aborting.")
+ # --- Transcribe Selected Stems ---
+ transcribed_midi_paths = []
+ if stems_to_transcribe:
+ stem_count = len(stems_to_transcribe)
+ for i, (name, tensor) in enumerate(stems_to_transcribe.items()):
+ update_progress(0.3 + (0.3 * (i / stem_count)), f"Transcribing stem: {name}...")
+ stem_path = os.path.join(temp_dir, f"{timestamped_base_name}_{name}.flac")
+ torchaudio.save(stem_path, tensor.cpu(), demucs_model.samplerate)
+ midi_path = _transcribe_stem(stem_path, f"{timestamped_base_name}_{name}", temp_dir, params)
+ if midi_path:
+ transcribed_midi_paths.append((name, midi_path))
+
+ # --- Merge Transcribed MIDIs ---
+ if not transcribed_midi_paths:
+ raise gr.Error("Separation was enabled, but no stems were selected for transcription, or transcription failed.")
+ elif len(transcribed_midi_paths) == 1:
+ midi_path_for_rendering = transcribed_midi_paths[0][1]
+ else:
+ update_progress(0.6, "Merging transcribed MIDIs...")
+ merged_midi = pretty_midi.PrettyMIDI()
+ for name, path in transcribed_midi_paths:
+ try:
+ midi_stem = pretty_midi.PrettyMIDI(path)
+ for inst in midi_stem.instruments:
+ inst.name = f"{name.capitalize()} - {inst.name}"
+ merged_midi.instruments.append(inst)
+ except Exception as e:
+ print(f"Warning: Could not merge MIDI for stem {name}. Error: {e}")
+ final_merged_midi_path = os.path.join(temp_dir, f"{timestamped_base_name}_full_transcription.mid")
+ merged_midi.write(final_merged_midi_path)
+ midi_path_for_rendering = final_merged_midi_path
+
+ else: # Standard workflow without separation
+ # --- Standard Workflow: Transcribe the original full audio ---
+ audio_to_transcribe_path = os.path.join(temp_dir, f"{timestamped_base_name}_original.flac")
+ torchaudio.save(audio_to_transcribe_path, audio_tensor, native_sample_rate)
+
+ update_progress(0.2, "Transcribing audio to MIDI...")
+ midi_path_for_rendering = _transcribe_stem(audio_to_transcribe_path, f"{timestamped_base_name}_original", temp_dir, params)
 
 if not midi_path_for_rendering or not os.path.exists(midi_path_for_rendering):
 print(f"ERROR: Transcription failed for {filename}. Skipping.")
 return None
-
+
 # --- Step 2: Render the FINAL MIDI file with selected options ---
 # The progress values are now conditional based on the input file type.
 update_progress(0.1 if is_midi_input else 0.6, "Applying MIDI transformations...")
@@ -1515,60 +1516,70 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
 except Exception as e:
 print(f"Could not auto-recommend parameters for {filename}: {e}.")
 
+ # --- Step 2: Render the FINAL MIDI file ---
 update_progress(0.2 if is_midi_input else 0.7, "Rendering MIDI to audio...")
 print(f"Proceeding to render MIDI file: {os.path.basename(midi_path_for_rendering)}")
 
 # Call the rendering function, Pass dictionaries directly to Render_MIDI
 results_tuple = Render_MIDI(input_midi_path=midi_path_for_rendering, params=params)
-
- # --- Vocal Re-merging Logic ---
- # Vocal Re-merging only happens for audio files, so its progress value doesn't need to be conditional.
- if params.separate_vocals and params.remerge_vocals and not params.transcribe_both_stems and other_part_tensor is not None:
- update_progress(0.8, "Re-merging rendered audio with vocals...")
- print(f"Re-merging the non-transcribed part with newly rendered music...")
+
+ # --- Final Audio Merging Logic ---
+ stems_to_merge = []
+ if params.separate_vocals:
+ if params.merge_vocals_to_render and 'vocals' in separated_stems:
+ stems_to_merge.append(separated_stems['vocals'])
 
- # 1. Unpack the original rendered audio from the results
+ if params.enable_advanced_separation:
+ if params.merge_drums_to_render and 'drums' in separated_stems:
+ stems_to_merge.append(separated_stems['drums'])
+ if params.merge_bass_to_render and 'bass' in separated_stems:
+ stems_to_merge.append(separated_stems['bass'])
+ if params.merge_other_or_accompaniment and 'other' in separated_stems:
+ stems_to_merge.append(separated_stems['other'])
+ else: # Simple mode
+ if params.merge_other_or_accompaniment: # 'other' checkbox now controls the whole accompaniment
+ accompaniment_tensor = separated_stems['drums'][0] + separated_stems['bass'][0] + separated_stems['other'][0]
+ stems_to_merge.append((accompaniment_tensor, demucs_model.samplerate))
+
+ if stems_to_merge:
+ update_progress(0.9, "Re-merging audio stems...")
 rendered_srate, rendered_music_int16 = results_tuple[4]
-
- # 2. Convert the rendered music to a float tensor
 rendered_music_float = rendered_music_int16.astype(np.float32) / 32767.0
- rendered_music_tensor = torch.from_numpy(rendered_music_float).T
+ final_mix_tensor = torch.from_numpy(rendered_music_float).T
+ final_srate = rendered_srate
 
- # 3. Resample if necessary
- if rendered_srate != other_part_sr:
- resampler = torchaudio.transforms.Resample(rendered_srate, other_part_sr)
- rendered_music_tensor = resampler(rendered_music_tensor)
+ for stem_tensor, stem_srate in stems_to_merge:
+ # Resample if necessary
+ if stem_srate != final_srate:
+ # Resample all stems to match the rendered audio's sample rate
+ resampler = torchaudio.transforms.Resample(stem_srate, final_srate)
+ stem_tensor = resampler(stem_tensor)
 
- # 4. Pad to match lengths
- len_music = rendered_music_tensor.shape[1]
- len_other = other_part_tensor.shape[1]
-
- if len_music > len_other:
- padding = len_music - len_other
- other_part_tensor = torch.nn.functional.pad(other_part_tensor, (0, padding))
- elif len_other > len_music:
- padding = len_other - len_music
- rendered_music_tensor = torch.nn.functional.pad(rendered_music_tensor, (0, padding))
+ # Pad and add to the final mix
+ len_mix = final_mix_tensor.shape[1]
+ len_stem = stem_tensor.shape[1]
+ if len_mix > len_stem:
+ stem_tensor = torch.nn.functional.pad(stem_tensor, (0, len_mix - len_stem))
+ elif len_stem > len_mix:
+ final_mix_tensor = torch.nn.functional.pad(final_mix_tensor, (0, len_stem - len_mix))
 
- # 5. Merge and normalize
- merged_audio_tensor = rendered_music_tensor + other_part_tensor.cpu()
- max_abs = torch.max(torch.abs(merged_audio_tensor))
- if max_abs > 1.0:
- merged_audio_tensor /= max_abs
+ final_mix_tensor += stem_tensor
 
- # 6. Convert back to the required format (int16 numpy array)
- merged_audio_int16 = (merged_audio_tensor.T.numpy() * 32767).astype(np.int16)
+ # Normalize final mix to prevent clipping
+ max_abs = torch.max(torch.abs(final_mix_tensor))
+ if max_abs > 1.0: final_mix_tensor /= max_abs
 
- # 7. Create the new audio tuple and UPDATE the main results_tuple
- new_audio_tuple = (other_part_sr, merged_audio_int16)
+ # Convert back to the required format (int16 numpy array)
+ merged_audio_int16 = (final_mix_tensor.T.numpy() * 32767).astype(np.int16)
 
+ # Update the results tuple with the newly merged audio
 temp_results_list = list(results_tuple)
- temp_results_list[4] = new_audio_tuple
+ temp_results_list[4] = (final_srate, merged_audio_int16)
 results_tuple = tuple(temp_results_list) # results_tuple is now updated
 print("Re-merging complete.")
-
+
 # --- Save final audio and return path ---
- update_progress(0.9, "Saving final files...")
+ update_progress(0.95, "Saving final files...")
 final_srate, final_audio_data = results_tuple[4]
 final_midi_path_from_render = results_tuple[3] # Get the path of the processed MIDI
 
@@ -1577,7 +1588,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppPa
 output_midi_dir = "output/final_midi"
 os.makedirs(output_audio_dir, exist_ok=True)
 os.makedirs(output_midi_dir, exist_ok=True)
-
+
 final_audio_path = os.path.join(output_audio_dir, f"{timestamped_base_name}_rendered.flac")
 # Also, copy the final processed MIDI to a consistent output directory with a timestamped name
 final_midi_path = os.path.join(output_midi_dir, f"{timestamped_base_name}_processed.mid")
@@ -2274,6 +2285,35 @@ if __name__ == "__main__":
 updates[component] = gr.update(value=value)
 
 return updates
+
+ # --- UI Controller Function for Dynamic Visibility ---
+ def update_separation_mode_ui(is_advanced):
+ """
+ Updates the visibility and labels of UI components based on whether
+ the advanced separation mode is enabled.
+ """
+ if is_advanced:
+ # Advanced Mode: Show individual controls, label becomes "Other"
+ return {
+ advanced_separation_controls: gr.update(visible=True),
+ transcribe_drums: gr.update(visible=True),
+ transcribe_bass: gr.update(visible=True),
+ transcribe_other_or_accompaniment: gr.update(label="Transcribe Other"),
+ merge_drums_to_render: gr.update(visible=True),
+ merge_bass_to_render: gr.update(visible=True),
+ merge_other_or_accompaniment: gr.update(label="Merge Other")
+ }
+ else:
+ # Simple Mode: Hide individual controls, label becomes "Accompaniment"
+ return {
+ advanced_separation_controls: gr.update(visible=False),
+ transcribe_drums: gr.update(visible=False),
+ transcribe_bass: gr.update(visible=False),
+ transcribe_other_or_accompaniment: gr.update(label="Transcribe Accompaniment"),
+ merge_drums_to_render: gr.update(visible=False),
+ merge_bass_to_render: gr.update(visible=False),
+ merge_other_or_accompaniment: gr.update(label="Merge Accompaniment")
+ }
 
 # --- Use the dataclass to define the master list of parameter keys ---
 # This is now the single source of truth for parameter order.
@@ -2363,16 +2403,41 @@ if __name__ == "__main__":
 enable_stereo_processing = gr.Checkbox(label="Enable Stereo Transcription", value=False,
 info="For stereo audio files only. When enabled, transcribes left and right channels independently, then merges them. Note: This will double the transcription time.")
 
- # --- Vocal Separation Checkboxes ---
+ # --- Vocal Separation Group ---
 with gr.Group():
- separate_vocals = gr.Checkbox(label="Separate Vocals", value=False,
- info="If checked, separates the audio into vocals and music stems before processing.")
- transcription_target = gr.Radio(["Transcribe Music (Accompaniment)", "Transcribe Vocals"], label="Transcription Target", value="Transcribe Music (Accompaniment)", visible=False,
- info="Choose which part of the separated audio to transcribe to MIDI.")
- remerge_vocals = gr.Checkbox(label="Re-merge Other Part with Rendered Audio", value=False, visible=False,
- info="After rendering, merges the non-transcribed part (e.g., original vocals) back with the new music.")
- transcribe_both_stems = gr.Checkbox(label="Transcribe Both Parts & Merge MIDI", value=False, visible=False,
- info="If checked, transcribes BOTH vocals and music, then merges them into one MIDI file for rendering. Disables audio re-merging.")
+ separate_vocals = gr.Checkbox(label="Enable Source Separation (Demucs)", value=False,
+ info="If checked, separates the audio into its component stems (vocals, drums, etc.) before processing.")
+
+ # --- Container for all separation options, visible only when enabled ---
+ with gr.Group(visible=False) as separation_options_box:
+ gr.Markdown("#### 1. Stem Separation Options")
+ enable_advanced_separation = gr.Checkbox(label="Enable Advanced Stem Control (for Accompaniment)", value=False,
+ info="If checked, you can individually control drums, bass, and other. If unchecked, they are treated as a single 'Accompaniment' track.")
+
+ with gr.Row(visible=False) as advanced_separation_controls:
+ separate_drums = gr.Checkbox(label="Drums", value=True)
+ separate_bass = gr.Checkbox(label="Bass", value=True)
+ separate_other = gr.Checkbox(label="Other", value=True)
+
+ gr.Markdown("#### 2. Transcription Targets")
+ gr.Markdown("_Select which separated stem(s) to convert to MIDI._")
+ with gr.Row():
+ transcribe_vocals = gr.Checkbox(label="Transcribe Vocals", value=False)
+ # These two will be hidden/shown dynamically
+ transcribe_drums = gr.Checkbox(label="Transcribe Drums", value=False, visible=False)
+ transcribe_bass = gr.Checkbox(label="Transcribe Bass", value=False, visible=False)
+ # This checkbox will have its label changed dynamically
+ transcribe_other_or_accompaniment = gr.Checkbox(label="Transcribe Accompaniment", value=True)
+
+ gr.Markdown("#### 3. Audio Merging Targets")
+ gr.Markdown("_Select which original stem(s) to re-merge with the final rendered audio._")
+ with gr.Row():
+ merge_vocals_to_render = gr.Checkbox(label="Merge Vocals", value=False)
+ # These two will be hidden/shown dynamically
+ merge_drums_to_render = gr.Checkbox(label="Merge Drums", value=False, visible=False)
+ merge_bass_to_render = gr.Checkbox(label="Merge Bass", value=False, visible=False)
+ # This checkbox will have its label changed dynamically
+ merge_other_or_accompaniment = gr.Checkbox(label="Merge Accompaniment", value=True)
 
 with gr.Accordion("General Purpose Transcription Settings", open=True) as general_transcription_settings:
 # --- Preset dropdown for basic_pitch ---
@@ -2657,10 +2722,26 @@ if __name__ == "__main__":
 )
 
 # Event listeners for UI visibility and presets
+ # When the main separation checkbox is toggled
 separate_vocals.change(
- fn=update_vocal_ui_visibility,
+ fn=lambda x: gr.update(visible=x),
 inputs=separate_vocals,
- outputs=[transcription_target, remerge_vocals, transcribe_both_stems]
+ outputs=[separation_options_box]
+ )
+
+ # When the advanced stem control checkbox is toggled, update all relevant UI parts
+ enable_advanced_separation.change(
+ fn=update_separation_mode_ui,
+ inputs=enable_advanced_separation,
+ outputs=[
+ advanced_separation_controls,
+ transcribe_drums,
+ transcribe_bass,
+ transcribe_other_or_accompaniment,
+ merge_drums_to_render,
+ merge_bass_to_render,
+ merge_other_or_accompaniment
+ ]
 )
 
 # --- Listeners for dynamic UI updates ---
 