Simonlob committed · commit 69a49c7 · verified · 1 parent: 1215562

Update app.py

Files changed (1): app.py (+174 −174)
app.py CHANGED
@@ -1,174 +1,174 @@
The removed and added sides of this hunk are identical line for line except for the audio output label:

-            audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
+            audio_out = gr.Audio(label="AkylAi-TTS", type="numpy", elem_id="audio_out")

The full updated file follows.
 
from pathlib import Path

import gradio as gr
import requests
import torch

from matcha.hifigan.config import v1
from matcha.hifigan.denoiser import Denoiser
from matcha.hifigan.env import AttrDict
from matcha.hifigan.models import Generator as HiFiGAN
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.utils import intersperse


def download_file(url, save_path):
    # Fetch a release asset, creating the target directory if needed.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    response = requests.get(url)
    response.raise_for_status()
    with open(save_path, 'wb') as file:
        file.write(response.content)


url_checkpoint = 'https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Matcha-TTS/checkpoint_epoch.499.ckpt'
save_checkpoint_path = './checkpoints/checkpoint.ckpt'
url_generator = 'https://github.com/simonlobgromov/AkylAI_Matcha_HiFiGan/releases/download/Generator/generator_v1'
save_generator_path = './checkpoints/generator'

download_file(url_checkpoint, save_checkpoint_path)
download_file(url_generator, save_generator_path)


def load_matcha(checkpoint_path, device):
    # Load the Matcha-TTS acoustic model from a Lightning checkpoint.
    model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
    _ = model.eval()
    return model


def load_hifigan(checkpoint_path, device):
    # Load the HiFi-GAN vocoder (v1 config) and strip weight norm for inference.
    h = AttrDict(v1)
    hifigan = HiFiGAN(h).to(device)
    hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"])
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan


def load_vocoder(checkpoint_path, device):
    vocoder = load_hifigan(checkpoint_path, device)
    denoiser = Denoiser(vocoder, mode="zeros")
    return vocoder, denoiser


def process_text(i: int, text: str, device: torch.device):
    # Convert raw text to a phoneme-ID tensor, with zeros interspersed as Matcha-TTS expects.
    print(f"[{i}] - Input text: {text}")
    x = torch.tensor(
        intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0),
        dtype=torch.long,
        device=device,
    )[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    print(f"[{i}] - Phonetised text: {x_phones[1::2]}")
    return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}


def to_waveform(mel, vocoder, denoiser=None):
    # Vocode the mel spectrogram, optionally denoising vocoder artefacts.
    audio = vocoder(mel).clamp(-1, 1)
    if denoiser is not None:
        audio = denoiser(audio.squeeze(), strength=0.00025)
    return audio.cpu().squeeze()


@torch.inference_mode()
def process_text_gradio(text):
    output = process_text(1, text, device)
    return output["x_phones"][1::2], output["x"], output["x_lengths"]


@torch.inference_mode()
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1):
    # Run the Matcha-TTS decoder, then vocode the resulting mel to a waveform.
    spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
    output = model.synthesise(
        text,
        text_length,
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spk,
        length_scale=length_scale,
    )
    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
    return output["waveform"].numpy()


def get_inference(text, n_timesteps=20, mel_temp=0.667, length_scale=0.8, spk=-1):
    phones, text, text_lengths = process_text_gradio(text)
    # Synthesise once and return the waveform as a NumPy array.
    return synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)


device = torch.device("cpu")
model_path = './checkpoints/checkpoint.ckpt'
vocoder_path = './checkpoints/generator'
model = load_matcha(model_path, device)
vocoder, denoiser = load_vocoder(vocoder_path, device)


def gen_tts(text, speaking_rate):
    # gr.Audio(type="numpy") expects a (sample_rate, waveform) pair; the models run at 22.05 kHz.
    return 22050, get_inference(text=text, length_scale=speaking_rate)


default_text = "Баарыңарга салам, менин атым Акылай."  # "Hello everyone, my name is Akylai."

css = """
#share-btn-container {
    display: flex;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
    background-color: #000000;
    justify-content: center;
    align-items: center;
    border-radius: 9999px !important;
    width: 13rem;
    margin-top: 10px;
    margin-left: auto;
    flex: unset !important;
}
#share-btn {
    all: initial;
    color: #ffffff;
    font-weight: 600;
    cursor: pointer;
    font-family: 'IBM Plex Sans', sans-serif;
    margin-left: 0.5rem !important;
    padding-top: 0.25rem !important;
    padding-bottom: 0.25rem !important;
    right: 0;
}
#share-btn * {
    all: unset !important;
}
#share-btn-container div:nth-child(-n+2) {
    width: auto !important;
    min-height: 0px !important;
}
#share-btn-container .wrap {
    display: none !important;
}
"""
with gr.Blocks(css=css) as block:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 700px; margin: 0 auto;">
          <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
            <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
              Akyl-AI TTS
            </h1>
          </div>
        </div>
        """
    )
    with gr.Row():
        image_path = "./photo_2024-04-07_15-59-52.png"
        gr.Image(image_path, label=None, width=660, height=315, show_label=False)
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
            speaking_rate = gr.Slider(label='Speaking rate', minimum=0.5, maximum=1, step=0.05, value=0.8, interactive=True, show_label=True, elem_id="speaking_rate")
            run_button = gr.Button("Generate Audio", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="AkylAi-TTS", type="numpy", elem_id="audio_out")

    inputs = [input_text, speaking_rate]
    outputs = [audio_out]
    run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)

block.queue()
block.launch(share=True)
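
For a quick smoke test without the UI, the same pipeline can be driven headlessly. This is a minimal sketch, assuming the definitions above have already executed in the current process (checkpoints downloaded; `model`, `vocoder`, and `denoiser` loaded); the output filename `sample.wav` is arbitrary.

# Headless smoke test: synthesise the default sentence and save it as a WAV file.
import soundfile as sf

sample_rate, waveform = gen_tts(default_text, speaking_rate=0.8)
sf.write("sample.wav", waveform, sample_rate)
print(f"Wrote {len(waveform) / sample_rate:.2f}s of audio to sample.wav")

Quality/speed trade-offs can be explored the same way by calling `get_inference` directly and varying `n_timesteps` (decoder ODE steps) and `mel_temp` (sampling temperature).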