avans06 committed
Commit d050f96 · 1 Parent(s): 7a65465

feat(separation): Integrate BS-RoFormer & Mel-RoFormer models

This commit introduces support for two additional audio separation models,
BS-RoFormer and Mel-RoFormer, providing users with more specialized
options for vocal and instrumental separation.
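
For reviewers who want to try the new path in isolation, here is a minimal sketch of the audio-separator calls this commit wires into app.py. The constructor arguments and checkpoint name are taken from the diff below; the input file name is a placeholder.

```python
from audio_separator.separator import Separator

# Mirrors the pre-load block added in app.py (see the diff below).
separator = Separator(output_dir="output/temp_transcribe",
                      output_format="flac",
                      model_file_dir="src/models")
separator.load_model("model_bs_roformer_ep_317_sdr_12.9755.ckpt")

# separate() writes the stems and returns their file names,
# e.g. "... (Vocals) ..." and "... (Instrumental) ...".
output_files = separator.separate("input.flac")  # "input.flac" is a placeholder
print(output_files)
```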

Files changed (2)
  1. app.py +166 -44
  2. requirements.txt +4 -1
app.py CHANGED
@@ -61,6 +61,7 @@ import torchaudio
 from demucs.apply import apply_model
 from demucs.pretrained import get_model
 from demucs.audio import convert_audio
+from audio_separator.separator import Separator
 
 from src.piano_transcription.utils import initialize_app
 from piano_transcription_inference import PianoTranscription, utilities, sample_rate as transcription_sample_rate
@@ -106,6 +107,7 @@ class AppParameters:
     # Global Settings
     s8bit_preset_selector: str = "Custom"
     separate_vocals: bool = False
+    separation_model: str = "Demucs (4-stem)"
 
     # --- Advanced Separation and Merging Controls ---
     enable_advanced_separation: bool = False # Controls visibility of advanced options
@@ -1609,7 +1611,7 @@ def TranscribePianoAudio(input_file):
     # Use 'cuda' if a GPU is available and configured, otherwise 'cpu'
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     print(f'Loading transcriptor model... device= {device}')
-    transcriptor = PianoTranscription(device=device, checkpoint_path="src/models/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth")
+    transcriptor = PianoTranscription(device=device, checkpoint_path=os.path.join("src", "models", "CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"))
     print('Transcriptor loaded.')
     print('-' * 70)
 
@@ -2377,7 +2379,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
 
         midi_path_for_rendering = input_file_path
     else:
-        temp_dir = "output/temp_transcribe" # Define temp_dir early for the fallback
+        temp_dir = os.path.join("output", "temp_transcribe") # Define temp_dir early for the fallback
         os.makedirs(temp_dir, exist_ok=True)
 
         # --- Audio Loading ---
@@ -2413,44 +2415,98 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
             print(f"ERROR: Could not load {filename}. Skipping. FFmpeg error: {stderr}")
             return None # Return None to indicate failure
 
-        # --- Demucs Vocal Separation Logic ---
+        # --- Vocal Separation Logic ---
         # This block now handles multi-stem separation, transcription, and merging logic.
         separated_stems = {} # This will store the audio tensors for merging
+        sources = {} # This will hold the tensors for transcription processing
 
-        if params.separate_vocals and demucs_model is not None:
-            # --- Vocal Separation Workflow ---
-            update_progress(0.2, "Separating audio with Demucs...")
-            # Convert to the format Demucs expects (e.g., 44.1kHz, stereo)
-            audio_tensor = convert_audio(audio_tensor, native_sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
-            # Move tensor to GPU if available for faster processing
-            if torch.cuda.is_available():
-                audio_tensor = audio_tensor.cuda()
-
-            print("Separating audio with Demucs... This may take some time.")
-            # --- Wrap the model call in a no_grad() context ---
-            with torch.no_grad():
-                all_stems = apply_model(
-                    demucs_model,
-                    audio_tensor[None], # The input shape is [batch, channels, samples]
-                    device='cuda' if torch.cuda.is_available() else 'cpu',
-                    progress=True
-                )[0] # Remove the batch dimension from the output
-
-            # --- Clear CUDA cache immediately after use ---
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-                print("CUDA cache cleared.")
-
-            sources = {name: stem for name, stem in zip(demucs_model.sources, all_stems)}
-
-            # --- Store original stems for potential re-merging ---
-            for name, tensor in sources.items():
-                separated_stems[name] = (tensor.cpu(), demucs_model.samplerate)
-
-            # --- Prepare Stems for Transcription ---
+        if params.separate_vocals:
+            model_name = params.separation_model
+
+            # --- Demucs Separation Workflow (4-stem) ---
+            if 'Demucs' in model_name and demucs_model is not None:
+                update_progress(0.2, "Separating audio with Demucs...")
+                # Convert to the format Demucs expects (e.g., 44.1kHz, stereo)
+                audio_tensor_demucs = convert_audio(audio_tensor, native_sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
+                # Move tensor to GPU if available for faster processing
+                if torch.cuda.is_available():
+                    audio_tensor_demucs = audio_tensor_demucs.cuda()
+
+                print("Separating audio with Demucs... This may take some time.")
+                # --- Wrap the model call in a no_grad() context ---
+                with torch.no_grad():
+                    all_stems = apply_model(
+                        demucs_model,
+                        audio_tensor_demucs[None], # The input shape is [batch, channels, samples]
+                        device='cuda' if torch.cuda.is_available() else 'cpu',
+                        progress=True
+                    )[0] # Remove the batch dimension from the output
+
+                # --- Clear CUDA cache immediately after use ---
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+                    print("CUDA cache cleared.")
+
+                # Populate sources for transcription and separated_stems for merging
+                sources = {name: stem for name, stem in zip(demucs_model.sources, all_stems)}
+
+                # --- Store original stems for potential re-merging ---
+                for name, tensor in sources.items():
+                    separated_stems[name] = (tensor.cpu(), demucs_model.samplerate)
+
+            # --- RoFormer Separation Workflow (2-stem) ---
+            elif ('BS-RoFormer' in model_name or 'Mel-RoFormer' in model_name):
+                if not separator_models:
+                    print("Warning: RoFormer models are not loaded. Skipping separation.")
+                    params.separate_vocals = False
+                else:
+                    roformer_key = 'BS-RoFormer' if 'BS-RoFormer' in model_name else 'Mel-RoFormer'
+                    update_progress(0.2, f"Separating audio with {roformer_key}...")
+
+                    temp_input_path = os.path.join(temp_dir, f"{timestamped_base_name}_roformer_in.flac")
+                    torchaudio.save(temp_input_path, audio_tensor.cpu(), native_sample_rate)
+
+                    try:
+                        separator = separator_models[roformer_key]
+                        output_paths = separator.separate(temp_input_path)
+
+                        vocals_path, accompaniment_path = None, None
+                        for path in output_paths:
+                            basename = os.path.basename(path).lower()
+                            path = os.path.join(temp_dir, path)
+                            if '(vocals)' in basename:
+                                vocals_path = path
+                            elif '(instrumental)' in basename:
+                                accompaniment_path = path
+
+                        if not vocals_path or not accompaniment_path:
+                            raise RuntimeError(f"Could not find expected vocal/instrumental stems in output: {output_paths}")
+
+                        print(f"Separation complete. Vocals: {os.path.basename(vocals_path)}, Accompaniment: {os.path.basename(accompaniment_path)}")
+
+                        vocals_tensor, stem_sr = torchaudio.load(vocals_path)
+                        accompaniment_tensor, stem_sr = torchaudio.load(accompaniment_path)
+
+                        # Populate 'sources' and 'separated_stems' to match Demucs structure
+                        # This ensures compatibility with downstream logic
+                        sources['vocals'] = vocals_tensor
+                        sources['other'] = accompaniment_tensor # The entire accompaniment
+                        sources['drums'] = torch.zeros_like(accompaniment_tensor) # Dummy tensor
+                        sources['bass'] = torch.zeros_like(accompaniment_tensor) # Dummy tensor
+
+                        for name, tensor in sources.items():
+                            separated_stems[name] = (tensor.cpu(), stem_sr)
+
+                    except Exception as e:
+                        print(f"ERROR: {roformer_key} separation failed: {e}. Skipping separation.")
+                        params.separate_vocals = False
+
+        # --- Prepare Stems for Transcription ---
+        if params.separate_vocals and sources: # Check if separation was successful
             stems_to_transcribe = {}
+            # NOTE: When a 2-stem model is used, the UI should ensure 'enable_advanced_separation' is False.
             if params.enable_advanced_separation:
-                # User is in advanced mode, handle each stem individually
+                # User is in advanced mode (Demucs only)
                 if params.transcribe_vocals:
                     stems_to_transcribe['vocals'] = sources['vocals']
                 if params.transcribe_drums:
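
The hunk above assumes audio-separator returns stem file names containing "(Vocals)" or "(Instrumental)", relative to the configured output_dir, hence the os.path.join. A small, self-contained sketch of that routing convention (the helper name and file names below are illustrative, not part of the commit):

```python
import os

def route_stems(output_paths, output_dir):
    """Map separator output files to vocal/instrumental roles by name, as app.py does."""
    vocals_path, accompaniment_path = None, None
    for name in output_paths:
        basename = os.path.basename(name).lower()
        full_path = os.path.join(output_dir, name)
        if '(vocals)' in basename:
            vocals_path = full_path
        elif '(instrumental)' in basename:
            accompaniment_path = full_path
    return vocals_path, accompaniment_path

# Illustrative file names in the pattern the commit matches on:
print(route_stems(["song (Vocals) BS-RoFormer.flac",
                   "song (Instrumental) BS-RoFormer.flac"],
                  "output/temp_transcribe"))
```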
@@ -2460,7 +2516,9 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
                 if params.transcribe_other_or_accompaniment:
                     stems_to_transcribe['other'] = sources['other']
             else:
-                # User is in simple mode, create a single 'accompaniment' stem
+                # Simple mode (Demucs) or RoFormer mode
+                # This logic correctly combines drums, bass, and other. For RoFormer, drums/bass are zero,
+                # so this correctly results in just the accompaniment tensor.
                 accompaniment_tensor = sources['drums'] + sources['bass'] + sources['other']
                 if params.transcribe_vocals:
                     stems_to_transcribe['vocals'] = sources['vocals']
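
The comment in this hunk relies on an identity worth spelling out: because the RoFormer path fills drums and bass with zero tensors, the simple-mode sum degenerates to the instrumental stem. A quick sanity check (shapes are arbitrary stand-ins):

```python
import torch

accompaniment = torch.randn(2, 44100)           # stand-in for the RoFormer instrumental stem
sources = {
    "vocals": torch.randn(2, 44100),
    "other": accompaniment,                     # the entire accompaniment
    "drums": torch.zeros_like(accompaniment),   # dummy tensor, as in the commit
    "bass": torch.zeros_like(accompaniment),    # dummy tensor, as in the commit
}
merged = sources["drums"] + sources["bass"] + sources["other"]
assert torch.equal(merged, accompaniment)       # adding zeros changes nothing
```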
@@ -2471,10 +2529,13 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
             transcribed_midi_paths = []
             if stems_to_transcribe:
                 stem_count = len(stems_to_transcribe)
+                # The samplerate of all stems from a single separation will be the same
+                stem_samplerate = separated_stems.get('vocals', (None, native_sample_rate))[1]
+
                 for i, (name, tensor) in enumerate(stems_to_transcribe.items()):
                     update_progress(0.3 + (0.3 * (i / stem_count)), f"Transcribing stem: {name}...")
                     stem_path = os.path.join(temp_dir, f"{timestamped_base_name}_{name}.flac")
-                    torchaudio.save(stem_path, tensor.cpu(), demucs_model.samplerate)
+                    torchaudio.save(stem_path, tensor.cpu(), stem_samplerate)
                     midi_path, used_bp_params = _transcribe_stem(stem_path, f"{timestamped_base_name}_{name}", temp_dir, params)
                     if midi_path:
                         transcribed_midi_paths.append((name, midi_path))
@@ -2554,7 +2615,7 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
 
         # --- Final Audio Merging Logic ---
         stems_to_merge = []
-        if params.separate_vocals:
+        if params.separate_vocals and separated_stems:
             if params.merge_vocals_to_render and 'vocals' in separated_stems:
                 stems_to_merge.append(separated_stems['vocals'])
 
@@ -2565,10 +2626,12 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
                     stems_to_merge.append(separated_stems['bass'])
                 if params.merge_other_or_accompaniment and 'other' in separated_stems:
                     stems_to_merge.append(separated_stems['other'])
-            else: # Simple mode
-                if params.merge_other_or_accompaniment: # 'other' checkbox now controls the whole accompaniment
+            else: # Simple mode or RoFormer
+                if params.merge_other_or_accompaniment:
+                    # This correctly combines the accompaniment, which for RoFormer is just the 'other' stem.
                     accompaniment_tensor = separated_stems['drums'][0] + separated_stems['bass'][0] + separated_stems['other'][0]
-                    stems_to_merge.append((accompaniment_tensor, demucs_model.samplerate))
+                    accompaniment_sr = separated_stems['other'][1]
+                    stems_to_merge.append((accompaniment_tensor, accompaniment_sr))
 
         if stems_to_merge:
             update_progress(0.9, "Re-merging audio stems...")
@@ -2584,6 +2647,10 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
                     resampler = torchaudio.transforms.Resample(stem_srate, final_srate)
                     stem_tensor = resampler(stem_tensor)
 
+                # Ensure stem is stereo if mix is stereo
+                if final_mix_tensor.shape[0] == 2 and stem_tensor.shape[0] == 1:
+                    stem_tensor = stem_tensor.repeat(2, 1)
+
                 # Pad and add to the final mix
                 len_mix = final_mix_tensor.shape[1]
                 len_stem = stem_tensor.shape[1]
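
The new channel guard above uses Tensor.repeat to upmix a mono stem before it is summed into a stereo mix; a standalone illustration of that call:

```python
import torch

mono = torch.randn(1, 48000)     # [channels, samples]
stereo = mono.repeat(2, 1)       # duplicate the single channel, as in the hunk above
assert stereo.shape == (2, 48000)
assert torch.equal(stereo[0], stereo[1])
```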
@@ -2613,8 +2680,8 @@ def run_single_file_pipeline(input_file_path: str, timestamp: str, params: AppParameters):
         final_midi_path_from_render = results_tuple[3] # Get the path of the processed MIDI
 
         # --- Use timestamped names for final outputs ---
-        output_audio_dir = "output/final_audio"
-        output_midi_dir = "output/final_midi"
+        output_audio_dir = os.path.join("output", "final_audio")
+        output_midi_dir = os.path.join("output", "final_midi")
         os.makedirs(output_audio_dir, exist_ok=True)
         os.makedirs(output_midi_dir, exist_ok=True)
 
@@ -2835,7 +2902,7 @@ if __name__ == "__main__":
     initialize_app()
 
     # --- Prepare soundfonts and make the map globally accessible ---
-    global soundfonts_dict, demucs_model
+    global soundfonts_dict, demucs_model, separator_models
     # On application start, download SoundFonts from Hugging Face Hub if they don't exist.
     soundfonts_dict = prepare_soundfonts()
     print(f"Found {len(soundfonts_dict)} local SoundFonts.")
@@ -2855,6 +2922,25 @@ if __name__ == "__main__":
         print(f"Warning: Could not load Demucs model. Vocal separation will not be available. Error: {e}")
         demucs_model = None
 
+    # --- Pre-load BS-RoFormer and Mel-RoFormer models ---
+    separator_models: dict[str, Separator] = {}
+    try:
+        temp_dir = os.path.join("output", "temp_transcribe")
+        print("Loading BS-RoFormer model...")
+        bs_roformer = Separator(output_dir=temp_dir, output_format='flac', model_file_dir=os.path.join("src", "models"))
+        bs_roformer.load_model("model_bs_roformer_ep_317_sdr_12.9755.ckpt")
+        separator_models['BS-RoFormer'] = bs_roformer
+        print("BS-RoFormer model loaded successfully.")
+
+        print("Loading Mel-RoFormer model...")
+        mel_roformer = Separator(output_dir=temp_dir, output_format='flac', model_file_dir=os.path.join("src", "models"))
+        mel_roformer.load_model("model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt")
+        separator_models['Mel-RoFormer'] = mel_roformer
+        print("Mel-RoFormer model loaded successfully.")
+
+    except Exception as e:
+        print(f"Warning: Could not load RoFormer models. They will not be available for separation. Error: {e}")
+
     # --- Dictionary containing descriptions for each render type ---
     RENDER_TYPE_DESCRIPTIONS = {
         "Render as-is": "**Mode: Pass-through.** Renders the MIDI file directly without any modifications. Advanced MIDI options will be ignored.",
@@ -3385,6 +3471,19 @@ if __name__ == "__main__":
         merge_other_or_accompaniment: gr.update(label="Merge Accompaniment")
     }
 
+    # --- UI controller for handling model selection ---
+    def on_separation_model_change(model_choice):
+        """
+        Update the UI when the separation model changes.
+        If a 2-stem model (RoFormer) is selected, hide advanced (4-stem) controls.
+        """
+        is_demucs = 'Demucs' in model_choice
+        # For 2-stem models, we force simple mode (is_advanced=False)
+        updates = update_separation_mode_ui(is_advanced=False)
+        # Also hide the checkbox that allows switching to advanced mode
+        updates[enable_advanced_separation] = gr.update(visible=is_demucs, value=False)
+        return updates
+
     # --- Use the dataclass to define the master list of parameter keys ---
     # This is now the single source of truth for parameter order.
     ALL_PARAM_KEYS = [field.name for field in fields(AppParameters) if field.name not in ["input_file", "batch_input_files"]]
@@ -3475,14 +3574,21 @@ if __name__ == "__main__":
 
         # --- Vocal Separation Group ---
         with gr.Group():
-            separate_vocals = gr.Checkbox(label="Enable Source Separation (Demucs)", value=False,
+            separate_vocals = gr.Checkbox(label="Enable Source Separation", value=False,
                 info="If checked, separates the audio into its component stems (vocals, drums, etc.) before processing.")
 
             # --- Container for all separation options, visible only when enabled ---
             with gr.Group(visible=False) as separation_options_box:
+                separation_model = gr.Radio(
+                    ["Demucs (4-stem)", "BS-RoFormer (Vocals/Instrumental)", "Mel-RoFormer (Vocals/Instrumental)"],
+                    label="Separation Model",
+                    value="Demucs (4-stem)",
+                    info="Select the separation model. Demucs provides 4 stems (vocals, drums, bass, other). RoFormer models are specialized for 2-stem (vocals/instrumental) separation.",
+                )
+
                 gr.Markdown("#### 1. Stem Separation Options")
-                enable_advanced_separation = gr.Checkbox(label="Enable Advanced Stem Control (for Accompaniment)", value=False,
-                    info="If checked, you can individually control drums, bass, and other. If unchecked, they are treated as a single 'Accompaniment' track.")
+                enable_advanced_separation = gr.Checkbox(label="Enable Advanced Stem Control (Demucs Only)", value=False,
+                    info="If checked, you can individually control drums, bass, and other. If unchecked, they are treated as a single 'Accompaniment' track. This option is only available for the Demucs model.")
 
                 with gr.Row(visible=False) as advanced_separation_controls:
                     separate_drums = gr.Checkbox(label="Drums", value=True)
@@ -4066,7 +4172,23 @@ if __name__ == "__main__":
         outputs=[separation_options_box]
     )
 
-    # When the advanced stem control checkbox is toggled, update all relevant UI parts
+    # When the model selection changes, trigger UI update
+    separation_model.change(
+        fn=on_separation_model_change,
+        inputs=separation_model,
+        outputs=[
+            enable_advanced_separation,
+            advanced_separation_controls,
+            transcribe_drums,
+            transcribe_bass,
+            transcribe_other_or_accompaniment,
+            merge_drums_to_render,
+            merge_bass_to_render,
+            merge_other_or_accompaniment
+        ]
+    )
+
+    # When the advanced stem control checkbox is toggled, update all related UI parts
    enable_advanced_separation.change(
         fn=update_separation_mode_ui,
         inputs=enable_advanced_separation,
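
on_separation_model_change (added earlier in this diff) returns a dict keyed by components, a Gradio pattern where only the components used as keys are updated, provided they are listed in outputs. A reduced sketch of the same pattern, with stand-in components rather than the app's full set:

```python
import gradio as gr

with gr.Blocks() as demo:
    model_choice = gr.Radio(["Demucs (4-stem)", "BS-RoFormer (Vocals/Instrumental)"],
                            value="Demucs (4-stem)", label="Separation Model")
    advanced = gr.Checkbox(label="Enable Advanced Stem Control (Demucs Only)")

    def on_change(choice):
        is_demucs = 'Demucs' in choice
        # Returning a dict updates only the components used as keys.
        return {advanced: gr.update(visible=is_demucs, value=False)}

    model_choice.change(fn=on_change, inputs=model_choice, outputs=[advanced])

# demo.launch()  # uncomment to try it locally
```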
 
requirements.txt CHANGED
@@ -28,4 +28,7 @@ basic-pitch[tf] @ git+https://github.com/avan06/basic-pitch; sys_platform == 'linux'
 
 git+https://github.com/avan06/pyfluidsynth
 
-demucs
+demucs
+
+audio-separator[gpu]; sys_platform != 'darwin'
+audio-separator[cpu]; sys_platform == 'darwin'
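
The environment markers mean pip resolves a different extra per platform: macOS gets the CPU build of audio-separator, every other platform the GPU build. A small Python illustration of the same condition:

```python
import sys

# Mirrors the sys_platform markers in requirements.txt.
extra = "cpu" if sys.platform == "darwin" else "gpu"
print(f"This machine would install: audio-separator[{extra}]")
```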