TDN-M commited on
Commit
88c80d0
·
verified ·
1 Parent(s): cf030b8

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +43 -46
tts.py CHANGED
@@ -2,36 +2,43 @@ import os
2
  import re
3
  import torch
4
  import torchaudio
5
- from huggingface_hub import snapshot_download, hf_hub_download
6
  from TTS.tts.configs.xtts_config import XttsConfig
7
  from TTS.tts.models.xtts import Xtts
8
  from vinorm import TTSnorm
 
9
 
10
  # Cấu hình đường dẫn và tải mô hình
11
  checkpoint_dir = "model/"
12
  repo_id = "capleaf/viXTTS"
13
  use_deepspeed = False
14
 
15
- # Kiểm tra xem GPU có sẵn không sử dụng A100 nếu có
16
- device = "cuda" if torch.cuda.is_available() and "A100" in torch.cuda.get_device_name(0) else "cpu"
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Tạo thư mục nếu chưa tồn tại
19
  os.makedirs(checkpoint_dir, exist_ok=True)
20
 
21
  # Kiểm tra và tải các file cần thiết
22
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
23
- files_in_dir = os.listdir(checkpoint_dir)
24
- if not all(file in files_in_dir for file in required_files):
25
- snapshot_download(
26
- repo_id=repo_id,
27
- repo_type="model",
28
- local_dir=checkpoint_dir,
29
- )
30
- hf_hub_download(
31
- repo_id="coqui/XTTS-v2",
32
- filename="speakers_xtts.pth",
33
- local_dir=checkpoint_dir,
34
- )
35
 
36
  # Tải cấu hình và mô hình
37
  xtts_config = os.path.join(checkpoint_dir, "config.json")
@@ -39,17 +46,12 @@ config = XttsConfig()
39
  config.load_json(xtts_config)
40
  MODEL = Xtts.init_from_config(config)
41
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
42
-
43
- # Tải mô hình vào thiết bị phù hợp
44
  MODEL.to(device)
45
 
46
- # Danh sách ngôn ngữ được hỗ trợ (chỉ tiếng Việt và tiếng Anh)
47
  supported_languages = ["vi", "en"]
48
 
49
  def normalize_vietnamese_text(text):
50
- """
51
- Chuẩn hóa văn bản tiếng Việt.
52
- """
53
  text = (
54
  TTSnorm(text, unknown=False, lower=False, rule=True)
55
  .replace("..", ".")
@@ -65,38 +67,33 @@ def normalize_vietnamese_text(text):
65
  return text
66
 
67
  def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
68
- """
69
- Tạo giọng nói từ văn bản.
70
- """
71
  if language not in supported_languages:
72
- raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ. Chỉ hỗ trợ tiếng Việt (vi) và tiếng Anh (en).")
73
  if len(text) < 2:
74
- raise ValueError("Văn bản quá ngắn. Vui lòng nhập văn bản dài hơn.")
 
75
  try:
76
- # Chuẩn hóa văn bản nếu cần
77
  if normalize_text and language == "vi":
78
  text = normalize_vietnamese_text(text)
79
 
80
- # Lấy latent và embedding từ file âm thanh mẫu
81
- with torch.no_grad(): # Tắt tính gradient để tiết kiệm bộ nhớ
82
- gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
83
- audio_path=speaker_wav,
84
- gpt_cond_len=30 if device == "cuda" else 15, # Tăng độ dài khi dùng GPU
85
- gpt_cond_chunk_len=8 if device == "cuda" else 4,
86
- max_ref_length=60 if device == "cuda" else 30,
87
- )
88
- # Tạo giọng nói
89
- out = MODEL.inference(
90
- text,
91
- language,
92
- gpt_cond_latent,
93
- speaker_embedding,
94
- repetition_penalty=5.0,
95
- temperature=0.75,
96
- enable_text_splitting=True,
97
- )
98
 
99
- # Lưu file âm thanh
100
  output_file = "output.wav"
101
  torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0).to("cpu"), 24000)
102
  return output_file
 
2
  import re
3
  import torch
4
  import torchaudio
5
+ from huggingface_hub import hf_hub_download
6
  from TTS.tts.configs.xtts_config import XttsConfig
7
  from TTS.tts.models.xtts import Xtts
8
  from vinorm import TTSnorm
9
+ from torch.cuda.amp import autocast
10
 
11
  # Cấu hình đường dẫn và tải mô hình
12
  checkpoint_dir = "model/"
13
  repo_id = "capleaf/viXTTS"
14
  use_deepspeed = False
15
 
16
+ # Kiểm tra GPU và hỗ trợ FP16
17
+ if torch.cuda.is_available():
18
+ device = "cuda"
19
+ if "A100" in torch.cuda.get_device_name(0):
20
+ print("Đang sử dụng GPU A100 với hỗ trợ FP16.")
21
+ use_fp16 = True
22
+ else:
23
+ print(f"Đang sử dụng GPU: {torch.cuda.get_device_name(0)}")
24
+ use_fp16 = False
25
+ else:
26
+ device = "cpu"
27
+ use_fp16 = False
28
 
29
  # Tạo thư mục nếu chưa tồn tại
30
  os.makedirs(checkpoint_dir, exist_ok=True)
31
 
32
  # Kiểm tra và tải các file cần thiết
33
  required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
34
+ for file in required_files:
35
+ file_path = os.path.join(checkpoint_dir, file)
36
+ if not os.path.exists(file_path):
37
+ hf_hub_download(
38
+ repo_id=repo_id if file != "speakers_xtts.pth" else "coqui/XTTS-v2",
39
+ filename=file,
40
+ local_dir=checkpoint_dir,
41
+ )
 
 
 
 
42
 
43
  # Tải cấu hình và mô hình
44
  xtts_config = os.path.join(checkpoint_dir, "config.json")
 
46
  config.load_json(xtts_config)
47
  MODEL = Xtts.init_from_config(config)
48
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
 
 
49
  MODEL.to(device)
50
 
51
+ # Danh sách ngôn ngữ được hỗ trợ
52
  supported_languages = ["vi", "en"]
53
 
54
  def normalize_vietnamese_text(text):
 
 
 
55
  text = (
56
  TTSnorm(text, unknown=False, lower=False, rule=True)
57
  .replace("..", ".")
 
67
  return text
68
 
69
  def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
 
 
 
70
  if language not in supported_languages:
71
+ raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ.")
72
  if len(text) < 2:
73
+ raise ValueError("Văn bản quá ngắn.")
74
+
75
  try:
 
76
  if normalize_text and language == "vi":
77
  text = normalize_vietnamese_text(text)
78
 
79
+ with torch.no_grad():
80
+ with autocast(enabled=use_fp16):
81
+ gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
82
+ audio_path=speaker_wav,
83
+ gpt_cond_len=30 if device == "cuda" else 15,
84
+ gpt_cond_chunk_len=8 if device == "cuda" else 4,
85
+ max_ref_length=60 if device == "cuda" else 30,
86
+ )
87
+ out = MODEL.inference(
88
+ text,
89
+ language,
90
+ gpt_cond_latent,
91
+ speaker_embedding,
92
+ repetition_penalty=5.0,
93
+ temperature=0.75,
94
+ enable_text_splitting=True,
95
+ )
 
96
 
 
97
  output_file = "output.wav"
98
  torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0).to("cpu"), 24000)
99
  return output_file