ChatTTS-Forge / modules /api /impl /google_api.py
zhzluke96
update
ebc4336
import base64
from fastapi import HTTPException
import io
import soundfile as sf
from pydantic import BaseModel
from modules.api.Api import APIManager
from modules.utils.audio import apply_prosody_to_audio_data
from modules.normalization import text_normalize
from modules import generate_audio as generate
from modules.speaker import speaker_mgr
from modules.ssml_parser.SSMLParser import create_ssml_parser
from modules.SynthesizeSegments import (
SynthesizeSegments,
combine_audio_segments,
)
from modules.api import utils as api_utils
class SynthesisInput(BaseModel):
text: str = ""
ssml: str = ""
class VoiceSelectionParams(BaseModel):
languageCode: str = "ZH-CN"
name: str = "female2"
style: str = ""
temperature: float = 0.3
topP: float = 0.7
topK: int = 20
seed: int = 42
class AudioConfig(BaseModel):
audioEncoding: api_utils.AudioFormat = "mp3"
speakingRate: float = 1
pitch: float = 0
volumeGainDb: float = 0
sampleRateHertz: int
batchSize: int = 1
spliterThreshold: int = 100
class GoogleTextSynthesizeRequest(BaseModel):
input: SynthesisInput
voice: VoiceSelectionParams
audioConfig: dict
class GoogleTextSynthesizeResponse(BaseModel):
audioContent: str
async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
input = request.input
voice = request.voice
audioConfig = request.audioConfig
# 提取参数
# TODO 这个也许应该传给 normalizer
language_code = voice.languageCode
voice_name = voice.name
infer_seed = voice.seed or 42
audio_format = audioConfig.get("audioEncoding", "mp3")
speaking_rate = audioConfig.get("speakingRate", 1)
pitch = audioConfig.get("pitch", 0)
volume_gain_db = audioConfig.get("volumeGainDb", 0)
batch_size = audioConfig.get("batchSize", 1)
# TODO spliter_threshold
spliter_threshold = audioConfig.get("spliterThreshold", 100)
# TODO sample_rate
sample_rate_hertz = audioConfig.get("sampleRateHertz", 24000)
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
# TODO maybe need to change the sample rate
sample_rate = 24000
# 虽然 calc_spk_style 可以解析 seed 形式,但是这个接口只准备支持 speakers list 中存在的 speaker
if speaker_mgr.get_speaker(voice_name) is None:
raise HTTPException(
status_code=400, detail="The specified voice name is not supported."
)
if audio_format != "mp3" and audio_format != "wav":
raise HTTPException(
status_code=400, detail="Invalid audio encoding format specified."
)
try:
if input.text:
# 处理文本合成逻辑
text = text_normalize(input.text, is_end=True)
sample_rate, audio_data = generate.generate_audio(
text,
temperature=(
voice.temperature
if voice.temperature
else params.get("temperature", 0.3)
),
top_P=voice.topP if voice.topP else params.get("top_p", 0.7),
top_K=voice.topK if voice.topK else params.get("top_k", 20),
spk=params.get("spk", -1),
infer_seed=infer_seed,
prompt1=params.get("prompt1", ""),
prompt2=params.get("prompt2", ""),
prefix=params.get("prefix", ""),
)
elif input.ssml:
# 处理SSML合成逻辑
parser = create_ssml_parser()
segments = parser.parse(input.ssml)
for seg in segments:
seg["text"] = text_normalize(seg["text"], is_end=True)
if len(segments) == 0:
raise HTTPException(
status_code=400, detail="The SSML text is empty or parsing failed."
)
synthesize = SynthesizeSegments(batch_size=batch_size)
audio_segments = synthesize.synthesize_segments(segments)
combined_audio = combine_audio_segments(audio_segments)
buffer = io.BytesIO()
combined_audio.export(buffer, format="wav")
buffer.seek(0)
audio_data = buffer.read()
else:
raise HTTPException(
status_code=400, detail="Either text or SSML input must be provided."
)
audio_data = apply_prosody_to_audio_data(
audio_data,
rate=speaking_rate,
pitch=pitch,
volume=volume_gain_db,
sr=sample_rate,
)
buffer = io.BytesIO()
sf.write(buffer, audio_data, sample_rate, format="wav")
buffer.seek(0)
if audio_format == "mp3":
buffer = api_utils.wav_to_mp3(buffer)
base64_encoded = base64.b64encode(buffer.read())
base64_string = base64_encoded.decode("utf-8")
return {
"audioContent": f"data:audio/{audio_format.lower()};base64,{base64_string}"
}
except Exception as e:
import logging
logging.exception(e)
if isinstance(e, HTTPException):
raise e
else:
raise HTTPException(status_code=500, detail=str(e))
def setup(app: APIManager):
app.post(
"/v1/text:synthesize",
response_model=GoogleTextSynthesizeResponse,
description="""
google api document: <br/>
[https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)
- 多个属性在本系统中无用仅仅是为了兼容google api
- voice 中的 topP, topK, temperature 为本系统中的参数
- voice.name 即 speaker name (或者speaker seed)
- voice.seed 为 infer seed (可在webui中测试具体作用)
- 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
""",
)(google_text_synthesize)