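# Two-stage fish-speech inference script: an LLM generates semantic tokens on
# cuda:0, then a DAC decoder on cuda:1 renders them to audio.
#
# Example invocation, assuming this file is saved as generate_speech.py
# (paths below are illustrative placeholders, not shipped defaults):
#
#   python generate_speech.py \
#       --llm-model-path checkpoints/text2semantic \
#       --dac-model-path checkpoints/dac.pth \
#       --output-path out/speech.wav \
#       --params-file sampling_params_example.json \
#       --ref-audio ref.wav
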
import json
import queue
from pathlib import Path
from typing import Optional

import click
import torch
import soundfile as sf
from loguru import logger

from fish_speech.models.text2semantic.inference import (
    CodebookSamplingParams,
    GenerateRequest,
    SamplingParams,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from fish_speech.models.text2semantic.llama import BaseTransformer, DualARTransformer
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.text import clean_text
from fish_speech.inference_engine.vq_manager import VQManager
from tools.api import load_audio


def load_llm_model(model_path: str, device: str, compile: bool = False):
    """加载LLM模型"""
    logger.info(f"Loading LLM model from {model_path}")
    model = BaseTransformer.from_pretrained(
        path=model_path,
        load_weights=True,
    )
    model = model.to(device=device, dtype=torch.bfloat16)
    
    if isinstance(model, DualARTransformer):
        from fish_speech.models.text2semantic.inference import decode_one_token_ar as decode_one_token
        logger.info("Using DualARTransformer")
    else:
        from fish_speech.models.text2semantic.inference import decode_one_token_naive as decode_one_token
        logger.info("Using NaiveTransformer")
    
    if compile:
        logger.info("Compiling decode function...")
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="reduce-overhead" if torch.cuda.is_available() else None,
        )
    
    return model.eval(), decode_one_token


def load_dac_model(config_name: str, checkpoint_path: str, device: str):
    """加载DAC模型"""
    logger.info(f"Loading DAC model from {checkpoint_path}")
    model = load_decoder_model(
        config_name=config_name,
        checkpoint_path=checkpoint_path,
        device=device,
    )
    return model


@click.command()
#@click.argument("text", type=str)
@click.option("--llm-model-path", type=str, required=True, help="Path to the LLM model")
@click.option("--dac-model-path", type=str, required=True, help="Path to the DAC model")
@click.option("--dac-config-name", type=str, default="modded_dac_vq", help="DAC model config name")
@click.option("--output-path", type=str, required=True, help="Path to save the output audio")
@click.option("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
@click.option("--max-new-tokens", type=int, default=4096, help="Maximum new tokens to generate")
@click.option("--chunk-length", type=int, default=1000, help="Chunk length for synthesis")
@click.option("--compile", is_flag=True, help="Whether to compile the model")
@click.option("--iterative-prompt", is_flag=True, help="Whether to use iterative prompt")
@click.option("--params-file", type=str, default="sampling_params_example.json", help="Path to JSON file containing sampling parameters")
@click.option(
    "--ref-audio",
    type=click.Path(path_type=Path, exists=True),
    default="ref.wav",
    help="参考音频文件路径,默认ref.wav"
)
def main(
    #text: str,
    llm_model_path: str,
    dac_model_path: str,
    dac_config_name: str,
    output_path: str,
    device: str,
    max_new_tokens: int,
    chunk_length: int,
    compile: bool,
    iterative_prompt: bool,
    params_file: Optional[str],
    ref_audio: Path,
):
    """生成语音,包括LLM生成token和DAC生成音频两个步骤"""
    
    # 设置精度
    precision = torch.half if torch.cuda.is_available() else torch.bfloat16
    
    # Load the LLM model behind a thread-safe request queue.
    # Note: this script pins the LLM to cuda:0 and the DAC decoder to cuda:1,
    # so it effectively assumes a two-GPU machine regardless of --device.
    logger.info("Loading LLM model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=llm_model_path,
        device="cuda:0",
        precision=precision,
        compile=compile,
    )
    logger.info("LLM model loaded")
    
    # Load the DAC model on the second GPU
    logger.info("Loading DAC model...")
    dac_model = load_dac_model(
        config_name=dac_config_name,
        checkpoint_path=dac_model_path,
        device="cuda:1",
    )
    logger.info("DAC model loaded")
    
    # Load sampling parameters and the input text from the params file
    if params_file:
        with open(params_file, "r", encoding="utf-8") as f:
            params_data = json.load(f)
        text = params_data.get("text", "")

        semantic_params = CodebookSamplingParams(**params_data.get("semantic", {}))
        codebook_params = [
            CodebookSamplingParams(**params) for params in params_data.get("codebooks", [])
        ]
        sampling_params = SamplingParams(
            semantic=semantic_params,
            codebooks=codebook_params,
        )
    else:
        logger.warning("No params file given; using default sampling parameters and empty text")
        text = ""
        sampling_params = SamplingParams()
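
    # For reference, the parsing above expects a JSON params file shaped roughly
    # like the sketch below. The inner keys must match the CodebookSamplingParams
    # fields; the names shown here are illustrative assumptions, so check that
    # dataclass for the exact set.
    #
    # {
    #     "text": "Text to synthesize.",
    #     "semantic": {"temperature": 0.7, "top_p": 0.7},
    #     "codebooks": [{"temperature": 0.7, "top_p": 0.7}]
    # }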
    
    # Normalize the input text
    text = clean_text(text)
    
    # Encode the reference audio into prompt tokens, if the file exists
    if ref_audio.exists():
        ref_audio_data, ref_sr = sf.read(ref_audio)
        logger.info(f"Loaded reference audio: {ref_audio}, shape={ref_audio_data.shape}, sr={ref_sr}")
        vq_manager = VQManager()
        vq_manager.decoder_model = dac_model
        vq_manager.load_audio = load_audio
        prompt_tokens = vq_manager.encode_reference(ref_audio, enable_reference_audio=True)
        logger.info(f"Encoded reference audio to prompt_tokens, shape={prompt_tokens.shape if prompt_tokens is not None else None}")
    else:
        prompt_tokens = []
        logger.warning(f"Reference audio {ref_audio} not found; generating without a voice prompt")
    
    # Generate speech
    logger.info(f"Generating speech for text: {text}")
    logger.info(f"Using sampling parameters: {sampling_params}")
    
    output_path = Path(output_path)
    if not output_path.suffix:
        output_path = output_path.with_suffix('.wav')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Create the queue that will receive generation responses
    response_queue = queue.Queue()

    # Prepare the generation request
    request = dict(
        device=device,
        max_new_tokens=max_new_tokens,
        text=text,
        sampling_params=sampling_params,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=[],
        prompt_tokens=[prompt_tokens] if prompt_tokens is not None and len(prompt_tokens) else [],
        #prompt_text=["Through the dense morning fog that rolled across the peaceful valley, the distant church bells chimed their melodic song, echoing off ancient stone walls and mingling with the gentle rustling of maple leaves in the cool breeze. Inside the cozy lakeside cottage, fresh bread baked in the old clay oven filled every corner with its rich, comforting aroma, while steam rose lazily from ceramic mugs of fresh-brewed coffee on the handcrafted pine table. The persistent rain finally gave way to brilliant sunshine, transforming ordinary dewdrops into countless sparkling diamonds scattered across the vibrant garden flowers."],
    )
    
    # Send the request to the LLM worker
    llama_queue.put(GenerateRequest(request=request, response_queue=response_queue))
    
    # Collect the generated tokens. Each item on the queue is a
    # WrappedGenerateResponse: status "error" carries an exception, otherwise
    # .response holds a chunk of codes until action "next" marks end of stream.
    all_tokens = []
    while True:
        wrapped_result: WrappedGenerateResponse = response_queue.get()
        
        if wrapped_result.status == "error":
            error = wrapped_result.response if isinstance(wrapped_result.response, Exception) else Exception("Unknown error")
            logger.error(f"Error during generation: {error}")
            break
            
        result = wrapped_result.response
        if result.action == "next":
            break
            
        all_tokens.append(result.codes)
        logger.info(f"Generated chunk {len(all_tokens)}")
    
    if not all_tokens:
        logger.error("No tokens generated")
        return
    
    # Concatenate all token chunks
    if len(all_tokens) > 1:
        tokens = torch.cat(all_tokens, dim=1)
    else:
        tokens = all_tokens[0]
    
    # Decode tokens to audio with the DAC model on cuda:1. `tokens` has shape
    # (num_codebooks, T), so `tokens[None]` adds the batch dimension.
    logger.info("Converting tokens to audio...")
    feature_lengths = torch.tensor([tokens.shape[1]], device=device)
    audio, _ = dac_model.decode(
        indices=tokens[None].to("cuda:1"),
        feature_lengths=feature_lengths.to("cuda:1"),
    )
    
    # Save the audio
    audio = audio[0, 0].detach().float().cpu().numpy()
    sf.write(output_path, audio, dac_model.sample_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    main()