GV-a / tts.py
TDN-M's picture
Update tts.py
a7ba2cd verified
import os
import re
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm
from torch.cuda.amp import autocast
# Cấu hình đường dẫn và tải mô hình
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False
# Kiểm tra GPU và hỗ trợ FP16
if torch.cuda.is_available():
device = "cuda"
if "A100" in torch.cuda.get_device_name(0):
print("Đang sử dụng GPU A100 với hỗ trợ FP16.")
use_fp16 = True
else:
print(f"Đang sử dụng GPU: {torch.cuda.get_device_name(0)}")
use_fp16 = False
else:
device = "cpu"
use_fp16 = False
# Tạo thư mục nếu chưa tồn tại
os.makedirs(checkpoint_dir, exist_ok=True)
# Kiểm tra và tải các file cần thiết
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
for file in required_files:
file_path = os.path.join(checkpoint_dir, file)
if not os.path.exists(file_path):
hf_hub_download(
repo_id=repo_id if file != "speakers_xtts.pth" else "coqui/XTTS-v2",
filename=file,
local_dir=checkpoint_dir,
)
# Tải cấu hình và mô hình
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
MODEL.to(device)
# Danh sách ngôn ngữ được hỗ trợ
supported_languages = ["vi", "en"]
def normalize_vietnamese_text(text):
text = (
TTSnorm(text, unknown=False, lower=False, rule=True)
.replace("..", ".")
.replace("!.", "!")
.replace("?.", "?")
.replace(" .", ".")
.replace(" ,", ",")
.replace('"', "")
.replace("'", "")
.replace("AI", "Ây Ai")
.replace("A.I", "Ây Ai")
)
return text
def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
if language not in supported_languages:
raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ.")
if len(text) < 2:
raise ValueError("Văn bản quá ngắn.")
try:
if normalize_text and language == "vi":
text = normalize_vietnamese_text(text)
with torch.no_grad():
with autocast(enabled=use_fp16):
gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
audio_path=speaker_wav,
gpt_cond_len=30 if device == "cuda" else 15,
gpt_cond_chunk_len=8 if device == "cuda" else 4,
max_ref_length=60 if device == "cuda" else 30,
)
out = MODEL.inference(
text,
language,
gpt_cond_latent,
speaker_embedding,
repetition_penalty=5.0,
temperature=0.75,
enable_text_splitting=True,
)
output_file = f"output_{os.urandom(4).hex()}.wav"
torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0).to("cpu"), 24000)
return output_file
except Exception as e:
raise RuntimeError(f"Lỗi khi tạo giọng nói: {str(e)}")