Update tts.py
Browse files
    	
        tts.py
    CHANGED
    
    | @@ -1,58 +1,116 @@ | |
| 1 | 
             
            import os
         | 
|  | |
| 2 | 
             
            import torch
         | 
| 3 | 
             
            import torchaudio
         | 
|  | |
| 4 | 
             
            from TTS.tts.configs.xtts_config import XttsConfig
         | 
| 5 | 
             
            from TTS.tts.models.xtts import Xtts
         | 
| 6 | 
            -
            from huggingface_hub import snapshot_download, hf_hub_download
         | 
| 7 | 
             
            from vinorm import TTSnorm
         | 
| 8 |  | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
                     | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 30 |  | 
| 31 | 
            -
                # Cấu hình và tải mô hình
         | 
| 32 | 
            -
                xtts_config = os.path.join(checkpoint_dir, "config.json")
         | 
| 33 | 
            -
                config = XttsConfig()
         | 
| 34 | 
            -
                config.load_json(xtts_config)
         | 
| 35 | 
            -
                MODEL = Xtts.init_from_config(config)
         | 
| 36 | 
            -
                MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
         | 
| 37 |  | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 40 |  | 
| 41 | 
            -
                # Chuẩn hóa văn bản
         | 
| 42 | 
            -
                normalized_text = TTSnorm(text)
         | 
| 43 |  | 
| 44 | 
            -
             | 
| 45 | 
            -
                 | 
| 46 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 | 
             
                    out = MODEL.inference(
         | 
| 48 | 
            -
                         | 
| 49 | 
             
                        language,
         | 
| 50 | 
             
                        gpt_cond_latent,
         | 
| 51 | 
             
                        speaker_embedding,
         | 
| 52 | 
            -
                         | 
|  | |
|  | |
| 53 | 
             
                    )
         | 
| 54 |  | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
             | 
| 58 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            import os
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
             
            import torch
         | 
| 4 | 
             
            import torchaudio
         | 
| 5 | 
            +
            from huggingface_hub import snapshot_download, hf_hub_download
         | 
| 6 | 
             
            from TTS.tts.configs.xtts_config import XttsConfig
         | 
| 7 | 
             
            from TTS.tts.models.xtts import Xtts
         | 
|  | |
| 8 | 
             
            from vinorm import TTSnorm
         | 
| 9 |  | 
| 10 | 
            +
            # Cấu hình đường dẫn và tải mô hình
         | 
| 11 | 
            +
            checkpoint_dir = "model/"
         | 
| 12 | 
            +
            repo_id = "capleaf/viXTTS"
         | 
| 13 | 
            +
            use_deepspeed = False
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # Tạo thư mục nếu chưa tồn tại
         | 
| 16 | 
            +
            os.makedirs(checkpoint_dir, exist_ok=True)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Kiểm tra và tải các file cần thiết
         | 
| 19 | 
            +
            required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
         | 
| 20 | 
            +
            files_in_dir = os.listdir(checkpoint_dir)
         | 
| 21 | 
            +
            if not all(file in files_in_dir for file in required_files):
         | 
| 22 | 
            +
                snapshot_download(
         | 
| 23 | 
            +
                    repo_id=repo_id,
         | 
| 24 | 
            +
                    repo_type="model",
         | 
| 25 | 
            +
                    local_dir=checkpoint_dir,
         | 
| 26 | 
            +
                )
         | 
| 27 | 
            +
                hf_hub_download(
         | 
| 28 | 
            +
                    repo_id="coqui/XTTS-v2",
         | 
| 29 | 
            +
                    filename="speakers_xtts.pth",
         | 
| 30 | 
            +
                    local_dir=checkpoint_dir,
         | 
| 31 | 
            +
                )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            # Tải cấu hình và mô hình
         | 
| 34 | 
            +
            xtts_config = os.path.join(checkpoint_dir, "config.json")
         | 
| 35 | 
            +
            config = XttsConfig()
         | 
| 36 | 
            +
            config.load_json(xtts_config)
         | 
| 37 | 
            +
            MODEL = Xtts.init_from_config(config)
         | 
| 38 | 
            +
            MODEL.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            # Sử dụng GPU nếu có
         | 
| 41 | 
            +
            if torch.cuda.is_available():
         | 
| 42 | 
            +
                MODEL.cuda()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            # Danh sách ngôn ngữ được hỗ trợ (chỉ tiếng Việt và tiếng Anh)
         | 
