diff --git a/.env.webui b/.env.webui
index a807cd0298c85378bbefe6e65990411d0b9ee51b..22d73786f170f6fa0741125fe3841a6eed8f0128 100644
--- a/.env.webui
+++ b/.env.webui
@@ -17,5 +17,5 @@ TTS_MAX_LEN=1000
 SSML_MAX_LEN=3000
 MAX_BATCH_SIZE=12
 
-V_GIT_TAG="🤗hf(0.5.6-rc)"
+V_GIT_TAG="🤗hf(0.6.1-rc)"
 V_GIT_COMMIT=main
diff --git a/language/zh-CN.json b/language/zh-CN.json
index 31e5890a575dcac42c44b75b7675650a38f22553..f4f41cf038fe73c2194b311b9357b8f5a3b77d6d 100644
--- a/language/zh-CN.json
+++ b/language/zh-CN.json
@@ -80,6 +80,9 @@
   "readme": "readme",
   "changelog": "changelog",
   "💼Speaker file": "💼音色文件",
+  "🎛️Spliter": "🎛️分割器配置",
+  "eos": "句尾词",
+  "Spliter Threshold": "分割器阈值",
   "TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
   "SSML_SPLITER_GUIDE": [
     "- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`",
diff --git a/modules/ChatTTS/ChatTTS/core.py b/modules/ChatTTS/ChatTTS/core.py
index 225de72f07cb237cfa7210872d316936afe808f8..549973e0c5dcdf9869ae1237a65fc7762ceae244 100644
--- a/modules/ChatTTS/ChatTTS/core.py
+++ b/modules/ChatTTS/ChatTTS/core.py
@@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code
 
 from huggingface_hub import snapshot_download
 
-logging.basicConfig(level=logging.INFO)
+logging.basicConfig(level=logging.ERROR)
 
 
 class Chat:
diff --git a/modules/SynthesizeSegments.py b/modules/SynthesizeSegments.py
index e7589ab85d6ebdbd3618d6a668a0bf60c6a51b56..de8a7778a27b2d89e6058a15716a0c53538c9a71 100644
--- a/modules/SynthesizeSegments.py
+++ b/modules/SynthesizeSegments.py
@@ -1,8 +1,10 @@
+import copy
 from box import Box
 from pydub import AudioSegment
 from typing import List, Union
 from scipy.io.wavfile import write
 import io
+from modules.SentenceSplitter import SentenceSplitter
 from modules.api.utils import calc_spk_style
 from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
 from modules.utils import rng
@@ -56,27 +58,27 @@ def to_number(value, t, default=0):
 
 
 class TTSAudioSegment(Box):
-    text: str
-    temperature: float
-    top_P: float
-    top_K: int
-    spk: int
-    infer_seed: int
-    prompt1: str
-    prompt2: str
-    prefix: str
-
-    _type: str
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._type = kwargs.get("_type", "voice")
+        self.text = kwargs.get("text", "")
+        self.temperature = kwargs.get("temperature", 0.3)
+        self.top_P = kwargs.get("top_P", 0.5)
+        self.top_K = kwargs.get("top_K", 20)
+        self.spk = kwargs.get("spk", -1)
+        self.infer_seed = kwargs.get("infer_seed", -1)
+        self.prompt1 = kwargs.get("prompt1", "")
+        self.prompt2 = kwargs.get("prompt2", "")
+        self.prefix = kwargs.get("prefix", "")
 
 
 class SynthesizeSegments:
-    def __init__(self, batch_size: int = 8):
+    def __init__(self, batch_size: int = 8, eos="", spliter_thr=100):
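+        # eos: end-of-sentence token appended to every segment's text before batching
+        #      (the HTTP APIs default it to "[uv_break]")
+        # spliter_thr: threshold handed to SentenceSplitter when long segments are split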
         self.batch_size = batch_size
         self.batch_default_spk_seed = rng.np_rng()
         self.batch_default_infer_seed = rng.np_rng()
+        self.eos = eos
+        self.spliter_thr = spliter_thr
 
     def segment_to_generate_params(
         self, segment: Union[SSMLSegment, SSMLBreak]
@@ -85,9 +87,11 @@ class SynthesizeSegments:
             return TTSAudioSegment(_type="break")
 
         if segment.get("params", None) is not None:
-            return TTSAudioSegment(**segment.get("params"))
+            params = segment.get("params")
+            text = segment.get("text", None) or segment.text or ""
+            return TTSAudioSegment(**params, text=text)
 
-        text = segment.get("text", "")
+        text = segment.get("text", None) or segment.text or ""
         is_end = segment.get("is_end", False)
 
         text = str(text).strip()
@@ -156,7 +160,7 @@ class SynthesizeSegments:
         for i in range(0, len(bucket), self.batch_size):
             batch = bucket[i : i + self.batch_size]
             param_arr = [self.segment_to_generate_params(segment) for segment in batch]
-            texts = [params.text for params in param_arr]
+            texts = [params.text + self.eos for params in param_arr]
 
             params = param_arr[0]
             audio_datas = generate_audio.generate_audio_batch(
@@ -204,9 +208,38 @@ class SynthesizeSegments:
 
         return buckets
 
+    def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]):
+        """
+        Split each segment's text with the sentence spliter, producing multiple segments.
+        """
+        spliter = SentenceSplitter(threshold=self.spliter_thr)
+        ret_segments: List[Union[SSMLSegment, SSMLBreak]] = []
+
+        for segment in segments:
+            if isinstance(segment, SSMLBreak):
+                ret_segments.append(segment)
+                continue
+
+            text = segment.text
+            if not text:
+                continue
+
+            sentences = spliter.parse(text)
+            for sentence in sentences:
+                ret_segments.append(
+                    SSMLSegment(
+                        text=sentence,
+                        attrs=segment.attrs.copy(),
+                        params=copy.copy(segment.params),
+                    )
+                )
+
+        return ret_segments
+
     def synthesize_segments(
         self, segments: List[Union[SSMLSegment, SSMLBreak]]
     ) -> List[AudioSegment]:
+        segments = self.split_segments(segments)
         audio_segments = [None] * len(segments)
         buckets = self.bucket_segments(segments)
 
diff --git a/modules/api/api_setup.py b/modules/api/api_setup.py
index e7de2a62e4131afe6fb5db0280feca8288f4d79d..bfe07f4de7b7f9ddc4a2c625579f1ecd07aa1e2a 100644
--- a/modules/api/api_setup.py
+++ b/modules/api/api_setup.py
@@ -18,6 +18,7 @@ from modules.api.impl import (
     speaker_api,
     ping_api,
     models_api,
+    xtts_v2_api,
 )
 
 logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ def create_api(app, exclude=[]):
     google_api.setup(app_mgr)
     openai_api.setup(app_mgr)
     refiner_api.setup(app_mgr)
+    xtts_v2_api.setup(app_mgr)
 
     return app_mgr
 
@@ -42,9 +44,9 @@ def create_api(app, exclude=[]):
 def setup_model_args(parser: argparse.ArgumentParser):
     parser.add_argument("--compile", action="store_true", help="Enable model compile")
     parser.add_argument(
-        "--half",
+        "--no_half",
         action="store_true",
-        help="Enable half precision for model inference",
+        help="Disalbe half precision for model inference",
     )
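+    # NOTE: half precision is now the default; pass --no_half to disable it,
+    # e.g. `python launch.py --no_half` (launch script name assumed).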
     parser.add_argument(
         "--off_tqdm",
@@ -82,7 +84,7 @@ def process_model_args(args):
     compile = env.get_and_update_env(args, "compile", False, bool)
     device_id = env.get_and_update_env(args, "device_id", None, str)
     use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
-    half = env.get_and_update_env(args, "half", False, bool)
+    no_half = env.get_and_update_env(args, "no_half", False, bool)
     off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
     debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
 
diff --git a/modules/api/impl/google_api.py b/modules/api/impl/google_api.py
index cd66e036578866f35b2cf8a9f1f559e1e2ca7e90..6244bc9fa8ab4c625b5ec8a04982be20798f8ed5 100644
--- a/modules/api/impl/google_api.py
+++ b/modules/api/impl/google_api.py
@@ -13,6 +13,7 @@ from modules.Enhancer.ResembleEnhance import (
 )
 from modules.api.Api import APIManager
 from modules.synthesize_audio import synthesize_audio
+from modules.utils import audio
 from modules.utils.audio import apply_prosody_to_audio_data
 from modules.normalization import text_normalize
 
@@ -44,6 +45,9 @@ class VoiceSelectionParams(BaseModel):
     topK: int = 20
     seed: int = 42
 
+    # end_of_sentence
+    eos: str = "[uv_break]"
+
 
 class AudioConfig(BaseModel):
     audioEncoding: api_utils.AudioFormat = "mp3"
@@ -87,6 +91,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
     language_code = voice.languageCode
     voice_name = voice.name
     infer_seed = voice.seed or 42
+    eos = voice.eos or "[uv_break]"
     audio_format = audioConfig.audioEncoding or "mp3"
     speaking_rate = audioConfig.speakingRate or 1
     pitch = audioConfig.pitch or 0
@@ -94,11 +99,9 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
 
     batch_size = audioConfig.batchSize or 1
 
-    # TODO spliter_threshold
     spliter_threshold = audioConfig.spliterThreshold or 100
 
-    # TODO sample_rate
-    sample_rate_hertz = audioConfig.sampleRateHertz or 24000
+    sample_rate = audioConfig.sampleRateHertz or 24000
 
     params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
 
@@ -137,10 +140,10 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
                 prefix=params.get("prefix", ""),
                 batch_size=batch_size,
                 spliter_threshold=spliter_threshold,
+                end_of_sentence=eos,
             )
 
         elif input.ssml:
-            # 处理SSML合成逻辑
             parser = create_ssml_parser()
             segments = parser.parse(input.ssml)
             for seg in segments:
@@ -151,17 +154,13 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
                     status_code=422, detail="The SSML text is empty or parsing failed."
                 )
 
-            synthesize = SynthesizeSegments(batch_size=batch_size)
+            synthesize = SynthesizeSegments(
+                batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
+            )
             audio_segments = synthesize.synthesize_segments(segments)
             combined_audio = combine_audio_segments(audio_segments)
 
-            buffer = io.BytesIO()
-            combined_audio.export(buffer, format="wav")
-
-            buffer.seek(0)
-
-            audio_data = buffer.read()
-
+            sample_rate, audio_data = audio.pydub_to_np(combined_audio)
         else:
             raise HTTPException(
                 status_code=422, detail="Either text or SSML input must be provided."
diff --git a/modules/api/impl/openai_api.py b/modules/api/impl/openai_api.py
index f1e21e1241c958047d3b8488105981528fde82eb..7c0e012c093c13bb6ce19415a3e504cc3fcafe46 100644
--- a/modules/api/impl/openai_api.py
+++ b/modules/api/impl/openai_api.py
@@ -41,6 +41,8 @@ class AudioSpeechRequest(BaseModel):
     spliter_threshold: float = Field(
         100, ge=10, le=1024, description="Threshold for sentence spliter"
     )
+    # end of sentence
+    eos: str = "[uv_break]"
 
 
 async def openai_speech_api(
@@ -52,6 +54,7 @@ async def openai_speech_api(
     input_text = request.input
     voice = request.voice
     style = request.style
+    eos = request.eos
     response_format = request.response_format
     batch_size = request.batch_size
     spliter_threshold = request.spliter_threshold
@@ -95,6 +98,7 @@ async def openai_speech_api(
             prompt1=prompt1,
             prompt2=prompt2,
             prefix=prefix,
+            end_of_sentence=eos,
         )
 
         if speed != 1:
diff --git a/modules/api/impl/ssml_api.py b/modules/api/impl/ssml_api.py
index 2696470d6afaa6f5ba6bac9a75b36e2bd6164ce8..c6277b6214fe18a5f9e271c766f1f16d9d3f981f 100644
--- a/modules/api/impl/ssml_api.py
+++ b/modules/api/impl/ssml_api.py
@@ -26,8 +26,13 @@ class SSMLRequest(BaseModel):
     # NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
     batch_size: int = 4
 
+    # end of sentence
+    eos: str = "[uv_break]"
 
-async def synthesize_ssml(
+    spliter_thr: int = 100
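+    # illustrative request body:
+    # {"ssml": "<speak>...</speak>", "format": "mp3", "batch_size": 4,
+    #  "eos": "[uv_break]", "spliter_thr": 100}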
+
+
+async def synthesize_ssml_api(
     request: SSMLRequest = Body(
         ..., description="JSON body with SSML string and format"
     )
@@ -36,12 +41,19 @@ async def synthesize_ssml(
         ssml = request.ssml
         format = request.format.lower()
         batch_size = request.batch_size
+        eos = request.eos
+        spliter_thr = request.spliter_thr
 
         if batch_size < 1:
             raise HTTPException(
                 status_code=400, detail="Batch size must be greater than 0."
             )
 
+        if spliter_thr < 50:
+            raise HTTPException(
+                status_code=400, detail="Spliter threshold must be greater than 50."
+            )
+
         if not ssml or ssml == "":
             raise HTTPException(status_code=400, detail="SSML content is required.")
 
@@ -55,7 +67,9 @@ async def synthesize_ssml(
         for seg in segments:
             seg["text"] = text_normalize(seg["text"], is_end=True)
 
-        synthesize = SynthesizeSegments(batch_size)
+        synthesize = SynthesizeSegments(
+            batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
+        )
         audio_segments = synthesize.synthesize_segments(segments)
         combined_audio = combine_audio_segments(audio_segments)
         buffer = io.BytesIO()
@@ -77,4 +91,4 @@ async def synthesize_ssml(
 
 
 def setup(api_manager: APIManager):
-    api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml)
+    api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml_api)
diff --git a/modules/api/impl/tts_api.py b/modules/api/impl/tts_api.py
index 7330b4820612341e8ca57fea0a04b3dbe3eadfa5..b91f5493ca898c5d7522110004a65443849536b8 100644
--- a/modules/api/impl/tts_api.py
+++ b/modules/api/impl/tts_api.py
@@ -38,6 +38,7 @@ class TTSParams(BaseModel):
     prefix: str = Query("", description="Text prefix for inference")
     bs: str = Query("8", description="Batch size for inference")
     thr: str = Query("100", description="Threshold for sentence spliter")
+    eos: str = Query("", description="End of sentence str")
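+    # illustrative query, assuming the model's text field above:
+    # GET /v1/tts?text=...&bs=8&thr=100&eos=[uv_break]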
 
 
 async def synthesize_tts(params: TTSParams = Depends()):
@@ -87,6 +88,7 @@ async def synthesize_tts(params: TTSParams = Depends()):
         prefix = params.prefix or calc_params.get("prefix", params.prefix)
         prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
         prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
+        eos = params.eos or ""
 
         batch_size = int(params.bs)
         threshold = int(params.thr)
@@ -103,6 +105,7 @@ async def synthesize_tts(params: TTSParams = Depends()):
             prefix=prefix,
             batch_size=batch_size,
             spliter_threshold=threshold,
+            end_of_sentence=eos,
         )
 
         buffer = io.BytesIO()
diff --git a/modules/api/impl/xtts_v2_api.py b/modules/api/impl/xtts_v2_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b660562b2ae1041f46c62fe17355040723a6590
--- /dev/null
+++ b/modules/api/impl/xtts_v2_api.py
@@ -0,0 +1,160 @@
+import io
+from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from modules.api import utils as api_utils
+from modules.api.Api import APIManager
+
+import soundfile as sf
+
+from modules import config
+from modules.normalization import text_normalize
+from modules.speaker import speaker_mgr
+from modules.synthesize_audio import synthesize_audio
+
+import logging
+
+from modules.utils.audio import apply_prosody_to_audio_data
+
+logger = logging.getLogger(__name__)
+
+
+class XTTS_V2_Settings:
+    def __init__(self):
+        self.stream_chunk_size = 100
+        self.temperature = 0.3
+        self.speed = 1
+        self.length_penalty = 0.5
+        self.repetition_penalty = 1.0
+        self.top_p = 0.7
+        self.top_k = 20
+        self.enable_text_splitting = True
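+        # These defaults are consumed by tts_to_audio below: temperature/top_p/top_k
+        # are passed straight to synthesis, stream_chunk_size doubles as the sentence
+        # spliter threshold, and speed is applied afterwards as a prosody rate.
+        # length_penalty, repetition_penalty and enable_text_splitting are accepted
+        # but not applied yet.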
+
+
+class TTSSettingsRequest(BaseModel):
+    stream_chunk_size: int
+    temperature: float
+    speed: float
+    length_penalty: float
+    repetition_penalty: float
+    top_p: float
+    top_k: int
+    enable_text_splitting: bool
+
+
+class SynthesisRequest(BaseModel):
+    text: str
+    speaker_wav: str
+    language: str
+
+
+def setup(app: APIManager):
+    XTTSV2 = XTTS_V2_Settings()
+
+    @app.get("/v1/xtts_v2/speakers")
+    async def speakers():
+        spks = speaker_mgr.list_speakers()
+        return [
+            {
+                "name": spk.name,
+                "voice_id": spk.id,
+                # TODO: maybe put a "/v1/tts" endpoint URL here as the preview
+                "preview_url": "",
+            }
+            for spk in spks
+        ]
+
+    @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
+    async def tts_to_audio(request: SynthesisRequest):
+        text = request.text
+        # speaker_wav is really just the speaker id...
+        voice_id = request.speaker_wav
+        language = request.language
+
+        spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
+            voice_id
+        )
+        if spk is None:
+            raise HTTPException(status_code=400, detail="Invalid speaker id")
+
+        text = text_normalize(text, is_end=True)
+        sample_rate, audio_data = synthesize_audio(
+            text=text,
+            temperature=XTTSV2.temperature,
+            # length_penalty=XTTSV2.length_penalty,
+            # repetition_penalty=XTTSV2.repetition_penalty,
+            top_P=XTTSV2.top_p,
+            top_K=XTTSV2.top_k,
+            spk=spk,
+            spliter_threshold=XTTSV2.stream_chunk_size,
+            # TODO: support configuring batch_size
+            batch_size=4,
+            end_of_sentence="[uv_break]",
+        )
+
+        if XTTSV2.speed:
+            audio_data = apply_prosody_to_audio_data(
+                audio_data,
+                rate=XTTSV2.speed,
+                sr=sample_rate,
+            )
+
+        # to mp3
+        buffer = io.BytesIO()
+        sf.write(buffer, audio_data, sample_rate, format="wav")
+        buffer.seek(0)
+
+        buffer = api_utils.wav_to_mp3(buffer)
+
+        return StreamingResponse(buffer, media_type="audio/mpeg")
+
+    @app.get("/v1/xtts_v2/tts_stream")
+    async def tts_stream():
+        raise HTTPException(status_code=501, detail="Not implemented")
+
+    @app.post("/v1/xtts_v2/set_tts_settings")
+    async def set_tts_settings(request: TTSSettingsRequest):
+        try:
+            if request.stream_chunk_size < 50:
+                raise HTTPException(
+                    status_code=400, detail="stream_chunk_size must be greater than 0"
+                )
+            if request.temperature < 0:
+                raise HTTPException(
+                    status_code=400, detail="temperature must be greater than 0"
+                )
+            if request.speed < 0:
+                raise HTTPException(
+                    status_code=400, detail="speed must be greater than 0"
+                )
+            if request.length_penalty < 0:
+                raise HTTPException(
+                    status_code=400, detail="length_penalty must be greater than 0"
+                )
+            if request.repetition_penalty < 0:
+                raise HTTPException(
+                    status_code=400, detail="repetition_penalty must be greater than 0"
+                )
+            if request.top_p < 0:
+                raise HTTPException(
+                    status_code=400, detail="top_p must be greater than 0"
+                )
+            if request.top_k < 0:
+                raise HTTPException(
+                    status_code=400, detail="top_k must be greater than 0"
+                )
+
+            XTTSV2.stream_chunk_size = request.stream_chunk_size
+            XTTSV2.temperature = request.temperature
+            XTTSV2.speed = request.speed
+            XTTSV2.length_penalty = request.length_penalty
+            XTTSV2.repetition_penalty = request.repetition_penalty
+            XTTSV2.top_p = request.top_p
+            XTTSV2.top_k = request.top_k
+            XTTSV2.enable_text_splitting = request.enable_text_splitting
+            return {"message": "Settings successfully applied"}
+        except Exception as e:
+            if isinstance(e, HTTPException):
+                raise e
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
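+
+
+# example (illustrative sketch; replace host/port with your own server address):
+"""
+curl http://localhost:7870/v1/xtts_v2/speakers
+
+curl -X POST http://localhost:7870/v1/xtts_v2/tts_to_audio \
+    -H "Content-Type: application/json" \
+    -d '{"text": "Hello there", "speaker_wav": "<voice_id from /speakers>", "language": "en"}' \
+    -o out.mp3
+"""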
diff --git a/modules/devices/devices.py b/modules/devices/devices.py
index c9e1862cdd427f19edf669b2b6c4daee9d6f3340..e11f6eba5da29af778ee1d254778acaa7492b6bb 100644
--- a/modules/devices/devices.py
+++ b/modules/devices/devices.py
@@ -127,7 +127,7 @@ def reset_device():
     global dtype_gpt
     global dtype_decoder
 
-    if config.runtime_env_vars.half:
+    if not config.runtime_env_vars.no_half:
         dtype = torch.float16
         dtype_dvae = torch.float16
         dtype_vocos = torch.float16
diff --git a/modules/finetune/__init__.py b/modules/finetune/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/finetune/model/__init__.py b/modules/finetune/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/finetune/model/encoder.py b/modules/finetune/model/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d445ad83d07ebdd9bfaa25e9e6d9c64001471d08
--- /dev/null
+++ b/modules/finetune/model/encoder.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+
+from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder
+
+from .wavenet import WaveNet
+
+
+def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
+    return {
+        "idim": decoder.conv_out.out_channels,
+        "odim": decoder.conv_in[0].in_channels,
+        "n_layer": len(decoder.decoder_block),
+        "bn_dim": decoder.conv_in[0].out_channels,
+        "hidden": decoder.conv_in[2].out_channels,
+        "kernel": decoder.decoder_block[0].dwconv.kernel_size[0],
+        "dilation": decoder.decoder_block[0].dwconv.dilation[0],
+        "down": decoder.up,
+    }
+
+
+class DVAEEncoder(nn.Module):
+    def __init__(
+        self,
+        idim: int,
+        odim: int,
+        n_layer: int = 12,
+        bn_dim: int = 64,
+        hidden: int = 256,
+        kernel: int = 7,
+        dilation: int = 2,
+        down: bool = False,
+    ) -> None:
+        super().__init__()
+        self.wavenet = WaveNet(
+            input_channels=100,
+            residual_channels=idim,
+            residual_layers=20,
+            dilation_cycle=4,
+        )
+        self.conv_in_transpose = nn.ConvTranspose1d(
+            idim, hidden, kernel_size=1, bias=False
+        )
+        # nn.Sequential(
+        #     nn.ConvTranspose1d(100, idim, 3, 1, 1, bias=False),
+        #     nn.ConvTranspose1d(idim, hidden, kernel_size=1, bias=False)
+        # )
+        self.encoder_block = nn.ModuleList(
+            [
+                ConvNeXtBlock(
+                    hidden,
+                    hidden * 4,
+                    kernel,
+                    dilation,
+                )
+                for _ in range(n_layer)
+            ]
+        )
+        self.conv_out_transpose = nn.Sequential(
+            nn.Conv1d(hidden, bn_dim, 3, 1, 1),
+            nn.GELU(),
+            nn.Conv1d(bn_dim, odim, 3, 1, 1),
+        )
+
+    def forward(
+        self,
+        audio_mel_specs: torch.Tensor,  # (batch_size, audio_len*2, 100)
+        audio_attention_mask: torch.Tensor,  # (batch_size, audio_len)
+        conditioning=None,
+    ) -> torch.Tensor:
+        mel_attention_mask = (
+            audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2).flatten(1)
+        )
+        x: torch.Tensor = self.wavenet(
+            audio_mel_specs.transpose(1, 2)
+        )  # (batch_size, idim, audio_len*2)
+        x = x * mel_attention_mask.unsqueeze(1)
+        x = self.conv_in_transpose(x)  # (batch_size, hidden, audio_len*2)
+        for f in self.encoder_block:
+            x = f(x, conditioning)
+        x = self.conv_out_transpose(x)  # (batch_size, odim, audio_len*2)
+        x = (
+            x.view(x.size(0), x.size(1), 2, x.size(2) // 2)
+            .permute(0, 3, 1, 2)
+            .flatten(2)
+        )
+        return x  # (batch_size, audio_len, audio_dim=odim*2)
diff --git a/modules/finetune/model/wavenet.py b/modules/finetune/model/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..828aa41c8afdbbb59f038d5053357a317721b5c0
--- /dev/null
+++ b/modules/finetune/model/wavenet.py
@@ -0,0 +1,227 @@
+"""https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/vqgan/modules/wavenet.py"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Mish(nn.Module):
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class DiffusionEmbedding(nn.Module):
+    """Diffusion Step Embedding"""
+
+    def __init__(self, d_denoiser):
+        super(DiffusionEmbedding, self).__init__()
+        self.dim = d_denoiser
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class LinearNorm(nn.Module):
+    """LinearNorm Projection"""
+
+    def __init__(self, in_features, out_features, bias=False):
+        super(LinearNorm, self).__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+
+        nn.init.xavier_uniform_(self.linear.weight)
+        if bias:
+            nn.init.constant_(self.linear.bias, 0.0)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+class ConvNorm(nn.Module):
+    """1D Convolution"""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=None,
+        dilation=1,
+        bias=True,
+        w_init_gain="linear",
+    ):
+        super(ConvNorm, self).__init__()
+
+        if padding is None:
+            assert kernel_size % 2 == 1
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+        nn.init.kaiming_normal_(self.conv.weight)
+
+    def forward(self, signal):
+        conv_signal = self.conv(signal)
+
+        return conv_signal
+
+
+class ResidualBlock(nn.Module):
+    """Residual Block"""
+
+    def __init__(
+        self,
+        residual_channels,
+        use_linear_bias=False,
+        dilation=1,
+        condition_channels=None,
+    ):
+        super(ResidualBlock, self).__init__()
+        self.conv_layer = ConvNorm(
+            residual_channels,
+            2 * residual_channels,
+            kernel_size=3,
+            stride=1,
+            padding=dilation,
+            dilation=dilation,
+        )
+
+        if condition_channels is not None:
+            self.diffusion_projection = LinearNorm(
+                residual_channels, residual_channels, use_linear_bias
+            )
+            self.condition_projection = ConvNorm(
+                condition_channels, 2 * residual_channels, kernel_size=1
+            )
+
+        self.output_projection = ConvNorm(
+            residual_channels, 2 * residual_channels, kernel_size=1
+        )
+
+    def forward(self, x, condition=None, diffusion_step=None):
+        y = x
+
+        if diffusion_step is not None:
+            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+            y = y + diffusion_step
+
+        y = self.conv_layer(y)
+
+        if condition is not None:
+            condition = self.condition_projection(condition)
+            y = y + condition
+
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+
+        return (x + residual) / math.sqrt(2.0), skip
+
+
+class WaveNet(nn.Module):
+    def __init__(
+        self,
+        input_channels: Optional[int] = None,
+        output_channels: Optional[int] = None,
+        residual_channels: int = 512,
+        residual_layers: int = 20,
+        dilation_cycle: Optional[int] = 4,
+        is_diffusion: bool = False,
+        condition_channels: Optional[int] = None,
+    ):
+        super().__init__()
+
+        # Input projection
+        self.input_projection = None
+        if input_channels is not None and input_channels != residual_channels:
+            self.input_projection = ConvNorm(
+                input_channels, residual_channels, kernel_size=1
+            )
+
+        if input_channels is None:
+            input_channels = residual_channels
+
+        self.input_channels = input_channels
+
+        # Residual layers
+        self.residual_layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    residual_channels=residual_channels,
+                    use_linear_bias=False,
+                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+                    condition_channels=condition_channels,
+                )
+                for i in range(residual_layers)
+            ]
+        )
+
+        # Skip projection
+        self.skip_projection = ConvNorm(
+            residual_channels, residual_channels, kernel_size=1
+        )
+
+        # Output projection
+        self.output_projection = None
+        if output_channels is not None and output_channels != residual_channels:
+            self.output_projection = ConvNorm(
+                residual_channels, output_channels, kernel_size=1
+            )
+
+        if is_diffusion:
+            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+            self.mlp = nn.Sequential(
+                LinearNorm(residual_channels, residual_channels * 4, False),
+                Mish(),
+                LinearNorm(residual_channels * 4, residual_channels, False),
+            )
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if getattr(m, "bias", None) is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, t=None, condition=None):
+        if self.input_projection is not None:
+            x = self.input_projection(x)
+            x = F.silu(x)
+
+        if t is not None:
+            t = self.diffusion_embedding(t)
+            t = self.mlp(t)
+
+        skip = []
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x, condition, t)
+            skip.append(skip_connection)
+
+        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+
+        if self.output_projection is not None:
+            x = F.silu(x)
+            x = self.output_projection(x)
+
+        return x
diff --git a/modules/finetune/train_gpt.py b/modules/finetune/train_gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..9642d37016a0b99348d11a79dcdf7b2bdd3c0aef
--- /dev/null
+++ b/modules/finetune/train_gpt.py
@@ -0,0 +1,246 @@
+import functools
+import numpy as np
+import torch
+import transformers
+import peft
+from transformers.trainer_pt_utils import LabelSmoother
+
+from .model.encoder import DVAEEncoder, get_encoder_config
+from .utils.dataset import AudioCollator, XzListTar
+from .utils.logger import MetricLogger
+from .utils.model import quantize
+from .utils.output import ansi, get_ansi_len, output_iter
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+
+
+def train_gpt_lora(
+    chat,
+    dataset,
+    decoder_encoder,
+    dvae_encoder,
+    batch_size=16,
+    epochs=10,
+    train_text=True,
+    speaker_embeds=None,
+    lora_r=8,
+    lora_alpha=16,
+):
+    if speaker_embeds is None:
+        speaker_embeds = {}
+
+    tokenizer = chat.pretrain_models["tokenizer"]
+    decoder_decoder = chat.pretrain_models["decoder"]
+    decoder_decoder.eval().requires_grad_(False)
+    decoder_encoder.to(device=dataset.device).eval().requires_grad_(False)
+    dvae_decoder = chat.pretrain_models["dvae"]
+    dvae_decoder.eval().requires_grad_(False)
+    dvae_encoder.to(device=dataset.device).eval().requires_grad_(False)
+
+    gpt = chat.pretrain_models["gpt"]
+    gpt.train().requires_grad_()
+
+    # Add LoRA to GPT model
+    lora_config = peft.LoraConfig(r=lora_r, lora_alpha=lora_alpha)
+    gpt.gpt = peft.get_peft_model(gpt.gpt, lora_config)
+
+    speaker_embeds = {
+        speaker: torch.randn(768, device=dataset.device, requires_grad=True)
+        for speaker in dataset.speakers
+    } | speaker_embeds
+
+    for speaker_embed in speaker_embeds.values():
+        std, mean = chat.pretrain_models["spk_stat"].chunk(2)
+        speaker_embed.data = speaker_embed.data * std + mean
+
+    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
+    AUDIO_EOS_TOKEN_ID = 0
+    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
+
+    train_params = list(gpt.parameters()) + list(speaker_embeds.values())
+    optimizer = torch.optim.Adam(
+        gpt.parameters(), lr=1e-3, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
+    )
+    optimizer.add_param_group({"params": speaker_embeds.values(), "lr": 1e-1})
+
+    loss_fn = torch.nn.CrossEntropyLoss()
+    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
+    )
+    logger = MetricLogger()
+    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)
+
+    for _epoch in range(epochs):
+        _epoch += 1
+        logger.reset()
+        header = "{blue_light}{0}: {1}{reset}".format(
+            "Epoch", output_iter(_epoch, epochs), **ansi
+        )
+        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
+        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
+
+        for batch in iterator:
+            speakers = batch["speaker"]
+            text_input_ids = batch["text_input_ids"]
+            text_attention_mask = batch["text_attention_mask"]
+            audio_mel_specs = batch["audio_mel_specs"]
+            audio_attention_mask = batch["audio_attention_mask"]
+
+            batch_size, text_len = text_attention_mask.size()
+
+            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
+            _, dvae_audio_input_ids = quantize(
+                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
+            )
+            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
+
+            extended_audio_attention_mask = torch.cat(
+                [
+                    audio_attention_mask,
+                    torch.zeros(
+                        (batch_size, 1),
+                        dtype=audio_attention_mask.dtype,
+                        device=audio_attention_mask.device,
+                    ),
+                ],
+                dim=1,
+            )
+            extended_audio_input_ids = torch.cat(
+                [
+                    dvae_audio_input_ids,
+                    AUDIO_PAD_TOKEN_ID
+                    * torch.ones(
+                        (batch_size, 1, gpt.num_vq),
+                        dtype=dvae_audio_input_ids.dtype,
+                        device=dvae_audio_input_ids.device,
+                    ),
+                ],
+                dim=1,
+            )
+
+            indices = audio_attention_mask.int().sum(dim=1)
+            for i in range(batch_size):
+                extended_audio_attention_mask[i, indices[i]] = 1
+                extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
+
+            input_ids = torch.cat(
+                [
+                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
+                    extended_audio_input_ids,
+                ],
+                dim=1,
+            )
+            attention_mask = torch.cat(
+                [text_attention_mask, extended_audio_attention_mask], dim=1
+            )
+            text_mask = torch.cat(
+                [
+                    torch.ones_like(text_attention_mask, dtype=bool),
+                    torch.zeros_like(extended_audio_attention_mask, dtype=bool),
+                ],
+                dim=1,
+            )
+            labels = input_ids.clone()
+            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
+
+            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
+
+            indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
+            for i, speaker in enumerate(speakers):
+                inputs_embeds[i, indices[i]] = torch.nn.functional.normalize(
+                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
+                    p=2.0,
+                    dim=-1,
+                    eps=1e-12,
+                ).unsqueeze(0)
+
+            outputs = gpt.gpt.forward(
+                inputs_embeds=inputs_embeds, attention_mask=attention_mask
+            )
+            hidden_states = outputs.last_hidden_state
+            text_hidden_states = hidden_states[:, : text_len - 1]
+            audio_hidden_states = hidden_states[:, text_len - 1 : -1]
+
+            audio_logits = torch.stack(
+                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
+                dim=2,
+            )
+            audio_loss = loss_fn(
+                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
+            )
+            loss = audio_loss
+
+            if train_text:
+                text_logits = gpt.head_text(text_hidden_states)
+                text_loss = loss_fn(
+                    text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
+                )
+                loss += text_loss
+                logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
+
+            gpt_gen_mel_specs = decoder_decoder(
+                audio_hidden_states[:, :-1].transpose(1, 2)
+            ).transpose(1, 2)
+            mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
+            loss += 0.01 * mse_loss
+
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(train_params, 1.0)
+            optimizer.step()
+
+            logger.meters["loss"].update(loss.item(), n=batch_size)
+            logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
+            logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)
+
+        lr_scheduler.step()
+    optimizer.zero_grad()
+    return speaker_embeds
+
+
+# Example usage
+def main():
+    # Load necessary models and data paths
+    chat = ChatTTS.Chat()
+    chat.load_models()
+    dataset = XzListTar(
+        root="data/all.list",
+        tokenizer=chat.pretrain_models["tokenizer"],
+        vocos_model=chat.pretrain_models["vocos"],
+        tar_path="data/Xz.tar",
+        tar_in_memory=True,
+        process_ahead=True,
+    )
+
+    decoder_encoder = DVAEEncoder(
+        **get_encoder_config(chat.pretrain_models["decoder"].decoder)
+    )
+    dvae_encoder = DVAEEncoder(
+        **get_encoder_config(chat.pretrain_models["dvae"].decoder)
+    )
+
+    # Train GPT with LoRA
+    speaker_embeds = train_gpt_lora(
+        chat=chat,
+        dataset=dataset,
+        decoder_encoder=decoder_encoder,
+        dvae_encoder=dvae_encoder,
+        batch_size=32,
+        epochs=10,
+        train_text=True,
+        lora_r=8,
+        lora_alpha=16,
+    )
+
+    # Save LoRA parameters and embeddings
+    lora_save_path = "./saved_models/gpt_lora.pth"
+    # gpt.gpt was wrapped into a PeftModel inside train_gpt_lora, so save it directly
+    gpt = chat.pretrain_models["gpt"]
+    gpt.gpt.save_pretrained(lora_save_path)
+    np.savez(
+        "./saved_models/speaker_embeds.npz",
+        **{k: v.detach().cpu().numpy() for k, v in speaker_embeds.items()}
+    )
+
+
+if __name__ == "__main__":
+    main()
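+
+# NOTE: main() above is an illustrative wiring of train_gpt_lora; the saved adapter
+# can presumably be reloaded with peft.PeftModel.from_pretrained before inference
+# (the reload step is an assumption and is not wired into the rest of the project).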
diff --git a/modules/finetune/train_speaker.py b/modules/finetune/train_speaker.py
new file mode 100644
index 0000000000000000000000000000000000000000..343d743d6c2acff51c170fde6fc3ae32bb47c482
--- /dev/null
+++ b/modules/finetune/train_speaker.py
@@ -0,0 +1,296 @@
+import torch
+import torch.nn.functional as F
+import transformers
+
+from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
+from modules.finetune.utils.output import get_ansi_len, output_iter, ansi
+from .utils.logger import MetricLogger
+from .utils.dataset import AudioCollator, XzListTar
+from .utils.model import quantize
+
+IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
+
+
+def train_speaker_embeddings(
+    chat,
+    dataset,
+    gpt,
+    batch_size=16,
+    epochs=10,
+    train_text=True,
+    speaker_embeds=None,
+):
+    tokenizer = chat.pretrain_models["tokenizer"]
+
+    decoder_decoder = chat.pretrain_models["decoder"]
+    decoder_decoder.eval().requires_grad_(False)
+    decoder_encoder = DVAEEncoder(**get_encoder_config(decoder_decoder.decoder)).to(
+        device=dataset.device
+    )
+    decoder_encoder.eval().requires_grad_(False)
+
+    dvae_decoder = chat.pretrain_models["dvae"]
+    dvae_decoder.eval().requires_grad_(False)
+    dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to(
+        device=dataset.device
+    )
+    dvae_encoder.eval().requires_grad_(False)
+
+    if speaker_embeds is None:
+        speaker_embeds = {
+            speaker: torch.randn(
+                768,
+                device=dataset.device,
+                requires_grad=True,
+            )
+            for speaker in dataset.speakers
+        }
+        for speaker_embed in speaker_embeds.values():
+            std, mean = chat.pretrain_models["spk_stat"].chunk(2)
+            speaker_embed.data = speaker_embed.data * std + mean
+
+    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
+    AUDIO_EOS_TOKEN_ID = 0
+    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
+
+    optimizer = torch.optim.Adam(
+        speaker_embeds.values(), lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
+    )
+    loss_fn = torch.nn.CrossEntropyLoss()
+    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
+    )
+    logger = MetricLogger()
+    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)
+
+    for _epoch in range(epochs):
+        _epoch += 1
+        logger.reset()
+        header = "{blue_light}{0}: {1}{reset}".format(
+            "Epoch", output_iter(_epoch, epochs), **ansi
+        )
+        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
+        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
+
+        for batch in iterator:
+            speakers = batch["speaker"]
+            text_input_ids = batch["text_input_ids"]
+            text_attention_mask = batch["text_attention_mask"]
+            audio_mel_specs = batch["audio_mel_specs"]
+            audio_attention_mask = batch["audio_attention_mask"]
+
+            batch_size, text_len = text_attention_mask.size()
+
+            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
+            _, dvae_audio_input_ids = quantize(
+                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
+            )
+            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
+
+            extended_audio_attention_mask = torch.cat(
+                [
+                    audio_attention_mask,
+                    torch.zeros(
+                        (batch_size, 1),
+                        dtype=audio_attention_mask.dtype,
+                        device=audio_attention_mask.device,
+                    ),
+                ],
+                dim=1,
+            )
+            extended_audio_input_ids = torch.cat(
+                [
+                    dvae_audio_input_ids,
+                    AUDIO_PAD_TOKEN_ID
+                    * torch.ones(
+                        (batch_size, 1, gpt.num_vq),
+                        dtype=dvae_audio_input_ids.dtype,
+                        device=dvae_audio_input_ids.device,
+                    ),
+                ],
+                dim=1,
+            )
+            indices = audio_attention_mask.int().sum(dim=1)
+            for i in range(batch_size):
+                extended_audio_attention_mask[i, indices[i]] = 1
+                extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
+
+            input_ids = torch.cat(
+                [
+                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
+                    extended_audio_input_ids,
+                ],
+                dim=1,
+            )
+            attention_mask = torch.cat(
+                [text_attention_mask, extended_audio_attention_mask], dim=1
+            )
+            text_mask = torch.cat(
+                [
+                    torch.ones_like(text_attention_mask, dtype=bool),
+                    torch.zeros_like(extended_audio_attention_mask, dtype=bool),
+                ],
+                dim=1,
+            )
+
+            labels = input_ids.clone()
+            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
+
+            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
+
+            indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
+            for i, speaker in enumerate(speakers):
+                inputs_embeds[i, indices[i]] = F.normalize(
+                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
+                    p=2.0,
+                    dim=-1,
+                    eps=1e-12,
+                ).unsqueeze(0)
+            outputs = gpt.gpt.forward(
+                inputs_embeds=inputs_embeds, attention_mask=attention_mask
+            )
+            hidden_states = outputs.last_hidden_state
+            text_hidden_states = hidden_states[:, : text_len - 1]
+            audio_hidden_states = hidden_states[:, text_len - 1 : -1]
+
+            audio_logits = torch.stack(
+                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
+                dim=2,
+            )
+            audio_loss = loss_fn(
+                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
+            )
+            loss = audio_loss
+            if train_text:
+                text_logits = gpt.head_text(text_hidden_states)
+                text_loss = loss_fn(
+                    text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
+                )
+                loss += text_loss
+                logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
+
+            gpt_gen_mel_specs = decoder_decoder(
+                audio_hidden_states[:, :-1].transpose(1, 2)
+            ).transpose(1, 2)
+            mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
+            loss += 0.01 * mse_loss
+
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
+            optimizer.step()
+            logger.meters["loss"].update(loss.item(), n=batch_size)
+            logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
+            logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)
+        lr_scheduler.step()
+    optimizer.zero_grad()
+    return speaker_embeds
+
+
+if __name__ == "__main__":
+    import argparse
+    import os
+    import numpy as np
+    import pathlib
+    from modules.models import load_chat_tts
+    from modules.devices import devices
+    from modules import config
+    from modules.speaker import Speaker
+
+    config.runtime_env_vars.no_half = True
+    devices.reset_device()
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--save_folder", type=str, default="./")
+    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--epochs", type=int, default=100)
+    parser.add_argument("--train_text", action="store_true", help="train text loss")
+    # speaker file used to initialize the embeddings
+    parser.add_argument("--init_speaker", type=str)
+    parser.add_argument(
+        "--data_path",
+        type=str,
+        default="datasets/data_speaker_a/speaker_a.list",
+        help="the data_path to json/list file",
+    )
+    parser.add_argument("--tar_path", type=str, help="the tarball path with wavs")
+    parser.add_argument(
+        "--tar_in_memory", action="store_true", help="load tarball in memory"
+    )
+
+    args = parser.parse_args()
+
+    data_path: str = args.data_path
+    tar_path: str | None = args.tar_path
+    tar_in_memory: bool = args.tar_in_memory
+    train_text: bool = args.train_text
+    # gpt_lora: bool = args.gpt_lora
+    # gpt_kbit: int = args.gpt_kbit
+    save_folder: str = args.save_folder
+    batch_size: int = args.batch_size
+    epochs: int = args.epochs
+    init_speaker: str = args.init_speaker
+
+    speaker_embeds_save_path = os.path.join(save_folder, "speaker_embeds.npz")
+
+    chat = load_chat_tts()
+    dataset = XzListTar(
+        root=data_path,
+        tokenizer=chat.pretrain_models["tokenizer"],
+        vocos_model=chat.pretrain_models["vocos"],
+        tar_path=tar_path,
+        tar_in_memory=tar_in_memory,
+        device=devices.device,
+        # speakers=None,  # set(['speaker_A', 'speaker_B'])
+    )
+
+    print("len(dataset)", len(dataset))
+
+    speaker_embeds = None
+    if init_speaker:
+        spk: Speaker = Speaker.from_file(init_speaker)
+        speaker_embeds = {
+            speaker: torch.tensor(
+                spk.emb,
+                device=devices.device,
+                requires_grad=True,
+            )
+            for speaker in dataset.speakers
+        }
+
+    speaker_embeds = train_speaker_embeddings(
+        chat,
+        dataset,
+        chat.pretrain_models["gpt"],
+        batch_size=batch_size,
+        epochs=epochs,
+        train_text=train_text,
+        speaker_embeds=speaker_embeds,
+    )
+    speaker_outs = {
+        speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}")
+        for speaker, speaker_embed in speaker_embeds.items()
+    }
+    time_str = np.datetime_as_string(np.datetime64("now", "s"))
+    time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_")
+    for speaker, speaker_out in speaker_outs.items():
+        torch.save(
+            speaker_out,
+            pathlib.Path(save_folder) / f"spk_{speaker}_{time_str}_ep{epochs}.pt",
+        )
+
+# example
+"""
+python -m modules.finetune.train_speaker \
+    --data_path datasets/data_speaker_a/speaker_a.list \
+    --save_folder ./data \
+    --init_speaker ./data/speakers/Bob.pt \
+    --epochs 100 \
+    --batch_size 6 \
+    --train_text
+"""
diff --git a/modules/finetune/utils/__init__.py b/modules/finetune/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/finetune/utils/dataset.py b/modules/finetune/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d86dae96de753e9c648b62329844958c3a33eb
--- /dev/null
+++ b/modules/finetune/utils/dataset.py
@@ -0,0 +1,487 @@
+import os
+import functools
+import json
+import tarfile
+import io
+import logging
+import abc
+import typing
+
+import torch.utils.data
+import torchaudio
+from torchvision.datasets.utils import download_url
+import transformers
+import vocos
+
+from modules.ChatTTS.ChatTTS.utils.infer_utils import (
+    count_invalid_characters,
+    apply_character_map,
+)
+
+
+class LazyDataType(typing.TypedDict):
+    filepath: str
+    speaker: str
+    lang: str
+    text: str
+
+
+class DataType(LazyDataType):
+    text_input_ids: torch.Tensor  # (batch_size, text_len)
+    text_attention_mask: torch.Tensor  # (batch_size, text_len)
+    audio_mel_specs: torch.Tensor  # (batch_size, audio_len*2, 100)
+    audio_attention_mask: torch.Tensor  # (batch_size, audio_len)
+
+
+class XzListTarKwargsType(typing.TypedDict):
+    tokenizer: typing.Union[transformers.PreTrainedTokenizer, None]
+    vocos_model: typing.Union[vocos.Vocos, None]
+    device: typing.Union[str, torch.device, None]
+    speakers: typing.Union[typing.Iterable[str], None]
+    sample_rate: typing.Union[int]
+    default_speaker: typing.Union[str, None]
+    default_lang: typing.Union[str, None]
+    tar_in_memory: typing.Union[bool, None]
+    process_ahead: typing.Union[bool, None]
+
+
+class AudioFolder(torch.utils.data.Dataset, abc.ABC):
+    def __init__(
+        self,
+        root: str | io.BytesIO,
+        tokenizer: transformers.PreTrainedTokenizer | None = None,
+        vocos_model: vocos.Vocos | None = None,
+        device: str | torch.device | None = None,
+        speakers: typing.Iterable[str] | None = None,
+        sample_rate: int = 24_000,
+        default_speaker: str | None = None,
+        default_lang: str | None = None,
+        tar_path: str | None = None,
+        tar_in_memory: bool = False,
+        process_ahead: bool = False,
+    ) -> None:
+        self.root = root
+        self.sample_rate = sample_rate
+        self.default_speaker = default_speaker
+        self.default_lang = default_lang
+
+        self.logger = logging.getLogger(__name__)
+        self.normalizer = {}
+
+        self.tokenizer = tokenizer
+        self.vocos = vocos_model
+        self.vocos_device = (
+            None if self.vocos is None else next(self.vocos.parameters()).device
+        )
+        self.device = device or self.vocos_device
+
+        # tar -cvf ../Xz.tar *
+        # tar -xf Xz.tar -C ./Xz
+        self.tar_path = tar_path
+        self.tar_file = None
+        self.tar_io = None
+        if tar_path is not None:
+            if tar_in_memory:
+                with open(tar_path, "rb") as f:
+                    self.tar_io = io.BytesIO(f.read())
+                self.tar_file = tarfile.open(fileobj=self.tar_io)
+            else:
+                self.tar_file = tarfile.open(tar_path)
+
+        self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)
+
+        self.text_input_ids: dict[int, torch.Tensor] = {}
+        self.audio_mel_specs: dict[int, torch.Tensor] = {}
+        if process_ahead:
+            for n, item in enumerate(self.lazy_data):
+                self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
+                self.text_input_ids[n] = self.preprocess_text(
+                    item["text"], item["lang"]
+                )
+            if self.tar_file is not None:
+                self.tar_file.close()
+            if self.tar_io is not None:
+                self.tar_io.close()
+
+    @abc.abstractmethod
+    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: ...
+
+    @staticmethod
+    @abc.abstractmethod
+    def save_config(
+        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
+    ) -> None: ...
+
+    def __len__(self):
+        return len(self.lazy_data)
+
+    def __getitem__(self, n: int) -> DataType:
+        lazy_data = self.lazy_data[n]
+        if n in self.audio_mel_specs:
+            audio_mel_specs = self.audio_mel_specs[n]
+            text_input_ids = self.text_input_ids[n]
+        else:
+            audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
+            text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
+            self.audio_mel_specs[n] = audio_mel_specs
+            self.text_input_ids[n] = text_input_ids
+            if len(self.audio_mel_specs) == len(self.lazy_data):
+                if self.tar_file is not None:
+                    self.tar_file.close()
+                if self.tar_io is not None:
+                    self.tar_io.close()
+        text_attention_mask = torch.ones(
+            len(text_input_ids), device=text_input_ids.device
+        )
+        audio_attention_mask = torch.ones(
+            (len(audio_mel_specs) + 1) // 2,
+            device=audio_mel_specs.device,
+        )
+        return {
+            "filepath": lazy_data["filepath"],
+            "speaker": lazy_data["speaker"],
+            "lang": lazy_data["lang"],
+            "text": lazy_data["text"],
+            "text_input_ids": text_input_ids,
+            "text_attention_mask": text_attention_mask,
+            "audio_mel_specs": audio_mel_specs,
+            "audio_attention_mask": audio_attention_mask,
+        }
+
+    def get_lazy_data(
+        self,
+        root: str | io.BytesIO,
+        speakers: typing.Iterable[str] | None = None,
+    ) -> tuple[list[LazyDataType], set[str]]:
+        if speakers is not None:
+            new_speakers = set(speakers)
+        else:
+            new_speakers = set()
+        lazy_data = []
+
+        raw_data = self.get_raw_data(root)
+        folder_path = os.path.dirname(root) if isinstance(root, str) else ""
+        for item in raw_data:
+            if "speaker" not in item:
+                item["speaker"] = self.default_speaker
+            if "lang" not in item:
+                item["lang"] = self.default_lang
+
+            if speakers is not None and item["speaker"] not in speakers:
+                continue
+            if speakers is None and item["speaker"] not in new_speakers:
+                new_speakers.add(item["speaker"])
+            if self.tar_file is None and isinstance(root, str):
+                filepath = os.path.join(folder_path, item["filepath"])
+            else:
+                filepath = item["filepath"]
+            lazy_data.append(
+                {
+                    "filepath": filepath,
+                    "speaker": item["speaker"],
+                    "lang": item["lang"].lower(),
+                    "text": item["text"],
+                }
+            )
+        return lazy_data, new_speakers
+
+    def preprocess_text(
+        self,
+        text: str,
+        lang: str,
+    ) -> torch.Tensor:
+        invalid_characters = count_invalid_characters(text)
+        if len(invalid_characters):
+            # self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
+            text = apply_character_map(text)
+
+        # if not skip_refine_text:
+        #     text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
+        #     text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
+        #     text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
+        #     if refine_text_only:
+        #         return text
+
+        text = f"[Stts][spk_emb]{text}[Ptts]"
+        # text = f'[Stts][empty_spk]{text}[Ptts]'
+
+        text_token = self.tokenizer(
+            text, return_tensors="pt", add_special_tokens=False
+        ).to(device=self.device)
+        return text_token["input_ids"].squeeze(0)
+
+    def preprocess_audio(self, filepath: str) -> torch.Tensor:
+        if self.tar_file is not None:
+            file = self.tar_file.extractfile(filepath)
+            waveform, sample_rate = torchaudio.load(file)
+        else:
+            waveform, sample_rate = torchaudio.load(filepath)
+        waveform = waveform.to(device=self.vocos_device)
+        if sample_rate != self.sample_rate:
+            waveform = torchaudio.functional.resample(
+                waveform,
+                orig_freq=sample_rate,
+                new_freq=self.sample_rate,
+            )
+        mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
+        return (
+            mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
+        )  # (audio_len*2, 100)
+
+
+class JsonFolder(AudioFolder):
+    """
+    In a json file, each item is formatted as in the following example:
+    `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`.
+
+    filepath is relative to the dirname of the root json file.
+    """
+
+    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
+        # `root` is either a filepath or a binary file object (e.g. a tar member)
+        if isinstance(root, str):
+            with open(root, "r", encoding="utf-8") as f:
+                raw_data = json.load(f)
+        else:
+            raw_data = json.load(io.TextIOWrapper(root, encoding="utf-8"))
+        return raw_data
+
+    @staticmethod
+    def save_config(
+        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
+    ) -> None:
+        save_data = [item.copy() for item in lazy_data]
+        for item in save_data:
+            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
+        with open(save_path, "w", encoding="utf-8") as f:
+            json.dump(save_data, f, ensure_ascii=False, indent=4)
+
+
+class ListFolder(AudioFolder):
+    """
+    In a list file, each row is formatted as `filepath|speaker|lang|text`, with `|` as the separator, e.g.
+    `path/to/file.wav|John|ZH|Hello`.
+
+    filepath is relative to the dirname of the root list file.
+    """
+
+    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
+        raw_data = []
+        # `root` is either a filepath or a binary file object (e.g. a tar member)
+        f = (
+            open(root, "r", encoding="utf-8")
+            if isinstance(root, str)
+            else io.TextIOWrapper(root, encoding="utf-8")
+        )
+        with f:
+            for line in f:
+                line = line.strip()
+                if len(line) == 0:
+                    continue
+                filepath, speaker, lang, text = line.split(sep="|", maxsplit=3)
+                raw_data.append(
+                    {
+                        "text": text,
+                        "filepath": filepath,
+                        "speaker": speaker,
+                        "lang": lang,
+                    }
+                )
+        return raw_data
+
+    @staticmethod
+    def save_config(
+        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
+    ) -> None:
+        save_data = [item.copy() for item in lazy_data]
+        for item in save_data:
+            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
+        with open(save_path, "w", encoding="utf-8") as f:
+            for item in save_data:
+                f.write(
+                    f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n"
+                )
+
+
+class XzListTar(ListFolder):
+    def __init__(
+        self,
+        *args,
+        root: str | io.BytesIO,
+        tar_path: str | None = None,
+        **kwargs,
+    ):
+        if isinstance(root, io.BytesIO):
+            assert tar_path is not None
+        else:
+            # make sure root is a list file
+            if not root.endswith(".list"):  # folder case
+                if os.path.isfile(root):
+                    raise FileExistsError(f"{root} is a file!")
+                elif not os.path.exists(root):
+                    os.makedirs(root)
+                root = os.path.join(root, "all.list")
+        self.tar_file = None
+        if isinstance(root, str) and not os.path.isfile(root):
+            # prepare all.list; concat_dataset reads the member .list/.json
+            # files from the tar archive, so open it here first
+            # (super().__init__ reopens it with the requested settings).
+            if tar_path is not None:
+                self.tar_file = tarfile.open(tar_path)
+            self.concat_dataset(
+                save_folder=os.path.dirname(root),
+                langs=kwargs.get("langs", ["zh", "en"]),
+            )
+            if self.tar_file is not None:
+                self.tar_file.close()
+
+        super().__init__(root, *args, tar_path=tar_path, **kwargs)
+
+    def concat_dataset(
+        self, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
+    ) -> None:
+        if save_folder is None:
+            save_folder = os.path.dirname(self.root)
+        if os.path.isfile(save_folder):
+            raise FileExistsError(f"{save_folder} already exists as a file!")
+        elif not os.path.exists(save_folder):
+            os.makedirs(save_folder)
+        lazy_data = []
+
+        for member in self.tar_file.getmembers():
+            if not member.isfile():
+                continue
+            if member.name.endswith(".list"):
+                print(member.name)
+                root_io = self.tar_file.extractfile(member)
+                lazy_data += ListFolder(root_io).lazy_data
+            if member.name.endswith(".json"):
+                print(member.name)
+                root_io = self.tar_file.extractfile(member)
+                lazy_data += JsonFolder(root_io).lazy_data
+        if langs is not None:
+            lazy_data = [item for item in lazy_data if item["lang"] in langs]
+        ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
+        JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
+        print(f"all.list and all.json are saved to {save_folder}")
+
+
+class XzListFolder(ListFolder):
+    """
+    [Xz乔希](https://space.bilibili.com/5859321)
+
+    Only the basename of each filepath in the list file is used; any leading folder
+    components are ignored. Files are expected to be organized as `[list basename]/[file basename]`.
+
+    Example tree structure:
+
+    [folder]
+    ├── speaker_A
+    │   ├── 1.wav
+    │   └── 2.wav
+    ├── speaker_A.list
+    ├── speaker_B
+    │   ├── 1.wav
+    │   └── 2.wav
+    └── speaker_B.list
+    """
+
+    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
+        raw_data = super().get_raw_data(root)
+        for item in raw_data:
+            item["filepath"] = os.path.join(
+                os.path.basename(root).removesuffix(".list"),
+                os.path.basename(item["filepath"]),
+            )
+        return raw_data
+
+
+class AudioCollator:
+    def __init__(self, text_pad: int = 0, audio_pad: int = 0):
+        self.text_pad = text_pad
+        self.audio_pad = audio_pad
+
+    def __call__(self, batch: list[DataType]):
+        batch = [x for x in batch if x is not None]
+
+        audio_maxlen = max(len(item["audio_attention_mask"]) for item in batch)
+        text_maxlen = max(len(item["text_attention_mask"]) for item in batch)
+
+        filepath = []
+        speaker = []
+        lang = []
+        text = []
+        text_input_ids = []
+        text_attention_mask = []
+        audio_mel_specs = []
+        audio_attention_mask = []
+
+        for x in batch:
+            filepath.append(x["filepath"])
+            speaker.append(x["speaker"])
+            lang.append(x["lang"])
+            text.append(x["text"])
+            text_input_ids.append(
+                torch.nn.functional.pad(
+                    x["text_input_ids"],
+                    (text_maxlen - len(x["text_input_ids"]), 0),
+                    value=self.text_pad,
+                )
+            )
+            text_attention_mask.append(
+                torch.nn.functional.pad(
+                    x["text_attention_mask"],
+                    (text_maxlen - len(x["text_attention_mask"]), 0),
+                    value=0,
+                )
+            )
+            audio_mel_specs.append(
+                torch.nn.functional.pad(
+                    x["audio_mel_specs"],
+                    (0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])),
+                    value=self.audio_pad,
+                )
+            )
+            audio_attention_mask.append(
+                torch.nn.functional.pad(
+                    x["audio_attention_mask"],
+                    (0, audio_maxlen - len(x["audio_attention_mask"])),
+                    value=0,
+                )
+            )
+        return {
+            "filepath": filepath,
+            "speaker": speaker,
+            "lang": lang,
+            "text": text,
+            "text_input_ids": torch.stack(text_input_ids),
+            "text_attention_mask": torch.stack(text_attention_mask),
+            "audio_mel_specs": torch.stack(audio_mel_specs),
+            "audio_attention_mask": torch.stack(audio_attention_mask),
+        }
+
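+# Usage sketch (illustrative; `dataset` is a placeholder for any AudioFolder
+# subclass instance defined above):
+#
+#   from torch.utils.data import DataLoader
+#
+#   loader = DataLoader(
+#       dataset,
+#       batch_size=8,
+#       shuffle=True,
+#       collate_fn=AudioCollator(text_pad=0),
+#   )
+#   for batch in loader:
+#       batch["text_input_ids"]      # (B, text_maxlen), left-padded
+#       batch["audio_mel_specs"]     # (B, 2 * audio_maxlen, 100), right-padded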
+
+def formalize_xz_list(src_folder: str):
+    for root, _, files in os.walk(src_folder):
+        for file in files:
+            if file.endswith(".list"):
+                filepath = os.path.join(root, file)
+                print(filepath)
+                lazy_data = XzListFolder(filepath).lazy_data
+                XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder)
+
+
+def concat_dataset(
+    src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
+) -> None:
+    if save_folder is None:
+        save_folder = src_folder
+    if os.path.isfile(save_folder):
+        raise FileExistsError(f"{save_folder} already exists as a file!")
+    elif not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+    lazy_data = []
+    same_folder = os.path.samefile(src_folder, save_folder)
+    for root, _, files in os.walk(src_folder):
+        for file in files:
+            filepath = os.path.join(root, file)
+            if same_folder and file in ("all.list", "all.json"):
+                continue
+            if file.endswith(".list"):
+                print(filepath)
+                lazy_data += ListFolder(filepath).lazy_data
+            if file.endswith(".json"):
+                print(filepath)
+                lazy_data += JsonFolder(filepath).lazy_data
+    if langs is not None:
+        lazy_data = [item for item in lazy_data if item["lang"] in langs]
+    ListFolder.save_config(
+        os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder
+    )
+    JsonFolder.save_config(
+        os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder
+    )
+    print(f"all.list and all.json are saved to {save_folder}")
diff --git a/modules/finetune/utils/logger.py b/modules/finetune/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e59b5cfd272cfb4294822924aba108e143d7310
--- /dev/null
+++ b/modules/finetune/utils/logger.py
@@ -0,0 +1,409 @@
+#!/usr/bin/env python3
+
+import statistics
+import time
+from collections import defaultdict, deque
+from tqdm import tqdm as tqdm_class
+
+from typing import Generator, Iterable, TypeVar
+from typing_extensions import Self
+
+import torch
+import torch.distributed as dist
+
+from .output import ansi, prints, get_ansi_len
+
+__all__ = ["SmoothedValue", "MetricLogger"]
+
+MB = 1 << 20
+T = TypeVar("T")
+
+
+class SmoothedValue:
+    r"""Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+
+    See Also:
+        https://github.com/pytorch/vision/blob/main/references/classification/utils.py
+
+    Args:
+        name (str): Name string.
+        window_size (int): The :attr:`maxlen` of :class:`~collections.deque`.
+        fmt (str): The format pattern of ``str(self)``.
+
+    Attributes:
+        name (str): Name string.
+        fmt (str): The string pattern.
+        deque (~collections.deque): The recent data series kept within the window.
+        count (int): The total number of recorded values.
+        total (float): The sum of all recorded values.
+
+        median (float): The median of :attr:`deque`.
+        avg (float): The avg of :attr:`deque`.
+        global_avg (float): :math:`\frac{\text{total}}{\text{count}}`
+        max (float): The max of :attr:`deque`.
+        min (float): The min of :attr:`deque`.
+        last_value (float): The last value of :attr:`deque`.
+    """
+
+    def __init__(
+        self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}"
+    ):
+        self.name = name
+        self.deque: deque[float] = deque(maxlen=window_size)
+        self.count: int = 0
+        self.total: float = 0.0
+        self.fmt = fmt
+
+    def update(self, value: float, n: int = 1) -> Self:
+        r"""Update :attr:`n` pieces of data with same :attr:`value`.
+
+        .. code-block:: python
+
+            self.deque.append(value)
+            self.total += value * n
+            self.count += n
+
+        Args:
+            value (float): the value to update.
+            n (int): the number of data points sharing the same :attr:`value`.
+
+        Returns:
+            SmoothedValue: return ``self`` for stream usage.
+        """
+        self.deque.append(value)
+        self.total += value * n
+        self.count += n
+        return self
+
+    def update_list(self, value_list: list[float]) -> Self:
+        r"""Update :attr:`value_list`.
+
+        .. code-block:: python
+
+            for value in value_list:
+                self.deque.append(value)
+                self.total += value
+            self.count += len(value_list)
+
+        Args:
+            value_list (list[float]): the value list to update.
+
+        Returns:
+            SmoothedValue: return ``self`` for stream usage.
+        """
+        for value in value_list:
+            self.deque.append(value)
+            self.total += value
+        self.count += len(value_list)
+        return self
+
+    def reset(self) -> Self:
+        r"""Reset ``deque``, ``count`` and ``total`` to be empty.
+
+        Returns:
+            SmoothedValue: return ``self`` for stream usage.
+        """
+        self.deque = deque(maxlen=self.deque.maxlen)
+        self.count = 0
+        self.total = 0.0
+        return self
+
+    def synchronize_between_processes(self):
+        r"""
+        Warning:
+            Does NOT synchronize the deque!
+        """
+        if not (dist.is_available() and dist.is_initialized()):
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = float(t[1])
+
+    @property
+    def median(self) -> float:
+        try:
+            return statistics.median(self.deque)
+        except Exception:
+            return 0.0
+
+    @property
+    def avg(self) -> float:
+        try:
+            return statistics.mean(self.deque)
+        except Exception:
+            return 0.0
+
+    @property
+    def global_avg(self) -> float:
+        try:
+            return self.total / self.count
+        except Exception:
+            return 0.0
+
+    @property
+    def max(self) -> float:
+        try:
+            return max(self.deque)
+        except Exception:
+            return 0.0
+
+    @property
+    def min(self) -> float:
+        try:
+            return min(self.deque)
+        except Exception:
+            return 0.0
+
+    @property
+    def last_value(self) -> float:
+        try:
+            return self.deque[-1]
+        except Exception:
+            return 0.0
+
+    def __str__(self):
+        return self.fmt.format(
+            name=self.name,
+            count=self.count,
+            total=self.total,
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            min=self.min,
+            max=self.max,
+            last_value=self.last_value,
+        )
+
+    def __format__(self, format_spec: str) -> str:
+        return self.__str__()
+
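+# Usage sketch (illustrative; `compute_loss()` is a placeholder):
+#
+#   loss_meter = SmoothedValue(name="loss", window_size=50, fmt="{avg:.4f}")
+#   for step in range(1000):
+#       loss_meter.update(compute_loss())
+#   print(f"windowed avg={loss_meter.avg:.4f}, global avg={loss_meter.global_avg:.4f}")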
+
+class MetricLogger:
+    r"""
+    See Also:
+        https://github.com/pytorch/vision/blob/main/references/classification/utils.py
+
+    Args:
+        delimiter (str): The delimiter to join different meter strings.
+            Defaults to ``''``.
+        meter_length (int): The minimum length for each meter.
+            Defaults to ``20``.
+        tqdm (bool): Whether to use tqdm to show iteration information.
+            Defaults to ``True``.
+        indent (int): The space indent for the entire string.
+            Defaults to ``0``.
+
+    Attributes:
+        meters (dict[str, SmoothedValue]): The meter dict.
+        iter_time (SmoothedValue): Iteration time meter.
+        data_time (SmoothedValue): Data loading time meter.
+        memory (SmoothedValue): Memory usage meter.
+    """
+
+    def __init__(
+        self,
+        delimiter: str = "",
+        meter_length: int = 20,
+        tqdm: bool = True,
+        indent: int = 0,
+        **kwargs,
+    ):
+        self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue)
+        self.create_meters(**kwargs)
+        self.delimiter = delimiter
+        self.meter_length = meter_length
+        self.tqdm = tqdm
+        self.indent = indent
+
+        self.iter_time = SmoothedValue()
+        self.data_time = SmoothedValue()
+        self.memory = SmoothedValue(fmt="{max:.0f}")
+
+    def create_meters(self, **kwargs: str) -> Self:
+        r"""Create meters with specific ``fmt`` in :attr:`self.meters`.
+
+        ``self.meters[meter_name] = SmoothedValue(fmt=fmt)``
+
+        Args:
+            **kwargs: ``(meter_name: fmt)``
+
+        Returns:
+            MetricLogger: return ``self`` for stream usage.
+        """
+        for k, v in kwargs.items():
+            self.meters[k] = SmoothedValue(fmt="{global_avg:.3f}" if v is None else v)
+        return self
+
+    def update(self, n: int = 1, **kwargs: float) -> Self:
+        r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`.
+
+        ``self.meters[meter_name].update(float(value), n=n)``
+
+        Args:
+            n (int): the number of data points sharing the same value.
+            **kwargs: ``{meter_name: value}``.
+
+        Returns:
+            MetricLogger: return ``self`` for stream usage.
+        """
+        for k, v in kwargs.items():
+            if k not in self.meters:
+                self.meters[k] = SmoothedValue()
+            self.meters[k].update(float(v), n=n)
+        return self
+
+    def update_list(self, **kwargs: list) -> Self:
+        r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`.
+
+        ``self.meters[meter_name].update_list(value_list)``
+
+        Args:
+            **kwargs: ``{meter_name: value_list}``.
+
+        Returns:
+            MetricLogger: return ``self`` for stream usage.
+        """
+        for k, v in kwargs.items():
+            self.meters[k].update_list(v)
+        return self
+
+    def reset(self) -> Self:
+        r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`.
+
+        Returns:
+            MetricLogger: return ``self`` for stream usage.
+        """
+        for meter in self.meters.values():
+            meter.reset()
+        return self
+
+    def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str:
+        r"""Generate formatted string based on keyword arguments.
+
+        Each ``key: value`` string is left-justified (and optionally truncated) to :attr:`self.meter_length` characters.
+
+        Args:
+            cut_too_long (bool): Whether to truncate each ``key: value`` string
+                to :attr:`self.meter_length` characters. Defaults to ``True``.
+            strip (bool): Whether to strip trailing whitespaces.
+                Defaults to ``True``.
+            **kwargs: Keyword arguments to generate string.
+        """
+        str_list: list[str] = []
+        for k, v in kwargs.items():
+            v_str = str(v)
+            _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi)
+            max_length = self.meter_length + get_ansi_len(_str)
+            if cut_too_long:
+                _str = _str[:max_length]
+            str_list.append(_str.ljust(max_length))
+        _str = self.delimiter.join(str_list)
+        if strip:
+            _str = _str.rstrip()
+        return _str
+
+    def __getattr__(self, attr: str) -> SmoothedValue:
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in vars(self):  # TODO: use hasattr
+            return vars(self)[attr]
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+        )
+
+    def __str__(self) -> str:
+        return self.get_str(**self.meters)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def log_every(
+        self,
+        iterable: Iterable[T],
+        header: str = "",
+        tqdm: bool = None,
+        tqdm_header: str = "Iter",
+        indent: int = None,
+        verbose: int = 1,
+    ) -> Generator[T, None, None]:
+        r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs.
+
+        * Middle Output:
+          ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}``
+        * Final Output
+          ``{header} str(self) {memory} {iter_time} {data_time} {total_time}``
+
+        Args:
+            iterable (~collections.abc.Iterable): The raw iterator.
+            header (str): The header string for final output.
+                Defaults to ``''``.
+            tqdm (bool): Whether to use tqdm to show iteration information.
+                Defaults to ``self.tqdm``.
+            tqdm_header (str): The header string for middle output.
+                Defaults to ``'Iter'``.
+            indent (int): The space indent for the entire string.
+                if ``None``, use ``self.indent``.
+                Defaults to ``None``.
+            verbose (int): The verbosity level of the output information. Defaults to ``1``.
+        """
+        tqdm = tqdm if tqdm is not None else self.tqdm
+        indent = indent if indent is not None else self.indent
+        iterator = iterable
+        if len(header) != 0:
+            header = header.ljust(30 + get_ansi_len(header))
+        if tqdm:
+            length = len(str(len(iterable)))
+            pattern: str = (
+                "{tqdm_header}: {blue_light}"
+                "[ {red}{{n_fmt:>{length}}}{blue_light} "
+                "/ {red}{{total_fmt}}{blue_light} ]{reset}"
+            ).format(tqdm_header=tqdm_header, length=length, **ansi)
+            offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length
+            pattern = pattern.ljust(30 + offset + get_ansi_len(pattern))
+            time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False)
+            bar_format = f"{pattern}{{desc}}{time_str}"
+            iterator = tqdm_class(iterable, leave=False, bar_format=bar_format)
+
+        self.iter_time.reset()
+        self.data_time.reset()
+        self.memory.reset()
+
+        end = time.time()
+        start_time = time.time()
+        for obj in iterator:
+            cur_data_time = time.time() - end
+            self.data_time.update(cur_data_time)
+            yield obj
+            cur_iter_time = time.time() - end
+            self.iter_time.update(cur_iter_time)
+            if torch.cuda.is_available():
+                cur_memory = torch.cuda.max_memory_allocated() / MB
+                self.memory.update(cur_memory)
+            if tqdm:
+                _dict = {k: v for k, v in self.meters.items()}
+                if verbose > 2 and torch.cuda.is_available():
+                    _dict.update(memory=f"{cur_memory:.0f} MB")
+                if verbose > 1:
+                    _dict.update(
+                        iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s"
+                    )
+                iterator.set_description_str(self.get_str(**_dict, strip=False))
+            end = time.time()
+        self.synchronize_between_processes()
+        total_time = time.time() - start_time
+        total_time_str = tqdm_class.format_interval(total_time)
+
+        _dict = {k: v for k, v in self.meters.items()}
+        if verbose > 2 and torch.cuda.is_available():
+            _dict.update(memory=f"{str(self.memory)} MB")
+        if verbose > 1:
+            _dict.update(
+                iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s"
+            )
+        _dict.update(time=total_time_str)
+        prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent)
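+
+
+# Usage sketch (illustrative; `train_loader` and `train_step` are placeholders):
+#
+#   logger = MetricLogger()
+#   logger.create_meters(loss=None, grad_norm=None)
+#   for epoch in range(10):
+#       logger.reset()
+#       for batch in logger.log_every(train_loader, header=f"Epoch {epoch}"):
+#           stats = train_step(batch)
+#           logger.update(loss=stats["loss"], grad_norm=stats["grad_norm"])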
diff --git a/modules/finetune/utils/model.py b/modules/finetune/utils/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..416cb1bcc7084c8d0e065e4de36a75e43ab47fa6
--- /dev/null
+++ b/modules/finetune/utils/model.py
@@ -0,0 +1,19 @@
+import torch
+from einops import rearrange
+from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ
+
+
+def quantize(
+    quantizer: GroupedResidualFSQ,
+    audio_latents: torch.Tensor,  # (batch_size, audio_len, audio_dim=1024)
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # feat shape (batch_size, audio_len, audio_dim)
+    # ind shape (GFSQ.G, batch_size, audio_len, GFSQ.R)
+    # num_vq=GFSQ.G*GFSQ.R
+    feat, ind = quantizer(audio_latents)
+    audio_quantized_latents = feat  # (batch_size, audio_len, audio_dim)
+    audio_input_ids = rearrange(  # (batch_size, audio_len, num_vq)
+        ind,
+        "g b t r ->b t (g r)",
+    )
+    return audio_quantized_latents, audio_input_ids
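+
+
+# Shape sketch (illustrative; the quantizer configuration is an assumption):
+#
+#   quantizer = GroupedResidualFSQ(dim=1024, levels=[5, 5, 5, 5], num_quantizers=2, groups=2)
+#   latents = torch.randn(4, 121, 1024)      # (batch_size, audio_len, audio_dim)
+#   quantized, ids = quantize(quantizer, latents)
+#   # quantized: (4, 121, 1024)
+#   # ids:       (4, 121, 4) == (batch_size, audio_len, groups * num_quantizers)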
diff --git a/modules/finetune/utils/output.py b/modules/finetune/utils/output.py
new file mode 100644
index 0000000000000000000000000000000000000000..541092ddfa1f33848e1d8ff914ffbeab312db44f
--- /dev/null
+++ b/modules/finetune/utils/output.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+
+import re
+import sys
+from contextlib import contextmanager
+
+
+class ANSI:
+    ansi_color = {
+        "black": "\033[30m",
+        "red": "\033[31m",
+        "green": "\033[32m",
+        "yellow": "\033[33m",
+        "blue": "\033[34m",
+        "purple": "\033[35m",
+        "blue_light": "\033[36m",
+        "white": "\033[37m",
+        "reset": "\033[0m",
+        "upline": "\033[1A",
+        "clear_line": "\033[2K",
+        "clear": "\033[2J",
+    }
+    ansi_nocolor = {
+        "black": "",
+        "red": "",
+        "green": "",
+        "yellow": "",
+        "blue": "",
+        "purple": "",
+        "blue_light": "",
+        "white": "",
+        "reset": "",
+        "upline": "\033[1A\033[",
+        "clear_line": "\033[K",
+        "clear": "\033[2J",
+    }
+
+    def __init__(self):
+        self._dict = ANSI.ansi_color if ("--color" in sys.argv) else ANSI.ansi_nocolor
+
+    def switch(self, color: bool):
+        self._dict = ANSI.ansi_color if color else ANSI.ansi_nocolor
+
+    def keys(self):
+        return self._dict.keys()
+
+    def items(self):
+        return self._dict.items()
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __str__(self):
+        return str(self._dict)
+
+    def __repr__(self):
+        return repr(self._dict)
+
+
+ansi = ANSI()
+
+
+def remove_ansi(s: str) -> str:
+    ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
+    return ansi_escape.sub("", s)
+
+
+def get_ansi_len(s: str) -> int:
+    return len(s) - len(remove_ansi(s))
+
+
+def prints(*args: str, indent: int = 0, prefix: str = "", **kwargs):
+    assert indent >= 0
+    new_args = []
+    for arg in args:
+        new_args.append(indent_str(str(arg), indent=indent))
+    if len(new_args):
+        new_args[0] = prefix + str(new_args[0])
+    print(*new_args, **kwargs)
+
+
+def output_iter(_iter: int, iteration: int = None, iter_len: int = 4) -> str:
+    if iteration is None:
+        pattern = "{blue_light}[ {red}{0}{blue_light} ]{reset}"
+        return pattern.format(str(_iter).rjust(iter_len), **ansi)
+    else:
+        iter_str = str(iteration)
+        length = len(iter_str)
+        pattern = (
+            "{blue_light}[ {red}{0}{blue_light} " "/ {red}{1}{blue_light} ]{reset}"
+        )
+        return pattern.format(str(_iter).rjust(length), iter_str, **ansi)
+
+
+def indent_str(s_: str, indent: int = 0) -> str:
+    # modified from torch.nn.modules._addindent
+    if indent > 0 and s_:
+        s_ = indent * " " + str(s_[:-1]).replace("\n", "\n" + indent * " ") + s_[-1]
+    return s_
+
+
+class IndentRedirect:  # TODO: inherit TextIOWrapper?
+    def __init__(self, buffer: bool = True, indent: int = 0):
+        self.__console__ = sys.stdout
+        self.indent = indent
+        self.__buffer: str = None
+        if buffer:
+            self.__buffer = ""
+
+    def write(self, text: str, indent: int = None):
+        indent = indent if indent is not None else self.indent
+        text = indent_str(text, indent=indent)
+        if self.__buffer is None:
+            self.__console__.write(text)
+        else:
+            self.__buffer += text
+
+    def flush(self):
+        if self.__buffer is not None:
+            self.__console__.write(self.__buffer)
+            self.__buffer = ""
+        self.__console__.flush()
+
+    @contextmanager
+    def __call__(self) -> None:
+        try:
+            sys.stdout = self
+            yield
+        finally:
+            sys.stdout = self.__console__
+            self.__buffer = ""
+
+    def enable(self):
+        sys.stdout = self
+
+    def disable(self):
+        if self.__buffer is not None:
+            self.__buffer = ""
+        sys.stdout = self.__console__
+
+    @property
+    def buffer(self) -> str:
+        return self.__buffer
+
+
+redirect = IndentRedirect()
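+
+# Usage sketch (illustrative): indent everything printed while the redirect is active.
+#
+#   indented = IndentRedirect(buffer=False, indent=4)
+#   indented.enable()
+#   print("printed with a 4-space indent")
+#   indented.disable()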
diff --git a/modules/generate_audio.py b/modules/generate_audio.py
index 9fcabe3954c3912b1f008ed8865ca04c12e18a14..a2e4552b9103d2bb13dd030724f93f740ab7f1b8 100644
--- a/modules/generate_audio.py
+++ b/modules/generate_audio.py
@@ -76,6 +76,8 @@ def generate_audio_batch(
             params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
         logger.debug(("spk", spk))
     elif isinstance(spk, Speaker):
+        if not isinstance(spk.emb, torch.Tensor):
+            raise ValueError("spk.pt is broken, please retrain the model.")
         params_infer_code["spk_emb"] = spk.emb
         logger.debug(("spk", spk.name))
     else:
diff --git a/modules/normalization.py b/modules/normalization.py
index 1d740e1ca6b914a37deceb515409e088cb5c29d2..cc6e941f78143b2b46b5eb8f886f55f68417c77f 100644
--- a/modules/normalization.py
+++ b/modules/normalization.py
@@ -120,6 +120,7 @@ character_map = {
     "~": " ",
     "~": " ",
     "/": " ",
+    "·": " ",
 }
 
 character_to_word = {
@@ -282,6 +283,9 @@ def text_normalize(text, is_end=False):
 
 
 if __name__ == "__main__":
+    from modules.devices import devices
+
+    devices.reset_device()
     test_cases = [
         "ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.",
         " [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
@@ -319,6 +323,7 @@ State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
         """
 120米
 有12%的概率会下雨
+埃隆·马斯克
 """,
     ]
 
diff --git a/modules/repos_static/resemble_enhance/data/distorter/base.py b/modules/repos_static/resemble_enhance/data/distorter/base.py
index d43d84fa840dd25804d9c5e5dc9673f65fdc5b94..f07ef407fa92190234ead9f7de43d7c5ea3c6b4d 100644
--- a/modules/repos_static/resemble_enhance/data/distorter/base.py
+++ b/modules/repos_static/resemble_enhance/data/distorter/base.py
@@ -2,6 +2,7 @@ import itertools
 import os
 import random
 import time
+from typing import Union
 import warnings
 
 import numpy as np
@@ -87,7 +88,7 @@ class Choice(Effect):
 
 
 class Permutation(Effect):
-    def __init__(self, *effects, n: int | None = None):
+    def __init__(self, *effects, n: Union[int, None] = None):
         super().__init__()
         self.effects = effects
         self.n = n
diff --git a/modules/repos_static/resemble_enhance/data/distorter/custom.py b/modules/repos_static/resemble_enhance/data/distorter/custom.py
index 28428f7789cebb2d174c581111711f4d73f6565b..fdabed6aac1647de9a7ee887f84308effa71c8da 100644
--- a/modules/repos_static/resemble_enhance/data/distorter/custom.py
+++ b/modules/repos_static/resemble_enhance/data/distorter/custom.py
@@ -3,6 +3,7 @@ import random
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
+from typing import Union
 
 import librosa
 import numpy as np
@@ -16,7 +17,7 @@ _logger = logging.getLogger(__name__)
 
 @dataclass
 class RandomRIR(Effect):
-    rir_dir: Path | None
+    rir_dir: Union[Path, None]
     rir_rate: int = 44_000
     rir_suffix: str = ".npy"
     deterministic: bool = False
@@ -49,7 +50,9 @@ class RandomRIR(Effect):
 
         length = len(wav)
 
-        wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
+        wav = librosa.resample(
+            wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast"
+        )
         rir = self._sample_rir()
 
         wav = signal.convolve(wav, rir, mode="same")
@@ -58,7 +61,9 @@ class RandomRIR(Effect):
         if actlev > 0.99:
             wav = (wav / actlev) * 0.98
 
-        wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
+        wav = librosa.resample(
+            wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast"
+        )
 
         if abs(length - len(wav)) > 10:
             _logger.warning(f"length mismatch: {length} vs {len(wav)}")
diff --git a/modules/repos_static/resemble_enhance/data/distorter/sox.py b/modules/repos_static/resemble_enhance/data/distorter/sox.py
index 92a2d74033d33b975141c1ece7ac5619d1dfcc39..3e08376087683222dd5db98f4c4b25ad0e38b847 100644
--- a/modules/repos_static/resemble_enhance/data/distorter/sox.py
+++ b/modules/repos_static/resemble_enhance/data/distorter/sox.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import random
+from typing import Union
 import warnings
 from functools import partial
 
@@ -29,7 +30,9 @@ class AttachableEffect(Effect):
         chain = augment.EffectChain()
         chain = self.attach(chain)
         tensor = torch.from_numpy(wav)[None].float()  # (1, T)
-        tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
+        tensor = chain.apply(
+            tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}
+        )
         wav = tensor.numpy()[0]  # (T,)
         return wav
 
@@ -41,7 +44,9 @@ class SoxEffect(AttachableEffect):
         self.kwargs = kwargs
 
     def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
-        _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
+        _logger.debug(
+            f"Attaching {self.effect_name} with {self.args} and {self.kwargs}"
+        )
         if not hasattr(chain, self.effect_name):
             raise ValueError(f"EffectChain has no attribute {self.effect_name}")
         return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
@@ -115,21 +120,30 @@ class Randint(Generator):
 
 
 class Concat(Generator):
-    def __init__(self, *parts: Generator | str):
+    def __init__(self, *parts: Union[Generator, str]):
         self.parts = parts
 
     def __call__(self):
-        return "".join([part if isinstance(part, str) else part() for part in self.parts])
+        return "".join(
+            [part if isinstance(part, str) else part() for part in self.parts]
+        )
 
 
 class RandomLowpassDistorter(SoxEffect):
     def __init__(self, low=2000, high=16000):
-        super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
+        super().__init__(
+            "sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))
+        )
 
 
 class RandomBandpassDistorter(SoxEffect):
     def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
-        super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
+        super().__init__(
+            "sinc",
+            "-n",
+            Randint(50, 200),
+            partial(self._fn, low, high, min_width, max_width),
+        )
 
     @staticmethod
     def _fn(low, high, min_width, max_width):
@@ -139,7 +153,15 @@ class RandomBandpassDistorter(SoxEffect):
 
 
 class RandomEqualizer(SoxEffect):
-    def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
+    def __init__(
+        self,
+        low=100,
+        high=4000,
+        q_low=1,
+        q_high=5,
+        db_low: int = -30,
+        db_high: int = 30,
+    ):
         super().__init__(
             "equalizer",
             Uniform(low, high),
@@ -150,7 +172,9 @@ class RandomEqualizer(SoxEffect):
 
 class RandomOverdrive(SoxEffect):
     def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
-        super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
+        super().__init__(
+            "overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)
+        )
 
 
 class RandomReverb(Chain):
diff --git a/modules/repos_static/resemble_enhance/data/utils.py b/modules/repos_static/resemble_enhance/data/utils.py
index 77f59d345b75cac76c6c423c734ae9c70a626590..38ca25fe36074d615962dc229599cf1b3a548aaa 100644
--- a/modules/repos_static/resemble_enhance/data/utils.py
+++ b/modules/repos_static/resemble_enhance/data/utils.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Callable
+from typing import Callable, Union
 
 from torch import Tensor
 
@@ -16,7 +16,9 @@ def rglob_audio_files(path: Path):
     return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
 
 
-def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
+def mix_fg_bg(
+    fg: Tensor, bg: Tensor, alpha: Union[float, Callable[..., float]] = 0.5, eps=1e-7
+):
     """
     Args:
         fg: (b, t)
diff --git a/modules/repos_static/resemble_enhance/denoiser/denoiser.py b/modules/repos_static/resemble_enhance/denoiser/denoiser.py
index c0d9c2b6ffbc471029cee620216a2d080b9dd100..4d672df3431d1877d3c8cb882aa8606d6d8b5d1f 100644
--- a/modules/repos_static/resemble_enhance/denoiser/denoiser.py
+++ b/modules/repos_static/resemble_enhance/denoiser/denoiser.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Union
 
 import torch
 import torch.nn.functional as F
@@ -154,7 +155,7 @@ class Denoiser(nn.Module):
         sep_sin = sin * cos_res + cos * sin_res
         return sep_mag, sep_cos, sep_sin
 
-    def forward(self, x: Tensor, y: Tensor | None = None):
+    def forward(self, x: Tensor, y: Union[Tensor, None] = None):
         """
         Args:
             x: (b t), a mixed audio
diff --git a/modules/repos_static/resemble_enhance/enhancer/download.py b/modules/repos_static/resemble_enhance/enhancer/download.py
index 614b9a4b4f9a1a10b79f12ca1a25821247ea2a16..089181893229ba67c9202e204f994d512975f9fc 100644
--- a/modules/repos_static/resemble_enhance/enhancer/download.py
+++ b/modules/repos_static/resemble_enhance/enhancer/download.py
@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from typing import Union
 
 import torch
 
@@ -12,14 +13,18 @@ def get_source_url(relpath):
     return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
 
 
-def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
+def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
     if run_dir is None:
         run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
     return Path(run_dir) / relpath
 
 
-def download(run_dir: str | Path | None = None):
-    relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
+def download(run_dir: Union[str, Path, None] = None):
+    relpaths = [
+        "hparams.yaml",
+        "ds/G/latest",
+        "ds/G/default/mp_rank_00_model_states.pt",
+    ]
     for relpath in relpaths:
         path = get_target_path(relpath, run_dir=run_dir)
         if path.exists():
diff --git a/modules/repos_static/resemble_enhance/enhancer/enhancer.py b/modules/repos_static/resemble_enhance/enhancer/enhancer.py
index c7ab9417deb429855b7fce43962426f6b6c4a9c0..1ea3f351752d8e8e13040fef842372367926c3e4 100644
--- a/modules/repos_static/resemble_enhance/enhancer/enhancer.py
+++ b/modules/repos_static/resemble_enhance/enhancer/enhancer.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Union
 
 import matplotlib.pyplot as plt
 import pandas as pd
@@ -109,7 +110,7 @@ class Enhancer(nn.Module):
             return self.mel_fn(x)[..., :-1]  # (b d t)
         return self.mel_fn(x)
 
-    def _may_denoise(self, x: Tensor, y: Tensor | None = None):
+    def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None):
         if self.hp.lcfm_training_mode == "cfm":
             return self.denoiser(x, y)
         return x
@@ -126,7 +127,9 @@ class Enhancer(nn.Module):
         self.lcfm.eval_tau_(tau)
         self._eval_lambd = lambd
 
-    def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
+    def forward(
+        self, x: Tensor, y: Union[Tensor, None] = None, z: Union[Tensor, None] = None
+    ):
         """
         Args:
             x: (b t), mix wavs (fg + bg)
diff --git a/modules/repos_static/resemble_enhance/enhancer/hparams.py b/modules/repos_static/resemble_enhance/enhancer/hparams.py
index ca89bea6f5d7d4ec4f543f8bde88b29dcae69f6a..7878e4172b5772c59aea0de54d2537f0523d9437 100644
--- a/modules/repos_static/resemble_enhance/enhancer/hparams.py
+++ b/modules/repos_static/resemble_enhance/enhancer/hparams.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from pathlib import Path
+from typing import Union
 
 from ..hparams import HParams as HParamsBase
 
@@ -17,7 +18,7 @@ class HParams(HParamsBase):
 
     vocoder_extra_dim: int = 32
 
-    gan_training_start_step: int | None = 5_000
-    enhancer_stage1_run_dir: Path | None = None
+    gan_training_start_step: Union[int, None] = 5_000
+    enhancer_stage1_run_dir: Union[Path, None] = None
 
-    denoiser_run_dir: Path | None = None
+    denoiser_run_dir: Union[Path, None] = None
diff --git a/modules/repos_static/resemble_enhance/enhancer/inference.py b/modules/repos_static/resemble_enhance/enhancer/inference.py
index af57a2c7d3e5cc7b08b00f85f0135e881e50fcbe..dc7712cb6a4d2126bb4d740d24ed9355312741ef 100644
--- a/modules/repos_static/resemble_enhance/enhancer/inference.py
+++ b/modules/repos_static/resemble_enhance/enhancer/inference.py
@@ -1,6 +1,7 @@
 import logging
 from functools import cache
 from pathlib import Path
+from typing import Union
 
 import torch
 
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
 
 
 @cache
-def load_enhancer(run_dir: str | Path | None, device):
+def load_enhancer(run_dir: Union[str, Path, None], device):
     run_dir = download(run_dir)
     hp = HParams.load(run_dir)
     enhancer = Enhancer(hp)
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
index a5125267b7f32e11c58e4b96bffa3ba1e96fdc4f..09b4a3e45ce3b50cca7ce7debe77ddb230ee9783 100644
--- a/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
@@ -1,7 +1,7 @@
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Protocol
+from typing import Protocol, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
 
 
 class VelocityField(Protocol):
-    def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
-        ...
+    def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ...
 
 
 class Solver:
@@ -40,7 +39,9 @@ class Solver:
 
         self._camera = None
         self._mel_fn = mel_fn
-        self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
+        self._time_mapping = partial(
+            self.exponential_decay_mapping, n=time_mapping_divisor
+        )
 
     def configurate_(self, nfe=None, method=None):
         if nfe is None:
@@ -50,7 +51,9 @@ class Solver:
             method = self.method
 
         if nfe == 1 and method in ("midpoint", "rk4"):
-            logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
+            logger.warning(
+                f"1 NFE is not supported for {method}, using euler method instead."
+            )
             method = "euler"
 
         self.nfe = nfe
@@ -105,7 +108,9 @@ class Solver:
                 )
             else:
                 # Spectrogram, b c t
-                plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
+                plt.imshow(
+                    ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none"
+                )
             ax = plt.gca()
             ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
             camera.snap()
@@ -271,7 +276,7 @@ class CFM(nn.Module):
             global_dim=self.time_emb_dim,
         )
 
-    def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
+    def _perturb(self, ψ1: Tensor, t: Union[Tensor, None] = None):
         """
         Perturb ψ1 to ψt.
         """
@@ -311,7 +316,7 @@ class CFM(nn.Module):
         """
         return ψ1 - ψ0
 
-    def _to_v(self, *, ψt, x, t: float | Tensor):
+    def _to_v(self, *, ψt, x, t: Union[float, Tensor]):
         """
         Args:
             ψt: (b c t)
@@ -364,7 +369,13 @@ class CFM(nn.Module):
         ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
         return ψ1
 
-    def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
+    def forward(
+        self,
+        x: Tensor,
+        y: Union[Tensor, None] = None,
+        ψ0: Union[Tensor, None] = None,
+        t0=0.0,
+    ):
         if y is None:
             y = self.sample(x, ψ0=ψ0, t0=t0)
         else:
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
index 91f5dbb506187271c67c7bbbf55475021854ab27..aa82827c8809b001d31827d76bbee731e11ae2e2 100644
--- a/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
@@ -1,5 +1,6 @@
 import logging
 from dataclasses import dataclass
+from typing import Union
 
 import torch.nn as nn
 import torch.nn.functional as F
@@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class IRMAEOutput:
     latent: Tensor  # latent vector
-    decoded: Tensor | None  # decoder output, include extra dim
+    decoded: Union[Tensor, None]  # decoder output, include extra dim
 
 
 class ResBlock(nn.Sequential):
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
index 4c2f5f88718e2f42f82e2f4714ea510b4677b450..8d1c241312f96525fcf7630e805560cbe9b84406 100644
--- a/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
@@ -1,5 +1,6 @@
 import logging
 from enum import Enum
+from typing import Union
 
 import matplotlib.pyplot as plt
 import torch
@@ -70,19 +71,34 @@ class LCFM(nn.Module):
             return
 
         plt.subplot(221)
-        plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.imshow(
+            y[0].detach().cpu().numpy(),
+            aspect="auto",
+            origin="lower",
+            interpolation="none",
+        )
         plt.title("GT")
 
         plt.subplot(222)
         y_ = y_[:, : y.shape[1]]
-        plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.imshow(
+            y_[0].detach().cpu().numpy(),
+            aspect="auto",
+            origin="lower",
+            interpolation="none",
+        )
         plt.title("Posterior")
 
         plt.subplot(223)
         z_ = self.cfm(x)
         y__ = self.ae.decode(z_)
         y__ = y__[:, : y.shape[1]]
-        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.imshow(
+            y__[0].detach().cpu().numpy(),
+            aspect="auto",
+            origin="lower",
+            interpolation="none",
+        )
         plt.title("C-Prior")
         del y__
 
@@ -90,7 +106,12 @@ class LCFM(nn.Module):
         z_ = torch.randn_like(z_)
         y__ = self.ae.decode(z_)
         y__ = y__[:, : y.shape[1]]
-        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.imshow(
+            y__[0].detach().cpu().numpy(),
+            aspect="auto",
+            origin="lower",
+            interpolation="none",
+        )
         plt.title("Prior")
         del z_, y__
 
@@ -109,7 +130,7 @@ class LCFM(nn.Module):
     def eval_tau_(self, tau):
         self._eval_tau = tau
 
-    def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
+    def forward(self, x, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None):
         """
         Args:
             x: (b d t), condition mel
@@ -139,14 +160,20 @@ class LCFM(nn.Module):
 
             h = self.ae.decode(z)
         else:
-            ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
+            ae_output: IRMAEOutput = self.ae(
+                y, skip_decoding=self.mode == self.Mode.CFM
+            )
 
             if self.mode == self.Mode.CFM:
                 _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
 
             h = ae_output.decoded
 
-            if h is not None and self.global_step is not None and self.global_step % 100 == 0:
+            if (
+                h is not None
+                and self.global_step is not None
+                and self.global_step % 100 == 0
+            ):
                 self._visualize(x[:1], y[:1], h[:1])
 
         return h
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
index bb20217f048f398236698f6a38927310d0c1ba9b..602f08851095b7a25a1bddc8a2daa7e48fc10cb1 100644
--- a/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
@@ -1,3 +1,4 @@
+from typing import Union
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -50,7 +51,9 @@ class UnivNet(nn.Module):
             ]
         )
 
-        self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
+        self.conv_pre = weight_norm(
+            nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect")
+        )
 
         self.conv_post = nn.Sequential(
             nn.LeakyReLU(0.2),
@@ -64,7 +67,7 @@ class UnivNet(nn.Module):
     def eps(self):
         return 1e-5
 
-    def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
+    def forward(self, x: Tensor, y: Union[Tensor, None] = None, npad=10):
         """
         Args:
             x: (b c t), acoustic features
@@ -74,7 +77,9 @@ class UnivNet(nn.Module):
         """
         assert x.ndim == 3, "x must be 3D tensor"
         assert y is None or y.ndim == 2, "y must be 2D tensor"
-        assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
+        assert (
+            x.shape[1] == self.d_input
+        ), f"x.shape[1] must be {self.d_input}, but got {x.shape}"
         assert npad >= 0, "npad must be positive or zero"
 
         x = F.pad(x, (0, npad), "constant", 0)
diff --git a/modules/repos_static/resemble_enhance/hparams.py b/modules/repos_static/resemble_enhance/hparams.py
index a8e716175fa962ada1d98cd755430e2ea770278c..9f796e97c3ab1c3d540d9aed14c8bf0796a7d39b 100644
--- a/modules/repos_static/resemble_enhance/hparams.py
+++ b/modules/repos_static/resemble_enhance/hparams.py
@@ -1,6 +1,7 @@
 import logging
 from dataclasses import asdict, dataclass
 from pathlib import Path
+from typing import Union
 
 from omegaconf import OmegaConf
 from rich.console import Console
@@ -102,7 +103,7 @@ class HParams:
         OmegaConf.save(asdict(self), str(path))
 
     @classmethod
-    def load(cls, run_dir, yaml: Path | None = None):
+    def load(cls, run_dir, yaml: Union[Path, None] = None):
         hps = []
 
         if (run_dir / "hparams.yaml").exists():
@@ -120,7 +121,9 @@ class HParams:
                 for k, v in asdict(hp).items():
                     if getattr(hps[0], k) != v:
                         errors[k] = f"{getattr(hps[0], k)} != {v}"
-                raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
+                raise ValueError(
+                    f"Found inconsistent hparams: {errors}, consider deleting {run_dir}"
+                )
 
         return hps[0]
 
diff --git a/modules/speaker.py b/modules/speaker.py
index 46ceaf947a5be1c2915c51d29c5a48707388af82..764cf48839fe8127af63704f9ea947710ea1b3a4 100644
--- a/modules/speaker.py
+++ b/modules/speaker.py
@@ -29,13 +29,15 @@ class Speaker:
         speaker.emb = tensor
         return speaker
 
-    def __init__(self, seed, name="", gender="", describe=""):
+    def __init__(
+        self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
+    ):
         self.id = uuid.uuid4()
-        self.seed = seed
+        self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor
         self.name = name
         self.gender = gender
         self.describe = describe
-        self.emb = None
+        self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor
 
         # TODO replace emb => tokens
         self.tokens = []
diff --git a/modules/ssml_parser/SSMLParser.py b/modules/ssml_parser/SSMLParser.py
index 5db290006a51a809178c7c0198a51a9d6324a888..bda224cd2e9846cff34eb21e7430f1a5b7f9e42a 100644
--- a/modules/ssml_parser/SSMLParser.py
+++ b/modules/ssml_parser/SSMLParser.py
@@ -11,8 +11,8 @@ import copy
 
 
 class SSMLContext(Box):
-    def __init__(self, parent=None):
-        self.parent: Union[SSMLContext, None] = parent
+    def __init__(self, *args, **kwargs):
+        self.parent: Union[SSMLContext, None] = None
 
         self.style = None
         self.spk = None
@@ -29,18 +29,14 @@ class SSMLContext(Box):
         self.prompt2 = None
         self.prefix = None
 
-    def clone(self):
-        ctx = SSMLContext()
-        for k, v in self.items():
-            ctx[k] = v
-        return ctx
+        super().__init__(*args, **kwargs)
 
 
 class SSMLSegment(Box):
-    def __init__(self, text: str, attrs=SSMLContext()):
-        self.attrs = attrs
+    def __init__(self, text: str, attrs=SSMLContext(), params=None):
+        self.attrs = SSMLContext(**attrs)
         self.text = text
-        self.params = None
+        self.params = params
 
 
 class SSMLBreak:
@@ -68,7 +64,7 @@ class SSMLParser:
         root = etree.fromstring(ssml)
 
         root_ctx = SSMLContext()
-        segments = []
+        segments: List[Union[SSMLSegment, SSMLBreak]] = []
         self.resolve(root, root_ctx, segments)
 
         return segments
@@ -89,8 +85,13 @@ def create_ssml_parser():
     parser = SSMLParser()
 
     @parser.resolver("speak")
-    def tag_speak(element, context, segments, parser):
-        ctx = context.clone() if context is not None else SSMLContext()
+    def tag_speak(
+        element: etree.Element,
+        context: Box,
+        segments: List[Union[SSMLSegment, SSMLBreak]],
+        parser: SSMLParser,
+    ):
+        ctx = context.copy() if context is not None else SSMLContext()
 
         version = element.get("version")
         if version != "0.1":
@@ -100,8 +101,13 @@ def create_ssml_parser():
             parser.resolve(child, ctx, segments)
 
     @parser.resolver("voice")
-    def tag_voice(element, context, segments, parser):
-        ctx = context.clone() if context is not None else SSMLContext()
+    def tag_voice(
+        element: etree.Element,
+        context: Box,
+        segments: List[Union[SSMLSegment, SSMLBreak]],
+        parser: SSMLParser,
+    ):
+        ctx = context.copy() if context is not None else SSMLContext()
 
         ctx.spk = element.get("spk", ctx.spk)
         ctx.style = element.get("style", ctx.style)
@@ -131,13 +137,23 @@ def create_ssml_parser():
                 segments.append(SSMLSegment(child.tail.strip(), ctx))
 
     @parser.resolver("break")
-    def tag_break(element, context, segments, parser):
+    def tag_break(
+        element: etree.Element,
+        context: Box,
+        segments: List[Union[SSMLSegment, SSMLBreak]],
+        parser: SSMLParser,
+    ):
         time_ms = int(element.get("time", "0").replace("ms", ""))
         segments.append(SSMLBreak(time_ms))
 
     @parser.resolver("prosody")
-    def tag_prosody(element, context, segments, parser):
-        ctx = context.clone() if context is not None else SSMLContext()
+    def tag_prosody(
+        element: etree.Element,
+        context: Box,
+        segments: List[Union[SSMLSegment, SSMLBreak]],
+        parser: SSMLParser,
+    ):
+        ctx = context.copy() if context is not None else SSMLContext()
 
         ctx.spk = element.get("spk", ctx.spk)
         ctx.style = element.get("style", ctx.style)
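
For illustration, a minimal parsing sketch (not part of the diff; the entry-point name `parse` is assumed from the hunk above that builds the segment list inside SSMLParser):

from modules.ssml_parser.SSMLParser import create_ssml_parser, SSMLBreak

parser = create_ssml_parser()
segments = parser.parse(
    '<speak version="0.1">'
    '<voice spk="female2" style="podcast">Hello, welcome to the show.</voice>'
    '<break time="500ms"/>'
    "</speak>"
)
for seg in segments:
    if isinstance(seg, SSMLBreak):
        print("break")
    else:
        print(seg.text, seg.attrs.spk, seg.attrs.style)
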
diff --git a/modules/synthesize_audio.py b/modules/synthesize_audio.py
index a07c7bc69de6352cfec6a9c01d6c0203ac4f8d94..c032abc2ee99237b729c23b9e4e2cd1b0cf683d9 100644
--- a/modules/synthesize_audio.py
+++ b/modules/synthesize_audio.py
@@ -7,6 +7,7 @@ from modules import generate_audio as generate
 
 
 from modules.speaker import Speaker
+from modules.ssml_parser.SSMLParser import SSMLSegment
 from modules.utils import audio
 
 
@@ -23,45 +24,33 @@ def synthesize_audio(
     prefix: str = "",
     batch_size: int = 1,
     spliter_threshold: int = 100,
+    end_of_sentence="",
 ):
-    if batch_size == 1:
-        return generate.generate_audio(
-            text,
-            temperature=temperature,
-            top_P=top_P,
-            top_K=top_K,
-            spk=spk,
-            infer_seed=infer_seed,
-            use_decoder=use_decoder,
-            prompt1=prompt1,
-            prompt2=prompt2,
-            prefix=prefix,
+    spliter = SentenceSplitter(spliter_threshold)
+    sentences = spliter.parse(text)
+
+    text_segments = [
+        SSMLSegment(
+            text=s,
+            params={
+                "temperature": temperature,
+                "top_P": top_P,
+                "top_K": top_K,
+                "spk": spk,
+                "infer_seed": infer_seed,
+                "use_decoder": use_decoder,
+                "prompt1": prompt1,
+                "prompt2": prompt2,
+                "prefix": prefix,
+            },
         )
-    else:
-        spliter = SentenceSplitter(spliter_threshold)
-        sentences = spliter.parse(text)
+        for s in sentences
+    ]
+    synthesizer = SynthesizeSegments(
+        batch_size=batch_size, eos=end_of_sentence, spliter_thr=spliter_threshold
+    )
+    audio_segments = synthesizer.synthesize_segments(text_segments)
 
-        text_segments = [
-            {
-                "text": s,
-                "params": {
-                    "text": s,
-                    "temperature": temperature,
-                    "top_P": top_P,
-                    "top_K": top_K,
-                    "spk": spk,
-                    "infer_seed": infer_seed,
-                    "use_decoder": use_decoder,
-                    "prompt1": prompt1,
-                    "prompt2": prompt2,
-                    "prefix": prefix,
-                },
-            }
-            for s in sentences
-        ]
-        synthesizer = SynthesizeSegments(batch_size)
-        audio_segments = synthesizer.synthesize_segments(text_segments)
+    combined_audio = combine_audio_segments(audio_segments)
 
-        combined_audio = combine_audio_segments(audio_segments)
-
-        return audio.pydub_to_np(combined_audio)
+    return audio.pydub_to_np(combined_audio)
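
A hedged usage sketch of the reworked synthesize_audio path (not part of the diff; parameter values are illustrative, and the return value is whatever audio.pydub_to_np produces):

from modules.synthesize_audio import synthesize_audio

result = synthesize_audio(
    "Hello there. This is a longer text that the sentence splitter will chunk.",
    temperature=0.3,
    top_P=0.7,
    top_K=20,
    batch_size=4,
    spliter_threshold=100,         # forwarded to SentenceSplitter / SynthesizeSegments
    end_of_sentence="[uv_break]",  # forwarded to SynthesizeSegments as eos
)
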
diff --git a/modules/utils/audio.py b/modules/utils/audio.py
index 48f38c598db590bad30687e519db78f1b0b491af..b1a97eeee49ea8aa9d877fc3e9cdeb6e8f1ea1cf 100644
--- a/modules/utils/audio.py
+++ b/modules/utils/audio.py
@@ -95,7 +95,11 @@ def pitch_shift(
 
 
 def apply_prosody_to_audio_data(
-    audio_data: np.ndarray, rate: float, volume: float, pitch: float, sr: int
+    audio_data: np.ndarray,
+    rate: float = 1,
+    volume: float = 0,
+    pitch: float = 0,
+    sr: int = 24000,
 ) -> np.ndarray:
     if rate != 1:
         audio_data = pyrb.time_stretch(audio_data, sr=sr, rate=rate)
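
A small sketch of the now-defaulted prosody helper (not part of the diff; the silent buffer stands in for real audio):

import numpy as np
from modules.utils.audio import apply_prosody_to_audio_data

sr = 24000
samples = np.zeros(sr, dtype=np.float32)                          # one second of silence
faster = apply_prosody_to_audio_data(samples, rate=1.2, sr=sr)    # time-stretch only
shifted = apply_prosody_to_audio_data(samples, pitch=2.0, sr=sr)  # pitch shift only
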
diff --git a/modules/webui/app.py b/modules/webui/app.py
index 1f64fa8d33ac67801307735ac3cfe26c63e384fa..b1acbf9e5f1fa6be03466d739b8f8445bbd853f7 100644
--- a/modules/webui/app.py
+++ b/modules/webui/app.py
@@ -7,6 +7,7 @@ from modules import config
 from modules.webui import gradio_extensions, webui_config
 
 from modules.webui.changelog_tab import create_changelog_tab
+from modules.webui.finetune.ft_tab import create_ft_tabs
 from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
 from modules.webui.ssml.podcast_tab import create_ssml_podcast_tab
 from modules.webui.system_tab import create_system_tab
@@ -118,6 +119,8 @@ def create_interface():
                 gr.Markdown("🚧 Under construction")
             with gr.TabItem("ASR", visible=webui_config.experimental):
                 gr.Markdown("🚧 Under construction")
+            with gr.TabItem("Finetune", visible=webui_config.experimental):
+                create_ft_tabs(demo)
 
             with gr.TabItem("System"):
                 create_system_tab()
diff --git a/modules/webui/finetune/ProcessMonitor.py b/modules/webui/finetune/ProcessMonitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a92c187ae0de80ec0ad93f56d65e623d4a916c55
--- /dev/null
+++ b/modules/webui/finetune/ProcessMonitor.py
@@ -0,0 +1,92 @@
+import os
+import sys
+import subprocess
+import threading
+
+
+class ProcessMonitor:
+    def __init__(self):
+        self.process = None
+        self.stdout = ""
+        self.stderr = ""
+        self.lock = threading.Lock()
+
+    def start_process(self, command):
+        self.process = subprocess.Popen(
+            command,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            bufsize=1,
+            universal_newlines=True,
+        )
+
+        # Set pipes to non-blocking mode
+        fd_out = self.process.stdout.fileno()
+        fd_err = self.process.stderr.fileno()
+
+        if sys.platform != "win32":
+            import fcntl
+
+            fl_out = fcntl.fcntl(fd_out, fcntl.F_GETFL)
+            fl_err = fcntl.fcntl(fd_err, fcntl.F_GETFL)
+            fcntl.fcntl(fd_out, fcntl.F_SETFL, fl_out | os.O_NONBLOCK)
+            fcntl.fcntl(fd_err, fcntl.F_SETFL, fl_err | os.O_NONBLOCK)
+
+        # Start threads to read stdout and stderr
+        threading.Thread(target=self._read_stdout, daemon=True).start()
+        threading.Thread(target=self._read_stderr, daemon=True).start()
+
+    def _read_stdout(self):
+        while self.process is not None and self.process.poll() is None:
+            try:
+                output = self.process.stdout.read()
+                if output:
+                    with self.lock:
+                        self.stdout += output
+            except Exception:
+                pass  # non-blocking read may raise when no data is available yet
+
+    def _read_stderr(self):
+        while self.process is not None and self.process.poll() is None:
+            try:
+                error = self.process.stderr.read()
+                if error:
+                    with self.lock:
+                        self.stderr += error
+            except Exception:
+                pass  # non-blocking read may raise when no data is available yet
+
+    def get_output(self):
+        with self.lock:
+            return self.stdout, self.stderr
+
+    def stop_process(self):
+        if self.process:
+            self.process.terminate()
+            self.process = None
+
+
+if __name__ == "__main__":
+    import time
+
+    pm = ProcessMonitor()
+    pm.start_process(
+        [
+            "python",
+            "-u",
+            "-c",
+            "import time; [print(i) or time.sleep(1) for i in range(5)]",
+        ]
+    )
+
+    while pm.process and pm.process.poll() is None:
+        stdout, stderr = pm.get_output()
+        if stdout:
+            print("STDOUT:", stdout)
+        if stderr:
+            print("STDERR:", stderr)
+        time.sleep(1)
+
+    stdout, stderr = pm.get_output()
+    print("Final STDOUT:", stdout)
+    print("Final STDERR:", stderr)
diff --git a/modules/webui/finetune/ft_tab.py b/modules/webui/finetune/ft_tab.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7147f40e598abf846ef5ee2d6e3e8a6bf005b9
--- /dev/null
+++ b/modules/webui/finetune/ft_tab.py
@@ -0,0 +1,13 @@
+import gradio as gr
+
+from modules.webui.finetune.speaker_ft_tab import create_speaker_ft_tab
+
+
+def create_ft_tabs(demo):
+    with gr.Tabs():
+        with gr.TabItem("Speaker"):
+            create_speaker_ft_tab(demo)
+        with gr.TabItem("GPT"):
+            gr.Markdown("🚧 Under construction")
+        with gr.TabItem("AE"):
+            gr.Markdown("🚧 Under construction")
diff --git a/modules/webui/finetune/ft_ui_utils.py b/modules/webui/finetune/ft_ui_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2a8e8ca09ebe5a8c9c6df6f6e336a77de10fb17
--- /dev/null
+++ b/modules/webui/finetune/ft_ui_utils.py
@@ -0,0 +1,49 @@
+import os
+from typing import IO, Union
+from modules.speaker import Speaker, speaker_mgr
+import subprocess
+
+
+def get_datasets_dir():
+    """
+    List the ./datasets/data_* folders.
+    """
+    dataset_path = "./datasets"
+    dataset_list = os.listdir(dataset_path)
+    dataset_list = [
+        d for d in dataset_list if os.path.isdir(os.path.join(dataset_path, d))
+    ]
+    dataset_list = [d for d in dataset_list if d.startswith("data_")]
+    return dataset_list
+
+
+def get_datasets_listfile():
+    datasets = get_datasets_dir()
+    listfiles = []
+    for d in datasets:
+        dir_path = os.path.join("./datasets", d)
+        files = os.listdir(dir_path)
+        for f in files:
+            if f.endswith(".list"):
+                listfiles.append(os.path.join(dir_path, f))
+    return listfiles
+
+
+def run_speaker_ft(
+    batch_size: int, epochs: int, train_text: bool, data_path: str, init_speaker: str
+):
+    command = ["python3", "-m", "modules.finetune.train_speaker"]
+    command += [
+        f"--batch_size={batch_size}",
+        f"--epochs={epochs}",
+        f"--data_path={data_path}",
+    ]
+    if train_text:
+        command.append("--train_text")
+    if init_speaker:
+        command.append(f"--init_speaker={init_speaker}")
+    process = subprocess.Popen(
+        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1
+    )
+
+    return process
diff --git a/modules/webui/finetune/speaker_ft_tab.py b/modules/webui/finetune/speaker_ft_tab.py
new file mode 100644
index 0000000000000000000000000000000000000000..f652ec422a5b71acfaec98aa74f93c72b61ae32b
--- /dev/null
+++ b/modules/webui/finetune/speaker_ft_tab.py
@@ -0,0 +1,130 @@
+import gradio as gr
+
+from modules.Enhancer.ResembleEnhance import unload_enhancer
+from modules.webui import webui_config
+from modules.webui.webui_utils import get_speaker_names
+from .ft_ui_utils import get_datasets_listfile, run_speaker_ft
+from .ProcessMonitor import ProcessMonitor
+from modules.speaker import speaker_mgr
+from modules.models import unload_chat_tts
+
+
+class SpeakerFt:
+    def __init__(self):
+        self.process_monitor = ProcessMonitor()
+        self.status_str = "idle"
+
+    def unload_main_thread_models(self):
+        unload_chat_tts()
+        unload_enhancer()
+
+    def run(
+        self,
+        batch_size: int,
+        epochs: int,
+        lr: str,
+        train_text: bool,
+        data_path: str,
+        select_speaker: str = "",
+    ):
+        if self.process_monitor.process:
+            return
+        self.unload_main_thread_models()
+        spk_path = None
+        if select_speaker != "" and select_speaker != "none":
+            select_speaker = select_speaker.split(" : ")[1].strip()
+            spk = speaker_mgr.get_speaker(select_speaker)
+            if spk is None:
+                return ["Speaker not found"]
+            spk_filename = speaker_mgr.get_speaker_filename(spk.id)
+            spk_path = f"./data/speakers/{spk_filename}"
+
+        command = ["python3", "-m", "modules.finetune.train_speaker"]
+        command += [
+            f"--batch_size={batch_size}",
+            f"--epochs={epochs}",
+            f"--data_path={data_path}",
+        ]
+        if train_text:
+            command.append("--train_text")
+        if spk_path:
+            command.append(f"--init_speaker={spk_path}")
+
+        self.status("Training process starting")
+
+        self.process_monitor.start_process(command)
+
+        self.status("Training started")
+
+    def status(self, text: str):
+        self.status_str = text
+
+    def flush(self):
+        stdout, stderr = self.process_monitor.get_output()
+        return f"{self.status_str}\n{stdout}\n{stderr}"
+
+    def clear(self):
+        self.process_monitor.stdout = ""
+        self.process_monitor.stderr = ""
+        self.status("Logs cleared")
+
+    def stop(self):
+        self.process_monitor.stop_process()
+        self.status("Training stopped")
+
+
+def create_speaker_ft_tab(demo: gr.Blocks):
+    spk_ft = SpeakerFt()
+    speakers, speaker_names = get_speaker_names()
+    speaker_names = ["none"] + speaker_names
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            with gr.Group():
+                gr.Markdown("🎛️hparams")
+                dataset_input = gr.Dropdown(
+                    label="Dataset", choices=get_datasets_listfile()
+                )
+                lr_input = gr.Textbox(label="Learning Rate", value="1e-2")
+                epochs_input = gr.Slider(
+                    label="Epochs", value=10, minimum=1, maximum=100, step=1
+                )
+                batch_size_input = gr.Slider(
+                    label="Batch Size", value=4, minimum=1, maximum=64, step=1
+                )
+                train_text_checkbox = gr.Checkbox(label="Train text_loss", value=True)
+                init_spk_dropdown = gr.Dropdown(
+                    label="Initial Speaker",
+                    choices=speaker_names,
+                    value="none",
+                )
+
+            with gr.Group():
+                start_train_btn = gr.Button("Start Training")
+                stop_train_btn = gr.Button("Stop Training")
+                clear_train_btn = gr.Button("Clear logs")
+        with gr.Column(scale=5):
+            with gr.Group():
+                # log
+                gr.Markdown("📜logs")
+                log_output = gr.Textbox(
+                    show_label=False, label="Log", value="", lines=20, interactive=True
+                )
+
+    start_train_btn.click(
+        spk_ft.run,
+        inputs=[
+            batch_size_input,
+            epochs_input,
+            lr_input,
+            train_text_checkbox,
+            dataset_input,
+            init_spk_dropdown,
+        ],
+        outputs=[],
+    )
+    stop_train_btn.click(spk_ft.stop)
+    clear_train_btn.click(spk_ft.clear)
+
+    if webui_config.experimental:
+        demo.load(spk_ft.flush, every=1, outputs=[log_output])
diff --git a/modules/webui/localization_runtime.py b/modules/webui/localization_runtime.py
index 9689c960e900ad1ec93c3e85e5c09d8bb5a54626..273eb05c66676f525fb484b6a6a64a1129091462 100644
--- a/modules/webui/localization_runtime.py
+++ b/modules/webui/localization_runtime.py
@@ -7,6 +7,7 @@ class LocalizationVars:
 
         self.ssml_examples = []
         self.tts_examples = []
+        self.podcast_default = []
 
 
 class ZHLocalizationVars(LocalizationVars):
@@ -167,6 +168,69 @@ class ZHLocalizationVars(LocalizationVars):
             },
         ]
 
+        self.podcast_default = [
+            [
+                1,
+                "female2",
+                "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。",
+                "podcast",
+            ],
+            [
+                2,
+                "Alice",
+                "嗨,我特别期待这个话题!中华料理真的是博大精深。",
+                "podcast",
+            ],
+            [
+                3,
+                "Bob",
+                "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。",
+                "podcast",
+            ],
+            [
+                4,
+                "female2",
+                "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。",
+                "podcast",
+            ],
+            [
+                5,
+                "Alice",
+                "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。",
+                "podcast",
+            ],
+            [
+                6,
+                "Bob",
+                "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。",
+                "podcast",
+            ],
+            [
+                7,
+                "female2",
+                "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。",
+                "podcast",
+            ],
+            [
+                8,
+                "Alice",
+                "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。",
+                "podcast",
+            ],
+            [
+                9,
+                "Bob",
+                "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。",
+                "podcast",
+            ],
+            [
+                10,
+                "female2",
+                "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。",
+                "podcast",
+            ],
+        ]
+
 
 class ENLocalizationVars(LocalizationVars):
     def __init__(self):
@@ -224,3 +288,65 @@ class ENLocalizationVars(LocalizationVars):
                 "text": "Don't ever let somebody tell you you can't do something. Not even me. Alright? You got a dream, you gotta protect it. When people can't do something themselves, they're gonna tell you that you can't do it. You want something, go get it. Period.",
             },
         ]
+        self.podcast_default = [
+            [
+                1,
+                "female2",
+                "Hello, welcome to today's podcast. Today, we're going to talk about global cuisine.",
+                "podcast",
+            ],
+            [
+                2,
+                "Alice",
+                "Hi, I'm really excited about this topic! Global cuisine is incredibly diverse and fascinating.",
+                "podcast",
+            ],
+            [
+                3,
+                "Bob",
+                "Absolutely, every country has its own unique culinary traditions and specialties.",
+                "podcast",
+            ],
+            [
+                4,
+                "female2",
+                "Let's start with Italian cuisine. Italian food is loved worldwide, especially for its pasta and pizza.",
+                "podcast",
+            ],
+            [
+                5,
+                "Alice",
+                "Yes, I especially love a good Margherita pizza and a hearty plate of spaghetti carbonara. The flavors are simply amazing.",
+                "podcast",
+            ],
+            [
+                6,
+                "Bob",
+                "Besides Italian cuisine, Japanese cuisine is also very popular. Dishes like sushi and ramen have become global favorites.",
+                "podcast",
+            ],
+            [
+                7,
+                "female2",
+                "Exactly, Japanese cuisine is known for its emphasis on fresh ingredients and delicate presentation.",
+                "podcast",
+            ],
+            [
+                8,
+                "Alice",
+                "And then there's Mexican cuisine, with its bold flavors and colorful dishes like tacos and guacamole.",
+                "podcast",
+            ],
+            [
+                9,
+                "Bob",
+                "Not to mention, there's also Indian cuisine, Thai cuisine, French cuisine, and so many more, each with its own distinctive flavors and techniques.",
+                "podcast",
+            ],
+            [
+                10,
+                "female2",
+                "Yes, like Indian curry, Thai tom yum soup, and French croissants, these are all mouth-watering dishes that are loved by people all over the world.",
+                "podcast",
+            ],
+        ]
diff --git a/modules/webui/ssml/podcast_tab.py b/modules/webui/ssml/podcast_tab.py
index 440e0b60eb5f8b394a86cc6dac7266ec018fb6ec..32e732d34fb9c4b4e86cf88b5e549312c88174b4 100644
--- a/modules/webui/ssml/podcast_tab.py
+++ b/modules/webui/ssml/podcast_tab.py
@@ -3,72 +3,9 @@ import pandas as pd
 import torch
 
 from modules.normalization import text_normalize
-from modules.webui import webui_utils
+from modules.webui import webui_config, webui_utils
 from modules.utils.hf import spaces
 
-podcast_default_case = [
-    [
-        1,
-        "female2",
-        "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
-        "podcast",
-    ],
-    [
-        2,
-        "Alice",
-        "嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
-        "podcast",
-    ],
-    [
-        3,
-        "Bob",
-        "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
-        "podcast",
-    ],
-    [
-        4,
-        "female2",
-        "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
-        "podcast",
-    ],
-    [
-        5,
-        "Alice",
-        "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
-        "podcast",
-    ],
-    [
-        6,
-        "Bob",
-        "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
-        "podcast",
-    ],
-    [
-        7,
-        "female2",
-        "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
-        "podcast",
-    ],
-    [
-        8,
-        "Alice",
-        "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
-        "podcast",
-    ],
-    [
-        9,
-        "Bob",
-        "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
-        "podcast",
-    ],
-    [
-        10,
-        "female2",
-        "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
-        "podcast",
-    ],
-]
-
 
 # NOTE: because text_normalize needs to use the tokenizer
 @torch.inference_mode()
@@ -133,7 +70,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
                     datatype=["number", "str", "str", "str"],
                     interactive=True,
                     wrap=True,
-                    value=podcast_default_case,
+                    value=webui_config.localization.podcast_default,
                     row_count=(0, "dynamic"),
                     col_count=(4, "fixed"),
                 )
diff --git a/modules/webui/ssml/ssml_tab.py b/modules/webui/ssml/ssml_tab.py
index f2de84c2e28bcd35e7c14611cc9ff3fdb57bdfc7..6fa6dd861daa3fd246d56aec4ded84ef068537d3 100644
--- a/modules/webui/ssml/ssml_tab.py
+++ b/modules/webui/ssml/ssml_tab.py
@@ -22,7 +22,6 @@ def create_ssml_interface():
                 ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
         with gr.Column(scale=1):
             with gr.Group():
-                # 参数
                 gr.Markdown("🎛️Parameters")
                 # batch size
                 batch_size_input = gr.Slider(
@@ -32,6 +31,19 @@ def create_ssml_interface():
                     maximum=webui_config.max_batch_size,
                     step=1,
                 )
+            with gr.Group():
+                gr.Markdown("🎛️Spliter")
+                eos_input = gr.Textbox(
+                    label="eos",
+                    value="[uv_break]",
+                )
+                spliter_thr_input = gr.Slider(
+                    label="Spliter Threshold",
+                    value=100,
+                    minimum=50,
+                    maximum=1000,
+                    step=1,
+                )
 
             with gr.Group():
                 gr.Markdown("💪🏼Enhance")
@@ -49,7 +61,14 @@ def create_ssml_interface():
 
     ssml_button.click(
         synthesize_ssml,
-        inputs=[ssml_input, batch_size_input, enable_enhance, enable_de_noise],
+        inputs=[
+            ssml_input,
+            batch_size_input,
+            enable_enhance,
+            enable_de_noise,
+            eos_input,
+            spliter_thr_input,
+        ],
         outputs=ssml_output,
     )
 
diff --git a/modules/webui/tts_tab.py b/modules/webui/tts_tab.py
index c39e7ed284cca32eae897cfcfaf80559a1c3d49b..ab81e12243bbd83eccf0a0f026cb9650f86df690 100644
--- a/modules/webui/tts_tab.py
+++ b/modules/webui/tts_tab.py
@@ -29,32 +29,6 @@ def create_tts_interface():
 
     with gr.Row():
         with gr.Column(scale=1):
-            with gr.Group():
-                gr.Markdown("🎛️Sampling")
-                temperature_input = gr.Slider(
-                    0.01, 2.0, value=0.3, step=0.01, label="Temperature"
-                )
-                top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
-                top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
-                batch_size_input = gr.Slider(
-                    1,
-                    webui_config.max_batch_size,
-                    value=4,
-                    step=1,
-                    label="Batch Size",
-                )
-
-            with gr.Row():
-                with gr.Group():
-                    gr.Markdown("🎭Style")
-                    gr.Markdown("TTS_STYLE_GUIDE")
-                    style_input_dropdown = gr.Dropdown(
-                        choices=styles,
-                        # label="Choose Style",
-                        interactive=True,
-                        show_label=False,
-                        value="*auto",
-                    )
             with gr.Row():
                 with gr.Group():
                     gr.Markdown("🗣️Speaker")
@@ -102,7 +76,47 @@ def create_tts_interface():
                                 fn=load_spk_info,
                                 inputs=[spk_file_upload],
                                 outputs=[infos],
-                            ),
+                            )
+
+            with gr.Row():
+                with gr.Group():
+                    gr.Markdown("🎭Style")
+                    gr.Markdown("TTS_STYLE_GUIDE")
+                    style_input_dropdown = gr.Dropdown(
+                        choices=styles,
+                        # label="Choose Style",
+                        interactive=True,
+                        show_label=False,
+                        value="*auto",
+                    )
+
+            with gr.Group():
+                gr.Markdown("🎛️Sampling")
+                temperature_input = gr.Slider(
+                    0.01, 2.0, value=0.3, step=0.01, label="Temperature"
+                )
+                top_p_input = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Top P")
+                top_k_input = gr.Slider(1, 50, value=20, step=1, label="Top K")
+                batch_size_input = gr.Slider(
+                    1,
+                    webui_config.max_batch_size,
+                    value=4,
+                    step=1,
+                    label="Batch Size",
+                )
+            with gr.Group():
+                gr.Markdown("🎛️Spliter")
+                eos_input = gr.Textbox(
+                    label="eos",
+                    value="[uv_break]",
+                )
+                spliter_thr_input = gr.Slider(
+                    label="Spliter Threshold",
+                    value=100,
+                    minimum=50,
+                    maximum=1000,
+                    step=1,
+                )
 
             with gr.Group():
                 gr.Markdown("💃Inference Seed")
@@ -202,7 +216,8 @@ def create_tts_interface():
                 )
                 refine_button = gr.Button("✍️Refine Text")
 
-            with gr.Group():
+            # Not very convenient to use yet, so keep it as an experimental feature
+            with gr.Group(visible=webui_config.experimental):
                 gr.Markdown("🔧Prompt engineering")
                 prompt1_input = gr.Textbox(label="Prompt 1")
                 prompt2_input = gr.Textbox(label="Prompt 2")
@@ -253,6 +268,8 @@ def create_tts_interface():
             enable_enhance,
             enable_de_noise,
             spk_file_upload,
+            spliter_thr_input,
+            eos_input,
         ],
         outputs=tts_output,
     )
diff --git a/modules/webui/webui_utils.py b/modules/webui/webui_utils.py
index 4a6d6dedf17c6a37c16a3546afab089a69599d38..cf57c80c597d20508a7e70a3d544e90f9bafe92a 100644
--- a/modules/webui/webui_utils.py
+++ b/modules/webui/webui_utils.py
@@ -95,6 +95,8 @@ def synthesize_ssml(
     batch_size=4,
     enable_enhance=False,
     enable_denoise=False,
+    eos: str = "[uv_break]",
+    spliter_thr: int = 100,
 ):
     try:
         batch_size = int(batch_size)
@@ -114,7 +116,9 @@ def synthesize_ssml(
     if len(segments) == 0:
         return None
 
-    synthesize = SynthesizeSegments(batch_size=batch_size)
+    synthesize = SynthesizeSegments(
+        batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
+    )
     audio_segments = synthesize.synthesize_segments(segments)
     combined_audio = combine_audio_segments(audio_segments)
 
@@ -151,6 +155,8 @@ def tts_generate(
     enable_enhance=False,
     enable_denoise=False,
     spk_file=None,
+    spliter_thr: int = 100,
+    eos: str = "[uv_break]",
 ):
     try:
         batch_size = int(batch_size)
@@ -199,6 +205,8 @@ def tts_generate(
         prompt2=prompt2,
         prefix=prefix,
         batch_size=batch_size,
+        end_of_sentence=eos,
+        spliter_threshold=spliter_thr,
     )
 
     audio_data, sample_rate = apply_audio_enhance(
diff --git a/webui.py b/webui.py
index 055874043e85b90485fee64bc4982cd4d3373486..d2b7fdb075b9eb61ca9f0e37c63cb9e4c7469ea2 100644
--- a/webui.py
+++ b/webui.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import logging
 
 from modules.api.api_setup import (
@@ -106,6 +107,7 @@ def process_webui_args(args):
         auth=auth,
         show_api=False,
         prevent_thread_lock=True,
+        inbrowser=sys.platform == "win32",
         app_kwargs={
             "title": app_title,
             "description": app_description,