|
import os |
|
import re |
|
import torch |
|
import torchaudio |
|
from huggingface_hub import hf_hub_download |
|
from TTS.tts.configs.xtts_config import XttsConfig |
|
from TTS.tts.models.xtts import Xtts |
|
from vinorm import TTSnorm |
|
from torch.cuda.amp import autocast |
|
|
|
|
|
# --- Model source / runtime configuration ---------------------------------
# Local directory that holds (or will receive) the checkpoint files.
checkpoint_dir: str = "model/"
# Hugging Face Hub repository the model weights, config and vocab are fetched from.
repo_id: str = "capleaf/viXTTS"
# Forwarded to Xtts.load_checkpoint; DeepSpeed inference stays disabled.
use_deepspeed: bool = False
|
|
|
|
|
# Choose the compute device. FP16 mixed precision (consumed by the autocast
# context in generate_speech) is enabled only on A100 GPUs; any other GPU and
# the CPU fallback run in full precision.
if not torch.cuda.is_available():
    device = "cpu"
    use_fp16 = False
else:
    device = "cuda"
    _gpu_name = torch.cuda.get_device_name(0)
    use_fp16 = "A100" in _gpu_name
    if use_fp16:
        print("Đang sử dụng GPU A100 với hỗ trợ FP16.")
    else:
        print(f"Đang sử dụng GPU: {_gpu_name}")
|
|
|
|
|
# Ensure the checkpoint directory exists before downloading into it.
os.makedirs(checkpoint_dir, exist_ok=True)

# Files needed to load the model. Everything comes from `repo_id` except the
# speaker-embedding file, which is taken from the upstream coqui/XTTS-v2 repo.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]

for file in required_files:
    file_path = os.path.join(checkpoint_dir, file)
    if os.path.exists(file_path):
        continue  # already present locally; no download needed

    source_repo = "coqui/XTTS-v2" if file == "speakers_xtts.pth" else repo_id
    hf_hub_download(repo_id=source_repo, filename=file, local_dir=checkpoint_dir)
|
|
|
|
|
# Build the XTTS model from the JSON config in the checkpoint directory, load
# its weights from that same directory, and move it to the selected device.
xtts_config = os.path.join(checkpoint_dir, "config.json")

config = XttsConfig()
config.load_json(xtts_config)

MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config,
    checkpoint_dir=checkpoint_dir,
    use_deepspeed=use_deepspeed,
)
MODEL.to(device)
|
|
|
|
|
# Language codes accepted by generate_speech().
supported_languages = [
    "vi",  # Vietnamese
    "en",  # English
]
|
|
|
# Matches the acronym "AI" or "A.I" only as a standalone word. A plain
# substring replace would also rewrite uppercase words that merely contain
# those letters (e.g. Vietnamese "HAI" -> "HÂy Ai"), since TTSnorm is called
# with lower=False and keeps the original casing.
AI_ACRONYM_RE = re.compile(r"\bA\.?I\b")


def normalize_vietnamese_text(text):
    """Normalize Vietnamese text before passing it to the TTS model.

    Runs vinorm's ``TTSnorm`` (rule-based, casing preserved), then cleans up
    punctuation left behind by normalization and spells out the acronym
    "AI" / "A.I" phonetically as "Ây Ai".

    Args:
        text: Raw input text.

    Returns:
        The normalized text.
    """
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        # Quotes are dropped entirely.
        .replace('"', "")
        .replace("'", "")
    )
    # Word-bounded so only the standalone acronym is expanded.
    return AI_ACRONYM_RE.sub("Ây Ai", text)
|
|
|
def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
    """Synthesize speech for ``text`` and write it to a new WAV file.

    Args:
        text: Text to speak. After stripping surrounding whitespace it must
            be at least two characters long.
        language: Language code; must be one of ``supported_languages``.
        speaker_wav: Reference audio forwarded unchanged to
            ``MODEL.get_conditioning_latents(audio_path=...)``.
            NOTE(review): the default ``None`` is passed through as-is;
            presumably a real reference clip is required — confirm.
        normalize_text: When True and ``language == "vi"``, the text is run
            through ``normalize_vietnamese_text`` first.

    Returns:
        Name of the written WAV file (``output_<random hex>.wav``), created in
        the current working directory and saved at a 24000 Hz sample rate.

    Raises:
        ValueError: If ``language`` is unsupported or ``text`` is too short.
        RuntimeError: If normalization, inference, or saving fails. The
            underlying exception is chained as ``__cause__``.
    """
    if language not in supported_languages:
        raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ.")
    # Measure the stripped text so whitespace-only input is rejected here with
    # a clear message instead of failing deep inside inference.
    if text is None or len(text.strip()) < 2:
        raise ValueError("Văn bản quá ngắn.")

    on_gpu = device == "cuda"

    try:
        if normalize_text and language == "vi":
            text = normalize_vietnamese_text(text)

        # torch.autocast("cuda", dtype=float16) is the non-deprecated
        # equivalent of torch.cuda.amp.autocast(); it is a no-op when
        # use_fp16 is False.
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=use_fp16
        ):
            # On CPU the conditioning lengths are halved (presumably to keep
            # latent extraction cheaper).
            gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30 if on_gpu else 15,
                gpt_cond_chunk_len=8 if on_gpu else 4,
                max_ref_length=60 if on_gpu else 30,
            )
            out = MODEL.inference(
                text,
                language,
                gpt_cond_latent,
                speaker_embedding,
                repetition_penalty=5.0,
                temperature=0.75,
                enable_text_splitting=True,
            )

        # Random suffix so concurrent/successive calls don't overwrite each other.
        output_file = f"output_{os.urandom(4).hex()}.wav"
        # torchaudio.save expects (channels, samples): add a mono channel axis.
        waveform = torch.as_tensor(out["wav"]).unsqueeze(0).cpu()
        torchaudio.save(output_file, waveform, 24000)
        return output_file
    except Exception as e:
        raise RuntimeError(f"Lỗi khi tạo giọng nói: {str(e)}") from e