Update app.py
Browse files
app.py
CHANGED
@@ -1,174 +1,174 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
import argparse
|
3 |
-
import soundfile as sf
|
4 |
-
import torch
|
5 |
-
import io
|
6 |
-
import argparse
|
7 |
-
from matcha.hifigan.config import v1
|
8 |
-
from matcha.hifigan.denoiser import Denoiser
|
9 |
-
from matcha.hifigan.env import AttrDict
|
10 |
-
from matcha.hifigan.models import Generator as HiFiGAN
|
11 |
-
from matcha.models.matcha_tts import MatchaTTS
|
12 |
-
from matcha.text import sequence_to_text, text_to_sequence
|
13 |
-
from matcha.utils.utils import intersperse
|
14 |
-
import gradio as gr
|
15 |
-
import requests
|
16 |
-
|
17 |
-
def download_file(url, save_path):
|
18 |
-
response = requests.get(url)
|
19 |
-
with open(save_path, 'wb') as file:
|
20 |
-
file.write(response.content)
|
21 |
-
|
22 |
-
url_checkpoint = 'https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Matcha-TTS/checkpoint_epoch.499.ckpt'
|
23 |
-
save_checkpoint_path = './checkpoints/checkpoint.ckpt'
|
24 |
-
url_generator = 'https://github.com/simonlobgromov/AkylAI_Matcha_HiFiGan/releases/download/Generator/generator_v1'
|
25 |
-
save_generator_path = './checkpoints/generator'
|
26 |
-
|
27 |
-
download_file(url_checkpoint, save_checkpoint_path)
|
28 |
-
download_file(url_generator, save_generator_path)
|
29 |
-
|
30 |
-
def load_matcha( checkpoint_path, device):
|
31 |
-
model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
|
32 |
-
_ = model.eval()
|
33 |
-
return model
|
34 |
-
|
35 |
-
def load_hifigan(checkpoint_path, device):
|
36 |
-
h = AttrDict(v1)
|
37 |
-
hifigan = HiFiGAN(h).to(device)
|
38 |
-
hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"])
|
39 |
-
_ = hifigan.eval()
|
40 |
-
hifigan.remove_weight_norm()
|
41 |
-
return hifigan
|
42 |
-
|
43 |
-
def load_vocoder(checkpoint_path, device):
|
44 |
-
vocoder = None
|
45 |
-
vocoder = load_hifigan(checkpoint_path, device)
|
46 |
-
denoiser = Denoiser(vocoder, mode="zeros")
|
47 |
-
return vocoder, denoiser
|
48 |
-
|
49 |
-
def process_text(i: int, text: str, device: torch.device):
|
50 |
-
print(f"[{i}] - Input text: {text}")
|
51 |
-
x = torch.tensor(
|
52 |
-
intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0),
|
53 |
-
dtype=torch.long,
|
54 |
-
device=device,
|
55 |
-
)[None]
|
56 |
-
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
|
57 |
-
x_phones = sequence_to_text(x.squeeze(0).tolist())
|
58 |
-
print(f"[{i}] - Phonetised text: {x_phones[1::2]}")
|
59 |
-
return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
|
60 |
-
|
61 |
-
def to_waveform(mel, vocoder, denoiser=None):
|
62 |
-
audio = vocoder(mel).clamp(-1, 1)
|
63 |
-
if denoiser is not None:
|
64 |
-
audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
|
65 |
-
return audio.cpu().squeeze()
|
66 |
-
|
67 |
-
@torch.inference_mode()
|
68 |
-
def process_text_gradio(text):
|
69 |
-
output = process_text(1, text, device)
|
70 |
-
return output["x_phones"][1::2], output["x"], output["x_lengths"]
|
71 |
-
|
72 |
-
@torch.inference_mode()
|
73 |
-
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1):
|
74 |
-
spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
|
75 |
-
output = model.synthesise(
|
76 |
-
text,
|
77 |
-
text_length,
|
78 |
-
n_timesteps=n_timesteps,
|
79 |
-
temperature=temperature,
|
80 |
-
spks=spk,
|
81 |
-
length_scale=length_scale,
|
82 |
-
)
|
83 |
-
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
84 |
-
return output["waveform"].numpy()
|
85 |
-
|
86 |
-
def get_inference(text, n_timesteps=20, mel_temp = 0.667, length_scale=0.8, spk=-1):
|
87 |
-
phones, text, text_lengths = process_text_gradio(text)
|
88 |
-
print(type(synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)))
|
89 |
-
return synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
|
90 |
-
|
91 |
-
|
92 |
-
device = torch.device("cpu")
|
93 |
-
model_path = './checkpoints/checkpoint.ckpt'
|
94 |
-
vocoder_path = './checkpoints/generator'
|
95 |
-
model = load_matcha(model_path, device)
|
96 |
-
vocoder, denoiser = load_vocoder(vocoder_path, device)
|
97 |
-
|
98 |
-
def gen_tts(text, speaking_rate):
|
99 |
-
return 22050, get_inference(text = text, length_scale = speaking_rate)
|
100 |
-
|
101 |
-
default_text = "Баарыңарга салам, менин атым Акылай."
|
102 |
-
|
103 |
-
css = """
|
104 |
-
#share-btn-container {
|
105 |
-
display: flex;
|
106 |
-
padding-left: 0.5rem !important;
|
107 |
-
padding-right: 0.5rem !important;
|
108 |
-
background-color: #000000;
|
109 |
-
justify-content: center;
|
110 |
-
align-items: center;
|
111 |
-
border-radius: 9999px !important;
|
112 |
-
width: 13rem;
|
113 |
-
margin-top: 10px;
|
114 |
-
margin-left: auto;
|
115 |
-
flex: unset !important;
|
116 |
-
}
|
117 |
-
#share-btn {
|
118 |
-
all: initial;
|
119 |
-
color: #ffffff;
|
120 |
-
font-weight: 600;
|
121 |
-
cursor: pointer;
|
122 |
-
font-family: 'IBM Plex Sans', sans-serif;
|
123 |
-
margin-left: 0.5rem !important;
|
124 |
-
padding-top: 0.25rem !important;
|
125 |
-
padding-bottom: 0.25rem !important;
|
126 |
-
right:0;
|
127 |
-
}
|
128 |
-
#share-btn * {
|
129 |
-
all: unset !important;
|
130 |
-
}
|
131 |
-
#share-btn-container div:nth-child(-n+2){
|
132 |
-
width: auto !important;
|
133 |
-
min-height: 0px !important;
|
134 |
-
}
|
135 |
-
#share-btn-container .wrap {
|
136 |
-
display: none !important;
|
137 |
-
}
|
138 |
-
"""
|
139 |
-
with gr.Blocks(css=css) as block:
|
140 |
-
gr.HTML(
|
141 |
-
"""
|
142 |
-
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
143 |
-
<div
|
144 |
-
style="
|
145 |
-
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
|
146 |
-
"
|
147 |
-
>
|
148 |
-
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
|
149 |
-
Akyl-AI TTS
|
150 |
-
</h1>
|
151 |
-
</div>
|
152 |
-
</div>
|
153 |
-
"""
|
154 |
-
)
|
155 |
-
with gr.Row():
|
156 |
-
image_path = "./photo_2024-04-07_15-59-52.png"
|
157 |
-
gr.Image(image_path, label=None, width=660, height=315, show_label=False)
|
158 |
-
with gr.Row():
|
159 |
-
with gr.Column():
|
160 |
-
input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
|
161 |
-
speaking_rate = gr.Slider(label='Speaking rate', minimum=0.5, maximum=1, step=0.05, value=0.8, interactive=True, show_label=True, elem_id="speaking_rate")
|
162 |
-
|
163 |
-
|
164 |
-
run_button = gr.Button("Generate Audio", variant="primary")
|
165 |
-
with gr.Column():
|
166 |
-
audio_out = gr.Audio(label="
|
167 |
-
|
168 |
-
inputs = [input_text, speaking_rate]
|
169 |
-
outputs = [audio_out]
|
170 |
-
run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
|
171 |
-
|
172 |
-
|
173 |
-
block.queue()
|
174 |
-
block.launch(share=True)
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import argparse
|
3 |
+
import soundfile as sf
|
4 |
+
import torch
|
5 |
+
import io
|
6 |
+
import argparse
|
7 |
+
from matcha.hifigan.config import v1
|
8 |
+
from matcha.hifigan.denoiser import Denoiser
|
9 |
+
from matcha.hifigan.env import AttrDict
|
10 |
+
from matcha.hifigan.models import Generator as HiFiGAN
|
11 |
+
from matcha.models.matcha_tts import MatchaTTS
|
12 |
+
from matcha.text import sequence_to_text, text_to_sequence
|
13 |
+
from matcha.utils.utils import intersperse
|
14 |
+
import gradio as gr
|
15 |
+
import requests
|
16 |
+
|
17 |
+
def download_file(url, save_path):
|
18 |
+
response = requests.get(url)
|
19 |
+
with open(save_path, 'wb') as file:
|
20 |
+
file.write(response.content)
|
21 |
+
|
22 |
+
url_checkpoint = 'https://github.com/simonlobgromov/AkylAI_Matcha_Checkpoint/releases/download/Matcha-TTS/checkpoint_epoch.499.ckpt'
|
23 |
+
save_checkpoint_path = './checkpoints/checkpoint.ckpt'
|
24 |
+
url_generator = 'https://github.com/simonlobgromov/AkylAI_Matcha_HiFiGan/releases/download/Generator/generator_v1'
|
25 |
+
save_generator_path = './checkpoints/generator'
|
26 |
+
|
27 |
+
download_file(url_checkpoint, save_checkpoint_path)
|
28 |
+
download_file(url_generator, save_generator_path)
|
29 |
+
|
30 |
+
def load_matcha( checkpoint_path, device):
|
31 |
+
model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
|
32 |
+
_ = model.eval()
|
33 |
+
return model
|
34 |
+
|
35 |
+
def load_hifigan(checkpoint_path, device):
|
36 |
+
h = AttrDict(v1)
|
37 |
+
hifigan = HiFiGAN(h).to(device)
|
38 |
+
hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)["generator"])
|
39 |
+
_ = hifigan.eval()
|
40 |
+
hifigan.remove_weight_norm()
|
41 |
+
return hifigan
|
42 |
+
|
43 |
+
def load_vocoder(checkpoint_path, device):
|
44 |
+
vocoder = None
|
45 |
+
vocoder = load_hifigan(checkpoint_path, device)
|
46 |
+
denoiser = Denoiser(vocoder, mode="zeros")
|
47 |
+
return vocoder, denoiser
|
48 |
+
|
49 |
+
def process_text(i: int, text: str, device: torch.device):
|
50 |
+
print(f"[{i}] - Input text: {text}")
|
51 |
+
x = torch.tensor(
|
52 |
+
intersperse(text_to_sequence(text, ["kyrgyz_cleaners"]), 0),
|
53 |
+
dtype=torch.long,
|
54 |
+
device=device,
|
55 |
+
)[None]
|
56 |
+
x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
|
57 |
+
x_phones = sequence_to_text(x.squeeze(0).tolist())
|
58 |
+
print(f"[{i}] - Phonetised text: {x_phones[1::2]}")
|
59 |
+
return {"x_orig": text, "x": x, "x_lengths": x_lengths, "x_phones": x_phones}
|
60 |
+
|
61 |
+
def to_waveform(mel, vocoder, denoiser=None):
|
62 |
+
audio = vocoder(mel).clamp(-1, 1)
|
63 |
+
if denoiser is not None:
|
64 |
+
audio = denoiser(audio.squeeze(), strength=0.00025).cpu().squeeze()
|
65 |
+
return audio.cpu().squeeze()
|
66 |
+
|
67 |
+
@torch.inference_mode()
|
68 |
+
def process_text_gradio(text):
|
69 |
+
output = process_text(1, text, device)
|
70 |
+
return output["x_phones"][1::2], output["x"], output["x_lengths"]
|
71 |
+
|
72 |
+
@torch.inference_mode()
|
73 |
+
def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk=-1):
|
74 |
+
spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
|
75 |
+
output = model.synthesise(
|
76 |
+
text,
|
77 |
+
text_length,
|
78 |
+
n_timesteps=n_timesteps,
|
79 |
+
temperature=temperature,
|
80 |
+
spks=spk,
|
81 |
+
length_scale=length_scale,
|
82 |
+
)
|
83 |
+
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
|
84 |
+
return output["waveform"].numpy()
|
85 |
+
|
86 |
+
def get_inference(text, n_timesteps=20, mel_temp = 0.667, length_scale=0.8, spk=-1):
|
87 |
+
phones, text, text_lengths = process_text_gradio(text)
|
88 |
+
print(type(synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)))
|
89 |
+
return synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
|
90 |
+
|
91 |
+
|
92 |
+
device = torch.device("cpu")
|
93 |
+
model_path = './checkpoints/checkpoint.ckpt'
|
94 |
+
vocoder_path = './checkpoints/generator'
|
95 |
+
model = load_matcha(model_path, device)
|
96 |
+
vocoder, denoiser = load_vocoder(vocoder_path, device)
|
97 |
+
|
98 |
+
def gen_tts(text, speaking_rate):
|
99 |
+
return 22050, get_inference(text = text, length_scale = speaking_rate)
|
100 |
+
|
101 |
+
default_text = "Баарыңарга салам, менин атым Акылай."
|
102 |
+
|
103 |
+
css = """
|
104 |
+
#share-btn-container {
|
105 |
+
display: flex;
|
106 |
+
padding-left: 0.5rem !important;
|
107 |
+
padding-right: 0.5rem !important;
|
108 |
+
background-color: #000000;
|
109 |
+
justify-content: center;
|
110 |
+
align-items: center;
|
111 |
+
border-radius: 9999px !important;
|
112 |
+
width: 13rem;
|
113 |
+
margin-top: 10px;
|
114 |
+
margin-left: auto;
|
115 |
+
flex: unset !important;
|
116 |
+
}
|
117 |
+
#share-btn {
|
118 |
+
all: initial;
|
119 |
+
color: #ffffff;
|
120 |
+
font-weight: 600;
|
121 |
+
cursor: pointer;
|
122 |
+
font-family: 'IBM Plex Sans', sans-serif;
|
123 |
+
margin-left: 0.5rem !important;
|
124 |
+
padding-top: 0.25rem !important;
|
125 |
+
padding-bottom: 0.25rem !important;
|
126 |
+
right:0;
|
127 |
+
}
|
128 |
+
#share-btn * {
|
129 |
+
all: unset !important;
|
130 |
+
}
|
131 |
+
#share-btn-container div:nth-child(-n+2){
|
132 |
+
width: auto !important;
|
133 |
+
min-height: 0px !important;
|
134 |
+
}
|
135 |
+
#share-btn-container .wrap {
|
136 |
+
display: none !important;
|
137 |
+
}
|
138 |
+
"""
|
139 |
+
with gr.Blocks(css=css) as block:
|
140 |
+
gr.HTML(
|
141 |
+
"""
|
142 |
+
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
143 |
+
<div
|
144 |
+
style="
|
145 |
+
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
|
146 |
+
"
|
147 |
+
>
|
148 |
+
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
|
149 |
+
Akyl-AI TTS
|
150 |
+
</h1>
|
151 |
+
</div>
|
152 |
+
</div>
|
153 |
+
"""
|
154 |
+
)
|
155 |
+
with gr.Row():
|
156 |
+
image_path = "./photo_2024-04-07_15-59-52.png"
|
157 |
+
gr.Image(image_path, label=None, width=660, height=315, show_label=False)
|
158 |
+
with gr.Row():
|
159 |
+
with gr.Column():
|
160 |
+
input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
|
161 |
+
speaking_rate = gr.Slider(label='Speaking rate', minimum=0.5, maximum=1, step=0.05, value=0.8, interactive=True, show_label=True, elem_id="speaking_rate")
|
162 |
+
|
163 |
+
|
164 |
+
run_button = gr.Button("Generate Audio", variant="primary")
|
165 |
+
with gr.Column():
|
166 |
+
audio_out = gr.Audio(label="AkylAi-TTS", type="numpy", elem_id="audio_out")
|
167 |
+
|
168 |
+
inputs = [input_text, speaking_rate]
|
169 |
+
outputs = [audio_out]
|
170 |
+
run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
|
171 |
+
|
172 |
+
|
173 |
+
block.queue()
|
174 |
+
block.launch(share=True)
|