Persian-TTS / app.py
akbarazimifar's picture
Update app.py
e343100 verified
raw
history blame
3.85 kB
import tempfile
import os
import gradio as gr
from TTS.config import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.utils.download import download_url
MODEL_NAMES = [
"vits male1 (best)",
"vits female (best)",
"vits-male",
"vits female1",
"glowtts-male",
"glowtts-female",
"female tacotron2"
]
MAX_TXT_LEN = 800
MODELS_DIRECTORY = "models"
modelInfo = [
["vits-male", "best_model_65633.pth", "config-0.json",
"https://huggingface.co/Kamtera/persian-tts-male-vits/resolve/main/"],
["vits female (best)", "checkpoint_48000.pth", "config-2.json",
"https://huggingface.co/Kamtera/persian-tts-female-vits/resolve/main/"],
["glowtts-male", "best_model_77797.pth", "config-1.json",
"https://huggingface.co/Kamtera/persian-tts-male-glow_tts/resolve/main/"],
["glowtts-female", "best_model.pth", "config.json",
"https://huggingface.co/Kamtera/persian-tts-female-glow_tts/resolve/main/"],
["vits male1 (best)", "checkpoint_88000.pth", "config.json",
"https://huggingface.co/Kamtera/persian-tts-male1-vits/resolve/main/"],
["vits female1", "checkpoint_50000.pth", "config.json",
"https://huggingface.co/Kamtera/persian-tts-female1-vits/resolve/main/"],
["female tacotron2", "checkpoint_313000.pth", "config-2.json",
"https://huggingface.co/Kamtera/persian-tts-female-tacotron2/resolve/main/"]
]
class PersianTTS:
def __init__(self):
self.model_manager = ModelManager(MODELS_DIRECTORY)
for model in modelInfo:
model_name, model_filename, config_filename, model_url = model
self.model_manager.download_model(model_name, model_filename, config_filename, model_url)
def tts(self, text: str, model_name: str):
if len(text) > MAX_TXT_LEN:
text = text[:MAX_TXT_LEN]
print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
# synthesize
model_path, config_path = self.model_manager.get_model_paths(model_name)
synthesizer = Synthesizer(model_path, config_path)
if synthesizer is None:
raise NameError("model not found")
wavs = synthesizer.tts(text)
# return output
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
synthesizer.save_wav(wavs, fp)
return fp.name
article = ""
examples = [
["و خداوند شما را با ارسال روح در جسم زندگانی و حیات بخشید", "vits-male"],
["تاجر تو چه تجارت می کنی ، تو را چه که چه تجارت می کنم؟", "vits female (best)"],
["شیش سیخ جیگر سیخی شیش هزار", "vits female (best)"],
["سه شیشه شیر ، سه سیر سرشیر", "vits female (best)"],
["دزدی دزدید ز بز دزدی بزی ، عجب دزدی که دزدید ز بز دزدی بزی", "vits male1 (best)"],
["مثنوی یکی از قالب های شعری است ک هر بیت قافیه ی جداگانه دارد", "vits female1"],
["در گلو ماند خس او سالها، چیست آن خس مهر جاه و مالها", "vits male1 (best)"],
]
persian_tts = PersianTTS()
iface = gr.Interface(
fn=persian_tts.tts,
inputs=[
gr.Textbox(
label="Text",
value="زندگی فقط یک بار است؛ از آن به خوبی استفاده کن",
),
gr.Radio(
label="Pick a TTS Model ",
choices=MODEL_NAMES,
value="vits-female",
),
],
outputs=gr.Audio(label="Output", type='filepath'),
examples=examples,
title="🗣️ Persian tts 🗣️",
description=description,
article=article,
live=False
)
iface.launch(share=False)