| 45 | 
            +
            supported_languages = ["vi", "en"]
         | 
| 46 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 47 |  | 
| 48 | 
            +
            def normalize_vietnamese_text(text):
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                Chuẩn hóa văn bản tiếng Việt.
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                text = (
         | 
| 53 | 
            +
                    TTSnorm(text, unknown=False, lower=False, rule=True)
         | 
| 54 | 
            +
                    .replace("..", ".")
         | 
| 55 | 
            +
                    .replace("!.", "!")
         | 
| 56 | 
            +
                    .replace("?.", "?")
         | 
| 57 | 
            +
                    .replace(" .", ".")
         | 
| 58 | 
            +
                    .replace(" ,", ",")
         | 
| 59 | 
            +
                    .replace('"', "")
         | 
| 60 | 
            +
                    .replace("'", "")
         | 
| 61 | 
            +
                    .replace("AI", "Ây Ai")
         | 
| 62 | 
            +
                    .replace("A.I", "Ây Ai")
         | 
| 63 | 
            +
                )
         | 
| 64 | 
            +
                return text
         | 
| 65 |  | 
|  | |
|  | |
| 66 |  | 
| 67 | 
            +
            def generate_speech(text, language="vi", speaker_wav=None, normalize_text=True):
         | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
                Tạo giọng nói từ văn bản.
         | 
| 70 | 
            +
                """
         | 
| 71 | 
            +
                if language not in supported_languages:
         | 
| 72 | 
            +
                    raise ValueError(f"Ngôn ngữ {language} không được hỗ trợ. Chỉ hỗ trợ tiếng Việt (vi) và tiếng Anh (en).")
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                if len(text) < 2:
         | 
| 75 | 
            +
                    raise ValueError("Văn bản quá ngắn. Vui lòng nhập văn bản dài hơn.")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                try:
         | 
| 78 | 
            +
                    # Chuẩn hóa văn bản nếu cần
         | 
| 79 | 
            +
                    if normalize_text and language == "vi":
         | 
| 80 | 
            +
                        text = normalize_vietnamese_text(text)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # Lấy latent và embedding từ file âm thanh mẫu
         | 
| 83 | 
            +
                    gpt_cond_latent, speaker_embedding = MODEL.get_conditioning_latents(
         | 
| 84 | 
            +
                        audio_path=speaker_wav,
         | 
| 85 | 
            +
                        gpt_cond_len=30,
         | 
| 86 | 
            +
                        gpt_cond_chunk_len=4,
         | 
| 87 | 
            +
                        max_ref_length=60,
         | 
| 88 | 
            +
                    )
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    # Tạo giọng nói
         | 
| 91 | 
             
                    out = MODEL.inference(
         | 
| 92 | 
            +
                        text,
         | 
| 93 | 
             
                        language,
         | 
| 94 | 
             
                        gpt_cond_latent,
         | 
| 95 | 
             
                        speaker_embedding,
         | 
| 96 | 
            +
                        repetition_penalty=5.0,
         | 
| 97 | 
            +
                        temperature=0.75,
         | 
| 98 | 
            +
                        enable_text_splitting=True,
         | 
| 99 | 
             
                    )
         | 
| 100 |  | 
| 101 | 
            +
                    # Lưu file âm thanh
         | 
| 102 | 
            +
                    output_file = "output.wav"
         | 
| 103 | 
            +
                    torchaudio.save(output_file, torch.tensor(out["wav"]).unsqueeze(0), 24000)
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                    return output_file
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                except Exception as e:
         | 
| 108 | 
            +
                    raise RuntimeError(f"Lỗi khi tạo giọng nói: {str(e)}")
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            if __name__ == "__main__":
         | 
| 112 | 
            +
                # Ví dụ sử dụng
         | 
| 113 | 
            +
                text = "Xin chào, đây là một đoạn văn bản được chuyển thành giọng nói."
         | 
| 114 | 
            +
                speaker_wav = "voices/sample_voice.wav"  # Đường dẫn đến file âm thanh mẫu trong thư mục /voices
         | 
| 115 | 
            +
                output_audio = generate_speech(text, language="vi", speaker_wav=speaker_wav)
         | 
| 116 | 
            +
                print(f"File âm thanh đã được tạo: {output_audio}")
         | 
