File size: 4,575 Bytes
3883c60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio
from gradio.components import IOComponent

import webui.modules.models as mod
import webui.modules.implementations.ttsmodels as tts_models
from webui.modules import util

# Model category passed to mod.get_installed_models() when listing models.
mod_type = 'text-to-speech'

# Currently active loader. It starts out as the TTSModelLoader class itself and
# is rebound to whatever loader.from_model() returns once a model is selected.
loader: mod.TTSModelLoader = mod.TTSModelLoader

# Filled in by text_to_speech() with the "Send to RVC" button and the audio
# output component (presumably so other tabs can hook into them - confirm).
to_rvc = None
audio_out = None


def get_models_installed():
    """Return the names of all text-to-speech models that can be selected.

    The result is the installed models of type ``mod_type`` whose name matches
    a known TTS model id (with ``/`` replaced by ``--``), followed by the
    ``model`` name of every TTS entry flagged ``no_install``.
    """
    # Build the lookup once; the previous form rebuilt this list (and called
    # mod.all_tts_models()) for every installed model being filtered.
    known_names = {tts.replace('/', '--') for tts in mod.all_tts_models()}
    installed = [model for model in mod.get_installed_models(mod_type) if model in known_names]
    no_install = [model.model for model in mod.all_tts() if model.no_install]
    return installed + no_install


def filter_components(components):
    """Return, in order, the entries of ``components`` that are ``IOComponent`` instances."""
    kept = []
    for candidate in components:
        if isinstance(candidate, IOComponent):
            kept.append(candidate)
    return kept


def text_to_speech():
    """Build the text-to-speech UI inside the currently open gradio context.

    Left column: the model dropdown with refresh and unload buttons, plus every
    model-specific component returned by ``tts_models.all_elements``.
    Right column: the generate and "Send to RVC" buttons and the audio, video
    and file outputs.

    Side effects: stores the "Send to RVC" button and the audio output in the
    module globals ``to_rvc`` and ``audio_out``. The registered ``load_model``
    callback rebinds the module global ``loader``.
    """
    global to_rvc, audio_out

    with gradio.Row():
        with gradio.Column():
            all_components_dict = tts_models.all_elements_dict()
            all_components = tts_models.all_elements(all_components_dict)
            with gradio.Row():
                selected = gradio.Dropdown(get_models_installed(), label='Model')
                with gradio.Column(elem_classes='smallsplit'):
                    refresh = gradio.Button('🔃', variant='tool secondary')
                    unload = gradio.Button('💣', variant='tool primary')

                def refresh_models():
                    # A plain list returned to a Dropdown output is taken as its
                    # new value; the choices must be replaced through an update.
                    return gradio.update(choices=get_models_installed())

                refresh.click(fn=refresh_models, outputs=selected, show_progress=True)

                def unload_model():
                    # Before the first load `loader` is still the class itself,
                    # so only an actual instance is asked to unload.
                    if isinstance(loader, mod.TTSModelLoader):
                        loader.unload_model()
                    # Clear the dropdown and hide every model-specific component.
                    return [gradio.update(value='')] + [gradio.update(visible=False) for _ in all_components]

                unload.click(fn=unload_model, outputs=[selected] + all_components, show_progress=True)

                def load_model(model, progress=gradio.Progress()):
                    global loader
                    progress(0, desc='Loading model.')
                    same_model = hasattr(loader, 'model') and model.lower().endswith(loader.model.lower())
                    if not same_model:
                        # Switching models: release the previous one first.
                        unload_model()
                    loader = loader.from_model(model)
                    loader.load_model(progress)
                    inputs = all_components_dict[loader.model]
                    # Leave the dropdown untouched; show only the components that
                    # belong to the loaded model and are not marked as hidden.
                    return_value = [gradio.update()] + [
                        gradio.update(visible=element in inputs and not getattr(element, 'hide', False))
                        for element in all_components]
                    progress(1, desc='Loaded model!')
                    return return_value

                selected.select(fn=load_model, inputs=selected, outputs=[selected] + all_components, show_progress=True)
        with gradio.Column():
            with gradio.Row():
                generate = gradio.Button('Generate', variant='primary', elem_id='tts-generate')
                to_rvc = gradio.Button('Send to RVC', elem_id='tts-to-rvc')
            with gradio.Row():
                audio_out = gradio.Audio(interactive=False)
            with gradio.Row():
                video_out = gradio.Video()
            with gradio.Row():
                file_out = gradio.File()

    filtered_components = filter_components(all_components)

    def _generate(progress=gradio.Progress(track_tqdm=True), *values):
        # NOTE(review): `progress` deliberately precedes `*values`; presumably
        # gradio injects the tracker at that position - confirm before reordering.
        progress(0, desc='Generating audio...')
        active = all_components_dict[loader.model]
        # Forward only the values whose component belongs to the loaded model.
        inputs = [value for component, value in zip(filtered_components, values) if component in active]
        response, file = loader.get_response(*inputs, progress=progress)
        return response, util.make_waveform(response), file

    generate.click(fn=_generate, inputs=filtered_components,
                   outputs=[audio_out, video_out, file_out], show_progress=True)

    # def _generate(inputs, values, progress=gradio.Progress()):
    #     global loader
    #     inputs = [values[i] for i in range(len(inputs)) if
    #               inputs[i] in all_components_dict[loader.model]]  # Filter and convert inputs
    #     response, file = loader.get_response(*inputs, progress=progress)
    #     return response, util.make_waveform(response), file
    #
    # filtered_components = filter_components(all_components)
    # generate.click(fn=lambda *values: _generate(filtered_components, values), inputs=filtered_components,
    #                outputs=[audio_out, video_out, file_out], show_progress=True)