TDN-M commited on
Commit
78505ef
·
verified ·
1 Parent(s): 1ee99ff

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +8 -6
tts.py CHANGED
@@ -12,6 +12,8 @@ checkpoint_dir = "model/"
12
  repo_id = "capleaf/viXTTS"
13
  use_deepspeed = False
14
 
 
 
15
  # Tạo thư mục nếu chưa tồn tại
16
  os.makedirs(checkpoint_dir, exist_ok=True)
17
 
@@ -37,8 +39,8 @@ config.load_json(xtts_config)
37
  MODEL = Xtts.init_from_config(config)
38
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
39
 
40
- # Đảm bảo mô hình chạy trên CPU
41
- MODEL.to("cpu")
42
 
43
  # Danh sách ngôn ngữ được hỗ trợ (chỉ tiếng Việt và tiếng Anh)
44
  supported_languages = ["vi", "en"]
@@ -80,9 +82,9 @@ def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
80
  with torch.no_grad(): # Tắt tính gradient để tiết kiệm bộ nhớ
81
  gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
82
  audio_path=speaker_wav,
83
- gpt_cond_len=15, # Giảm độ dài để tối ưu hóa cho CPU
84
- gpt_cond_chunk_len=4,
85
- max_ref_length=30, # Giảm độ dài để tối ưu hóa cho CPU
86
  )
87
 
88
  # Tạo giọng nói
@@ -98,7 +100,7 @@ def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
98
 
99
  # Lưu file âm thanh
100
  output_file = "output.wav"
101
- torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0), 24000)
102
 
103
  return output_file
104
 
 
12
  repo_id = "capleaf/viXTTS"
13
  use_deepspeed = False
14
 
15
+ device = "cuda" if torch.cuda.is_available() and "T4" in torch.cuda.get_device_name(0) else "cpu"
16
+
17
  # Tạo thư mục nếu chưa tồn tại
18
  os.makedirs(checkpoint_dir, exist_ok=True)
19
 
 
39
  MODEL = Xtts.init_from_config(config)
40
  MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
41
 
42
+ # Tải mô hình vào thiết bị phù hợp
43
+ MODEL.to(device)
44
 
45
  # Danh sách ngôn ngữ được hỗ trợ (chỉ tiếng Việt và tiếng Anh)
46
  supported_languages = ["vi", "en"]
 
82
  with torch.no_grad(): # Tắt tính gradient để tiết kiệm bộ nhớ
83
  gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
84
  audio_path=speaker_wav,
85
+ gpt_cond_len=30 if device == "cuda" else 15, # Tăng độ dài khi dùng GPU
86
+ gpt_cond_chunk_len=8 if device == "cuda" else 4,
87
+ max_ref_length=60 if device == "cuda" else 30,
88
  )
89
 
90
  # Tạo giọng nói
 
100
 
101
  # Lưu file âm thanh
102
  output_file = "output.wav"
103
+ torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0).to("cpu"), 24000)
104
 
105
  return output_file
106