File size: 8,828 Bytes
3883c60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import gradio
import webui.ui.tabs.training.training.rvc_workspace as rvc_ws


def change_setting(name, value):
    """Store *value* under *name* in the active workspace and persist it to disk."""
    workspace = rvc_ws.current_workspace
    workspace.data[name] = value
    workspace.save()


def train_rvc():
    with gradio.Row():
        with gradio.Column():
            gradio.Markdown('''
            # πŸš€ Workspaces
            ''', elem_classes='text-center')
            with gradio.Tabs():
                with gradio.Tab('🧠 Load'):
                    with gradio.Row():
                        workspace_select = gradio.Dropdown(rvc_ws.get_workspaces(), label='Select workspace')
                        refresh_workspaces = gradio.Button('πŸ”ƒ', variant='primary tool offset--10')
                with gradio.Tab('βž• Create'):
                    create_name = gradio.Textbox(label='Name')
                    version_sample_rate = gradio.Radio(['v1 40k', 'v1 48k', 'v2 40k', 'v2 48k'], value='v2 40k', label='version and sample rate')
                    create_button = gradio.Button('Create workspace', variant='primary')
        with gradio.Column(visible=False) as settings:
            gradio.Markdown('''
            # βš™ Settings
            ''', elem_classes='text-center')
            with gradio.Tabs():
                with gradio.Tab('🚚 data'):
                    dataset_path = gradio.Textbox(label='Dataset path', info='The path to the dataset containing your training audio.')
                    dataset_path.change(fn=lambda val: change_setting('dataset', val), inputs=dataset_path)
                    process_dataset = gradio.Button('Resample and split dataset', variant='primary')
                    f0_method = gradio.CheckboxGroup(["none", "dio", "pm", "harvest", "torchcrepe", "torchcrepe tiny", "mangio-crepe", "mangio-crepe tiny", "rmvpe"], value='harvest', label='Pitch extraction method', info='Harvest is usually good, crepe has potential to be even better.')
                    crepe_hop_length = gradio.Slider(visible=False, minimum=64, maximum=512, step=64, value=128,
                                                     label='torchcrepe hop length',
                                                     info='The length of the hops used for torchcrepe\'s crepe implementation',
                                                     interactive=True)
                    filter_radius = gradio.Slider(0, 7, 3, step=1, label='Filter radius',
                                                  info='Default: 3. Smooth out the pitches, should yield less voice cracks.', interactive=True)

                    def set_f0(val: list[str]):
                        change_setting('f0', val)
                        return gradio.update(visible=any(['crepe' in v for v in val]))

                    f0_method.change(fn=set_f0, inputs=f0_method, outputs=crepe_hop_length)
                    crepe_hop_length.change(fn=lambda val: change_setting('crepe_hop_length', val), inputs=crepe_hop_length)
                    filter_radius.change(fn=lambda val: change_setting('filter_radius', val), inputs=filter_radius)
                    pitch_extract = gradio.Button('Extract pitches', variant='primary')
                    create_index_button = gradio.Button('Create index file', variant='primary padding-h-0')
                with gradio.Tab('πŸƒβ€ train'):
                    with gradio.Row():
                        base_ckpt = gradio.Dropdown(['f0'], value='f0', label='Base checkpoint', info='The base checkpoint to train from, select f0 if you haven\'t trained yet.')
                        refresh_checkpoints = gradio.Button('πŸ”ƒ', variant='primary tool offset--10')

                        def refresh_checkpoints_click():
                            return gradio.update(choices=rvc_ws.get_continue_models())

                        refresh_checkpoints.click(fn=refresh_checkpoints_click, outputs=base_ckpt)
                    epochs = gradio.Number(label='Epochs to train (added)', value=100, interactive=True)
                    batch_size = gradio.Slider(1, 50, 6, step=1, label='Batch size', info='Higher uses more VRAM.')
                    batch_size.change(fn=lambda v: change_setting('batch_size', v), inputs=batch_size)
                    save_n_epochs = gradio.Number(10, label='Save every n epochs', info='Save every time n epochs of training have been processed. 0 = disabled.')
                    save_n_epochs.change(fn=lambda v: change_setting('save_epochs', v), inputs=save_n_epochs)
                    with gradio.Accordion('advanced', open=False):
                        lr = gradio.Textbox(label='Learning rate', info='Default: 1e-4, must be readable as a number')
                        lr.change(fn=lambda v: change_setting('lr', v), inputs=lr)
                    with gradio.Row():
                        train_button = gradio.Button('Train', variant='primary padding-h-0')
                        stop_button = gradio.Button('Stop', variant='stop padding-h-0')
                    copy_button = gradio.Button('Copy to RVC models')

                with gradio.Tab('❓ how to?'):
                    gradio.Markdown('''
                    **Don't use too much training data or too many epochs, your model will overtrain. 10 to 60 minutes is the standard**
                    ## How to train
                    1. Collect audio data (if from youtube, you can use the Utils tab to download audio from youtube quickly).
                      * Optional: use Utils tab to split vocals from music if there's background music.
                    2. Open the "data" tab.
                        1. Put the path to the folder containing your audio data as the training path.
                        2. Click "Process dataset"
                        3. Pick your preferred pitch extraction method. Harvest and crepe are recommended.
                        4. Click "Extract pitches"
                        5. Click "Create index file" (Optional, adds a .index file, which uses more space, but improves results)
                    3. Open the "train" tab.
                        1. Set training epochs. (you'll need more if you have less audio) Try it out, you can continue training by selecting the model to continue from and then training.
                        2. Click "Train", and wait for training to complete.
                    4. Select the model in the "Base checkpoint" dropdown (refresh first)
                    5. Now, click "Copy to RVC models". Now you can use the trained model inside of your RVC tab.
                    ''')
        with gradio.Column():
            gradio.Markdown('''
            # ✨ Status
            ''', elem_classes='text-center')
            with gradio.Row():
                loss_plot = gradio.LinePlot(label='Loss', x_title='steps', y_title='loss', x='x', y='y')
            with gradio.Row():
                status_box = gradio.TextArea(label='Status')

    def load_workspace(name):
        rvc_ws.current_workspace = rvc_ws.RvcWorkspace(name).load()
        ws = rvc_ws.current_workspace
        return f'Loaded workspace {name}', ws.name, gradio.update(visible=True), ws.data['dataset'], ws.data['f0'], ws.data['crepe_hop_length'], ws.data['save_epochs'], ws.data['batch_size'], list_models(), ws.data['lr'], ws.data['filter_radius']

    def list_workspaces():
        return gradio.update(choices=rvc_ws.get_workspaces())

    def list_models():
        return gradio.update(choices=rvc_ws.get_continue_models())

    def create_workspace(name: str, vsr):
        if name == "":
            raise gradio.Error("Please enter a name for your new workspace.")
        rvc_ws.current_workspace = rvc_ws.RvcWorkspace(name).create({
            'vsr': vsr
        })
        rvc_ws.current_workspace.save()
        return load_workspace(name)

    setting_elements = [status_box, workspace_select, settings, dataset_path, f0_method, crepe_hop_length, save_n_epochs, batch_size, base_ckpt, lr, filter_radius]

    process_dataset.click(fn=rvc_ws.process_dataset, outputs=status_box)
    pitch_extract.click(fn=rvc_ws.pitch_extract, outputs=status_box)

    copy_button.click(fn=rvc_ws.copy_model, inputs=base_ckpt, outputs=status_box, queue=False)
    train_button.click(fn=rvc_ws.train_model, inputs=[base_ckpt, epochs], outputs=[status_box, loss_plot])
    create_index_button.click(fn=rvc_ws.create_index, outputs=status_box)
    stop_button.click(fn=rvc_ws.cancel_train, outputs=status_box, queue=False)


    workspace_select.select(fn=load_workspace, inputs=workspace_select, outputs=setting_elements, show_progress=True)
    refresh_workspaces.click(fn=list_workspaces, outputs=workspace_select, show_progress=True, queue=False)
    create_button.click(fn=create_workspace, inputs=[create_name, version_sample_rate], outputs=setting_elements, show_progress=True)