import ctypes
import io
import os
import struct
from dataclasses import dataclass
from enum import Enum
from typing import AsyncGenerator

import httpx
import numpy as np
import ormsgpack
import soundfile as sf

from fish_speech.utils.schema import (
    ServeChatRequest,
    ServeMessage,
    ServeTextPart,
    ServeVQGANDecodeRequest,
    ServeVQGANEncodeRequest,
    ServeVQPart,
)
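

# This module streams audio/text through a fish-speech API server: audio is
# encoded into VQ token codes, the codes are fed to the chat endpoint, and VQ
# codes returned by the model are decoded back into PCM speech segments.
# CustomAudioFrame below is a small container for the interleaved int16 PCM
# produced by that decode step.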
class CustomAudioFrame:
    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
            ctypes.c_int16
        ):
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )

        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    @property
    def data(self):
        return memoryview(self._data).cast("h")

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def num_channels(self):
        return self._num_channels

    @property
    def samples_per_channel(self):
        return self._samples_per_channel

    @property
    def duration(self):
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        return (
            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
            f"num_channels={self.num_channels}, "
            f"samples_per_channel={self.samples_per_channel}, "
            f"duration={self.duration:.3f})"
        )
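

# Event vocabulary for FishE2EAgent.stream(): every yielded FishE2EEvent is
# tagged with one of these types and carries the matching payload (audio
# frame, text, or VQ codes).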
class FishE2EEventType(Enum):
    SPEECH_SEGMENT = 1
    TEXT_SEGMENT = 2
    END_OF_TEXT = 3
    END_OF_SPEECH = 4
    ASR_RESULT = 5
    USER_CODES = 6


@dataclass
class FishE2EEvent:
    type: FishE2EEventType
    frame: CustomAudioFrame | None = None
    text: str | None = None
    vq_codes: list[list[int]] | None = None
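

# Shared module-level HTTP client with no timeout and unbounded connection
# limits. Note that FishE2EAgent below creates its own client instead of
# reusing this one.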
client = httpx.AsyncClient(
    timeout=None,
    limits=httpx.Limits(
        max_connections=None,
        max_keepalive_connections=None,
        keepalive_expiry=None,
    ),
)
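

# NOTE: the URLs below assume a fish-speech API server on localhost:8080
# exposing /v1/chat as well as the /v1/vqgan encode/decode routes.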
class FishE2EAgent:
    def __init__(self):
        self.llm_url = "http://localhost:8080/v1/chat"
        self.vqgan_url = "http://localhost:8080"
        self.client = httpx.AsyncClient(timeout=None)

    async def get_codes(self, audio_data, sample_rate):
        # Serialize the waveform as WAV and ask the VQGAN encode endpoint for
        # the corresponding VQ token codes.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
        audio_buffer.seek(0)

        encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
        encode_request_bytes = ormsgpack.packb(
            encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
        )
        encode_response = await self.client.post(
            f"{self.vqgan_url}/v1/vqgan/encode",
            data=encode_request_bytes,
            headers={"Content-Type": "application/msgpack"},
        )
        encode_response_data = ormsgpack.unpackb(encode_response.content)
        codes = encode_response_data["tokens"][0]
        return codes
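
    # Run one conversation turn: optionally encode system/user audio into VQ
    # codes, post a streaming chat request, and translate the streamed deltas
    # into TEXT_SEGMENT / SPEECH_SEGMENT / END_OF_* events.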
    async def stream(
        self,
        system_audio_data: np.ndarray | None,
        user_audio_data: np.ndarray | None,
        sample_rate: int,
        num_channels: int,
        chat_ctx: dict | None = None,
    ) -> AsyncGenerator[FishE2EEvent, None]:

        if system_audio_data is not None:
            sys_codes = await self.get_codes(system_audio_data, sample_rate)
        else:
            sys_codes = None
        if user_audio_data is not None:
            user_codes = await self.get_codes(user_audio_data, sample_rate)

        if chat_ctx is None:
            sys_parts = [
                ServeTextPart(
                    # System prompt (in Chinese): "You are a voice assistant
                    # designed by Fish Audio, providing end-to-end voice
                    # interaction for a seamless user experience. First
                    # transcribe the user's speech, then answer using the
                    # format: 'Question: [user speech]\n\nAnswer: [your answer]\n'."
                    text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
                ),
            ]
            if system_audio_data is not None:
                sys_parts.append(ServeVQPart(codes=sys_codes))
            chat_ctx = {
                "messages": [
                    ServeMessage(
                        role="system",
                        parts=sys_parts,
                    ),
                ],
            }
        else:
            if not chat_ctx.get("added_sysaudio") and sys_codes:
                chat_ctx["added_sysaudio"] = True
                chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))

        prev_messages = chat_ctx["messages"].copy()
        if user_audio_data is not None:
            yield FishE2EEvent(
                type=FishE2EEventType.USER_CODES,
                vq_codes=user_codes,
            )
        else:
            user_codes = None

        request = ServeChatRequest(
            messages=prev_messages
            + (
                [
                    ServeMessage(
                        role="user",
                        parts=[ServeVQPart(codes=user_codes)],
                    )
                ]
                if user_codes
                else []
            ),
            streaming=True,
            num_samples=1,
        )

        buffer = b""
        vq_codes = []
        current_vq = False
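
        # Flush the accumulated VQ codes through the VQGAN decode endpoint and
        # emit the resulting audio as a SPEECH_SEGMENT event, then reset the
        # accumulator.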
        async def decode_send():
            nonlocal current_vq
            nonlocal vq_codes

            data = np.concatenate(vq_codes, axis=1).tolist()

            decode_request = ServeVQGANDecodeRequest(tokens=[data])
            decode_response = await self.client.post(
                f"{self.vqgan_url}/v1/vqgan/decode",
                data=ormsgpack.packb(
                    decode_request,
                    option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
                ),
                headers={"Content-Type": "application/msgpack"},
            )
            decode_data = ormsgpack.unpackb(decode_response.content)

            # The decoded audio arrives as float16 samples in [-1, 1]; convert
            # them to 16-bit PCM bytes.
            audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
            audio_data = (audio_data * 32768).astype(np.int16).tobytes()

            audio_frame = CustomAudioFrame(
                data=audio_data,
                samples_per_channel=len(audio_data) // 2,
                sample_rate=44100,
                num_channels=1,
            )
            yield FishE2EEvent(
                type=FishE2EEventType.SPEECH_SEGMENT,
                frame=audio_frame,
                vq_codes=data,
            )

            current_vq = False
            vq_codes = []

        async with self.client.stream(
            "POST",
            self.llm_url,
            data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
            headers={"Content-Type": "application/msgpack"},
        ) as response:
            async for chunk in response.aiter_bytes():
                buffer += chunk

                # Each streamed message is a 4-byte unsigned length prefix
                # followed by an ormsgpack-encoded body.
                while len(buffer) >= 4:
                    read_length = struct.unpack("I", buffer[:4])[0]
                    if len(buffer) < 4 + read_length:
                        break

                    body = buffer[4 : 4 + read_length]
                    buffer = buffer[4 + read_length :]
                    data = ormsgpack.unpackb(body)

                    if data["delta"] and data["delta"]["part"]:
                        # A text part arriving after a run of VQ parts closes
                        # the current speech segment: decode and emit it first.
                        if current_vq and data["delta"]["part"]["type"] == "text":
                            async for event in decode_send():
                                yield event
                        if data["delta"]["part"]["type"] == "text":
                            yield FishE2EEvent(
                                type=FishE2EEventType.TEXT_SEGMENT,
                                text=data["delta"]["part"]["text"],
                            )
                        elif data["delta"]["part"]["type"] == "vq":
                            vq_codes.append(np.array(data["delta"]["part"]["codes"]))
                            current_vq = True

        # Flush any speech still buffered when the stream ended.
        if current_vq and vq_codes:
            async for event in decode_send():
                yield event

        yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
        yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
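

# Minimal demo: stream a local audio file through the agent, print text deltas
# as they arrive, and append the raw PCM of each speech segment to a file.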
async def main():
    import torchaudio

    agent = FishE2EAgent()

    # torchaudio.load returns a (channels, samples) float tensor in [-1, 1];
    # fold it to a mono int16 waveform before encoding.
    audio_tensor, sample_rate = torchaudio.load("uz_story_en.m4a")
    audio_data = (audio_tensor.numpy().mean(axis=0) * 32768).astype(np.int16)

    # No system/reference audio here; the loaded clip is sent as the user turn.
    stream = agent.stream(None, audio_data, sample_rate, 1)
    if os.path.exists("audio_segment.wav"):
        os.remove("audio_segment.wav")

    async for event in stream:
        if event.type == FishE2EEventType.SPEECH_SEGMENT:
            # Append raw 16-bit PCM samples (no WAV header is written, so the
            # file is headerless PCM despite its name).
            with open("audio_segment.wav", "ab+") as f:
                f.write(event.frame.data)
        elif event.type == FishE2EEventType.ASR_RESULT:
            print(event.text, flush=True)
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            print(event.text, flush=True, end="")
        elif event.type == FishE2EEventType.END_OF_TEXT:
            print("\nEnd of text reached.")
        elif event.type == FishE2EEventType.END_OF_SPEECH:
            print("End of speech reached.")


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())