Spaces:
Sleeping
Sleeping
Upload 37 files
Browse files- tools/__pycache__/api.cpython-310.pyc +0 -0
- tools/__pycache__/api.cpython-311.pyc +0 -0
- tools/__pycache__/auto_rerank.cpython-310.pyc +0 -0
- tools/__pycache__/commons.cpython-310.pyc +0 -0
- tools/__pycache__/file.cpython-310.pyc +0 -0
- tools/__pycache__/schema.cpython-310.pyc +0 -0
- tools/__pycache__/webui.cpython-310.pyc +0 -0
- tools/api.py +943 -0
- tools/auto_rerank.py +159 -0
- tools/commons.py +35 -0
- tools/download_models.py +55 -0
- tools/e2e_webui.py +232 -0
- tools/extract_model.py +21 -0
- tools/file.py +125 -0
- tools/fish_e2e.py +298 -0
- tools/llama/__pycache__/generate.cpython-310.pyc +0 -0
- tools/llama/build_dataset.py +169 -0
- tools/llama/eval_in_context.py +171 -0
- tools/llama/generate.py +1087 -0
- tools/llama/merge_lora.py +95 -0
- tools/llama/quantize.py +497 -0
- tools/llama/rebuild_tokenizer.py +57 -0
- tools/msgpack_api.py +95 -0
- tools/post_api.py +227 -0
- tools/schema.py +187 -0
- tools/sensevoice/README.md +59 -0
- tools/sensevoice/__init__.py +0 -0
- tools/sensevoice/auto_model.py +573 -0
- tools/sensevoice/fun_asr.py +332 -0
- tools/sensevoice/vad_utils.py +61 -0
- tools/smart_pad.py +60 -0
- tools/vqgan/__pycache__/inference.cpython-310.pyc +0 -0
- tools/vqgan/create_train_split.py +83 -0
- tools/vqgan/extract_vq.py +233 -0
- tools/vqgan/inference.py +121 -0
- tools/webui.py +570 -0
- tools/whisper_asr.py +176 -0
tools/__pycache__/api.cpython-310.pyc
ADDED
Binary file (22.2 kB). View file
|
|
tools/__pycache__/api.cpython-311.pyc
ADDED
Binary file (45 kB). View file
|
|
tools/__pycache__/auto_rerank.cpython-310.pyc
ADDED
Binary file (3.49 kB). View file
|
|
tools/__pycache__/commons.cpython-310.pyc
ADDED
Binary file (1.49 kB). View file
|
|
tools/__pycache__/file.cpython-310.pyc
ADDED
Binary file (2.99 kB). View file
|
|
tools/__pycache__/schema.cpython-310.pyc
ADDED
Binary file (7.67 kB). View file
|
|
tools/__pycache__/webui.cpython-310.pyc
ADDED
Binary file (11.6 kB). View file
|
|
tools/api.py
ADDED
@@ -0,0 +1,943 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import queue
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
import traceback
|
7 |
+
import wave
|
8 |
+
from argparse import ArgumentParser
|
9 |
+
from http import HTTPStatus
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Annotated, Any
|
12 |
+
|
13 |
+
import librosa
|
14 |
+
import numpy as np
|
15 |
+
import ormsgpack
|
16 |
+
import pyrootutils
|
17 |
+
import soundfile as sf
|
18 |
+
import torch
|
19 |
+
import torchaudio
|
20 |
+
from baize.datastructures import ContentType
|
21 |
+
from kui.asgi import (
|
22 |
+
Body,
|
23 |
+
FactoryClass,
|
24 |
+
HTTPException,
|
25 |
+
HttpRequest,
|
26 |
+
HttpView,
|
27 |
+
JSONResponse,
|
28 |
+
Kui,
|
29 |
+
OpenAPI,
|
30 |
+
StreamResponse,
|
31 |
+
request,
|
32 |
+
)
|
33 |
+
from kui.asgi.routing import MultimethodRoutes
|
34 |
+
from loguru import logger
|
35 |
+
from transformers import AutoTokenizer
|
36 |
+
|
37 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
38 |
+
import struct
|
39 |
+
from threading import Lock
|
40 |
+
|
41 |
+
import httpx
|
42 |
+
from cachetools import LRUCache, cached
|
43 |
+
from funasr import AutoModel
|
44 |
+
from silero_vad import get_speech_timestamps, load_silero_vad
|
45 |
+
|
46 |
+
from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
|
47 |
+
from fish_speech.models.text2semantic.llama import BaseModelArgs
|
48 |
+
|
49 |
+
# from fish_speech.models.vqgan.lit_module import VQGAN
|
50 |
+
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
51 |
+
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
52 |
+
from fish_speech.utils import autocast_exclude_mps, set_seed
|
53 |
+
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
54 |
+
from tools.llama.generate import (
|
55 |
+
GenerateRequest,
|
56 |
+
GenerateResponse,
|
57 |
+
WrappedGenerateResponse,
|
58 |
+
launch_thread_safe_queue,
|
59 |
+
launch_thread_safe_queue_agent,
|
60 |
+
)
|
61 |
+
from tools.schema import (
|
62 |
+
GLOBAL_NUM_SAMPLES,
|
63 |
+
ASRPackRequest,
|
64 |
+
ServeASRRequest,
|
65 |
+
ServeASRResponse,
|
66 |
+
ServeASRSegment,
|
67 |
+
ServeAudioPart,
|
68 |
+
ServeForwardMessage,
|
69 |
+
ServeMessage,
|
70 |
+
ServeRequest,
|
71 |
+
ServeResponse,
|
72 |
+
ServeStreamDelta,
|
73 |
+
ServeStreamResponse,
|
74 |
+
ServeTextPart,
|
75 |
+
ServeTimedASRResponse,
|
76 |
+
ServeTTSRequest,
|
77 |
+
ServeVQGANDecodeRequest,
|
78 |
+
ServeVQGANDecodeResponse,
|
79 |
+
ServeVQGANEncodeRequest,
|
80 |
+
ServeVQGANEncodeResponse,
|
81 |
+
ServeVQPart,
|
82 |
+
)
|
83 |
+
from tools.vqgan.inference import load_model as load_decoder_model
|
84 |
+
|
85 |
+
global_lock = Lock()
|
86 |
+
|
87 |
+
# Whether to disable keepalive (which is helpful if the server is in the same cluster)
|
88 |
+
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
|
89 |
+
async_client = httpx.AsyncClient(
|
90 |
+
timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
|
91 |
+
)
|
92 |
+
backends = torchaudio.list_audio_backends()
|
93 |
+
|
94 |
+
if "ffmpeg" in backends:
|
95 |
+
backend = "ffmpeg"
|
96 |
+
else:
|
97 |
+
backend = "soundfile"
|
98 |
+
|
99 |
+
|
100 |
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
101 |
+
buffer = io.BytesIO()
|
102 |
+
|
103 |
+
with wave.open(buffer, "wb") as wav_file:
|
104 |
+
wav_file.setnchannels(channels)
|
105 |
+
wav_file.setsampwidth(bit_depth // 8)
|
106 |
+
wav_file.setframerate(sample_rate)
|
107 |
+
|
108 |
+
wav_header_bytes = buffer.getvalue()
|
109 |
+
buffer.close()
|
110 |
+
return wav_header_bytes
|
111 |
+
|
112 |
+
|
113 |
+
# Define utils for web server
|
114 |
+
async def http_execption_handler(exc: HTTPException):
|
115 |
+
return JSONResponse(
|
116 |
+
dict(
|
117 |
+
statusCode=exc.status_code,
|
118 |
+
message=exc.content,
|
119 |
+
error=HTTPStatus(exc.status_code).phrase,
|
120 |
+
),
|
121 |
+
exc.status_code,
|
122 |
+
exc.headers,
|
123 |
+
)
|
124 |
+
|
125 |
+
|
126 |
+
async def other_exception_handler(exc: "Exception"):
|
127 |
+
traceback.print_exc()
|
128 |
+
|
129 |
+
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
130 |
+
return JSONResponse(
|
131 |
+
dict(statusCode=status, message=str(exc), error=status.phrase),
|
132 |
+
status,
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
def load_audio(reference_audio, sr):
|
137 |
+
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
138 |
+
audio_data = reference_audio
|
139 |
+
reference_audio = io.BytesIO(audio_data)
|
140 |
+
|
141 |
+
waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
|
142 |
+
|
143 |
+
if waveform.shape[0] > 1:
|
144 |
+
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
145 |
+
|
146 |
+
if original_sr != sr:
|
147 |
+
resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
|
148 |
+
waveform = resampler(waveform)
|
149 |
+
|
150 |
+
audio = waveform.squeeze().numpy()
|
151 |
+
return audio
|
152 |
+
|
153 |
+
|
154 |
+
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
|
155 |
+
if enable_reference_audio and reference_audio is not None:
|
156 |
+
# Load audios, and prepare basic info here
|
157 |
+
reference_audio_content = load_audio(
|
158 |
+
reference_audio, decoder_model.spec_transform.sample_rate
|
159 |
+
)
|
160 |
+
|
161 |
+
audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
|
162 |
+
None, None, :
|
163 |
+
]
|
164 |
+
audio_lengths = torch.tensor(
|
165 |
+
[audios.shape[2]], device=decoder_model.device, dtype=torch.long
|
166 |
+
)
|
167 |
+
logger.info(
|
168 |
+
f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
|
169 |
+
)
|
170 |
+
|
171 |
+
# VQ Encoder
|
172 |
+
if isinstance(decoder_model, FireflyArchitecture):
|
173 |
+
prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
|
174 |
+
|
175 |
+
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
176 |
+
else:
|
177 |
+
prompt_tokens = None
|
178 |
+
logger.info("No reference audio provided")
|
179 |
+
|
180 |
+
return prompt_tokens
|
181 |
+
|
182 |
+
|
183 |
+
def decode_vq_tokens(
|
184 |
+
*,
|
185 |
+
decoder_model,
|
186 |
+
codes,
|
187 |
+
):
|
188 |
+
feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
|
189 |
+
logger.info(f"VQ features: {codes.shape}")
|
190 |
+
|
191 |
+
if isinstance(decoder_model, FireflyArchitecture):
|
192 |
+
# VQGAN Inference
|
193 |
+
return decoder_model.decode(
|
194 |
+
indices=codes[None],
|
195 |
+
feature_lengths=feature_lengths,
|
196 |
+
)[0].squeeze()
|
197 |
+
|
198 |
+
raise ValueError(f"Unknown model type: {type(decoder_model)}")
|
199 |
+
|
200 |
+
|
201 |
+
routes = MultimethodRoutes(base_class=HttpView)
|
202 |
+
|
203 |
+
|
204 |
+
def get_content_type(audio_format):
|
205 |
+
if audio_format == "wav":
|
206 |
+
return "audio/wav"
|
207 |
+
elif audio_format == "flac":
|
208 |
+
return "audio/flac"
|
209 |
+
elif audio_format == "mp3":
|
210 |
+
return "audio/mpeg"
|
211 |
+
else:
|
212 |
+
return "application/octet-stream"
|
213 |
+
|
214 |
+
|
215 |
+
@torch.no_grad()
|
216 |
+
@torch.autocast(device_type="cuda", dtype=torch.half)
|
217 |
+
def batch_encode(model, audios: list[bytes | torch.Tensor]):
|
218 |
+
audios = [
|
219 |
+
(
|
220 |
+
torch.from_numpy(
|
221 |
+
librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
|
222 |
+
)[None]
|
223 |
+
if isinstance(audio, bytes)
|
224 |
+
else audio
|
225 |
+
)
|
226 |
+
for audio in audios
|
227 |
+
]
|
228 |
+
|
229 |
+
# if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
|
230 |
+
# raise ValueError("Single audio length is too long (>120s)")
|
231 |
+
|
232 |
+
max_length = max(audio.shape[-1] for audio in audios)
|
233 |
+
print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
|
234 |
+
|
235 |
+
lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
|
236 |
+
max_length = lengths.max().item()
|
237 |
+
padded = torch.stack(
|
238 |
+
[
|
239 |
+
torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
|
240 |
+
for audio in audios
|
241 |
+
]
|
242 |
+
).to(model.device)
|
243 |
+
|
244 |
+
features, feature_lengths = model.encode(padded, audio_lengths=lengths)
|
245 |
+
features, feature_lengths = features.cpu(), feature_lengths.cpu()
|
246 |
+
|
247 |
+
return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
|
248 |
+
|
249 |
+
|
250 |
+
@cached(
|
251 |
+
cache=LRUCache(maxsize=10000),
|
252 |
+
key=lambda model, audios: (model.device, tuple(audios)),
|
253 |
+
)
|
254 |
+
def cached_vqgan_batch_encode(model, audios: list[bytes]):
|
255 |
+
return batch_encode(model, audios)
|
256 |
+
|
257 |
+
|
258 |
+
@routes.http.post("/v1/vqgan/encode")
|
259 |
+
def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
|
260 |
+
|
261 |
+
start_time = time.time()
|
262 |
+
tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
|
263 |
+
logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
|
264 |
+
|
265 |
+
return ormsgpack.packb(
|
266 |
+
ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
|
267 |
+
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
@torch.autocast(device_type="cuda", dtype=torch.half)
|
273 |
+
def vqgan_decode(model, features):
|
274 |
+
lengths = torch.tensor(
|
275 |
+
[feature.shape[-1] for feature in features], device=model.device
|
276 |
+
)
|
277 |
+
max_length = lengths.max().item()
|
278 |
+
padded = torch.stack(
|
279 |
+
[
|
280 |
+
torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
|
281 |
+
for feature in features
|
282 |
+
]
|
283 |
+
).to(model.device)
|
284 |
+
|
285 |
+
# If bs too large, we do micro batch decode
|
286 |
+
audios, audio_lengths = [], []
|
287 |
+
for i in range(0, padded.shape[0], 8):
|
288 |
+
audio, audio_length = model.decode(
|
289 |
+
padded[i : i + 8], feature_lengths=lengths[i : i + 8]
|
290 |
+
)
|
291 |
+
audios.append(audio)
|
292 |
+
audio_lengths.append(audio_length)
|
293 |
+
audios = torch.cat(audios, dim=0)
|
294 |
+
audio_lengths = torch.cat(audio_lengths, dim=0)
|
295 |
+
audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
|
296 |
+
|
297 |
+
return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
|
298 |
+
|
299 |
+
|
300 |
+
@routes.http.post("/v1/vqgan/decode")
|
301 |
+
def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
|
302 |
+
tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
|
303 |
+
start_time = time.time()
|
304 |
+
audios = vqgan_decode(decoder_model, tokens)
|
305 |
+
logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
|
306 |
+
audios = [audio.astype(np.float16).tobytes() for audio in audios]
|
307 |
+
return ormsgpack.packb(
|
308 |
+
ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
309 |
+
)
|
310 |
+
|
311 |
+
|
312 |
+
@torch.no_grad()
|
313 |
+
def batch_asr(model, audios, sr, language="auto"):
|
314 |
+
resampled_audios = []
|
315 |
+
for audio in audios:
|
316 |
+
audio = torchaudio.functional.resample(audio, sr, 16000)
|
317 |
+
assert audio.ndim == 1
|
318 |
+
resampled_audios.append(audio)
|
319 |
+
|
320 |
+
with global_lock:
|
321 |
+
res = model.generate(
|
322 |
+
input=resampled_audios,
|
323 |
+
batch_size=len(resampled_audios),
|
324 |
+
language=language,
|
325 |
+
use_itn=True,
|
326 |
+
)
|
327 |
+
|
328 |
+
results = []
|
329 |
+
for r, audio in zip(res, audios):
|
330 |
+
text = r["text"]
|
331 |
+
text = re.sub(r"<\|.*?\|>", "", text)
|
332 |
+
duration = len(audio) / sr * 1000
|
333 |
+
huge_gap = False
|
334 |
+
|
335 |
+
if "timestamp" in r and len(r["timestamp"]) > 2:
|
336 |
+
for timestamp_a, timestamp_b in zip(
|
337 |
+
r["timestamp"][:-1], r["timestamp"][1:]
|
338 |
+
):
|
339 |
+
# If there is a gap of more than 5 seconds, we consider it as a huge gap
|
340 |
+
if timestamp_b[0] - timestamp_a[1] > 5000:
|
341 |
+
huge_gap = True
|
342 |
+
break
|
343 |
+
|
344 |
+
# Doesn't make sense to have a huge gap at the end
|
345 |
+
if duration - r["timestamp"][-1][1] > 3000:
|
346 |
+
huge_gap = True
|
347 |
+
|
348 |
+
results.append(
|
349 |
+
{
|
350 |
+
"text": text,
|
351 |
+
"duration": duration,
|
352 |
+
"huge_gap": huge_gap,
|
353 |
+
}
|
354 |
+
)
|
355 |
+
|
356 |
+
return results
|
357 |
+
|
358 |
+
|
359 |
+
@routes.http.post("/v1/asr")
|
360 |
+
def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
|
361 |
+
start_time = time.time()
|
362 |
+
audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
|
363 |
+
audios = [torch.from_numpy(audio).float() for audio in audios]
|
364 |
+
|
365 |
+
if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
|
366 |
+
raise HTTPException(status_code=400, detail="Audio length is too long")
|
367 |
+
|
368 |
+
transcriptions = batch_asr(
|
369 |
+
asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
|
370 |
+
)
|
371 |
+
logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
|
372 |
+
|
373 |
+
return ormsgpack.packb(
|
374 |
+
ServeASRResponse(transcriptions=transcriptions),
|
375 |
+
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
376 |
+
)
|
377 |
+
|
378 |
+
|
379 |
+
from fish_speech.conversation import Conversation, Message
|
380 |
+
|
381 |
+
|
382 |
+
def execute_request(
|
383 |
+
input_queue: queue.Queue,
|
384 |
+
tokenizer: AutoTokenizer,
|
385 |
+
config: BaseModelArgs,
|
386 |
+
request: ServeRequest,
|
387 |
+
device: str = "cuda:0",
|
388 |
+
):
|
389 |
+
semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
|
390 |
+
[SEMANTIC_TOKEN, IM_END_TOKEN]
|
391 |
+
)
|
392 |
+
messages = []
|
393 |
+
for message in request.messages:
|
394 |
+
messages.append(message.to_conversation_message())
|
395 |
+
|
396 |
+
assert len(messages) >= 1, "At least one message is required"
|
397 |
+
# assert messages[-1].role == "user", "The last message must be from the user"
|
398 |
+
|
399 |
+
if messages[-1].role == "user":
|
400 |
+
messages.append(Message(role="assistant", parts=[], add_im_end=False))
|
401 |
+
else:
|
402 |
+
assert (
|
403 |
+
messages[-1].role == "assistant"
|
404 |
+
), "The last message must be from the assistant"
|
405 |
+
messages[-1].add_im_end = False
|
406 |
+
|
407 |
+
conv = Conversation(messages=messages)
|
408 |
+
prompt = conv.encode_for_inference(
|
409 |
+
tokenizer=tokenizer, num_codebooks=config.num_codebooks
|
410 |
+
).to(device)
|
411 |
+
|
412 |
+
if request.streaming:
|
413 |
+
for i in range(request.num_samples):
|
414 |
+
yield ServeStreamResponse(
|
415 |
+
sample_id=i,
|
416 |
+
delta=ServeStreamDelta(
|
417 |
+
role="assistant",
|
418 |
+
),
|
419 |
+
)
|
420 |
+
|
421 |
+
req = {
|
422 |
+
"prompt": prompt,
|
423 |
+
"max_new_tokens": request.max_new_tokens,
|
424 |
+
"im_end_id": im_end_id,
|
425 |
+
"semantic_id": semantic_id,
|
426 |
+
"temperature": request.temperature,
|
427 |
+
"top_p": request.top_p,
|
428 |
+
"repetition_penalty": request.repetition_penalty,
|
429 |
+
"num_samples": request.num_samples,
|
430 |
+
"early_stop_threshold": request.early_stop_threshold,
|
431 |
+
}
|
432 |
+
|
433 |
+
start = time.time()
|
434 |
+
response_queue = queue.Queue()
|
435 |
+
input_queue.put(GenerateRequest(req, response_queue))
|
436 |
+
|
437 |
+
# Decoding
|
438 |
+
decode_buffer = [[] for _ in range(request.num_samples)]
|
439 |
+
parts = [[] for _ in range(request.num_samples)]
|
440 |
+
|
441 |
+
def send_reset_buffer(sample_id):
|
442 |
+
nonlocal decode_buffer
|
443 |
+
if len(decode_buffer[sample_id]) == 0:
|
444 |
+
return
|
445 |
+
|
446 |
+
decoded = tokenizer.decode(decode_buffer[sample_id])
|
447 |
+
part = ServeTextPart(text=decoded)
|
448 |
+
|
449 |
+
if request.streaming:
|
450 |
+
yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
|
451 |
+
else:
|
452 |
+
parts[sample_id].append(part)
|
453 |
+
|
454 |
+
decode_buffer[sample_id] = []
|
455 |
+
|
456 |
+
# Decode process
|
457 |
+
finished = [False for _ in range(request.num_samples)]
|
458 |
+
stats = {}
|
459 |
+
idx = 0
|
460 |
+
while True:
|
461 |
+
response = response_queue.get()
|
462 |
+
|
463 |
+
if response in ["stop", "error"]:
|
464 |
+
break
|
465 |
+
|
466 |
+
for sample_id, tokens in enumerate(response):
|
467 |
+
if finished[sample_id]:
|
468 |
+
continue
|
469 |
+
|
470 |
+
if tokens[0] == im_end_id:
|
471 |
+
finished[sample_id] = True
|
472 |
+
if request.streaming:
|
473 |
+
yield from send_reset_buffer(sample_id)
|
474 |
+
yield ServeStreamResponse(
|
475 |
+
sample_id=sample_id,
|
476 |
+
finish_reason="stop",
|
477 |
+
stats=stats,
|
478 |
+
)
|
479 |
+
continue
|
480 |
+
|
481 |
+
if tokens[0] == semantic_id and request.streaming:
|
482 |
+
yield from send_reset_buffer(sample_id)
|
483 |
+
# Streaming vq
|
484 |
+
_tokens = tokens[1:].clone() - 1
|
485 |
+
|
486 |
+
if config.share_codebook_embeddings is False:
|
487 |
+
for i in range(len(_tokens)):
|
488 |
+
_tokens[i] -= config.codebook_size * i
|
489 |
+
|
490 |
+
yield ServeStreamResponse(
|
491 |
+
sample_id=sample_id,
|
492 |
+
delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
|
493 |
+
)
|
494 |
+
continue
|
495 |
+
|
496 |
+
# Not streaming vq
|
497 |
+
if tokens[0] == semantic_id:
|
498 |
+
yield from send_reset_buffer(sample_id)
|
499 |
+
# None streaming vq
|
500 |
+
if len(parts[sample_id]) == 0 or not isinstance(
|
501 |
+
parts[sample_id][-1], ServeVQPart
|
502 |
+
):
|
503 |
+
_tokens = tokens[1:].clone() - 1
|
504 |
+
|
505 |
+
if config.share_codebook_embeddings is False:
|
506 |
+
for i in range(len(_tokens)):
|
507 |
+
_tokens[i] -= config.codebook_size * i
|
508 |
+
|
509 |
+
parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
|
510 |
+
else:
|
511 |
+
for codebook_id, value in enumerate(tokens[1:, :]):
|
512 |
+
val = value.item() - 1
|
513 |
+
if config.share_codebook_embeddings is False:
|
514 |
+
val -= config.codebook_size * codebook_id
|
515 |
+
|
516 |
+
parts[sample_id][-1].codes[codebook_id].append(val)
|
517 |
+
continue
|
518 |
+
|
519 |
+
if tokens[0] != semantic_id:
|
520 |
+
# Stream text decode is not supported now
|
521 |
+
decode_buffer[sample_id].append(tokens[0, 0])
|
522 |
+
|
523 |
+
if idx == 0:
|
524 |
+
stats["time_to_first_token"] = (time.time() - start) * 1000
|
525 |
+
|
526 |
+
idx += 1
|
527 |
+
|
528 |
+
for sample_id in range(request.num_samples):
|
529 |
+
yield from send_reset_buffer(sample_id)
|
530 |
+
|
531 |
+
stats["total_time"] = (time.time() - start) * 1000
|
532 |
+
stats["total_tokens"] = idx
|
533 |
+
|
534 |
+
if request.streaming:
|
535 |
+
for sample_id in range(request.num_samples):
|
536 |
+
if finished[sample_id]:
|
537 |
+
continue
|
538 |
+
yield ServeStreamResponse(
|
539 |
+
finish_reason=response, stats=stats, sample_id=sample_id
|
540 |
+
)
|
541 |
+
return
|
542 |
+
|
543 |
+
yield ServeResponse(
|
544 |
+
messages=[
|
545 |
+
ServeMessage(role="assistant", parts=parts[i])
|
546 |
+
for i in range(request.num_samples)
|
547 |
+
],
|
548 |
+
finish_reason=response,
|
549 |
+
stats=stats,
|
550 |
+
)
|
551 |
+
|
552 |
+
|
553 |
+
@routes.http.post("/v1/chat")
|
554 |
+
def api_invoke_chat(
|
555 |
+
req: Annotated[ServeRequest, Body(exclusive=True)],
|
556 |
+
):
|
557 |
+
"""
|
558 |
+
Invoke model and generate audio
|
559 |
+
"""
|
560 |
+
|
561 |
+
# This makes torch compile happy
|
562 |
+
assert (
|
563 |
+
req.num_samples == GLOBAL_NUM_SAMPLES
|
564 |
+
), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
|
565 |
+
|
566 |
+
content_type = request.headers.get("Content-Type", "application/json")
|
567 |
+
json_mode = "application/json" in content_type
|
568 |
+
|
569 |
+
async def wrapped_generator():
|
570 |
+
generator = execute_request(llama_queue, tokenizer, config, req, args.device)
|
571 |
+
|
572 |
+
for i in generator:
|
573 |
+
if json_mode:
|
574 |
+
body = i.model_dump_json().encode("utf-8")
|
575 |
+
yield b"data: " + body + b"\n\n"
|
576 |
+
else:
|
577 |
+
body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
578 |
+
yield struct.pack("I", len(body)) + body
|
579 |
+
|
580 |
+
# Naive mode
|
581 |
+
if req.streaming is False:
|
582 |
+
result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
|
583 |
+
|
584 |
+
if json_mode:
|
585 |
+
return JSONResponse(result.model_dump())
|
586 |
+
else:
|
587 |
+
return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
588 |
+
|
589 |
+
return StreamResponse(
|
590 |
+
iterable=wrapped_generator(), content_type="text/event-stream"
|
591 |
+
)
|
592 |
+
|
593 |
+
|
594 |
+
@torch.inference_mode()
|
595 |
+
def inference(req: ServeTTSRequest):
|
596 |
+
|
597 |
+
global prompt_tokens, prompt_texts
|
598 |
+
|
599 |
+
idstr: str | None = req.reference_id
|
600 |
+
if idstr is not None:
|
601 |
+
ref_folder = Path("references") / idstr
|
602 |
+
ref_folder.mkdir(parents=True, exist_ok=True)
|
603 |
+
ref_audios = list_files(
|
604 |
+
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
605 |
+
)
|
606 |
+
|
607 |
+
if req.use_memory_cache == "never" or (
|
608 |
+
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
609 |
+
):
|
610 |
+
prompt_tokens = [
|
611 |
+
encode_reference(
|
612 |
+
decoder_model=decoder_model,
|
613 |
+
reference_audio=audio_to_bytes(str(ref_audio)),
|
614 |
+
enable_reference_audio=True,
|
615 |
+
)
|
616 |
+
for ref_audio in ref_audios
|
617 |
+
]
|
618 |
+
prompt_texts = [
|
619 |
+
read_ref_text(str(ref_audio.with_suffix(".lab")))
|
620 |
+
for ref_audio in ref_audios
|
621 |
+
]
|
622 |
+
else:
|
623 |
+
logger.info("Use same references")
|
624 |
+
|
625 |
+
else:
|
626 |
+
# Parse reference audio aka prompt
|
627 |
+
refs = req.references
|
628 |
+
|
629 |
+
if req.use_memory_cache == "never" or (
|
630 |
+
req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
|
631 |
+
):
|
632 |
+
prompt_tokens = [
|
633 |
+
encode_reference(
|
634 |
+
decoder_model=decoder_model,
|
635 |
+
reference_audio=ref.audio,
|
636 |
+
enable_reference_audio=True,
|
637 |
+
)
|
638 |
+
for ref in refs
|
639 |
+
]
|
640 |
+
prompt_texts = [ref.text for ref in refs]
|
641 |
+
else:
|
642 |
+
logger.info("Use same references")
|
643 |
+
|
644 |
+
if req.seed is not None:
|
645 |
+
set_seed(req.seed)
|
646 |
+
logger.warning(f"set seed: {req.seed}")
|
647 |
+
|
648 |
+
# LLAMA Inference
|
649 |
+
request = dict(
|
650 |
+
device=decoder_model.device,
|
651 |
+
max_new_tokens=req.max_new_tokens,
|
652 |
+
text=(
|
653 |
+
req.text
|
654 |
+
if not req.normalize
|
655 |
+
else ChnNormedText(raw_text=req.text).normalize()
|
656 |
+
),
|
657 |
+
top_p=req.top_p,
|
658 |
+
repetition_penalty=req.repetition_penalty,
|
659 |
+
temperature=req.temperature,
|
660 |
+
compile=args.compile,
|
661 |
+
iterative_prompt=req.chunk_length > 0,
|
662 |
+
chunk_length=req.chunk_length,
|
663 |
+
max_length=4096,
|
664 |
+
prompt_tokens=prompt_tokens,
|
665 |
+
prompt_text=prompt_texts,
|
666 |
+
)
|
667 |
+
|
668 |
+
response_queue = queue.Queue()
|
669 |
+
llama_queue.put(
|
670 |
+
GenerateRequest(
|
671 |
+
request=request,
|
672 |
+
response_queue=response_queue,
|
673 |
+
)
|
674 |
+
)
|
675 |
+
|
676 |
+
if req.streaming:
|
677 |
+
yield wav_chunk_header()
|
678 |
+
|
679 |
+
segments = []
|
680 |
+
while True:
|
681 |
+
result: WrappedGenerateResponse = response_queue.get()
|
682 |
+
if result.status == "error":
|
683 |
+
raise result.response
|
684 |
+
break
|
685 |
+
|
686 |
+
result: GenerateResponse = result.response
|
687 |
+
if result.action == "next":
|
688 |
+
break
|
689 |
+
|
690 |
+
with autocast_exclude_mps(
|
691 |
+
device_type=decoder_model.device.type, dtype=args.precision
|
692 |
+
):
|
693 |
+
fake_audios = decode_vq_tokens(
|
694 |
+
decoder_model=decoder_model,
|
695 |
+
codes=result.codes,
|
696 |
+
)
|
697 |
+
|
698 |
+
fake_audios = fake_audios.float().cpu().numpy()
|
699 |
+
|
700 |
+
if req.streaming:
|
701 |
+
yield (fake_audios * 32768).astype(np.int16).tobytes()
|
702 |
+
else:
|
703 |
+
segments.append(fake_audios)
|
704 |
+
|
705 |
+
if req.streaming:
|
706 |
+
return
|
707 |
+
|
708 |
+
if len(segments) == 0:
|
709 |
+
raise HTTPException(
|
710 |
+
HTTPStatus.INTERNAL_SERVER_ERROR,
|
711 |
+
content="No audio generated, please check the input text.",
|
712 |
+
)
|
713 |
+
|
714 |
+
fake_audios = np.concatenate(segments, axis=0)
|
715 |
+
yield fake_audios
|
716 |
+
|
717 |
+
|
718 |
+
async def inference_async(req: ServeTTSRequest):
|
719 |
+
for chunk in inference(req):
|
720 |
+
yield chunk
|
721 |
+
|
722 |
+
|
723 |
+
async def buffer_to_async_generator(buffer):
|
724 |
+
yield buffer
|
725 |
+
|
726 |
+
|
727 |
+
@routes.http.post("/v1/tts")
|
728 |
+
async def api_invoke_model(
|
729 |
+
req: Annotated[ServeTTSRequest, Body(exclusive=True)],
|
730 |
+
):
|
731 |
+
"""
|
732 |
+
Invoke model and generate audio
|
733 |
+
"""
|
734 |
+
|
735 |
+
if args.max_text_length > 0 and len(req.text) > args.max_text_length:
|
736 |
+
raise HTTPException(
|
737 |
+
HTTPStatus.BAD_REQUEST,
|
738 |
+
content=f"Text is too long, max length is {args.max_text_length}",
|
739 |
+
)
|
740 |
+
|
741 |
+
if req.streaming and req.format != "wav":
|
742 |
+
raise HTTPException(
|
743 |
+
HTTPStatus.BAD_REQUEST,
|
744 |
+
content="Streaming only supports WAV format",
|
745 |
+
)
|
746 |
+
|
747 |
+
if req.streaming:
|
748 |
+
return StreamResponse(
|
749 |
+
iterable=inference_async(req),
|
750 |
+
headers={
|
751 |
+
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
752 |
+
},
|
753 |
+
content_type=get_content_type(req.format),
|
754 |
+
)
|
755 |
+
else:
|
756 |
+
fake_audios = next(inference(req))
|
757 |
+
buffer = io.BytesIO()
|
758 |
+
sf.write(
|
759 |
+
buffer,
|
760 |
+
fake_audios,
|
761 |
+
decoder_model.spec_transform.sample_rate,
|
762 |
+
format=req.format,
|
763 |
+
)
|
764 |
+
|
765 |
+
return StreamResponse(
|
766 |
+
iterable=buffer_to_async_generator(buffer.getvalue()),
|
767 |
+
headers={
|
768 |
+
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
769 |
+
},
|
770 |
+
content_type=get_content_type(req.format),
|
771 |
+
)
|
772 |
+
|
773 |
+
|
774 |
+
@routes.http.post("/v1/health")
|
775 |
+
async def api_health():
|
776 |
+
"""
|
777 |
+
Health check
|
778 |
+
"""
|
779 |
+
|
780 |
+
return JSONResponse({"status": "ok"})
|
781 |
+
|
782 |
+
|
783 |
+
def parse_args():
|
784 |
+
parser = ArgumentParser()
|
785 |
+
parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
|
786 |
+
parser.add_argument("--load-asr-model", action="store_true")
|
787 |
+
parser.add_argument(
|
788 |
+
"--llama-checkpoint-path",
|
789 |
+
type=str,
|
790 |
+
default="checkpoints/fish-speech-1.4",
|
791 |
+
)
|
792 |
+
parser.add_argument(
|
793 |
+
"--decoder-checkpoint-path",
|
794 |
+
type=str,
|
795 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
796 |
+
)
|
797 |
+
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
798 |
+
parser.add_argument("--device", type=str, default="cuda")
|
799 |
+
parser.add_argument("--half", action="store_true")
|
800 |
+
parser.add_argument("--compile", action="store_true")
|
801 |
+
parser.add_argument("--max-text-length", type=int, default=0)
|
802 |
+
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
803 |
+
parser.add_argument("--workers", type=int, default=1)
|
804 |
+
|
805 |
+
return parser.parse_args()
|
806 |
+
|
807 |
+
|
808 |
+
# Define Kui app
|
809 |
+
openapi = OpenAPI(
|
810 |
+
{
|
811 |
+
"title": "Fish Speech API",
|
812 |
+
"version": "1.4.2",
|
813 |
+
},
|
814 |
+
).routes
|
815 |
+
|
816 |
+
|
817 |
+
class MsgPackRequest(HttpRequest):
|
818 |
+
async def data(
|
819 |
+
self,
|
820 |
+
) -> Annotated[
|
821 |
+
Any, ContentType("application/msgpack"), ContentType("application/json")
|
822 |
+
]:
|
823 |
+
if self.content_type == "application/msgpack":
|
824 |
+
return ormsgpack.unpackb(await self.body)
|
825 |
+
|
826 |
+
elif self.content_type == "application/json":
|
827 |
+
return await self.json
|
828 |
+
|
829 |
+
raise HTTPException(
|
830 |
+
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
831 |
+
headers={"Accept": "application/msgpack, application/json"},
|
832 |
+
)
|
833 |
+
|
834 |
+
|
835 |
+
app = Kui(
|
836 |
+
routes=routes + openapi[1:], # Remove the default route
|
837 |
+
exception_handlers={
|
838 |
+
HTTPException: http_execption_handler,
|
839 |
+
Exception: other_exception_handler,
|
840 |
+
},
|
841 |
+
factory_class=FactoryClass(http=MsgPackRequest),
|
842 |
+
cors_config={},
|
843 |
+
)
|
844 |
+
|
845 |
+
|
846 |
+
def load_asr_model(*, device="cuda", hub="ms"):
|
847 |
+
return AutoModel(
|
848 |
+
model="iic/SenseVoiceSmall",
|
849 |
+
device=device,
|
850 |
+
disable_pbar=True,
|
851 |
+
hub=hub,
|
852 |
+
)
|
853 |
+
|
854 |
+
|
855 |
+
# Each worker process created by Uvicorn has its own memory space,
|
856 |
+
# meaning that models and variables are not shared between processes.
|
857 |
+
# Therefore, any global variables (like `llama_queue` or `decoder_model`)
|
858 |
+
# will not be shared across workers.
|
859 |
+
|
860 |
+
|
861 |
+
# Multi-threading for deep learning can cause issues, such as inconsistent
|
862 |
+
# outputs if multiple threads access the same buffers simultaneously.
|
863 |
+
# Instead, it's better to use multiprocessing or independent models per thread.
|
864 |
+
@app.on_startup
|
865 |
+
def initialize_app(app: Kui):
|
866 |
+
|
867 |
+
global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
|
868 |
+
|
869 |
+
prompt_tokens, prompt_texts = [], []
|
870 |
+
|
871 |
+
args = parse_args() # args same as ones in other processes
|
872 |
+
args.precision = torch.half if args.half else torch.bfloat16
|
873 |
+
|
874 |
+
if args.load_asr_model:
|
875 |
+
logger.info(f"Loading ASR model...")
|
876 |
+
asr_model = load_asr_model(device=args.device)
|
877 |
+
|
878 |
+
logger.info("Loading Llama model...")
|
879 |
+
|
880 |
+
if args.mode == "tts":
|
881 |
+
llama_queue = launch_thread_safe_queue(
|
882 |
+
checkpoint_path=args.llama_checkpoint_path,
|
883 |
+
device=args.device,
|
884 |
+
precision=args.precision,
|
885 |
+
compile=args.compile,
|
886 |
+
)
|
887 |
+
else:
|
888 |
+
llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
|
889 |
+
checkpoint_path=args.llama_checkpoint_path,
|
890 |
+
device=args.device,
|
891 |
+
precision=args.precision,
|
892 |
+
compile=args.compile,
|
893 |
+
)
|
894 |
+
|
895 |
+
logger.info("Llama model loaded, loading VQ-GAN model...")
|
896 |
+
|
897 |
+
decoder_model = load_decoder_model(
|
898 |
+
config_name=args.decoder_config_name,
|
899 |
+
checkpoint_path=args.decoder_checkpoint_path,
|
900 |
+
device=args.device,
|
901 |
+
)
|
902 |
+
|
903 |
+
logger.info("VQ-GAN model loaded, warming up...")
|
904 |
+
|
905 |
+
vad_model = load_silero_vad()
|
906 |
+
|
907 |
+
logger.info("VAD model loaded, warming up...")
|
908 |
+
|
909 |
+
if args.mode == "tts":
|
910 |
+
# Dry run to ensure models work and avoid first-time latency
|
911 |
+
list(
|
912 |
+
inference(
|
913 |
+
ServeTTSRequest(
|
914 |
+
text="Hello world.",
|
915 |
+
references=[],
|
916 |
+
reference_id=None,
|
917 |
+
max_new_tokens=0,
|
918 |
+
chunk_length=200,
|
919 |
+
top_p=0.7,
|
920 |
+
repetition_penalty=1.2,
|
921 |
+
temperature=0.7,
|
922 |
+
emotion=None,
|
923 |
+
format="wav",
|
924 |
+
)
|
925 |
+
)
|
926 |
+
)
|
927 |
+
|
928 |
+
logger.info(f"Warming up done, starting server at http://{args.listen}")
|
929 |
+
|
930 |
+
|
931 |
+
if __name__ == "__main__":
|
932 |
+
|
933 |
+
import uvicorn
|
934 |
+
|
935 |
+
args = parse_args()
|
936 |
+
host, port = args.listen.split(":")
|
937 |
+
uvicorn.run(
|
938 |
+
"tools.api:app",
|
939 |
+
host='0.0.0.0',
|
940 |
+
port=int(port),
|
941 |
+
workers=args.workers,
|
942 |
+
log_level="info",
|
943 |
+
)
|
tools/auto_rerank.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["MODELSCOPE_CACHE"] = ".cache/"
|
4 |
+
|
5 |
+
import string
|
6 |
+
import time
|
7 |
+
from threading import Lock
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import opencc
|
12 |
+
import torch
|
13 |
+
from faster_whisper import WhisperModel
|
14 |
+
|
15 |
+
t2s_converter = opencc.OpenCC("t2s")
|
16 |
+
|
17 |
+
|
18 |
+
def load_model(*, device="cuda"):
|
19 |
+
model = WhisperModel(
|
20 |
+
"medium",
|
21 |
+
device=device,
|
22 |
+
compute_type="float16",
|
23 |
+
download_root="faster_whisper",
|
24 |
+
)
|
25 |
+
print("faster_whisper loaded!")
|
26 |
+
return model
|
27 |
+
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def batch_asr_internal(model: WhisperModel, audios, sr):
|
31 |
+
resampled_audios = []
|
32 |
+
for audio in audios:
|
33 |
+
|
34 |
+
if isinstance(audio, np.ndarray):
|
35 |
+
audio = torch.from_numpy(audio).float()
|
36 |
+
|
37 |
+
if audio.dim() > 1:
|
38 |
+
audio = audio.squeeze()
|
39 |
+
|
40 |
+
assert audio.dim() == 1
|
41 |
+
audio_np = audio.numpy()
|
42 |
+
resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
|
43 |
+
resampled_audios.append(resampled_audio)
|
44 |
+
|
45 |
+
trans_results = []
|
46 |
+
|
47 |
+
for resampled_audio in resampled_audios:
|
48 |
+
segments, info = model.transcribe(
|
49 |
+
resampled_audio,
|
50 |
+
language=None,
|
51 |
+
beam_size=5,
|
52 |
+
initial_prompt="Punctuation is needed in any language.",
|
53 |
+
)
|
54 |
+
trans_results.append(list(segments))
|
55 |
+
|
56 |
+
results = []
|
57 |
+
for trans_res, audio in zip(trans_results, audios):
|
58 |
+
|
59 |
+
duration = len(audio) / sr * 1000
|
60 |
+
huge_gap = False
|
61 |
+
max_gap = 0.0
|
62 |
+
|
63 |
+
text = None
|
64 |
+
last_tr = None
|
65 |
+
|
66 |
+
for tr in trans_res:
|
67 |
+
delta = tr.text.strip()
|
68 |
+
if tr.id > 1:
|
69 |
+
max_gap = max(tr.start - last_tr.end, max_gap)
|
70 |
+
text += delta
|
71 |
+
else:
|
72 |
+
text = delta
|
73 |
+
|
74 |
+
last_tr = tr
|
75 |
+
if max_gap > 3.0:
|
76 |
+
huge_gap = True
|
77 |
+
break
|
78 |
+
|
79 |
+
sim_text = t2s_converter.convert(text)
|
80 |
+
results.append(
|
81 |
+
{
|
82 |
+
"text": sim_text,
|
83 |
+
"duration": duration,
|
84 |
+
"huge_gap": huge_gap,
|
85 |
+
}
|
86 |
+
)
|
87 |
+
|
88 |
+
return results
|
89 |
+
|
90 |
+
|
91 |
+
global_lock = Lock()
|
92 |
+
|
93 |
+
|
94 |
+
def batch_asr(model, audios, sr):
|
95 |
+
return batch_asr_internal(model, audios, sr)
|
96 |
+
|
97 |
+
|
98 |
+
def is_chinese(text):
|
99 |
+
return True
|
100 |
+
|
101 |
+
|
102 |
+
def calculate_wer(text1, text2, debug=False):
|
103 |
+
chars1 = remove_punctuation(text1)
|
104 |
+
chars2 = remove_punctuation(text2)
|
105 |
+
|
106 |
+
m, n = len(chars1), len(chars2)
|
107 |
+
|
108 |
+
if m > n:
|
109 |
+
chars1, chars2 = chars2, chars1
|
110 |
+
m, n = n, m
|
111 |
+
|
112 |
+
prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
|
113 |
+
curr = [0] * (m + 1)
|
114 |
+
|
115 |
+
for j in range(1, n + 1):
|
116 |
+
curr[0] = j
|
117 |
+
for i in range(1, m + 1):
|
118 |
+
if chars1[i - 1] == chars2[j - 1]:
|
119 |
+
curr[i] = prev[i - 1]
|
120 |
+
else:
|
121 |
+
curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
|
122 |
+
prev, curr = curr, prev
|
123 |
+
|
124 |
+
edits = prev[m]
|
125 |
+
tot = max(len(chars1), len(chars2))
|
126 |
+
wer = edits / tot
|
127 |
+
|
128 |
+
if debug:
|
129 |
+
print(" gt: ", chars1)
|
130 |
+
print(" pred: ", chars2)
|
131 |
+
print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
|
132 |
+
|
133 |
+
return wer
|
134 |
+
|
135 |
+
|
136 |
+
def remove_punctuation(text):
|
137 |
+
chinese_punctuation = (
|
138 |
+
" \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
|
139 |
+
'‛""„‟…‧﹏'
|
140 |
+
)
|
141 |
+
all_punctuation = string.punctuation + chinese_punctuation
|
142 |
+
translator = str.maketrans("", "", all_punctuation)
|
143 |
+
text_without_punctuation = text.translate(translator)
|
144 |
+
return text_without_punctuation
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
model = load_model()
|
149 |
+
audios = [
|
150 |
+
librosa.load("44100.wav", sr=44100)[0],
|
151 |
+
librosa.load("lengyue.wav", sr=44100)[0],
|
152 |
+
]
|
153 |
+
print(np.array(audios[0]))
|
154 |
+
print(batch_asr(model, audios, 44100))
|
155 |
+
|
156 |
+
start_time = time.time()
|
157 |
+
for _ in range(10):
|
158 |
+
print(batch_asr(model, audios, 44100))
|
159 |
+
print("Time taken:", time.time() - start_time)
|
tools/commons.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, Literal, Optional
|
2 |
+
|
3 |
+
from pydantic import BaseModel, Field, conint
|
4 |
+
|
5 |
+
|
6 |
+
class ServeReferenceAudio(BaseModel):
|
7 |
+
audio: bytes
|
8 |
+
text: str
|
9 |
+
|
10 |
+
|
11 |
+
class ServeTTSRequest(BaseModel):
|
12 |
+
text: str
|
13 |
+
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
14 |
+
# Audio format
|
15 |
+
format: Literal["wav", "pcm", "mp3"] = "wav"
|
16 |
+
mp3_bitrate: Literal[64, 128, 192] = 128
|
17 |
+
# References audios for in-context learning
|
18 |
+
references: list[ServeReferenceAudio] = []
|
19 |
+
# Reference id
|
20 |
+
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
21 |
+
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
22 |
+
reference_id: str | None = None
|
23 |
+
# Normalize text for en & zh, this increase stability for numbers
|
24 |
+
normalize: bool = True
|
25 |
+
mp3_bitrate: Optional[int] = 64
|
26 |
+
opus_bitrate: Optional[int] = -1000
|
27 |
+
# Balance mode will reduce latency to 300ms, but may decrease stability
|
28 |
+
latency: Literal["normal", "balanced"] = "normal"
|
29 |
+
# not usually used below
|
30 |
+
streaming: bool = False
|
31 |
+
emotion: Optional[str] = None
|
32 |
+
max_new_tokens: int = 1024
|
33 |
+
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
34 |
+
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
35 |
+
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
tools/download_models.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
|
5 |
+
|
6 |
+
# Download
|
7 |
+
def check_and_download_files(repo_id, file_list, local_dir):
|
8 |
+
os.makedirs(local_dir, exist_ok=True)
|
9 |
+
for file in file_list:
|
10 |
+
file_path = os.path.join(local_dir, file)
|
11 |
+
if not os.path.exists(file_path):
|
12 |
+
print(f"{file} 不存在,从 Hugging Face 仓库下载...")
|
13 |
+
hf_hub_download(
|
14 |
+
repo_id=repo_id,
|
15 |
+
filename=file,
|
16 |
+
resume_download=True,
|
17 |
+
local_dir=local_dir,
|
18 |
+
local_dir_use_symlinks=False,
|
19 |
+
)
|
20 |
+
else:
|
21 |
+
print(f"{file} 已存在,跳过下载。")
|
22 |
+
|
23 |
+
|
24 |
+
# 1st
|
25 |
+
repo_id_1 = "fishaudio/fish-speech-1.4"
|
26 |
+
local_dir_1 = "./checkpoints/fish-speech-1.4"
|
27 |
+
files_1 = [
|
28 |
+
"model.pth",
|
29 |
+
"README.md",
|
30 |
+
"special_tokens_map.json",
|
31 |
+
"tokenizer_config.json",
|
32 |
+
"tokenizer.json",
|
33 |
+
"config.json",
|
34 |
+
"firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
35 |
+
]
|
36 |
+
|
37 |
+
# 3rd
|
38 |
+
repo_id_3 = "fishaudio/fish-speech-1"
|
39 |
+
local_dir_3 = "./"
|
40 |
+
files_3 = [
|
41 |
+
"ffmpeg.exe",
|
42 |
+
"ffprobe.exe",
|
43 |
+
]
|
44 |
+
|
45 |
+
# 4th
|
46 |
+
repo_id_4 = "SpicyqSama007/fish-speech-packed"
|
47 |
+
local_dir_4 = "./"
|
48 |
+
files_4 = [
|
49 |
+
"asr-label-win-x64.exe",
|
50 |
+
]
|
51 |
+
|
52 |
+
check_and_download_files(repo_id_1, files_1, local_dir_1)
|
53 |
+
|
54 |
+
check_and_download_files(repo_id_3, files_3, local_dir_3)
|
55 |
+
check_and_download_files(repo_id_4, files_4, local_dir_4)
|
tools/e2e_webui.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import re
|
3 |
+
import wave
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from .fish_e2e import FishE2EAgent, FishE2EEventType
|
9 |
+
from .schema import ServeMessage, ServeTextPart, ServeVQPart
|
10 |
+
|
11 |
+
|
12 |
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
13 |
+
buffer = io.BytesIO()
|
14 |
+
|
15 |
+
with wave.open(buffer, "wb") as wav_file:
|
16 |
+
wav_file.setnchannels(channels)
|
17 |
+
wav_file.setsampwidth(bit_depth // 8)
|
18 |
+
wav_file.setframerate(sample_rate)
|
19 |
+
|
20 |
+
wav_header_bytes = buffer.getvalue()
|
21 |
+
buffer.close()
|
22 |
+
return wav_header_bytes
|
23 |
+
|
24 |
+
|
25 |
+
class ChatState:
|
26 |
+
def __init__(self):
|
27 |
+
self.conversation = []
|
28 |
+
self.added_systext = False
|
29 |
+
self.added_sysaudio = False
|
30 |
+
|
31 |
+
def get_history(self):
|
32 |
+
results = []
|
33 |
+
for msg in self.conversation:
|
34 |
+
results.append({"role": msg.role, "content": self.repr_message(msg)})
|
35 |
+
|
36 |
+
# Process assistant messages to extract questions and update user messages
|
37 |
+
for i, msg in enumerate(results):
|
38 |
+
if msg["role"] == "assistant":
|
39 |
+
match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
|
40 |
+
if match and i > 0 and results[i - 1]["role"] == "user":
|
41 |
+
# Update previous user message with extracted question
|
42 |
+
results[i - 1]["content"] += "\n" + match.group(1)
|
43 |
+
# Remove the Question/Answer format from assistant message
|
44 |
+
msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
|
45 |
+
return results
|
46 |
+
|
47 |
+
def repr_message(self, msg: ServeMessage):
|
48 |
+
response = ""
|
49 |
+
for part in msg.parts:
|
50 |
+
if isinstance(part, ServeTextPart):
|
51 |
+
response += part.text
|
52 |
+
elif isinstance(part, ServeVQPart):
|
53 |
+
response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
|
54 |
+
return response
|
55 |
+
|
56 |
+
|
57 |
+
def clear_fn():
|
58 |
+
return [], ChatState(), None, None, None
|
59 |
+
|
60 |
+
|
61 |
+
async def process_audio_input(
|
62 |
+
sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
|
63 |
+
):
|
64 |
+
if audio_input is None and not text_input:
|
65 |
+
raise gr.Error("No input provided")
|
66 |
+
|
67 |
+
agent = FishE2EAgent() # Create new agent instance for each request
|
68 |
+
|
69 |
+
# Convert audio input to numpy array
|
70 |
+
if isinstance(audio_input, tuple):
|
71 |
+
sr, audio_data = audio_input
|
72 |
+
elif text_input:
|
73 |
+
sr = 44100
|
74 |
+
audio_data = None
|
75 |
+
else:
|
76 |
+
raise gr.Error("Invalid audio format")
|
77 |
+
|
78 |
+
if isinstance(sys_audio_input, tuple):
|
79 |
+
sr, sys_audio_data = sys_audio_input
|
80 |
+
else:
|
81 |
+
sr = 44100
|
82 |
+
sys_audio_data = None
|
83 |
+
|
84 |
+
def append_to_chat_ctx(
|
85 |
+
part: ServeTextPart | ServeVQPart, role: str = "assistant"
|
86 |
+
) -> None:
|
87 |
+
if not state.conversation or state.conversation[-1].role != role:
|
88 |
+
state.conversation.append(ServeMessage(role=role, parts=[part]))
|
89 |
+
else:
|
90 |
+
state.conversation[-1].parts.append(part)
|
91 |
+
|
92 |
+
if state.added_systext is False and sys_text_input:
|
93 |
+
state.added_systext = True
|
94 |
+
append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
|
95 |
+
if text_input:
|
96 |
+
append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
|
97 |
+
audio_data = None
|
98 |
+
|
99 |
+
result_audio = b""
|
100 |
+
async for event in agent.stream(
|
101 |
+
sys_audio_data,
|
102 |
+
audio_data,
|
103 |
+
sr,
|
104 |
+
1,
|
105 |
+
chat_ctx={
|
106 |
+
"messages": state.conversation,
|
107 |
+
"added_sysaudio": state.added_sysaudio,
|
108 |
+
},
|
109 |
+
):
|
110 |
+
if event.type == FishE2EEventType.USER_CODES:
|
111 |
+
append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
|
112 |
+
elif event.type == FishE2EEventType.SPEECH_SEGMENT:
|
113 |
+
append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
|
114 |
+
yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
|
115 |
+
elif event.type == FishE2EEventType.TEXT_SEGMENT:
|
116 |
+
append_to_chat_ctx(ServeTextPart(text=event.text))
|
117 |
+
yield state.get_history(), None, None, None
|
118 |
+
|
119 |
+
yield state.get_history(), None, None, None
|
120 |
+
|
121 |
+
|
122 |
+
async def process_text_input(
|
123 |
+
sys_audio_input, sys_text_input, state: ChatState, text_input: str
|
124 |
+
):
|
125 |
+
async for event in process_audio_input(
|
126 |
+
sys_audio_input, sys_text_input, None, state, text_input
|
127 |
+
):
|
128 |
+
yield event
|
129 |
+
|
130 |
+
|
131 |
+
def create_demo():
|
132 |
+
with gr.Blocks() as demo:
|
133 |
+
state = gr.State(ChatState())
|
134 |
+
|
135 |
+
with gr.Row():
|
136 |
+
# Left column (70%) for chatbot and notes
|
137 |
+
with gr.Column(scale=7):
|
138 |
+
chatbot = gr.Chatbot(
|
139 |
+
[],
|
140 |
+
elem_id="chatbot",
|
141 |
+
bubble_full_width=False,
|
142 |
+
height=600,
|
143 |
+
type="messages",
|
144 |
+
)
|
145 |
+
|
146 |
+
# notes = gr.Markdown(
|
147 |
+
# """
|
148 |
+
# # Fish Agent
|
149 |
+
# 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
|
150 |
+
# 2. 你可以在我们的官方仓��找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
|
151 |
+
# 3. Demo为早期灰度测试版本,推理速度尚待优化.
|
152 |
+
# # 特色
|
153 |
+
# 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
|
154 |
+
# 2. 模型可以使用reference audio控制说话音色.
|
155 |
+
# 3. 可以生成具有较强情感与韵律的音频.
|
156 |
+
# """
|
157 |
+
# )
|
158 |
+
notes = gr.Markdown(
|
159 |
+
"""
|
160 |
+
# Fish Agent
|
161 |
+
1. This demo is Fish Audio's self-researh end-to-end language model, Fish Agent version 3B.
|
162 |
+
2. You can find the code and weights in our official repo in [gitub](https://github.com/fishaudio/fish-speech) and [hugging face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but the content is released under a CC BY-NC-SA 4.0 licence.
|
163 |
+
3. The demo is an early alpha test version, the inference speed needs to be optimised.
|
164 |
+
# Features
|
165 |
+
1. The model automatically integrates ASR and TTS parts, no need to plug-in other models, i.e., true end-to-end, not three-stage (ASR+LLM+TTS).
|
166 |
+
2. The model can use reference audio to control the speech timbre.
|
167 |
+
3. The model can generate speech with strong emotion.
|
168 |
+
"""
|
169 |
+
)
|
170 |
+
|
171 |
+
# Right column (30%) for controls
|
172 |
+
with gr.Column(scale=3):
|
173 |
+
sys_audio_input = gr.Audio(
|
174 |
+
sources=["upload"],
|
175 |
+
type="numpy",
|
176 |
+
label="Give a timbre for your assistant",
|
177 |
+
)
|
178 |
+
sys_text_input = gr.Textbox(
|
179 |
+
label="What is your assistant's role?",
|
180 |
+
value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
|
181 |
+
type="text",
|
182 |
+
)
|
183 |
+
audio_input = gr.Audio(
|
184 |
+
sources=["microphone"], type="numpy", label="Speak your message"
|
185 |
+
)
|
186 |
+
|
187 |
+
text_input = gr.Textbox(label="Or type your message", type="text")
|
188 |
+
|
189 |
+
output_audio = gr.Audio(
|
190 |
+
label="Assistant's Voice",
|
191 |
+
streaming=True,
|
192 |
+
autoplay=True,
|
193 |
+
interactive=False,
|
194 |
+
)
|
195 |
+
|
196 |
+
send_button = gr.Button("Send", variant="primary")
|
197 |
+
clear_button = gr.Button("Clear")
|
198 |
+
|
199 |
+
# Event handlers
|
200 |
+
audio_input.stop_recording(
|
201 |
+
process_audio_input,
|
202 |
+
inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
|
203 |
+
outputs=[chatbot, output_audio, audio_input, text_input],
|
204 |
+
show_progress=True,
|
205 |
+
)
|
206 |
+
|
207 |
+
send_button.click(
|
208 |
+
process_text_input,
|
209 |
+
inputs=[sys_audio_input, sys_text_input, state, text_input],
|
210 |
+
outputs=[chatbot, output_audio, audio_input, text_input],
|
211 |
+
show_progress=True,
|
212 |
+
)
|
213 |
+
|
214 |
+
text_input.submit(
|
215 |
+
process_text_input,
|
216 |
+
inputs=[sys_audio_input, sys_text_input, state, text_input],
|
217 |
+
outputs=[chatbot, output_audio, audio_input, text_input],
|
218 |
+
show_progress=True,
|
219 |
+
)
|
220 |
+
|
221 |
+
clear_button.click(
|
222 |
+
clear_fn,
|
223 |
+
inputs=[],
|
224 |
+
outputs=[chatbot, state, audio_input, output_audio, text_input],
|
225 |
+
)
|
226 |
+
|
227 |
+
return demo
|
228 |
+
|
229 |
+
|
230 |
+
if __name__ == "__main__":
|
231 |
+
demo = create_demo()
|
232 |
+
demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
|
tools/extract_model.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import click
|
2 |
+
import torch
|
3 |
+
from loguru import logger
|
4 |
+
|
5 |
+
|
6 |
+
@click.command()
|
7 |
+
@click.argument("model_path")
|
8 |
+
@click.argument("output_path")
|
9 |
+
def main(model_path, output_path):
|
10 |
+
if model_path == output_path:
|
11 |
+
logger.error("Model path and output path are the same")
|
12 |
+
return
|
13 |
+
|
14 |
+
logger.info(f"Loading model from {model_path}")
|
15 |
+
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
|
16 |
+
torch.save(state_dict, output_path)
|
17 |
+
logger.info(f"Model saved to {output_path}")
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
main()
|
tools/file.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
from loguru import logger
|
6 |
+
from natsort import natsorted
|
7 |
+
|
8 |
+
AUDIO_EXTENSIONS = {
|
9 |
+
".mp3",
|
10 |
+
".wav",
|
11 |
+
".flac",
|
12 |
+
".ogg",
|
13 |
+
".m4a",
|
14 |
+
".wma",
|
15 |
+
".aac",
|
16 |
+
".aiff",
|
17 |
+
".aif",
|
18 |
+
".aifc",
|
19 |
+
}
|
20 |
+
|
21 |
+
VIDEO_EXTENSIONS = {
|
22 |
+
".mp4",
|
23 |
+
".avi",
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def audio_to_bytes(file_path):
|
28 |
+
if not file_path or not Path(file_path).exists():
|
29 |
+
return None
|
30 |
+
with open(file_path, "rb") as wav_file:
|
31 |
+
wav = wav_file.read()
|
32 |
+
return wav
|
33 |
+
|
34 |
+
|
35 |
+
def read_ref_text(ref_text):
|
36 |
+
path = Path(ref_text)
|
37 |
+
if path.exists() and path.is_file():
|
38 |
+
with path.open("r", encoding="utf-8") as file:
|
39 |
+
return file.read()
|
40 |
+
return ref_text
|
41 |
+
|
42 |
+
|
43 |
+
def list_files(
|
44 |
+
path: Union[Path, str],
|
45 |
+
extensions: set[str] = None,
|
46 |
+
recursive: bool = False,
|
47 |
+
sort: bool = True,
|
48 |
+
) -> list[Path]:
|
49 |
+
"""List files in a directory.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
path (Path): Path to the directory.
|
53 |
+
extensions (set, optional): Extensions to filter. Defaults to None.
|
54 |
+
recursive (bool, optional): Whether to search recursively. Defaults to False.
|
55 |
+
sort (bool, optional): Whether to sort the files. Defaults to True.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
list: List of files.
|
59 |
+
"""
|
60 |
+
|
61 |
+
if isinstance(path, str):
|
62 |
+
path = Path(path)
|
63 |
+
|
64 |
+
if not path.exists():
|
65 |
+
raise FileNotFoundError(f"Directory {path} does not exist.")
|
66 |
+
|
67 |
+
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
|
68 |
+
|
69 |
+
if sort:
|
70 |
+
files = natsorted(files)
|
71 |
+
|
72 |
+
return files
|
73 |
+
|
74 |
+
|
75 |
+
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
|
76 |
+
"""
|
77 |
+
Load a Bert-VITS2 style filelist.
|
78 |
+
"""
|
79 |
+
|
80 |
+
files = set()
|
81 |
+
results = []
|
82 |
+
count_duplicated, count_not_found = 0, 0
|
83 |
+
|
84 |
+
LANGUAGE_TO_LANGUAGES = {
|
85 |
+
"zh": ["zh", "en"],
|
86 |
+
"jp": ["jp", "en"],
|
87 |
+
"en": ["en"],
|
88 |
+
}
|
89 |
+
|
90 |
+
with open(path, "r", encoding="utf-8") as f:
|
91 |
+
for line in f.readlines():
|
92 |
+
splits = line.strip().split("|", maxsplit=3)
|
93 |
+
if len(splits) != 4:
|
94 |
+
logger.warning(f"Invalid line: {line}")
|
95 |
+
continue
|
96 |
+
|
97 |
+
filename, speaker, language, text = splits
|
98 |
+
file = Path(filename)
|
99 |
+
language = language.strip().lower()
|
100 |
+
|
101 |
+
if language == "ja":
|
102 |
+
language = "jp"
|
103 |
+
|
104 |
+
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
|
105 |
+
languages = LANGUAGE_TO_LANGUAGES[language]
|
106 |
+
|
107 |
+
if file in files:
|
108 |
+
logger.warning(f"Duplicated file: {file}")
|
109 |
+
count_duplicated += 1
|
110 |
+
continue
|
111 |
+
|
112 |
+
if not file.exists():
|
113 |
+
logger.warning(f"File not found: {file}")
|
114 |
+
count_not_found += 1
|
115 |
+
continue
|
116 |
+
|
117 |
+
results.append((file, speaker, languages, text))
|
118 |
+
|
119 |
+
if count_duplicated > 0:
|
120 |
+
logger.warning(f"Total duplicated files: {count_duplicated}")
|
121 |
+
|
122 |
+
if count_not_found > 0:
|
123 |
+
logger.warning(f"Total files not found: {count_not_found}")
|
124 |
+
|
125 |
+
return results
|
tools/fish_e2e.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import ctypes
|
3 |
+
import io
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import struct
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from enum import Enum
|
9 |
+
from typing import AsyncGenerator, Union
|
10 |
+
|
11 |
+
import httpx
|
12 |
+
import numpy as np
|
13 |
+
import ormsgpack
|
14 |
+
import soundfile as sf
|
15 |
+
|
16 |
+
from .schema import (
|
17 |
+
ServeMessage,
|
18 |
+
ServeRequest,
|
19 |
+
ServeTextPart,
|
20 |
+
ServeVQGANDecodeRequest,
|
21 |
+
ServeVQGANEncodeRequest,
|
22 |
+
ServeVQPart,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class CustomAudioFrame:
|
27 |
+
def __init__(self, data, sample_rate, num_channels, samples_per_channel):
|
28 |
+
if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
|
29 |
+
ctypes.c_int16
|
30 |
+
):
|
31 |
+
raise ValueError(
|
32 |
+
"data length must be >= num_channels * samples_per_channel * sizeof(int16)"
|
33 |
+
)
|
34 |
+
|
35 |
+
self._data = bytearray(data)
|
36 |
+
self._sample_rate = sample_rate
|
37 |
+
self._num_channels = num_channels
|
38 |
+
self._samples_per_channel = samples_per_channel
|
39 |
+
|
40 |
+
@property
|
41 |
+
def data(self):
|
42 |
+
return memoryview(self._data).cast("h")
|
43 |
+
|
44 |
+
@property
|
45 |
+
def sample_rate(self):
|
46 |
+
return self._sample_rate
|
47 |
+
|
48 |
+
@property
|
49 |
+
def num_channels(self):
|
50 |
+
return self._num_channels
|
51 |
+
|
52 |
+
@property
|
53 |
+
def samples_per_channel(self):
|
54 |
+
return self._samples_per_channel
|
55 |
+
|
56 |
+
@property
|
57 |
+
def duration(self):
|
58 |
+
return self.samples_per_channel / self.sample_rate
|
59 |
+
|
60 |
+
def __repr__(self):
|
61 |
+
return (
|
62 |
+
f"CustomAudioFrame(sample_rate={self.sample_rate}, "
|
63 |
+
f"num_channels={self.num_channels}, "
|
64 |
+
f"samples_per_channel={self.samples_per_channel}, "
|
65 |
+
f"duration={self.duration:.3f})"
|
66 |
+
)
|
67 |
+
|
68 |
+
|
69 |
+
class FishE2EEventType(Enum):
|
70 |
+
SPEECH_SEGMENT = 1
|
71 |
+
TEXT_SEGMENT = 2
|
72 |
+
END_OF_TEXT = 3
|
73 |
+
END_OF_SPEECH = 4
|
74 |
+
ASR_RESULT = 5
|
75 |
+
USER_CODES = 6
|
76 |
+
|
77 |
+
|
78 |
+
@dataclass
|
79 |
+
class FishE2EEvent:
|
80 |
+
type: FishE2EEventType
|
81 |
+
frame: np.ndarray = None
|
82 |
+
text: str = None
|
83 |
+
vq_codes: list[list[int]] = None
|
84 |
+
|
85 |
+
|
86 |
+
client = httpx.AsyncClient(
|
87 |
+
timeout=None,
|
88 |
+
limits=httpx.Limits(
|
89 |
+
max_connections=None,
|
90 |
+
max_keepalive_connections=None,
|
91 |
+
keepalive_expiry=None,
|
92 |
+
),
|
93 |
+
)
|
94 |
+
|
95 |
+
|
96 |
+
class FishE2EAgent:
|
97 |
+
def __init__(self):
|
98 |
+
self.llm_url = "http://localhost:8080/v1/chat"
|
99 |
+
self.vqgan_url = "http://localhost:8080"
|
100 |
+
self.client = httpx.AsyncClient(timeout=None)
|
101 |
+
|
102 |
+
async def get_codes(self, audio_data, sample_rate):
|
103 |
+
audio_buffer = io.BytesIO()
|
104 |
+
sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
105 |
+
audio_buffer.seek(0)
|
106 |
+
# Step 1: Encode audio using VQGAN
|
107 |
+
encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
|
108 |
+
encode_request_bytes = ormsgpack.packb(
|
109 |
+
encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
110 |
+
)
|
111 |
+
encode_response = await self.client.post(
|
112 |
+
f"{self.vqgan_url}/v1/vqgan/encode",
|
113 |
+
data=encode_request_bytes,
|
114 |
+
headers={"Content-Type": "application/msgpack"},
|
115 |
+
)
|
116 |
+
encode_response_data = ormsgpack.unpackb(encode_response.content)
|
117 |
+
codes = encode_response_data["tokens"][0]
|
118 |
+
return codes
|
119 |
+
|
120 |
+
async def stream(
|
121 |
+
self,
|
122 |
+
system_audio_data: np.ndarray | None,
|
123 |
+
user_audio_data: np.ndarray | None,
|
124 |
+
sample_rate: int,
|
125 |
+
num_channels: int,
|
126 |
+
chat_ctx: dict | None = None,
|
127 |
+
) -> AsyncGenerator[bytes, None]:
|
128 |
+
|
129 |
+
if system_audio_data is not None:
|
130 |
+
sys_codes = await self.get_codes(system_audio_data, sample_rate)
|
131 |
+
else:
|
132 |
+
sys_codes = None
|
133 |
+
if user_audio_data is not None:
|
134 |
+
user_codes = await self.get_codes(user_audio_data, sample_rate)
|
135 |
+
# Step 2: Prepare LLM request
|
136 |
+
if chat_ctx is None:
|
137 |
+
sys_parts = [
|
138 |
+
ServeTextPart(
|
139 |
+
text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
|
140 |
+
),
|
141 |
+
]
|
142 |
+
if system_audio_data is not None:
|
143 |
+
sys_parts.append(ServeVQPart(codes=sys_codes))
|
144 |
+
chat_ctx = {
|
145 |
+
"messages": [
|
146 |
+
ServeMessage(
|
147 |
+
role="system",
|
148 |
+
parts=sys_parts,
|
149 |
+
),
|
150 |
+
],
|
151 |
+
}
|
152 |
+
else:
|
153 |
+
if chat_ctx["added_sysaudio"] is False and sys_codes:
|
154 |
+
chat_ctx["added_sysaudio"] = True
|
155 |
+
chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
|
156 |
+
|
157 |
+
prev_messages = chat_ctx["messages"].copy()
|
158 |
+
if user_audio_data is not None:
|
159 |
+
yield FishE2EEvent(
|
160 |
+
type=FishE2EEventType.USER_CODES,
|
161 |
+
vq_codes=user_codes,
|
162 |
+
)
|
163 |
+
else:
|
164 |
+
user_codes = None
|
165 |
+
|
166 |
+
request = ServeRequest(
|
167 |
+
messages=prev_messages
|
168 |
+
+ (
|
169 |
+
[
|
170 |
+
ServeMessage(
|
171 |
+
role="user",
|
172 |
+
parts=[ServeVQPart(codes=user_codes)],
|
173 |
+
)
|
174 |
+
]
|
175 |
+
if user_codes
|
176 |
+
else []
|
177 |
+
),
|
178 |
+
streaming=True,
|
179 |
+
num_samples=1,
|
180 |
+
)
|
181 |
+
|
182 |
+
# Step 3: Stream LLM response and decode audio
|
183 |
+
buffer = b""
|
184 |
+
vq_codes = []
|
185 |
+
current_vq = False
|
186 |
+
|
187 |
+
async def decode_send():
|
188 |
+
nonlocal current_vq
|
189 |
+
nonlocal vq_codes
|
190 |
+
|
191 |
+
data = np.concatenate(vq_codes, axis=1).tolist()
|
192 |
+
# Decode VQ codes to audio
|
193 |
+
decode_request = ServeVQGANDecodeRequest(tokens=[data])
|
194 |
+
decode_response = await self.client.post(
|
195 |
+
f"{self.vqgan_url}/v1/vqgan/decode",
|
196 |
+
data=ormsgpack.packb(
|
197 |
+
decode_request,
|
198 |
+
option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
199 |
+
),
|
200 |
+
headers={"Content-Type": "application/msgpack"},
|
201 |
+
)
|
202 |
+
decode_data = ormsgpack.unpackb(decode_response.content)
|
203 |
+
|
204 |
+
# Convert float16 audio data to int16
|
205 |
+
audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
|
206 |
+
audio_data = (audio_data * 32768).astype(np.int16).tobytes()
|
207 |
+
|
208 |
+
audio_frame = CustomAudioFrame(
|
209 |
+
data=audio_data,
|
210 |
+
samples_per_channel=len(audio_data) // 2,
|
211 |
+
sample_rate=44100,
|
212 |
+
num_channels=1,
|
213 |
+
)
|
214 |
+
yield FishE2EEvent(
|
215 |
+
type=FishE2EEventType.SPEECH_SEGMENT,
|
216 |
+
frame=audio_frame,
|
217 |
+
vq_codes=data,
|
218 |
+
)
|
219 |
+
|
220 |
+
current_vq = False
|
221 |
+
vq_codes = []
|
222 |
+
|
223 |
+
async with self.client.stream(
|
224 |
+
"POST",
|
225 |
+
self.llm_url,
|
226 |
+
data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
227 |
+
headers={"Content-Type": "application/msgpack"},
|
228 |
+
) as response:
|
229 |
+
|
230 |
+
async for chunk in response.aiter_bytes():
|
231 |
+
buffer += chunk
|
232 |
+
|
233 |
+
while len(buffer) >= 4:
|
234 |
+
read_length = struct.unpack("I", buffer[:4])[0]
|
235 |
+
if len(buffer) < 4 + read_length:
|
236 |
+
break
|
237 |
+
|
238 |
+
body = buffer[4 : 4 + read_length]
|
239 |
+
buffer = buffer[4 + read_length :]
|
240 |
+
data = ormsgpack.unpackb(body)
|
241 |
+
|
242 |
+
if data["delta"] and data["delta"]["part"]:
|
243 |
+
if current_vq and data["delta"]["part"]["type"] == "text":
|
244 |
+
async for event in decode_send():
|
245 |
+
yield event
|
246 |
+
if data["delta"]["part"]["type"] == "text":
|
247 |
+
yield FishE2EEvent(
|
248 |
+
type=FishE2EEventType.TEXT_SEGMENT,
|
249 |
+
text=data["delta"]["part"]["text"],
|
250 |
+
)
|
251 |
+
elif data["delta"]["part"]["type"] == "vq":
|
252 |
+
vq_codes.append(np.array(data["delta"]["part"]["codes"]))
|
253 |
+
current_vq = True
|
254 |
+
|
255 |
+
if current_vq and vq_codes:
|
256 |
+
async for event in decode_send():
|
257 |
+
yield event
|
258 |
+
|
259 |
+
yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
|
260 |
+
yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
|
261 |
+
|
262 |
+
|
263 |
+
# Example usage:
|
264 |
+
async def main():
|
265 |
+
import torchaudio
|
266 |
+
|
267 |
+
agent = FishE2EAgent()
|
268 |
+
|
269 |
+
# Replace this with actual audio data loading
|
270 |
+
with open("uz_story_en.m4a", "rb") as f:
|
271 |
+
audio_data = f.read()
|
272 |
+
|
273 |
+
audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
|
274 |
+
audio_data = (audio_data.numpy() * 32768).astype(np.int16)
|
275 |
+
|
276 |
+
stream = agent.stream(audio_data, sample_rate, 1)
|
277 |
+
if os.path.exists("audio_segment.wav"):
|
278 |
+
os.remove("audio_segment.wav")
|
279 |
+
|
280 |
+
async for event in stream:
|
281 |
+
if event.type == FishE2EEventType.SPEECH_SEGMENT:
|
282 |
+
# Handle speech segment (e.g., play audio or save to file)
|
283 |
+
with open("audio_segment.wav", "ab+") as f:
|
284 |
+
f.write(event.frame.data)
|
285 |
+
elif event.type == FishE2EEventType.ASR_RESULT:
|
286 |
+
print(event.text, flush=True)
|
287 |
+
elif event.type == FishE2EEventType.TEXT_SEGMENT:
|
288 |
+
print(event.text, flush=True, end="")
|
289 |
+
elif event.type == FishE2EEventType.END_OF_TEXT:
|
290 |
+
print("\nEnd of text reached.")
|
291 |
+
elif event.type == FishE2EEventType.END_OF_SPEECH:
|
292 |
+
print("End of speech reached.")
|
293 |
+
|
294 |
+
|
295 |
+
if __name__ == "__main__":
|
296 |
+
import asyncio
|
297 |
+
|
298 |
+
asyncio.run(main())
|
tools/llama/__pycache__/generate.cpython-310.pyc
ADDED
Binary file (21.1 kB). View file
|
|
tools/llama/build_dataset.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from collections import defaultdict
|
5 |
+
from functools import partial
|
6 |
+
from multiprocessing import Pool
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import click
|
10 |
+
import numpy as np
|
11 |
+
from loguru import logger
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
15 |
+
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
16 |
+
from tools.file import load_filelist
|
17 |
+
|
18 |
+
# To avoid CPU overload
|
19 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
20 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
21 |
+
|
22 |
+
|
23 |
+
def task_generator_folder(root: Path, text_extension: str):
|
24 |
+
files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
|
25 |
+
files = sorted(files)
|
26 |
+
|
27 |
+
grouped_files = defaultdict(list)
|
28 |
+
for file in tqdm(files, desc=f"Grouping {root}"):
|
29 |
+
p = str(file.parent)
|
30 |
+
speaker = file.parent.name
|
31 |
+
|
32 |
+
try:
|
33 |
+
if isinstance(text_extension, str):
|
34 |
+
texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
|
35 |
+
else:
|
36 |
+
texts = [
|
37 |
+
file.with_suffix(ext).read_text(encoding="utf-8")
|
38 |
+
for ext in text_extension
|
39 |
+
]
|
40 |
+
except Exception as e:
|
41 |
+
logger.error(f"Failed to read text {file}: {e}")
|
42 |
+
continue
|
43 |
+
|
44 |
+
grouped_files[p].append((speaker, file, texts))
|
45 |
+
|
46 |
+
logger.info(
|
47 |
+
f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
|
48 |
+
)
|
49 |
+
|
50 |
+
for i in grouped_files.values():
|
51 |
+
subset = [(f, t) for _, f, t in i]
|
52 |
+
yield i[0][0], subset, "folder"
|
53 |
+
|
54 |
+
|
55 |
+
def task_generator_filelist(filelist):
|
56 |
+
grouped_files = defaultdict(list)
|
57 |
+
for filename, speaker, _, text in load_filelist(filelist):
|
58 |
+
grouped_files[speaker].append((Path(filename), [text]))
|
59 |
+
|
60 |
+
logger.info(f"Found {len(grouped_files)} groups in {filelist}")
|
61 |
+
for speaker, values in grouped_files.items():
|
62 |
+
yield speaker, values, "filelist"
|
63 |
+
|
64 |
+
|
65 |
+
def run_task(task):
|
66 |
+
name, subset, source = task
|
67 |
+
|
68 |
+
# Parse the files
|
69 |
+
sentences = []
|
70 |
+
for file, texts in subset:
|
71 |
+
np_file = file.with_suffix(".npy")
|
72 |
+
if np_file.exists() is False:
|
73 |
+
logger.warning(f"Can't find {np_file}")
|
74 |
+
continue
|
75 |
+
|
76 |
+
new_texts = []
|
77 |
+
|
78 |
+
for text in texts:
|
79 |
+
# Simple cleaning: replace { xxx } and < xxx > with space
|
80 |
+
text = re.sub(r"\{.*?\}", " ", text)
|
81 |
+
text = re.sub(r"<.*?>", " ", text)
|
82 |
+
text = re.sub(r"\s+", " ", text)
|
83 |
+
new_texts.append(text)
|
84 |
+
|
85 |
+
try:
|
86 |
+
semantics = np.load(np_file)
|
87 |
+
except Exception as e:
|
88 |
+
logger.error(f"Failed to parse {file}: {e}")
|
89 |
+
continue
|
90 |
+
|
91 |
+
if isinstance(semantics, np.ndarray):
|
92 |
+
semantics = semantics.tolist()
|
93 |
+
|
94 |
+
sentences.append(
|
95 |
+
Sentence(
|
96 |
+
texts=new_texts,
|
97 |
+
semantics=[Semantics(values=s) for s in semantics],
|
98 |
+
)
|
99 |
+
)
|
100 |
+
|
101 |
+
# Pack the sentences
|
102 |
+
return pack_pb_stream(
|
103 |
+
TextData(
|
104 |
+
source=source,
|
105 |
+
name=name,
|
106 |
+
sentences=sentences,
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
@click.command()
|
112 |
+
@click.option(
|
113 |
+
"--input",
|
114 |
+
type=click.Path(path_type=Path),
|
115 |
+
required=True,
|
116 |
+
help="A folder containing the dataset or a filelist",
|
117 |
+
multiple=True,
|
118 |
+
)
|
119 |
+
@click.option(
|
120 |
+
"--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
|
121 |
+
)
|
122 |
+
@click.option("--num-workers", type=int, default=16)
|
123 |
+
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
|
124 |
+
@click.option(
|
125 |
+
"--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
|
126 |
+
)
|
127 |
+
def main(input, output, num_workers, text_extension, shard_size):
|
128 |
+
generator_fns = []
|
129 |
+
|
130 |
+
for f in input:
|
131 |
+
assert f.exists(), f"{f} not found"
|
132 |
+
|
133 |
+
if f.is_dir():
|
134 |
+
generator_fn = task_generator_folder(f, text_extension)
|
135 |
+
else:
|
136 |
+
generator_fn = task_generator_filelist(f)
|
137 |
+
|
138 |
+
generator_fns.append(generator_fn)
|
139 |
+
|
140 |
+
generator_fn = itertools.chain(*generator_fns)
|
141 |
+
output.mkdir(parents=True, exist_ok=True)
|
142 |
+
|
143 |
+
dataset_fp = None
|
144 |
+
tar_idx = 0
|
145 |
+
written_size = 0
|
146 |
+
|
147 |
+
with Pool(num_workers) as p:
|
148 |
+
for result in tqdm(p.imap_unordered(run_task, generator_fn)):
|
149 |
+
if dataset_fp is None:
|
150 |
+
dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
|
151 |
+
|
152 |
+
dataset_fp.write(result)
|
153 |
+
written_size += len(result)
|
154 |
+
|
155 |
+
if written_size > shard_size * 1024 * 1024:
|
156 |
+
logger.info(f"Finished writing {tar_idx} shards to {output}")
|
157 |
+
dataset_fp.close()
|
158 |
+
dataset_fp = None
|
159 |
+
written_size = 0
|
160 |
+
tar_idx += 1
|
161 |
+
|
162 |
+
if dataset_fp is not None:
|
163 |
+
dataset_fp.close()
|
164 |
+
|
165 |
+
logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
main()
|
tools/llama/eval_in_context.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pyrootutils
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from matplotlib import pyplot as plt
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
|
7 |
+
# register eval resolver and root
|
8 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
9 |
+
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
|
12 |
+
from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
|
13 |
+
from tools.llama.generate import load_model
|
14 |
+
|
15 |
+
|
16 |
+
def smooth(
|
17 |
+
scalars: list[float], weight: float
|
18 |
+
) -> list[float]: # Weight between 0 and 1
|
19 |
+
last = scalars[0] # First value in the plot (first timestep)
|
20 |
+
smoothed = list()
|
21 |
+
for point in scalars:
|
22 |
+
smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
|
23 |
+
smoothed.append(smoothed_val) # Save it
|
24 |
+
last = smoothed_val # Anchor the last smoothed value
|
25 |
+
|
26 |
+
return smoothed
|
27 |
+
|
28 |
+
|
29 |
+
@torch.inference_mode()
|
30 |
+
def analyze_one_model(loader, config, weight, max_length):
|
31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
model = load_model(
|
33 |
+
config,
|
34 |
+
weight,
|
35 |
+
device,
|
36 |
+
torch.bfloat16,
|
37 |
+
max_length,
|
38 |
+
compile=False,
|
39 |
+
)[0]
|
40 |
+
|
41 |
+
current_step = 0
|
42 |
+
model.eval()
|
43 |
+
|
44 |
+
semantic_loss_sum = torch.zeros(
|
45 |
+
max_length,
|
46 |
+
dtype=torch.float32,
|
47 |
+
device=device,
|
48 |
+
)
|
49 |
+
counter = torch.zeros(
|
50 |
+
max_length,
|
51 |
+
dtype=torch.long,
|
52 |
+
device=device,
|
53 |
+
)
|
54 |
+
|
55 |
+
for batch in loader:
|
56 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
57 |
+
|
58 |
+
labels = batch["labels"]
|
59 |
+
outputs = model(
|
60 |
+
inp=batch["inputs"],
|
61 |
+
key_padding_mask=batch["attention_masks"],
|
62 |
+
)
|
63 |
+
|
64 |
+
token_logits = outputs.token_logits
|
65 |
+
codebook_logits = outputs.codebook_logits
|
66 |
+
|
67 |
+
# Generate labels
|
68 |
+
base_loss = F.cross_entropy(
|
69 |
+
token_logits.reshape(-1, token_logits.size(-1)),
|
70 |
+
labels[:, 0].reshape(-1),
|
71 |
+
ignore_index=-100,
|
72 |
+
reduction="none",
|
73 |
+
)
|
74 |
+
|
75 |
+
codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
|
76 |
+
semantic_loss = F.cross_entropy(
|
77 |
+
codebook_logits.reshape(-1, codebook_logits.size(-1)),
|
78 |
+
codebook_labels.reshape(-1),
|
79 |
+
ignore_index=-100,
|
80 |
+
reduction="none",
|
81 |
+
)
|
82 |
+
|
83 |
+
base_loss = base_loss.reshape(labels[:, 0].shape)
|
84 |
+
semantic_loss = semantic_loss.reshape(codebook_labels.shape)
|
85 |
+
|
86 |
+
semantic_loss_frame = semantic_loss.mean(-1)
|
87 |
+
pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
|
88 |
+
|
89 |
+
for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
|
90 |
+
semantic_loss_sum[~pad] += loss_sample[~pad]
|
91 |
+
counter[~pad] += 1
|
92 |
+
|
93 |
+
current_step += 1
|
94 |
+
if current_step == 10:
|
95 |
+
break
|
96 |
+
|
97 |
+
semantic_loss = semantic_loss.cpu()
|
98 |
+
counter = counter.cpu()
|
99 |
+
xs, ys = [], []
|
100 |
+
|
101 |
+
for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
|
102 |
+
if count > 0:
|
103 |
+
xs.append(i)
|
104 |
+
ys.append((loss / count).item()) # for better loss visualization
|
105 |
+
|
106 |
+
smoothed_ys = smooth(ys, 0.95)
|
107 |
+
|
108 |
+
# Unload model
|
109 |
+
del model
|
110 |
+
torch.cuda.empty_cache()
|
111 |
+
|
112 |
+
return xs, ys, smoothed_ys
|
113 |
+
|
114 |
+
|
115 |
+
def main():
|
116 |
+
tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
|
117 |
+
max_length = 4096
|
118 |
+
|
119 |
+
ds = AutoAugTextDataset(
|
120 |
+
["data/protos/sft/云天河"],
|
121 |
+
tokenizer=tokenizer,
|
122 |
+
use_speaker=False,
|
123 |
+
interactive_prob=1.0,
|
124 |
+
max_length=max_length,
|
125 |
+
)
|
126 |
+
|
127 |
+
loader = DataLoader(
|
128 |
+
ds,
|
129 |
+
batch_size=8,
|
130 |
+
collate_fn=TextDataCollator(tokenizer, max_length=max_length),
|
131 |
+
num_workers=0,
|
132 |
+
shuffle=False,
|
133 |
+
)
|
134 |
+
|
135 |
+
plt.figure(figsize=(10, 5), dpi=200)
|
136 |
+
|
137 |
+
plt.xlabel("Frame")
|
138 |
+
plt.ylabel("Loss")
|
139 |
+
plt.yscale("log")
|
140 |
+
plt.title("Semantic Loss")
|
141 |
+
plt.grid(which="both", axis="both")
|
142 |
+
plt.xlim(0, max_length)
|
143 |
+
|
144 |
+
tests = [
|
145 |
+
(
|
146 |
+
"pertrain-medium",
|
147 |
+
"dual_ar_2_codebook_medium",
|
148 |
+
"checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
|
149 |
+
),
|
150 |
+
(
|
151 |
+
"sft-medium",
|
152 |
+
"dual_ar_2_codebook_medium",
|
153 |
+
"checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
|
154 |
+
),
|
155 |
+
(
|
156 |
+
"sft-large",
|
157 |
+
"dual_ar_2_codebook_large",
|
158 |
+
"checkpoints/text2semantic-sft-large-v1.1-4k.pth",
|
159 |
+
),
|
160 |
+
]
|
161 |
+
|
162 |
+
for name, config, weight in tests:
|
163 |
+
xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
|
164 |
+
plt.plot(xs, smoothed_ys, label=name)
|
165 |
+
|
166 |
+
plt.legend()
|
167 |
+
plt.savefig("semantic_loss.png")
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == "__main__":
|
171 |
+
main()
|
tools/llama/generate.py
ADDED
@@ -0,0 +1,1087 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import queue
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
from contextlib import nullcontext
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Literal, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import click
|
11 |
+
import hydra
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch._dynamo.config
|
15 |
+
import torch._inductor.config
|
16 |
+
from loguru import logger
|
17 |
+
from tqdm import tqdm
|
18 |
+
from transformers import AutoTokenizer
|
19 |
+
|
20 |
+
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
21 |
+
from fish_speech.models.text2semantic.llama import BaseModelArgs
|
22 |
+
from fish_speech.text import clean_text, split_text
|
23 |
+
|
24 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
25 |
+
torch._inductor.config.coordinate_descent_tuning = True
|
26 |
+
torch._inductor.config.triton.unique_kernel_names = True
|
27 |
+
|
28 |
+
if hasattr(torch._inductor.config, "fx_graph_cache"):
|
29 |
+
# Experimental feature to reduce compilation times, will be on by default in future
|
30 |
+
torch._inductor.config.fx_graph_cache = True
|
31 |
+
|
32 |
+
|
33 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
34 |
+
|
35 |
+
from fish_speech.models.text2semantic.llama import (
|
36 |
+
BaseTransformer,
|
37 |
+
DualARTransformer,
|
38 |
+
NaiveTransformer,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def multinomial_sample_one_no_sync(
|
43 |
+
probs_sort,
|
44 |
+
): # Does multinomial sampling without a cuda synchronization
|
45 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
46 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
47 |
+
|
48 |
+
|
49 |
+
def logits_to_probs(
|
50 |
+
logits,
|
51 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
52 |
+
temperature: torch.Tensor = 1.0,
|
53 |
+
top_p: torch.Tensor = 1.0,
|
54 |
+
repetition_penalty: torch.Tensor = 1.0,
|
55 |
+
) -> torch.Tensor:
|
56 |
+
# Apply repetition penalty
|
57 |
+
if previous_tokens is not None:
|
58 |
+
previous_tokens = previous_tokens.long()
|
59 |
+
score = torch.gather(logits, dim=0, index=previous_tokens)
|
60 |
+
score = torch.where(
|
61 |
+
score < 0, score * repetition_penalty, score / repetition_penalty
|
62 |
+
)
|
63 |
+
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
64 |
+
|
65 |
+
# Apply top-p sampling
|
66 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
67 |
+
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
68 |
+
sorted_indices_to_remove = cum_probs > top_p
|
69 |
+
sorted_indices_to_remove[0] = False # keep at least one option
|
70 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
71 |
+
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
72 |
+
)
|
73 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
74 |
+
|
75 |
+
logits = logits / max(temperature, 1e-5)
|
76 |
+
|
77 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
78 |
+
return probs
|
79 |
+
|
80 |
+
|
81 |
+
def multinomial_sample_one_no_sync_agent(
|
82 |
+
probs_sort,
|
83 |
+
): # Does multinomial sampling without a cuda synchronization
|
84 |
+
q = torch.empty_like(probs_sort).exponential_(1)
|
85 |
+
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
86 |
+
|
87 |
+
|
88 |
+
def logits_to_probs_agent(
|
89 |
+
logits,
|
90 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
91 |
+
temperature: torch.Tensor = 1.0,
|
92 |
+
top_p: torch.Tensor = 1.0,
|
93 |
+
repetition_penalty: torch.Tensor = 1.0,
|
94 |
+
) -> torch.Tensor:
|
95 |
+
# Apply repetition penalty
|
96 |
+
if previous_tokens is not None:
|
97 |
+
previous_tokens = previous_tokens.long()
|
98 |
+
score = torch.gather(logits, dim=-1, index=previous_tokens)
|
99 |
+
score = torch.where(
|
100 |
+
score < 0, score * repetition_penalty, score / repetition_penalty
|
101 |
+
)
|
102 |
+
logits.scatter_(dim=-1, index=previous_tokens, src=score)
|
103 |
+
|
104 |
+
# Apply top-p sampling
|
105 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
106 |
+
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
107 |
+
sorted_indices_to_remove = cum_probs > top_p
|
108 |
+
sorted_indices_to_remove[..., 0] = False # keep at least one option
|
109 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
110 |
+
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
111 |
+
)
|
112 |
+
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
113 |
+
|
114 |
+
logits = logits / max(temperature, 1e-5)
|
115 |
+
|
116 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
117 |
+
return probs
|
118 |
+
|
119 |
+
|
120 |
+
def sample(
|
121 |
+
logits,
|
122 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
123 |
+
**sampling_kwargs,
|
124 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
125 |
+
probs = logits_to_probs(
|
126 |
+
logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
|
127 |
+
)
|
128 |
+
idx_next = multinomial_sample_one_no_sync(probs)
|
129 |
+
return idx_next, probs
|
130 |
+
|
131 |
+
|
132 |
+
def sample_agent(
|
133 |
+
logits,
|
134 |
+
previous_tokens: Optional[torch.Tensor] = None,
|
135 |
+
**sampling_kwargs,
|
136 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
137 |
+
probs = logits_to_probs_agent(
|
138 |
+
logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
|
139 |
+
)
|
140 |
+
idx_next = multinomial_sample_one_no_sync_agent(probs)
|
141 |
+
return idx_next, probs
|
142 |
+
|
143 |
+
|
144 |
+
def decode_one_token_ar_agent(
|
145 |
+
model: DualARTransformer,
|
146 |
+
x: torch.Tensor,
|
147 |
+
input_pos: torch.Tensor,
|
148 |
+
previous_tokens: torch.Tensor = None,
|
149 |
+
semantic_id: int = 32003,
|
150 |
+
**sampling_kwargs,
|
151 |
+
) -> torch.Tensor:
|
152 |
+
# print(x, input_pos)
|
153 |
+
x = model.forward_generate(x, input_pos)
|
154 |
+
logits = x.logits # [:, -1:]
|
155 |
+
hidden_states = x.hidden_states # [:, -1:]
|
156 |
+
|
157 |
+
sampling_kwargs_main = sampling_kwargs.copy()
|
158 |
+
sampling_kwargs_main["temperature"] = 0.1
|
159 |
+
sampling_kwargs_main["top_p"] = 0.1
|
160 |
+
sampling_kwargs_main["repetition_penalty"] = 1.0
|
161 |
+
|
162 |
+
codebooks = [
|
163 |
+
sample_agent(
|
164 |
+
logits,
|
165 |
+
previous_tokens=None, # Disable repetition penalty for the token codebook
|
166 |
+
**sampling_kwargs_main,
|
167 |
+
)[0]
|
168 |
+
]
|
169 |
+
|
170 |
+
# Cleanup the cache
|
171 |
+
for layer in model.fast_layers:
|
172 |
+
layer.attention.kv_cache.k_cache.fill_(0)
|
173 |
+
layer.attention.kv_cache.v_cache.fill_(0)
|
174 |
+
|
175 |
+
for codebook_idx in range(model.config.num_codebooks):
|
176 |
+
input_pos = torch.tensor(
|
177 |
+
[codebook_idx], device=hidden_states.device, dtype=torch.long
|
178 |
+
)
|
179 |
+
logits = model.forward_generate_fast(hidden_states, input_pos)
|
180 |
+
a = sample_agent(
|
181 |
+
logits,
|
182 |
+
previous_tokens=(
|
183 |
+
previous_tokens[:, codebook_idx + 1]
|
184 |
+
if previous_tokens is not None
|
185 |
+
else None
|
186 |
+
),
|
187 |
+
**sampling_kwargs,
|
188 |
+
)[0]
|
189 |
+
hidden_states = model.fast_embeddings(a)
|
190 |
+
codebooks.append(a)
|
191 |
+
|
192 |
+
codebooks = torch.stack(codebooks, dim=1)
|
193 |
+
codebooks[:, 1:, :] = torch.masked_fill(
|
194 |
+
codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
195 |
+
)
|
196 |
+
|
197 |
+
# for i in range(codebooks.size(1) - 1):
|
198 |
+
# codebooks[:, i + 1, :] = torch.masked_fill(
|
199 |
+
# codebooks[:, i + 1, :],
|
200 |
+
# codebooks[:, :1, :] != semantic_id,
|
201 |
+
# CODEBOOK_PAD_TOKEN_ID + i * 1024,
|
202 |
+
# )
|
203 |
+
|
204 |
+
# print(codebooks)
|
205 |
+
|
206 |
+
return codebooks
|
207 |
+
|
208 |
+
|
209 |
+
def decode_one_token_naive_agent(
|
210 |
+
model: NaiveTransformer,
|
211 |
+
x: torch.Tensor,
|
212 |
+
input_pos: torch.Tensor,
|
213 |
+
previous_tokens: torch.Tensor = None,
|
214 |
+
semantic_id: int = 32003,
|
215 |
+
**sampling_kwargs,
|
216 |
+
) -> torch.Tensor:
|
217 |
+
x = model.forward_generate(x, input_pos)
|
218 |
+
|
219 |
+
codebooks = [
|
220 |
+
sample(
|
221 |
+
x.token_logits,
|
222 |
+
previous_tokens=None, # Disable repetition penalty for the token codebook
|
223 |
+
**sampling_kwargs,
|
224 |
+
)[0]
|
225 |
+
]
|
226 |
+
|
227 |
+
for i in range(model.config.num_codebooks):
|
228 |
+
codebooks.append(
|
229 |
+
sample_agent(
|
230 |
+
x.codebook_logits[:, :, i],
|
231 |
+
previous_tokens=(
|
232 |
+
previous_tokens[:, i + 1] if previous_tokens is not None else None
|
233 |
+
),
|
234 |
+
**sampling_kwargs,
|
235 |
+
)[0]
|
236 |
+
)
|
237 |
+
|
238 |
+
codebooks = torch.stack(codebooks, dim=1)
|
239 |
+
codebooks[:, 1:, :] = torch.masked_fill(
|
240 |
+
codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
241 |
+
)
|
242 |
+
|
243 |
+
return codebooks
|
244 |
+
|
245 |
+
|
246 |
+
def decode_one_token_ar(
|
247 |
+
model: DualARTransformer,
|
248 |
+
x: torch.Tensor,
|
249 |
+
input_pos: torch.Tensor,
|
250 |
+
previous_tokens: torch.Tensor = None,
|
251 |
+
semantic_id: int = 0,
|
252 |
+
**sampling_kwargs,
|
253 |
+
) -> torch.Tensor:
|
254 |
+
x = model.forward_generate(x, input_pos)
|
255 |
+
|
256 |
+
sampling_kwargs_main = sampling_kwargs.copy()
|
257 |
+
# sampling_kwargs_main["temperature"] = 0.1
|
258 |
+
# sampling_kwargs_main["top_p"] = 0.1
|
259 |
+
# sampling_kwargs_main["repetition_penalty"] = 1.0
|
260 |
+
|
261 |
+
codebooks = [
|
262 |
+
sample(
|
263 |
+
x.logits,
|
264 |
+
previous_tokens=None, # Disable repetition penalty for the token codebook
|
265 |
+
**sampling_kwargs_main,
|
266 |
+
)[0]
|
267 |
+
]
|
268 |
+
|
269 |
+
x = x.hidden_states
|
270 |
+
|
271 |
+
# Cleanup the cache
|
272 |
+
for layer in model.fast_layers:
|
273 |
+
layer.attention.kv_cache.k_cache.fill_(0)
|
274 |
+
layer.attention.kv_cache.v_cache.fill_(0)
|
275 |
+
|
276 |
+
for codebook_idx in range(model.config.num_codebooks):
|
277 |
+
input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
|
278 |
+
logits = model.forward_generate_fast(x, input_pos)
|
279 |
+
a = sample(
|
280 |
+
logits,
|
281 |
+
previous_tokens=(
|
282 |
+
previous_tokens[codebook_idx + 1]
|
283 |
+
if previous_tokens is not None
|
284 |
+
else None
|
285 |
+
),
|
286 |
+
**sampling_kwargs,
|
287 |
+
)[0]
|
288 |
+
x = model.fast_embeddings(a)
|
289 |
+
codebooks.append(a)
|
290 |
+
|
291 |
+
codebooks = torch.stack(codebooks, dim=0)
|
292 |
+
codebooks[1:, :] = torch.masked_fill(
|
293 |
+
codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
|
294 |
+
)
|
295 |
+
|
296 |
+
return codebooks
|
297 |
+
|
298 |
+
|
299 |
+
def decode_one_token_naive(
|
300 |
+
model: NaiveTransformer,
|
301 |
+
x: torch.Tensor,
|
302 |
+
input_pos: torch.Tensor,
|
303 |
+
previous_tokens: torch.Tensor = None,
|
304 |
+
**sampling_kwargs,
|
305 |
+
) -> torch.Tensor:
|
306 |
+
x = model.forward_generate(x, input_pos)
|
307 |
+
|
308 |
+
sampling_kwargs_main = sampling_kwargs.copy()
|
309 |
+
sampling_kwargs_main["temperature"] = 0.1
|
310 |
+
sampling_kwargs_main["top_p"] = 0.1
|
311 |
+
sampling_kwargs_main["repetition_penalty"] = 1.0
|
312 |
+
|
313 |
+
codebooks = [
|
314 |
+
sample(
|
315 |
+
x.logits,
|
316 |
+
previous_tokens=None, # Disable repetition penalty for the token codebook
|
317 |
+
**sampling_kwargs_main,
|
318 |
+
)[0]
|
319 |
+
]
|
320 |
+
|
321 |
+
for i in range(model.config.num_codebooks):
|
322 |
+
codebooks.append(
|
323 |
+
sample(
|
324 |
+
x.codebook_logits[:, :, i],
|
325 |
+
previous_tokens=(
|
326 |
+
previous_tokens[i + 1] if previous_tokens is not None else None
|
327 |
+
),
|
328 |
+
**sampling_kwargs,
|
329 |
+
)[0]
|
330 |
+
)
|
331 |
+
|
332 |
+
return torch.stack(codebooks, dim=0)
|
333 |
+
|
334 |
+
|
335 |
+
def decode_n_tokens(
|
336 |
+
model: NaiveTransformer,
|
337 |
+
cur_token: torch.Tensor,
|
338 |
+
input_pos: torch.Tensor,
|
339 |
+
num_new_tokens: int,
|
340 |
+
im_end_id: int = 4,
|
341 |
+
decode_one_token=decode_one_token_naive,
|
342 |
+
semantic_id: int = 0,
|
343 |
+
**sampling_kwargs,
|
344 |
+
):
|
345 |
+
previous_tokens = torch.zeros(
|
346 |
+
(model.config.num_codebooks + 1, model.config.max_seq_len),
|
347 |
+
dtype=torch.int,
|
348 |
+
device=cur_token.device,
|
349 |
+
)
|
350 |
+
|
351 |
+
for i in tqdm(range(num_new_tokens)):
|
352 |
+
# We need to get windowed repeat penalty
|
353 |
+
win_size = 16
|
354 |
+
if i < win_size:
|
355 |
+
window = previous_tokens[:, :win_size]
|
356 |
+
else:
|
357 |
+
window = previous_tokens[:, i - win_size : i]
|
358 |
+
|
359 |
+
with (
|
360 |
+
torch.backends.cuda.sdp_kernel(
|
361 |
+
enable_flash=False, enable_mem_efficient=False, enable_math=True
|
362 |
+
)
|
363 |
+
if torch.cuda.is_available()
|
364 |
+
else nullcontext()
|
365 |
+
): # Actually better for Inductor to codegen attention here
|
366 |
+
next_token = decode_one_token(
|
367 |
+
model=model,
|
368 |
+
x=cur_token,
|
369 |
+
input_pos=input_pos,
|
370 |
+
previous_tokens=window,
|
371 |
+
semantic_id=semantic_id,
|
372 |
+
**sampling_kwargs,
|
373 |
+
)
|
374 |
+
|
375 |
+
input_pos += 1
|
376 |
+
cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
|
377 |
+
previous_tokens[:, i : i + 1] = next_token.view(
|
378 |
+
model.config.num_codebooks + 1, -1
|
379 |
+
)
|
380 |
+
|
381 |
+
if cur_token[0, 0, -1] == im_end_id:
|
382 |
+
break
|
383 |
+
|
384 |
+
return previous_tokens[:, : i + 1]
|
385 |
+
|
386 |
+
|
387 |
+
@torch.no_grad()
|
388 |
+
@torch.inference_mode()
|
389 |
+
def generate(
|
390 |
+
*,
|
391 |
+
model: NaiveTransformer,
|
392 |
+
prompt: torch.Tensor,
|
393 |
+
max_new_tokens: int =600,
|
394 |
+
im_end_id: int = 4,
|
395 |
+
decode_one_token=decode_one_token_naive,
|
396 |
+
**sampling_kwargs,
|
397 |
+
) -> torch.Tensor:
|
398 |
+
"""
|
399 |
+
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
400 |
+
"""
|
401 |
+
|
402 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
403 |
+
T = prompt.size(1)
|
404 |
+
semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
|
405 |
+
|
406 |
+
if max_new_tokens:
|
407 |
+
if T + max_new_tokens > model.config.max_seq_len:
|
408 |
+
max_new_tokens = model.config.max_seq_len - T
|
409 |
+
logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
|
410 |
+
|
411 |
+
T_new = T + max_new_tokens
|
412 |
+
else:
|
413 |
+
T_new = model.config.max_seq_len
|
414 |
+
max_new_tokens = T_new - T
|
415 |
+
|
416 |
+
device, dtype = prompt.device, prompt.dtype
|
417 |
+
|
418 |
+
codebook_dim = 1 + model.config.num_codebooks
|
419 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
420 |
+
empty = torch.empty(
|
421 |
+
(codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
|
422 |
+
)
|
423 |
+
empty[:, :T] = prompt
|
424 |
+
seq = empty
|
425 |
+
input_pos = torch.arange(0, T, device=device)
|
426 |
+
|
427 |
+
# Use non-accelerated version for now, to avoid compilation overhead
|
428 |
+
prefill_decode = (
|
429 |
+
decode_one_token_naive
|
430 |
+
if isinstance(model, NaiveTransformer)
|
431 |
+
else decode_one_token_ar
|
432 |
+
)
|
433 |
+
|
434 |
+
next_token = prefill_decode(
|
435 |
+
model,
|
436 |
+
prompt.view(1, codebook_dim, -1),
|
437 |
+
input_pos,
|
438 |
+
semantic_id=semantic_id,
|
439 |
+
**sampling_kwargs,
|
440 |
+
)
|
441 |
+
seq[:, T : T + 1] = next_token
|
442 |
+
|
443 |
+
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
444 |
+
x = decode_n_tokens(
|
445 |
+
model,
|
446 |
+
next_token.view(1, codebook_dim, -1),
|
447 |
+
input_pos,
|
448 |
+
max_new_tokens - 1,
|
449 |
+
im_end_id=im_end_id,
|
450 |
+
decode_one_token=decode_one_token,
|
451 |
+
semantic_id=semantic_id,
|
452 |
+
**sampling_kwargs,
|
453 |
+
)
|
454 |
+
# x = torch.cat(generated_tokens, dim=1)
|
455 |
+
seq = seq[:, : T + 1 + x.size(1)]
|
456 |
+
seq[:, T + 1 :] = x
|
457 |
+
|
458 |
+
return seq
|
459 |
+
|
460 |
+
|
461 |
+
def decode_n_tokens_agent(
|
462 |
+
model: NaiveTransformer,
|
463 |
+
cur_token: torch.Tensor,
|
464 |
+
input_pos: torch.Tensor,
|
465 |
+
num_new_tokens: int,
|
466 |
+
im_end_id: int = 4,
|
467 |
+
semantic_id: int = 32003,
|
468 |
+
decode_one_token=decode_one_token_naive_agent,
|
469 |
+
early_stop_threshold: float = 0.6,
|
470 |
+
**sampling_kwargs,
|
471 |
+
):
|
472 |
+
batch_size = cur_token.size(0)
|
473 |
+
previous_tokens = torch.zeros(
|
474 |
+
(batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
|
475 |
+
dtype=torch.int,
|
476 |
+
device=cur_token.device,
|
477 |
+
)
|
478 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
|
479 |
+
finished = finished | (cur_token[:, 0, -1] == im_end_id)
|
480 |
+
start_time = time.time()
|
481 |
+
|
482 |
+
for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
|
483 |
+
# We need to get windowed repeat penalty
|
484 |
+
win_size = 16
|
485 |
+
if i < win_size:
|
486 |
+
window = previous_tokens[:, :, :win_size]
|
487 |
+
else:
|
488 |
+
window = previous_tokens[:, :, i - win_size : i]
|
489 |
+
|
490 |
+
with sdpa_kernel(
|
491 |
+
SDPBackend.MATH
|
492 |
+
): # Actually better for Inductor to codegen attention here
|
493 |
+
next_token = decode_one_token(
|
494 |
+
model=model,
|
495 |
+
x=cur_token,
|
496 |
+
input_pos=input_pos,
|
497 |
+
previous_tokens=window,
|
498 |
+
semantic_id=semantic_id,
|
499 |
+
**sampling_kwargs,
|
500 |
+
)
|
501 |
+
|
502 |
+
input_pos += 1
|
503 |
+
cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
|
504 |
+
previous_tokens[:, :, i : i + 1] = next_token.view(
|
505 |
+
batch_size, model.config.num_codebooks + 1, -1
|
506 |
+
)
|
507 |
+
|
508 |
+
yield cur_token.cpu()
|
509 |
+
|
510 |
+
finished = finished | (cur_token[:, 0, -1] == im_end_id)
|
511 |
+
if finished.all() or (
|
512 |
+
0 < early_stop_threshold < 1
|
513 |
+
and finished.sum() >= round(batch_size * early_stop_threshold)
|
514 |
+
):
|
515 |
+
break
|
516 |
+
|
517 |
+
total_time = time.time() - start_time
|
518 |
+
generated_tokens = i + 1
|
519 |
+
tokens_per_second = (generated_tokens / total_time) * batch_size
|
520 |
+
logger.info(
|
521 |
+
f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
|
522 |
+
)
|
523 |
+
|
524 |
+
|
525 |
+
@torch.no_grad()
|
526 |
+
@torch.inference_mode()
|
527 |
+
def generate_agent(
|
528 |
+
*,
|
529 |
+
model: BaseTransformer,
|
530 |
+
prompt: torch.Tensor,
|
531 |
+
max_new_tokens: int =500,
|
532 |
+
im_end_id: int = 4,
|
533 |
+
semantic_id: int = 32003,
|
534 |
+
decode_one_token=decode_one_token_naive_agent,
|
535 |
+
num_samples: int = 1,
|
536 |
+
early_stop_threshold: float = 0.6,
|
537 |
+
**sampling_kwargs,
|
538 |
+
):
|
539 |
+
"""
|
540 |
+
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
541 |
+
"""
|
542 |
+
|
543 |
+
# create an empty tensor of the expected final shape and fill in the current tokens
|
544 |
+
T = prompt.size(1)
|
545 |
+
prompt = prompt[None].repeat(num_samples, 1, 1)
|
546 |
+
|
547 |
+
if T >= model.config.max_seq_len:
|
548 |
+
raise ValueError(
|
549 |
+
f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
|
550 |
+
)
|
551 |
+
|
552 |
+
if max_new_tokens:
|
553 |
+
if T + max_new_tokens > model.config.max_seq_len:
|
554 |
+
max_new_tokens = model.config.max_seq_len - T
|
555 |
+
logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
|
556 |
+
|
557 |
+
T_new = T + max_new_tokens
|
558 |
+
else:
|
559 |
+
T_new = model.config.max_seq_len
|
560 |
+
max_new_tokens = T_new - T
|
561 |
+
|
562 |
+
device, dtype = prompt.device, prompt.dtype
|
563 |
+
|
564 |
+
codebook_dim = 1 + model.config.num_codebooks
|
565 |
+
input_pos = torch.arange(0, T, device=device)
|
566 |
+
|
567 |
+
# Use non-accelerated version for now, to avoid compilation overhead
|
568 |
+
prefill_decode = (
|
569 |
+
decode_one_token_naive_agent
|
570 |
+
if isinstance(model, NaiveTransformer)
|
571 |
+
else decode_one_token_ar_agent
|
572 |
+
)
|
573 |
+
next_token = prefill_decode(
|
574 |
+
model,
|
575 |
+
prompt,
|
576 |
+
input_pos,
|
577 |
+
semantic_id=semantic_id,
|
578 |
+
**sampling_kwargs,
|
579 |
+
).view(num_samples, codebook_dim, -1)
|
580 |
+
yield next_token.cpu()
|
581 |
+
|
582 |
+
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
583 |
+
|
584 |
+
yield from decode_n_tokens_agent(
|
585 |
+
model,
|
586 |
+
next_token,
|
587 |
+
input_pos,
|
588 |
+
max_new_tokens - 1,
|
589 |
+
im_end_id=im_end_id,
|
590 |
+
semantic_id=semantic_id,
|
591 |
+
decode_one_token=decode_one_token,
|
592 |
+
early_stop_threshold=early_stop_threshold,
|
593 |
+
**sampling_kwargs,
|
594 |
+
)
|
595 |
+
|
596 |
+
|
597 |
+
def encode_tokens(
|
598 |
+
tokenizer,
|
599 |
+
string,
|
600 |
+
device="cuda",
|
601 |
+
prompt_tokens=None,
|
602 |
+
num_codebooks=4,
|
603 |
+
):
|
604 |
+
string = clean_text(string)
|
605 |
+
string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
|
606 |
+
|
607 |
+
new_tokens = tokenizer.encode(
|
608 |
+
string,
|
609 |
+
add_special_tokens=False,
|
610 |
+
max_length=10**6,
|
611 |
+
truncation=False,
|
612 |
+
)
|
613 |
+
tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
|
614 |
+
|
615 |
+
# Codebooks
|
616 |
+
zeros = (
|
617 |
+
torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
|
618 |
+
* CODEBOOK_PAD_TOKEN_ID
|
619 |
+
)
|
620 |
+
prompt = torch.cat((tokens, zeros), dim=0)
|
621 |
+
|
622 |
+
if prompt_tokens is None:
|
623 |
+
return prompt
|
624 |
+
|
625 |
+
# Get prompt tokens
|
626 |
+
if prompt_tokens.ndim == 3:
|
627 |
+
assert (
|
628 |
+
prompt_tokens.shape[0] == 1
|
629 |
+
), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
|
630 |
+
prompt_tokens = prompt_tokens[0]
|
631 |
+
|
632 |
+
assert prompt_tokens.ndim == 2
|
633 |
+
data = prompt_tokens + 1
|
634 |
+
|
635 |
+
if prompt_tokens.shape[0] > num_codebooks:
|
636 |
+
logger.warning(
|
637 |
+
f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
|
638 |
+
)
|
639 |
+
data = data[:num_codebooks]
|
640 |
+
|
641 |
+
# Add pad token for each codebook
|
642 |
+
data = torch.cat(
|
643 |
+
(data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
|
644 |
+
dim=1,
|
645 |
+
)
|
646 |
+
|
647 |
+
# Since 1.0, we use <|semantic|>
|
648 |
+
s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
|
649 |
+
end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
650 |
+
main_token_ids = (
|
651 |
+
torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
|
652 |
+
)
|
653 |
+
main_token_ids[0, -1] = end_token_id
|
654 |
+
|
655 |
+
data = torch.cat((main_token_ids, data), dim=0)
|
656 |
+
prompt = torch.cat((prompt, data), dim=1)
|
657 |
+
|
658 |
+
return prompt
|
659 |
+
|
660 |
+
|
661 |
+
def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
|
662 |
+
model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
|
663 |
+
checkpoint_path, load_weights=True
|
664 |
+
)
|
665 |
+
|
666 |
+
model = model.to(device=device, dtype=precision)
|
667 |
+
logger.info(f"Restored model from checkpoint")
|
668 |
+
|
669 |
+
if isinstance(model, DualARTransformer):
|
670 |
+
decode_one_token = (
|
671 |
+
decode_one_token_ar_agent if is_agent else decode_one_token_ar
|
672 |
+
)
|
673 |
+
logger.info("Using DualARTransformer")
|
674 |
+
else:
|
675 |
+
decode_one_token = (
|
676 |
+
decode_one_token_naive_agent if is_agent else decode_one_token_naive
|
677 |
+
)
|
678 |
+
logger.info("Using NaiveTransformer")
|
679 |
+
|
680 |
+
if compile:
|
681 |
+
logger.info("Compiling function...")
|
682 |
+
decode_one_token = torch.compile(
|
683 |
+
decode_one_token,
|
684 |
+
fullgraph=True,
|
685 |
+
backend="inductor" if torch.cuda.is_available() else "aot_eager",
|
686 |
+
mode="reduce-overhead" if torch.cuda.is_available() else None,
|
687 |
+
)
|
688 |
+
|
689 |
+
return model.eval(), decode_one_token
|
690 |
+
|
691 |
+
|
692 |
+
@dataclass
|
693 |
+
class GenerateResponse:
|
694 |
+
action: Literal["sample", "next"]
|
695 |
+
codes: Optional[torch.Tensor] = None
|
696 |
+
text: Optional[str] = None
|
697 |
+
|
698 |
+
|
699 |
+
def generate_long(
|
700 |
+
*,
|
701 |
+
model,
|
702 |
+
device: str | torch.device,
|
703 |
+
decode_one_token: callable,
|
704 |
+
text: str,
|
705 |
+
num_samples: int = 1,
|
706 |
+
max_new_tokens: int = 600,
|
707 |
+
top_p: int = 0.7,
|
708 |
+
repetition_penalty: float = 1.5,
|
709 |
+
temperature: float = 0.7,
|
710 |
+
compile: bool = False,
|
711 |
+
iterative_prompt: bool = True,
|
712 |
+
max_length: int = 2048,
|
713 |
+
chunk_length: int = 150,
|
714 |
+
prompt_text: Optional[str | list[str]] = None,
|
715 |
+
prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
|
716 |
+
):
|
717 |
+
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
718 |
+
assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
719 |
+
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
720 |
+
|
721 |
+
use_prompt = prompt_text is not None and prompt_tokens is not None
|
722 |
+
if use_prompt and isinstance(prompt_text, str):
|
723 |
+
prompt_text = [prompt_text]
|
724 |
+
prompt_tokens = [prompt_tokens]
|
725 |
+
|
726 |
+
assert use_prompt is False or len(prompt_text) == len(
|
727 |
+
prompt_tokens
|
728 |
+
), "Prompt text and tokens must have the same length"
|
729 |
+
|
730 |
+
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
731 |
+
tokenizer = model.tokenizer
|
732 |
+
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
733 |
+
|
734 |
+
encoded = []
|
735 |
+
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
736 |
+
encoded_prompts = []
|
737 |
+
|
738 |
+
if use_prompt:
|
739 |
+
for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
|
740 |
+
encoded_prompts.append(
|
741 |
+
encode_tokens(
|
742 |
+
tokenizer,
|
743 |
+
string=t,
|
744 |
+
device=device,
|
745 |
+
prompt_tokens=c,
|
746 |
+
num_codebooks=model.config.num_codebooks,
|
747 |
+
)
|
748 |
+
)
|
749 |
+
|
750 |
+
for idx, text in enumerate(texts):
|
751 |
+
encoded.append(
|
752 |
+
encode_tokens(
|
753 |
+
tokenizer,
|
754 |
+
string=text,
|
755 |
+
device=device,
|
756 |
+
num_codebooks=model.config.num_codebooks,
|
757 |
+
)
|
758 |
+
)
|
759 |
+
logger.info(f"Encoded text: {text}")
|
760 |
+
|
761 |
+
# Move temperature, top_p, repetition_penalty to device
|
762 |
+
# This is important so that changing params doesn't trigger recompile
|
763 |
+
temperature = torch.tensor(temperature, device=device, dtype=torch.float)
|
764 |
+
top_p = torch.tensor(top_p, device=device, dtype=torch.float)
|
765 |
+
repetition_penalty = torch.tensor(
|
766 |
+
repetition_penalty, device=device, dtype=torch.float
|
767 |
+
)
|
768 |
+
|
769 |
+
for sample_idx in range(num_samples):
|
770 |
+
if torch.cuda.is_available():
|
771 |
+
torch.cuda.synchronize()
|
772 |
+
|
773 |
+
global_encoded = []
|
774 |
+
seg_idx = 0
|
775 |
+
|
776 |
+
while seg_idx < len(encoded):
|
777 |
+
logger.info(
|
778 |
+
f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
|
779 |
+
)
|
780 |
+
|
781 |
+
seg = encoded[seg_idx]
|
782 |
+
global_encoded.append(seg)
|
783 |
+
|
784 |
+
lengths = reversed([seg.size(1) for seg in global_encoded])
|
785 |
+
|
786 |
+
# Pick last 2000 tokens
|
787 |
+
count = 0
|
788 |
+
for i, length in enumerate(lengths):
|
789 |
+
count += length
|
790 |
+
if count + length > max_length - 1024 - sum(
|
791 |
+
t.shape[1] for t in encoded_prompts
|
792 |
+
):
|
793 |
+
break
|
794 |
+
|
795 |
+
if i != 0 and i % 2 == 0:
|
796 |
+
i -= 1
|
797 |
+
|
798 |
+
# Rotate the list, always make sure first segment is included to avoid drift
|
799 |
+
if i < len(global_encoded) - 2:
|
800 |
+
partial_encoded = global_encoded[:2] + global_encoded[-i:]
|
801 |
+
else:
|
802 |
+
partial_encoded = global_encoded
|
803 |
+
|
804 |
+
if use_prompt:
|
805 |
+
partial_encoded = encoded_prompts + partial_encoded
|
806 |
+
|
807 |
+
cat_encoded = torch.cat(partial_encoded, dim=1)
|
808 |
+
prompt_length = cat_encoded.size(1)
|
809 |
+
|
810 |
+
t0 = time.perf_counter()
|
811 |
+
y = generate(
|
812 |
+
model=model,
|
813 |
+
prompt=cat_encoded,
|
814 |
+
max_new_tokens=max_new_tokens,
|
815 |
+
im_end_id=im_end_id,
|
816 |
+
decode_one_token=decode_one_token,
|
817 |
+
temperature=temperature,
|
818 |
+
top_p=top_p,
|
819 |
+
repetition_penalty=repetition_penalty,
|
820 |
+
)
|
821 |
+
|
822 |
+
if sample_idx == 0 and seg_idx == 0 and compile:
|
823 |
+
logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
|
824 |
+
|
825 |
+
if torch.cuda.is_available():
|
826 |
+
torch.cuda.synchronize()
|
827 |
+
|
828 |
+
t = time.perf_counter() - t0
|
829 |
+
|
830 |
+
tokens_generated = y.size(1) - prompt_length
|
831 |
+
tokens_sec = tokens_generated / t
|
832 |
+
logger.info(
|
833 |
+
f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
|
834 |
+
)
|
835 |
+
logger.info(
|
836 |
+
f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
|
837 |
+
)
|
838 |
+
|
839 |
+
if torch.cuda.is_available():
|
840 |
+
logger.info(
|
841 |
+
f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
|
842 |
+
)
|
843 |
+
|
844 |
+
# Put the generated tokens
|
845 |
+
# since there is <im_end> and <eos> tokens, we remove last 2 tokens
|
846 |
+
codes = y[1:, prompt_length:-1].clone()
|
847 |
+
codes = codes - 1
|
848 |
+
assert (codes >= 0).all(), f"Negative code found"
|
849 |
+
|
850 |
+
decoded = y[:, prompt_length:-1].clone()
|
851 |
+
# But for global encoding, we should keep the <im_end> token
|
852 |
+
|
853 |
+
global_encoded.append(decoded)
|
854 |
+
assert (codes >= 0).all(), f"Negative code found: {codes}"
|
855 |
+
yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
|
856 |
+
seg_idx += 1
|
857 |
+
|
858 |
+
# This indicates the end of the current sample
|
859 |
+
yield GenerateResponse(action="next")
|
860 |
+
|
861 |
+
|
862 |
+
@dataclass
|
863 |
+
class WrappedGenerateResponse:
|
864 |
+
status: Literal["success", "error"]
|
865 |
+
response: Optional[GenerateResponse | Exception] = None
|
866 |
+
|
867 |
+
|
868 |
+
@dataclass
|
869 |
+
class GenerateRequest:
|
870 |
+
request: dict
|
871 |
+
response_queue: queue.Queue
|
872 |
+
|
873 |
+
|
874 |
+
def launch_thread_safe_queue(
|
875 |
+
checkpoint_path,
|
876 |
+
device,
|
877 |
+
precision,
|
878 |
+
compile: bool = False,
|
879 |
+
):
|
880 |
+
input_queue = queue.Queue()
|
881 |
+
init_event = threading.Event()
|
882 |
+
|
883 |
+
def worker():
|
884 |
+
model, decode_one_token = load_model(
|
885 |
+
checkpoint_path, device, precision, compile=compile
|
886 |
+
)
|
887 |
+
with torch.device(device):
|
888 |
+
model.setup_caches(
|
889 |
+
max_batch_size=1,
|
890 |
+
max_seq_len=model.config.max_seq_len,
|
891 |
+
dtype=next(model.parameters()).dtype,
|
892 |
+
)
|
893 |
+
init_event.set()
|
894 |
+
|
895 |
+
while True:
|
896 |
+
item: GenerateRequest | None = input_queue.get()
|
897 |
+
if item is None:
|
898 |
+
break
|
899 |
+
|
900 |
+
kwargs = item.request
|
901 |
+
response_queue = item.response_queue
|
902 |
+
|
903 |
+
try:
|
904 |
+
for chunk in generate_long(
|
905 |
+
model=model, decode_one_token=decode_one_token, **kwargs
|
906 |
+
):
|
907 |
+
response_queue.put(
|
908 |
+
WrappedGenerateResponse(status="success", response=chunk)
|
909 |
+
)
|
910 |
+
except Exception as e:
|
911 |
+
response_queue.put(WrappedGenerateResponse(status="error", response=e))
|
912 |
+
|
913 |
+
threading.Thread(target=worker, daemon=True).start()
|
914 |
+
init_event.wait()
|
915 |
+
|
916 |
+
return input_queue
|
917 |
+
|
918 |
+
|
919 |
+
def launch_thread_safe_queue_agent(
|
920 |
+
checkpoint_path,
|
921 |
+
device,
|
922 |
+
precision,
|
923 |
+
compile: bool = False,
|
924 |
+
):
|
925 |
+
input_queue = queue.Queue()
|
926 |
+
init_event = threading.Event()
|
927 |
+
|
928 |
+
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
|
929 |
+
config = BaseModelArgs.from_pretrained(checkpoint_path)
|
930 |
+
|
931 |
+
def worker():
|
932 |
+
model, decode_one_token = load_model(
|
933 |
+
checkpoint_path, device, precision, compile=compile, is_agent=True
|
934 |
+
)
|
935 |
+
|
936 |
+
with torch.device(device):
|
937 |
+
model.setup_caches(
|
938 |
+
max_batch_size=1,
|
939 |
+
max_seq_len=model.config.max_seq_len,
|
940 |
+
dtype=next(model.parameters()).dtype,
|
941 |
+
)
|
942 |
+
init_event.set()
|
943 |
+
|
944 |
+
while True:
|
945 |
+
item: GenerateRequest | None = input_queue.get()
|
946 |
+
if item is None:
|
947 |
+
break
|
948 |
+
|
949 |
+
kwargs = item.request
|
950 |
+
response_queue = item.response_queue
|
951 |
+
|
952 |
+
try:
|
953 |
+
for token in generate_agent(
|
954 |
+
model=model,
|
955 |
+
decode_one_token=decode_one_token,
|
956 |
+
**kwargs,
|
957 |
+
):
|
958 |
+
response_queue.put(token)
|
959 |
+
|
960 |
+
response_queue.put("stop")
|
961 |
+
except Exception as e:
|
962 |
+
import traceback
|
963 |
+
|
964 |
+
logger.exception(f"Error in worker: {traceback.format_exc()}")
|
965 |
+
response_queue.put("error")
|
966 |
+
|
967 |
+
threading.Thread(target=worker, daemon=True).start()
|
968 |
+
init_event.wait()
|
969 |
+
|
970 |
+
return input_queue, tokenizer, config
|
971 |
+
|
972 |
+
|
973 |
+
@click.command()
|
974 |
+
@click.option(
|
975 |
+
"--text",
|
976 |
+
type=str,
|
977 |
+
default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
978 |
+
)
|
979 |
+
@click.option("--prompt-text", type=str, default=None, multiple=True)
|
980 |
+
@click.option(
|
981 |
+
"--prompt-tokens",
|
982 |
+
type=click.Path(path_type=Path, exists=True),
|
983 |
+
default=None,
|
984 |
+
multiple=True,
|
985 |
+
)
|
986 |
+
@click.option("--num-samples", type=int, default=1)
|
987 |
+
@click.option("--max-new-tokens", type=int, default=0)
|
988 |
+
@click.option("--top-p", type=float, default=0.7)
|
989 |
+
@click.option("--repetition-penalty", type=float, default=1.2)
|
990 |
+
@click.option("--temperature", type=float, default=0.7)
|
991 |
+
@click.option(
|
992 |
+
"--checkpoint-path",
|
993 |
+
type=click.Path(path_type=Path, exists=True),
|
994 |
+
default="checkpoints/fish-speech-1.4",
|
995 |
+
)
|
996 |
+
@click.option("--device", type=str, default="cuda")
|
997 |
+
@click.option("--compile/--no-compile", default=False)
|
998 |
+
@click.option("--seed", type=int, default=42)
|
999 |
+
@click.option("--half/--no-half", default=False)
|
1000 |
+
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
1001 |
+
@click.option("--chunk-length", type=int, default=100)
|
1002 |
+
def main(
|
1003 |
+
text: str,
|
1004 |
+
prompt_text: Optional[list[str]],
|
1005 |
+
prompt_tokens: Optional[list[Path]],
|
1006 |
+
num_samples: int,
|
1007 |
+
max_new_tokens: int,
|
1008 |
+
top_p: int,
|
1009 |
+
repetition_penalty: float,
|
1010 |
+
temperature: float,
|
1011 |
+
checkpoint_path: Path,
|
1012 |
+
device: str,
|
1013 |
+
compile: bool,
|
1014 |
+
seed: int,
|
1015 |
+
half: bool,
|
1016 |
+
iterative_prompt: bool,
|
1017 |
+
chunk_length: int,
|
1018 |
+
) -> None:
|
1019 |
+
|
1020 |
+
precision = torch.half if half else torch.bfloat16
|
1021 |
+
|
1022 |
+
if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
|
1023 |
+
raise ValueError(
|
1024 |
+
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
logger.info("Loading model ...")
|
1028 |
+
t0 = time.time()
|
1029 |
+
model, decode_one_token = load_model(
|
1030 |
+
checkpoint_path, device, precision, compile=compile
|
1031 |
+
)
|
1032 |
+
with torch.device(device):
|
1033 |
+
model.setup_caches(
|
1034 |
+
max_batch_size=1,
|
1035 |
+
max_seq_len=model.config.max_seq_len,
|
1036 |
+
dtype=next(model.parameters()).dtype,
|
1037 |
+
)
|
1038 |
+
if torch.cuda.is_available():
|
1039 |
+
torch.cuda.synchronize()
|
1040 |
+
|
1041 |
+
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
1042 |
+
|
1043 |
+
if prompt_tokens is not None:
|
1044 |
+
prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
|
1045 |
+
|
1046 |
+
torch.manual_seed(seed)
|
1047 |
+
|
1048 |
+
if torch.cuda.is_available():
|
1049 |
+
torch.cuda.manual_seed(seed)
|
1050 |
+
|
1051 |
+
generator = generate_long(
|
1052 |
+
model=model,
|
1053 |
+
device=device,
|
1054 |
+
decode_one_token=decode_one_token,
|
1055 |
+
text=text,
|
1056 |
+
num_samples=num_samples,
|
1057 |
+
max_new_tokens=max_new_tokens,
|
1058 |
+
top_p=top_p,
|
1059 |
+
repetition_penalty=repetition_penalty,
|
1060 |
+
temperature=temperature,
|
1061 |
+
compile=compile,
|
1062 |
+
iterative_prompt=iterative_prompt,
|
1063 |
+
chunk_length=chunk_length,
|
1064 |
+
prompt_text=prompt_text,
|
1065 |
+
prompt_tokens=prompt_tokens,
|
1066 |
+
)
|
1067 |
+
|
1068 |
+
idx = 0
|
1069 |
+
codes = []
|
1070 |
+
|
1071 |
+
for response in generator:
|
1072 |
+
if response.action == "sample":
|
1073 |
+
codes.append(response.codes)
|
1074 |
+
logger.info(f"Sampled text: {response.text}")
|
1075 |
+
elif response.action == "next":
|
1076 |
+
if codes:
|
1077 |
+
np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
|
1078 |
+
logger.info(f"Saved codes to codes_{idx}.npy")
|
1079 |
+
logger.info(f"Next sample")
|
1080 |
+
codes = []
|
1081 |
+
idx += 1
|
1082 |
+
else:
|
1083 |
+
logger.error(f"Error: {response}")
|
1084 |
+
|
1085 |
+
|
1086 |
+
if __name__ == "__main__":
|
1087 |
+
main()
|
tools/llama/merge_lora.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from copy import deepcopy
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import click
|
6 |
+
import hydra
|
7 |
+
import torch
|
8 |
+
from hydra import compose, initialize
|
9 |
+
from hydra.utils import instantiate
|
10 |
+
from loguru import logger
|
11 |
+
|
12 |
+
from fish_speech.models.text2semantic.llama import BaseTransformer
|
13 |
+
from fish_speech.models.text2semantic.lora import get_merged_state_dict
|
14 |
+
|
15 |
+
|
16 |
+
@click.command()
|
17 |
+
@click.option("--lora-config", type=str, default="r_8_alpha_16")
|
18 |
+
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
|
19 |
+
@click.option("--lora-weight", type=str, required=True)
|
20 |
+
@click.option("--output", type=str, required=True)
|
21 |
+
def merge(lora_config, base_weight, lora_weight, output):
|
22 |
+
output = Path(output)
|
23 |
+
logger.info(
|
24 |
+
f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
|
25 |
+
)
|
26 |
+
|
27 |
+
with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
|
28 |
+
cfg = compose(config_name=lora_config)
|
29 |
+
|
30 |
+
lora_config = instantiate(cfg)
|
31 |
+
logger.info(f"Loaded lora model with config {lora_config}")
|
32 |
+
|
33 |
+
llama_model = BaseTransformer.from_pretrained(
|
34 |
+
path=base_weight,
|
35 |
+
load_weights=True,
|
36 |
+
lora_config=lora_config,
|
37 |
+
)
|
38 |
+
logger.info(f"Loaded llama model")
|
39 |
+
|
40 |
+
llama_state_dict = llama_model.state_dict()
|
41 |
+
llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
|
42 |
+
llama_state_dict_copy = deepcopy(llama_state_dict)
|
43 |
+
lora_state_dict = torch.load(lora_weight, map_location="cpu")
|
44 |
+
|
45 |
+
if "state_dict" in llama_state_dict:
|
46 |
+
llama_state_dict = llama_state_dict["state_dict"]
|
47 |
+
|
48 |
+
if "state_dict" in lora_state_dict:
|
49 |
+
lora_state_dict = lora_state_dict["state_dict"]
|
50 |
+
|
51 |
+
# remove prefix model.
|
52 |
+
if any(k.startswith("model.") for k in llama_state_dict.keys()):
|
53 |
+
llama_state_dict = {
|
54 |
+
k.replace("model.", ""): v
|
55 |
+
for k, v in llama_state_dict.items()
|
56 |
+
if k.startswith("model.")
|
57 |
+
}
|
58 |
+
if any(k.startswith("model.") for k in lora_state_dict.keys()):
|
59 |
+
lora_state_dict = {
|
60 |
+
k.replace("model.", ""): v
|
61 |
+
for k, v in lora_state_dict.items()
|
62 |
+
if k.startswith("model.")
|
63 |
+
}
|
64 |
+
|
65 |
+
logger.info(f"Found {len(llama_state_dict)} keys in llama model")
|
66 |
+
logger.info(f"Found {len(lora_state_dict)} keys in lora model")
|
67 |
+
|
68 |
+
merged_state_dict = llama_state_dict | lora_state_dict
|
69 |
+
llama_model.load_state_dict(merged_state_dict, strict=True)
|
70 |
+
logger.info(f"Merged model loaded")
|
71 |
+
|
72 |
+
# Trigger eval mode to merge lora
|
73 |
+
llama_model.eval()
|
74 |
+
llama_model.save_pretrained(output, drop_lora=True)
|
75 |
+
logger.info(f"Saved merged model to {output}, validating")
|
76 |
+
|
77 |
+
new_state_dict = torch.load(output / "model.pth", map_location="cpu")
|
78 |
+
original_keys = set(llama_state_dict_copy.keys())
|
79 |
+
merged_keys = set(new_state_dict.keys())
|
80 |
+
|
81 |
+
assert original_keys == merged_keys, "Keys should be same"
|
82 |
+
|
83 |
+
for key in original_keys:
|
84 |
+
diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
|
85 |
+
if diff_l1 != 0:
|
86 |
+
break
|
87 |
+
else:
|
88 |
+
logger.error("Merged model is same as the original model")
|
89 |
+
exit(1)
|
90 |
+
|
91 |
+
logger.info("Merged model is different from the original model, check passed")
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
merge()
|
tools/llama/quantize.py
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
import datetime
|
4 |
+
import shutil
|
5 |
+
|
6 |
+
# This source code is licensed under the license found in the
|
7 |
+
# LICENSE file in the root directory of this source tree.
|
8 |
+
import time
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import click
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from fish_speech.models.text2semantic.llama import find_multiple
|
17 |
+
from tools.llama.generate import load_model
|
18 |
+
|
19 |
+
##### Quantization Primitives ######
|
20 |
+
|
21 |
+
|
22 |
+
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
|
23 |
+
# assumes symmetric quantization
|
24 |
+
# assumes axis == 0
|
25 |
+
# assumes dense memory format
|
26 |
+
# TODO(future): relax ^ as needed
|
27 |
+
|
28 |
+
# default setup for affine quantization of activations
|
29 |
+
eps = torch.finfo(torch.float32).eps
|
30 |
+
|
31 |
+
# get min and max
|
32 |
+
min_val, max_val = torch.aminmax(x, dim=1)
|
33 |
+
|
34 |
+
# calculate scales and zero_points based on min and max
|
35 |
+
# reference: https://fburl.com/code/srbiybme
|
36 |
+
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
|
37 |
+
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
|
38 |
+
device = min_val_neg.device
|
39 |
+
|
40 |
+
# reference: https://fburl.com/code/4wll53rk
|
41 |
+
max_val_pos = torch.max(-min_val_neg, max_val_pos)
|
42 |
+
scales = max_val_pos / (float(quant_max - quant_min) / 2)
|
43 |
+
# ensure scales is the same dtype as the original tensor
|
44 |
+
scales = torch.clamp(scales, min=eps).to(x.dtype)
|
45 |
+
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
|
46 |
+
|
47 |
+
# quantize based on qmin/qmax/scales/zp
|
48 |
+
# reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
|
49 |
+
x_div = x / scales.unsqueeze(-1)
|
50 |
+
x_round = torch.round(x_div)
|
51 |
+
x_zp = x_round + zero_points.unsqueeze(-1)
|
52 |
+
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
|
53 |
+
|
54 |
+
return quant, scales, zero_points
|
55 |
+
|
56 |
+
|
57 |
+
def get_group_qparams(w, n_bit=4, groupsize=128):
|
58 |
+
# needed for GPTQ with padding
|
59 |
+
if groupsize > w.shape[-1]:
|
60 |
+
groupsize = w.shape[-1]
|
61 |
+
assert groupsize > 1
|
62 |
+
assert w.shape[-1] % groupsize == 0
|
63 |
+
assert w.dim() == 2
|
64 |
+
|
65 |
+
to_quant = w.reshape(-1, groupsize)
|
66 |
+
assert torch.isnan(to_quant).sum() == 0
|
67 |
+
|
68 |
+
max_val = to_quant.amax(dim=1, keepdim=True)
|
69 |
+
min_val = to_quant.amin(dim=1, keepdim=True)
|
70 |
+
max_int = 2**n_bit - 1
|
71 |
+
scales = (max_val - min_val).clamp(min=1e-6) / max_int
|
72 |
+
zeros = min_val + scales * (2 ** (n_bit - 1))
|
73 |
+
return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
|
74 |
+
torch.bfloat16
|
75 |
+
).reshape(w.shape[0], -1)
|
76 |
+
|
77 |
+
|
78 |
+
def pack_scales_and_zeros(scales, zeros):
|
79 |
+
assert scales.shape == zeros.shape
|
80 |
+
assert scales.dtype == torch.bfloat16
|
81 |
+
assert zeros.dtype == torch.bfloat16
|
82 |
+
return (
|
83 |
+
torch.cat(
|
84 |
+
[
|
85 |
+
scales.reshape(scales.size(0), scales.size(1), 1),
|
86 |
+
zeros.reshape(zeros.size(0), zeros.size(1), 1),
|
87 |
+
],
|
88 |
+
2,
|
89 |
+
)
|
90 |
+
.transpose(0, 1)
|
91 |
+
.contiguous()
|
92 |
+
)
|
93 |
+
|
94 |
+
|
95 |
+
def unpack_scales_and_zeros(scales_and_zeros):
|
96 |
+
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
|
97 |
+
assert scales_and_zeros.dtype == torch.float
|
98 |
+
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
|
99 |
+
|
100 |
+
|
101 |
+
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
|
102 |
+
assert groupsize > 1
|
103 |
+
# needed for GPTQ single column quantize
|
104 |
+
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
|
105 |
+
groupsize = w.shape[-1]
|
106 |
+
|
107 |
+
assert w.shape[-1] % groupsize == 0
|
108 |
+
assert w.dim() == 2
|
109 |
+
|
110 |
+
to_quant = w.reshape(-1, groupsize)
|
111 |
+
assert torch.isnan(to_quant).sum() == 0
|
112 |
+
|
113 |
+
scales = scales.reshape(-1, 1)
|
114 |
+
zeros = zeros.reshape(-1, 1)
|
115 |
+
min_val = zeros - scales * (2 ** (n_bit - 1))
|
116 |
+
max_int = 2**n_bit - 1
|
117 |
+
min_int = 0
|
118 |
+
w_int32 = (
|
119 |
+
to_quant.sub(min_val)
|
120 |
+
.div(scales)
|
121 |
+
.round()
|
122 |
+
.clamp_(min_int, max_int)
|
123 |
+
.to(torch.int32)
|
124 |
+
.reshape_as(w)
|
125 |
+
)
|
126 |
+
|
127 |
+
return w_int32
|
128 |
+
|
129 |
+
|
130 |
+
def group_quantize_tensor(w, n_bit=4, groupsize=128):
|
131 |
+
scales, zeros = get_group_qparams(w, n_bit, groupsize)
|
132 |
+
w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
|
133 |
+
scales_and_zeros = pack_scales_and_zeros(scales, zeros)
|
134 |
+
return w_int32, scales_and_zeros
|
135 |
+
|
136 |
+
|
137 |
+
def group_dequantize_tensor_from_qparams(
|
138 |
+
w_int32, scales, zeros, n_bit=4, groupsize=128
|
139 |
+
):
|
140 |
+
assert groupsize > 1
|
141 |
+
# needed for GPTQ single column dequantize
|
142 |
+
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
|
143 |
+
groupsize = w_int32.shape[-1]
|
144 |
+
assert w_int32.shape[-1] % groupsize == 0
|
145 |
+
assert w_int32.dim() == 2
|
146 |
+
|
147 |
+
w_int32_grouped = w_int32.reshape(-1, groupsize)
|
148 |
+
scales = scales.reshape(-1, 1)
|
149 |
+
zeros = zeros.reshape(-1, 1)
|
150 |
+
|
151 |
+
w_dq = (
|
152 |
+
w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
|
153 |
+
)
|
154 |
+
return w_dq
|
155 |
+
|
156 |
+
|
157 |
+
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
|
158 |
+
scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
|
159 |
+
return group_dequantize_tensor_from_qparams(
|
160 |
+
w_int32, scales, zeros, n_bit, groupsize
|
161 |
+
)
|
162 |
+
|
163 |
+
|
164 |
+
class QuantHandler:
|
165 |
+
def __init__(self, mod):
|
166 |
+
self.mod = mod
|
167 |
+
|
168 |
+
def create_quantized_state_dict(self) -> "StateDict":
|
169 |
+
pass
|
170 |
+
|
171 |
+
def convert_for_runtime(self) -> "nn.Module":
|
172 |
+
pass
|
173 |
+
|
174 |
+
|
175 |
+
##### Weight-only int8 per-channel quantized code ######
|
176 |
+
|
177 |
+
|
178 |
+
def replace_linear_weight_only_int8_per_channel(module):
|
179 |
+
for name, child in module.named_children():
|
180 |
+
if isinstance(child, nn.Linear):
|
181 |
+
setattr(
|
182 |
+
module,
|
183 |
+
name,
|
184 |
+
WeightOnlyInt8Linear(child.in_features, child.out_features),
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
replace_linear_weight_only_int8_per_channel(child)
|
188 |
+
|
189 |
+
|
190 |
+
class WeightOnlyInt8QuantHandler:
|
191 |
+
def __init__(self, mod):
|
192 |
+
self.mod = mod
|
193 |
+
|
194 |
+
@torch.no_grad()
|
195 |
+
def create_quantized_state_dict(self):
|
196 |
+
cur_state_dict = self.mod.state_dict()
|
197 |
+
for fqn, mod in self.mod.named_modules():
|
198 |
+
if isinstance(mod, torch.nn.Linear):
|
199 |
+
int8_weight, scales, _ = dynamically_quantize_per_channel(
|
200 |
+
mod.weight.float(), -128, 127, torch.int8
|
201 |
+
)
|
202 |
+
cur_state_dict[f"{fqn}.weight"] = int8_weight
|
203 |
+
cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
|
204 |
+
|
205 |
+
return cur_state_dict
|
206 |
+
|
207 |
+
def convert_for_runtime(self):
|
208 |
+
replace_linear_weight_only_int8_per_channel(self.mod)
|
209 |
+
return self.mod
|
210 |
+
|
211 |
+
|
212 |
+
class WeightOnlyInt8Linear(torch.nn.Module):
|
213 |
+
__constants__ = ["in_features", "out_features"]
|
214 |
+
in_features: int
|
215 |
+
out_features: int
|
216 |
+
weight: torch.Tensor
|
217 |
+
|
218 |
+
def __init__(
|
219 |
+
self,
|
220 |
+
in_features: int,
|
221 |
+
out_features: int,
|
222 |
+
bias: bool = True,
|
223 |
+
device=None,
|
224 |
+
dtype=None,
|
225 |
+
) -> None:
|
226 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
227 |
+
super().__init__()
|
228 |
+
self.in_features = in_features
|
229 |
+
self.out_features = out_features
|
230 |
+
self.register_buffer(
|
231 |
+
"weight", torch.empty((out_features, in_features), dtype=torch.int8)
|
232 |
+
)
|
233 |
+
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
|
234 |
+
|
235 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
236 |
+
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
|
237 |
+
|
238 |
+
|
239 |
+
##### weight only int4 per channel groupwise quantized code ######
|
240 |
+
|
241 |
+
|
242 |
+
def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
|
243 |
+
weight_int32, scales_and_zeros = group_quantize_tensor(
|
244 |
+
weight_bf16, n_bit=4, groupsize=groupsize
|
245 |
+
)
|
246 |
+
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
|
247 |
+
weight_int32, inner_k_tiles
|
248 |
+
)
|
249 |
+
return weight_int4pack, scales_and_zeros
|
250 |
+
|
251 |
+
|
252 |
+
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
|
253 |
+
origin_x_size = x.size()
|
254 |
+
x = x.reshape(-1, origin_x_size[-1])
|
255 |
+
c = torch.ops.aten._weight_int4pack_mm(
|
256 |
+
x, weight_int4pack, groupsize, scales_and_zeros
|
257 |
+
)
|
258 |
+
new_shape = origin_x_size[:-1] + (out_features,)
|
259 |
+
c = c.reshape(new_shape)
|
260 |
+
return c
|
261 |
+
|
262 |
+
|
263 |
+
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
|
264 |
+
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
|
265 |
+
|
266 |
+
|
267 |
+
def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
|
268 |
+
for name, child in module.named_children():
|
269 |
+
if isinstance(child, nn.Linear):
|
270 |
+
if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
|
271 |
+
setattr(
|
272 |
+
module,
|
273 |
+
name,
|
274 |
+
WeightOnlyInt4Linear(
|
275 |
+
child.in_features,
|
276 |
+
child.out_features,
|
277 |
+
bias=False,
|
278 |
+
groupsize=groupsize,
|
279 |
+
inner_k_tiles=inner_k_tiles,
|
280 |
+
padding=False,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
elif padding:
|
284 |
+
setattr(
|
285 |
+
module,
|
286 |
+
name,
|
287 |
+
WeightOnlyInt4Linear(
|
288 |
+
child.in_features,
|
289 |
+
child.out_features,
|
290 |
+
bias=False,
|
291 |
+
groupsize=groupsize,
|
292 |
+
inner_k_tiles=inner_k_tiles,
|
293 |
+
padding=True,
|
294 |
+
),
|
295 |
+
)
|
296 |
+
else:
|
297 |
+
replace_linear_int4(child, groupsize, inner_k_tiles, padding)
|
298 |
+
|
299 |
+
|
300 |
+
class WeightOnlyInt4QuantHandler:
|
301 |
+
def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
|
302 |
+
self.mod = mod
|
303 |
+
self.groupsize = groupsize
|
304 |
+
self.inner_k_tiles = inner_k_tiles
|
305 |
+
self.padding = padding
|
306 |
+
assert groupsize in [32, 64, 128, 256]
|
307 |
+
assert inner_k_tiles in [2, 4, 8]
|
308 |
+
|
309 |
+
@torch.no_grad()
|
310 |
+
def create_quantized_state_dict(self):
|
311 |
+
cur_state_dict = self.mod.state_dict()
|
312 |
+
for fqn, mod in self.mod.named_modules():
|
313 |
+
if isinstance(mod, torch.nn.Linear):
|
314 |
+
assert not mod.bias
|
315 |
+
out_features = mod.out_features
|
316 |
+
in_features = mod.in_features
|
317 |
+
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
318 |
+
print(f"linear: {fqn}, in={in_features}, out={out_features}")
|
319 |
+
|
320 |
+
weight = mod.weight.data
|
321 |
+
if not _check_linear_int4_k(
|
322 |
+
in_features, self.groupsize, self.inner_k_tiles
|
323 |
+
):
|
324 |
+
if self.padding:
|
325 |
+
import torch.nn.functional as F
|
326 |
+
|
327 |
+
print(
|
328 |
+
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
|
329 |
+
)
|
330 |
+
padded_in_features = find_multiple(in_features, 1024)
|
331 |
+
weight = F.pad(
|
332 |
+
weight, pad=(0, padded_in_features - in_features)
|
333 |
+
)
|
334 |
+
else:
|
335 |
+
print(
|
336 |
+
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
|
337 |
+
+ "and that groupsize and inner_k_tiles*16 evenly divide into it"
|
338 |
+
)
|
339 |
+
continue
|
340 |
+
(
|
341 |
+
weight_int4pack,
|
342 |
+
scales_and_zeros,
|
343 |
+
) = prepare_int4_weight_and_scales_and_zeros(
|
344 |
+
weight.to(torch.bfloat16).to("cuda"),
|
345 |
+
self.groupsize,
|
346 |
+
self.inner_k_tiles,
|
347 |
+
)
|
348 |
+
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
|
349 |
+
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
|
350 |
+
|
351 |
+
return cur_state_dict
|
352 |
+
|
353 |
+
def convert_for_runtime(self):
|
354 |
+
replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
|
355 |
+
return self.mod
|
356 |
+
|
357 |
+
|
358 |
+
class WeightOnlyInt4Linear(torch.nn.Module):
|
359 |
+
__constants__ = ["in_features", "out_features"]
|
360 |
+
in_features: int
|
361 |
+
out_features: int
|
362 |
+
weight: torch.Tensor
|
363 |
+
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
in_features: int,
|
367 |
+
out_features: int,
|
368 |
+
bias=True,
|
369 |
+
device=None,
|
370 |
+
dtype=None,
|
371 |
+
groupsize: int = 128,
|
372 |
+
inner_k_tiles: int = 8,
|
373 |
+
padding: bool = True,
|
374 |
+
) -> None:
|
375 |
+
super().__init__()
|
376 |
+
self.padding = padding
|
377 |
+
if padding:
|
378 |
+
self.origin_in_features = in_features
|
379 |
+
in_features = find_multiple(in_features, 1024)
|
380 |
+
|
381 |
+
self.in_features = in_features
|
382 |
+
self.out_features = out_features
|
383 |
+
assert not bias, "require bias=False"
|
384 |
+
self.groupsize = groupsize
|
385 |
+
self.inner_k_tiles = inner_k_tiles
|
386 |
+
|
387 |
+
assert out_features % 8 == 0, "require out_features % 8 == 0"
|
388 |
+
assert (
|
389 |
+
in_features % (inner_k_tiles * 16) == 0
|
390 |
+
), "require in_features % (innerKTiles * 16) == 0"
|
391 |
+
self.register_buffer(
|
392 |
+
"weight",
|
393 |
+
torch.empty(
|
394 |
+
(
|
395 |
+
out_features // 8,
|
396 |
+
in_features // (inner_k_tiles * 16),
|
397 |
+
32,
|
398 |
+
inner_k_tiles // 2,
|
399 |
+
),
|
400 |
+
dtype=torch.int32,
|
401 |
+
),
|
402 |
+
)
|
403 |
+
self.register_buffer(
|
404 |
+
"scales_and_zeros",
|
405 |
+
torch.empty(
|
406 |
+
(in_features // groupsize, out_features, 2), dtype=torch.bfloat16
|
407 |
+
),
|
408 |
+
)
|
409 |
+
|
410 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
411 |
+
input = input.to(torch.bfloat16)
|
412 |
+
if self.padding:
|
413 |
+
import torch.nn.functional as F
|
414 |
+
|
415 |
+
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
|
416 |
+
return linear_forward_int4(
|
417 |
+
input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
|
418 |
+
)
|
419 |
+
|
420 |
+
|
421 |
+
def generate_folder_name():
|
422 |
+
now = datetime.datetime.now()
|
423 |
+
folder_name = now.strftime("%Y%m%d_%H%M%S")
|
424 |
+
return folder_name
|
425 |
+
|
426 |
+
|
427 |
+
@click.command()
|
428 |
+
@click.option(
|
429 |
+
"--checkpoint-path",
|
430 |
+
type=click.Path(path_type=Path, exists=True),
|
431 |
+
default="checkpoints/fish-speech-1.4",
|
432 |
+
)
|
433 |
+
@click.option(
|
434 |
+
"--mode", type=str, default="int8", help="type of quantization to perform"
|
435 |
+
)
|
436 |
+
@click.option(
|
437 |
+
"--groupsize", type=int, default=128, help="Group size for int4 quantization."
|
438 |
+
)
|
439 |
+
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
|
440 |
+
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
|
441 |
+
|
442 |
+
device = "cpu"
|
443 |
+
precision = torch.bfloat16
|
444 |
+
|
445 |
+
print("Loading model ...")
|
446 |
+
t0 = time.time()
|
447 |
+
|
448 |
+
model, _ = load_model(
|
449 |
+
checkpoint_path=checkpoint_path,
|
450 |
+
device=device,
|
451 |
+
precision=precision,
|
452 |
+
compile=False,
|
453 |
+
)
|
454 |
+
vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
|
455 |
+
now = timestamp if timestamp != "None" else generate_folder_name()
|
456 |
+
|
457 |
+
if mode == "int8":
|
458 |
+
print(
|
459 |
+
"Quantizing model weights for int8 weight-only symmetric per-channel quantization"
|
460 |
+
)
|
461 |
+
quant_handler = WeightOnlyInt8QuantHandler(model)
|
462 |
+
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
463 |
+
|
464 |
+
dir_name = checkpoint_path
|
465 |
+
dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
|
466 |
+
shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
|
467 |
+
if (dst_name / vq_model).exists():
|
468 |
+
(dst_name / vq_model).unlink()
|
469 |
+
quantize_path = dst_name / "model.pth"
|
470 |
+
|
471 |
+
elif mode == "int4":
|
472 |
+
print(
|
473 |
+
"Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
|
474 |
+
)
|
475 |
+
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
|
476 |
+
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
477 |
+
|
478 |
+
dir_name = checkpoint_path
|
479 |
+
dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
|
480 |
+
shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
|
481 |
+
if (dst_name / vq_model).exists():
|
482 |
+
(dst_name / vq_model).unlink()
|
483 |
+
quantize_path = dst_name / "model.pth"
|
484 |
+
|
485 |
+
else:
|
486 |
+
raise ValueError(
|
487 |
+
f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
|
488 |
+
)
|
489 |
+
|
490 |
+
print(f"Writing quantized weights to {quantize_path}")
|
491 |
+
quantize_path.unlink(missing_ok=True) # remove existing file if one already there
|
492 |
+
torch.save(quantized_state_dict, quantize_path)
|
493 |
+
print(f"Quantization complete took {time.time() - t0:.02f} seconds")
|
494 |
+
|
495 |
+
|
496 |
+
if __name__ == "__main__":
|
497 |
+
quantize()
|
tools/llama/rebuild_tokenizer.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
|
2 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
3 |
+
|
4 |
+
# Initialize a tokenizer
|
5 |
+
tokenizer = Tokenizer(models.BPE())
|
6 |
+
|
7 |
+
# Customize pre-tokenization and decoding
|
8 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
9 |
+
tokenizer.decoder = decoders.ByteLevel()
|
10 |
+
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
11 |
+
|
12 |
+
# Don't train the tokenizer
|
13 |
+
trainer = trainers.BpeTrainer(
|
14 |
+
vocab_size=0,
|
15 |
+
min_frequency=2,
|
16 |
+
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
17 |
+
special_tokens=[
|
18 |
+
"<|begin_of_sequence|>",
|
19 |
+
"<|end_of_sequence|>",
|
20 |
+
"<|im_start|>",
|
21 |
+
"<|im_sep|>", # system, user, assistant, etc.
|
22 |
+
"<|im_end|>",
|
23 |
+
"<|semantic|>", # audio features
|
24 |
+
"<|pad|>",
|
25 |
+
],
|
26 |
+
)
|
27 |
+
|
28 |
+
# <|im_start|>user<|im_sep|>...<|im_end|>
|
29 |
+
# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
|
30 |
+
tokenizer.train_from_iterator([], trainer=trainer)
|
31 |
+
|
32 |
+
print(len(tokenizer.get_vocab()))
|
33 |
+
x = tokenizer.encode(
|
34 |
+
"Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
|
35 |
+
).ids
|
36 |
+
print(x, len(x))
|
37 |
+
print(tokenizer.decode(x, skip_special_tokens=True))
|
38 |
+
|
39 |
+
|
40 |
+
tokenizer = PreTrainedTokenizerFast(
|
41 |
+
tokenizer_object=tokenizer,
|
42 |
+
pad_token="<|pad|>",
|
43 |
+
bos_token="<|begin_of_sequence|>",
|
44 |
+
eos_token="<|end_of_sequence|>",
|
45 |
+
)
|
46 |
+
|
47 |
+
# Try tokenizing a new sequence
|
48 |
+
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
|
49 |
+
encoded = tokenizer(sequence).input_ids
|
50 |
+
|
51 |
+
print("Test encoding....")
|
52 |
+
print(f"\tSentence: {sequence}")
|
53 |
+
print(f"\tEncoded: {encoded}")
|
54 |
+
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
|
55 |
+
print(f"\tDecoded: {tokenizer.decode(encoded)}")
|
56 |
+
|
57 |
+
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
|
tools/msgpack_api.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import httpx
|
6 |
+
import ormsgpack
|
7 |
+
|
8 |
+
from tools.schema import ServeReferenceAudio, ServeTTSRequest
|
9 |
+
|
10 |
+
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
|
11 |
+
|
12 |
+
|
13 |
+
def audio_request():
|
14 |
+
# priority: ref_id > references
|
15 |
+
request = ServeTTSRequest(
|
16 |
+
text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
17 |
+
# reference_id="114514",
|
18 |
+
references=[
|
19 |
+
ServeReferenceAudio(
|
20 |
+
audio=open("lengyue.wav", "rb").read(),
|
21 |
+
text=open("lengyue.lab", "r", encoding="utf-8").read(),
|
22 |
+
)
|
23 |
+
],
|
24 |
+
streaming=True,
|
25 |
+
)
|
26 |
+
|
27 |
+
api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
|
28 |
+
|
29 |
+
with (
|
30 |
+
httpx.Client() as client,
|
31 |
+
open("hello.wav", "wb") as f,
|
32 |
+
):
|
33 |
+
with client.stream(
|
34 |
+
"POST",
|
35 |
+
"http://127.0.0.1:8080/v1/tts",
|
36 |
+
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
37 |
+
headers={
|
38 |
+
"authorization": f"Bearer {api_key}",
|
39 |
+
"content-type": "application/msgpack",
|
40 |
+
},
|
41 |
+
timeout=None,
|
42 |
+
) as response:
|
43 |
+
for chunk in response.iter_bytes():
|
44 |
+
f.write(chunk)
|
45 |
+
|
46 |
+
|
47 |
+
def asr_request(audio_path: Path):
|
48 |
+
|
49 |
+
# Read the audio file
|
50 |
+
with open(
|
51 |
+
str(audio_path),
|
52 |
+
"rb",
|
53 |
+
) as audio_file:
|
54 |
+
audio_data = audio_file.read()
|
55 |
+
|
56 |
+
# Prepare the request data
|
57 |
+
request_data = {
|
58 |
+
"audio": audio_data,
|
59 |
+
"language": "en", # Optional: specify the language
|
60 |
+
"ignore_timestamps": False, # Optional: set to True to ignore precise timestamps
|
61 |
+
}
|
62 |
+
|
63 |
+
# Send the request
|
64 |
+
with httpx.Client() as client:
|
65 |
+
response = client.post(
|
66 |
+
"https://api.fish.audio/v1/asr",
|
67 |
+
headers={
|
68 |
+
"Authorization": f"Bearer {api_key}",
|
69 |
+
"Content-Type": "application/msgpack",
|
70 |
+
},
|
71 |
+
content=ormsgpack.packb(request_data),
|
72 |
+
)
|
73 |
+
|
74 |
+
# Parse the response
|
75 |
+
result = response.json()
|
76 |
+
|
77 |
+
print(f"Transcribed text: {result['text']}")
|
78 |
+
print(f"Audio duration: {result['duration']} seconds")
|
79 |
+
|
80 |
+
for segment in result["segments"]:
|
81 |
+
print(f"Segment: {segment['text']}")
|
82 |
+
print(f"Start time: {segment['start']}, End time: {segment['end']}")
|
83 |
+
|
84 |
+
|
85 |
+
def parse_args():
|
86 |
+
parser = ArgumentParser()
|
87 |
+
parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
|
88 |
+
|
89 |
+
return parser.parse_args()
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == "__main__":
|
93 |
+
args = parse_args()
|
94 |
+
|
95 |
+
asr_request(args.audio_path)
|
tools/post_api.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import base64
|
3 |
+
import wave
|
4 |
+
|
5 |
+
import ormsgpack
|
6 |
+
import pyaudio
|
7 |
+
import requests
|
8 |
+
from pydub import AudioSegment
|
9 |
+
from pydub.playback import play
|
10 |
+
|
11 |
+
from tools.file import audio_to_bytes, read_ref_text
|
12 |
+
from tools.schema import ServeReferenceAudio, ServeTTSRequest
|
13 |
+
|
14 |
+
|
15 |
+
def parse_args():
|
16 |
+
|
17 |
+
parser = argparse.ArgumentParser(
|
18 |
+
description="Send a WAV file and text to a server and receive synthesized audio.",
|
19 |
+
formatter_class=argparse.RawTextHelpFormatter,
|
20 |
+
)
|
21 |
+
|
22 |
+
parser.add_argument(
|
23 |
+
"--url",
|
24 |
+
"-u",
|
25 |
+
type=str,
|
26 |
+
default="http://127.0.0.1:8080/v1/tts",
|
27 |
+
help="URL of the server",
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--text", "-t", type=str, required=True, help="Text to be synthesized"
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
"--reference_id",
|
34 |
+
"-id",
|
35 |
+
type=str,
|
36 |
+
default=None,
|
37 |
+
help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--reference_audio",
|
41 |
+
"-ra",
|
42 |
+
type=str,
|
43 |
+
nargs="+",
|
44 |
+
default=None,
|
45 |
+
help="Path to the audio file",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--reference_text",
|
49 |
+
"-rt",
|
50 |
+
type=str,
|
51 |
+
nargs="+",
|
52 |
+
default=None,
|
53 |
+
help="Reference text for voice synthesis",
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--output",
|
57 |
+
"-o",
|
58 |
+
type=str,
|
59 |
+
default="generated_audio",
|
60 |
+
help="Output audio file name",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--play",
|
64 |
+
type=bool,
|
65 |
+
default=True,
|
66 |
+
help="Whether to play audio after receiving data",
|
67 |
+
)
|
68 |
+
parser.add_argument("--normalize", type=bool, default=True)
|
69 |
+
parser.add_argument(
|
70 |
+
"--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
|
74 |
+
)
|
75 |
+
parser.add_argument("--opus_bitrate", type=int, default=-1000)
|
76 |
+
parser.add_argument(
|
77 |
+
"--latency",
|
78 |
+
type=str,
|
79 |
+
default="normal",
|
80 |
+
choices=["normal", "balanced"],
|
81 |
+
help="Used in api.fish.audio/v1/tts",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--max_new_tokens",
|
85 |
+
type=int,
|
86 |
+
default=0,
|
87 |
+
help="Maximum new tokens to generate. \n0 means no limit.",
|
88 |
+
)
|
89 |
+
parser.add_argument(
|
90 |
+
"--chunk_length", type=int, default=200, help="Chunk length for synthesis"
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
|
94 |
+
)
|
95 |
+
parser.add_argument(
|
96 |
+
"--repetition_penalty",
|
97 |
+
type=float,
|
98 |
+
default=1.2,
|
99 |
+
help="Repetition penalty for synthesis",
|
100 |
+
)
|
101 |
+
parser.add_argument(
|
102 |
+
"--temperature", type=float, default=0.7, help="Temperature for sampling"
|
103 |
+
)
|
104 |
+
|
105 |
+
parser.add_argument(
|
106 |
+
"--streaming", type=bool, default=False, help="Enable streaming response"
|
107 |
+
)
|
108 |
+
parser.add_argument(
|
109 |
+
"--channels", type=int, default=1, help="Number of audio channels"
|
110 |
+
)
|
111 |
+
parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
|
112 |
+
parser.add_argument(
|
113 |
+
"--use_memory_cache",
|
114 |
+
type=str,
|
115 |
+
default="never",
|
116 |
+
choices=["on-demand", "never"],
|
117 |
+
help="Cache encoded references codes in memory.\n"
|
118 |
+
"If `on-demand`, the server will use cached encodings\n "
|
119 |
+
"instead of encoding reference audio again.",
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--seed",
|
123 |
+
type=int,
|
124 |
+
default=None,
|
125 |
+
help="`None` means randomized inference, otherwise deterministic.\n"
|
126 |
+
"It can't be used for fixing a timbre.",
|
127 |
+
)
|
128 |
+
|
129 |
+
return parser.parse_args()
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == "__main__":
|
133 |
+
|
134 |
+
args = parse_args()
|
135 |
+
|
136 |
+
idstr: str | None = args.reference_id
|
137 |
+
# priority: ref_id > [{text, audio},...]
|
138 |
+
if idstr is None:
|
139 |
+
ref_audios = args.reference_audio
|
140 |
+
ref_texts = args.reference_text
|
141 |
+
if ref_audios is None:
|
142 |
+
byte_audios = []
|
143 |
+
else:
|
144 |
+
byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
|
145 |
+
if ref_texts is None:
|
146 |
+
ref_texts = []
|
147 |
+
else:
|
148 |
+
ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
|
149 |
+
else:
|
150 |
+
byte_audios = []
|
151 |
+
ref_texts = []
|
152 |
+
pass # in api.py
|
153 |
+
|
154 |
+
data = {
|
155 |
+
"text": args.text,
|
156 |
+
"references": [
|
157 |
+
ServeReferenceAudio(audio=ref_audio, text=ref_text)
|
158 |
+
for ref_text, ref_audio in zip(ref_texts, byte_audios)
|
159 |
+
],
|
160 |
+
"reference_id": idstr,
|
161 |
+
"normalize": args.normalize,
|
162 |
+
"format": args.format,
|
163 |
+
"mp3_bitrate": args.mp3_bitrate,
|
164 |
+
"opus_bitrate": args.opus_bitrate,
|
165 |
+
"max_new_tokens": args.max_new_tokens,
|
166 |
+
"chunk_length": args.chunk_length,
|
167 |
+
"top_p": args.top_p,
|
168 |
+
"repetition_penalty": args.repetition_penalty,
|
169 |
+
"temperature": args.temperature,
|
170 |
+
"streaming": args.streaming,
|
171 |
+
"use_memory_cache": args.use_memory_cache,
|
172 |
+
"seed": args.seed,
|
173 |
+
}
|
174 |
+
|
175 |
+
pydantic_data = ServeTTSRequest(**data)
|
176 |
+
|
177 |
+
response = requests.post(
|
178 |
+
args.url,
|
179 |
+
data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
180 |
+
stream=args.streaming,
|
181 |
+
headers={
|
182 |
+
"authorization": "Bearer YOUR_API_KEY",
|
183 |
+
"content-type": "application/msgpack",
|
184 |
+
},
|
185 |
+
)
|
186 |
+
|
187 |
+
if response.status_code == 200:
|
188 |
+
if args.streaming:
|
189 |
+
p = pyaudio.PyAudio()
|
190 |
+
audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
|
191 |
+
stream = p.open(
|
192 |
+
format=audio_format, channels=args.channels, rate=args.rate, output=True
|
193 |
+
)
|
194 |
+
|
195 |
+
wf = wave.open(f"{args.output}.wav", "wb")
|
196 |
+
wf.setnchannels(args.channels)
|
197 |
+
wf.setsampwidth(p.get_sample_size(audio_format))
|
198 |
+
wf.setframerate(args.rate)
|
199 |
+
|
200 |
+
stream_stopped_flag = False
|
201 |
+
|
202 |
+
try:
|
203 |
+
for chunk in response.iter_content(chunk_size=1024):
|
204 |
+
if chunk:
|
205 |
+
stream.write(chunk)
|
206 |
+
wf.writeframesraw(chunk)
|
207 |
+
else:
|
208 |
+
if not stream_stopped_flag:
|
209 |
+
stream.stop_stream()
|
210 |
+
stream_stopped_flag = True
|
211 |
+
finally:
|
212 |
+
stream.close()
|
213 |
+
p.terminate()
|
214 |
+
wf.close()
|
215 |
+
else:
|
216 |
+
audio_content = response.content
|
217 |
+
audio_path = f"{args.output}.{args.format}"
|
218 |
+
with open(audio_path, "wb") as audio_file:
|
219 |
+
audio_file.write(audio_content)
|
220 |
+
|
221 |
+
audio = AudioSegment.from_file(audio_path, format=args.format)
|
222 |
+
if args.play:
|
223 |
+
play(audio)
|
224 |
+
print(f"Audio has been saved to '{audio_path}'.")
|
225 |
+
else:
|
226 |
+
print(f"Request failed with status code {response.status_code}")
|
227 |
+
print(response.json())
|
tools/schema.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import queue
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Annotated, Literal, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
|
8 |
+
from pydantic.functional_validators import SkipValidation
|
9 |
+
|
10 |
+
from fish_speech.conversation import Message, TextPart, VQPart
|
11 |
+
|
12 |
+
GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
|
13 |
+
|
14 |
+
|
15 |
+
class ServeVQPart(BaseModel):
|
16 |
+
type: Literal["vq"] = "vq"
|
17 |
+
codes: SkipValidation[list[list[int]]]
|
18 |
+
|
19 |
+
|
20 |
+
class ServeTextPart(BaseModel):
|
21 |
+
type: Literal["text"] = "text"
|
22 |
+
text: str
|
23 |
+
|
24 |
+
|
25 |
+
class ServeAudioPart(BaseModel):
|
26 |
+
type: Literal["audio"] = "audio"
|
27 |
+
audio: bytes
|
28 |
+
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class ASRPackRequest:
|
32 |
+
audio: torch.Tensor
|
33 |
+
result_queue: queue.Queue
|
34 |
+
language: str
|
35 |
+
|
36 |
+
|
37 |
+
class ServeASRRequest(BaseModel):
|
38 |
+
# The audio should be an uncompressed PCM float16 audio
|
39 |
+
audios: list[bytes]
|
40 |
+
sample_rate: int = 44100
|
41 |
+
language: Literal["zh", "en", "ja", "auto"] = "auto"
|
42 |
+
|
43 |
+
|
44 |
+
class ServeASRTranscription(BaseModel):
|
45 |
+
text: str
|
46 |
+
duration: float
|
47 |
+
huge_gap: bool
|
48 |
+
|
49 |
+
|
50 |
+
class ServeASRSegment(BaseModel):
|
51 |
+
text: str
|
52 |
+
start: float
|
53 |
+
end: float
|
54 |
+
|
55 |
+
|
56 |
+
class ServeTimedASRResponse(BaseModel):
|
57 |
+
text: str
|
58 |
+
segments: list[ServeASRSegment]
|
59 |
+
duration: float
|
60 |
+
|
61 |
+
|
62 |
+
class ServeASRResponse(BaseModel):
|
63 |
+
transcriptions: list[ServeASRTranscription]
|
64 |
+
|
65 |
+
|
66 |
+
class ServeMessage(BaseModel):
|
67 |
+
role: Literal["system", "assistant", "user"]
|
68 |
+
parts: list[ServeVQPart | ServeTextPart]
|
69 |
+
|
70 |
+
def to_conversation_message(self):
|
71 |
+
new_message = Message(role=self.role, parts=[])
|
72 |
+
for part in self.parts:
|
73 |
+
if isinstance(part, ServeTextPart):
|
74 |
+
new_message.parts.append(TextPart(text=part.text))
|
75 |
+
elif isinstance(part, ServeVQPart):
|
76 |
+
new_message.parts.append(
|
77 |
+
VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
raise ValueError(f"Unsupported part type: {part}")
|
81 |
+
|
82 |
+
return new_message
|
83 |
+
|
84 |
+
|
85 |
+
class ServeRequest(BaseModel):
|
86 |
+
messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
|
87 |
+
max_new_tokens: int = 1024
|
88 |
+
top_p: float = 0.7
|
89 |
+
repetition_penalty: float = 1.2
|
90 |
+
temperature: float = 0.7
|
91 |
+
streaming: bool = False
|
92 |
+
num_samples: int = 1
|
93 |
+
early_stop_threshold: float = 1.0
|
94 |
+
|
95 |
+
|
96 |
+
class ServeVQGANEncodeRequest(BaseModel):
|
97 |
+
# The audio here should be in wav, mp3, etc
|
98 |
+
audios: list[bytes]
|
99 |
+
|
100 |
+
|
101 |
+
class ServeVQGANEncodeResponse(BaseModel):
|
102 |
+
tokens: SkipValidation[list[list[list[int]]]]
|
103 |
+
|
104 |
+
|
105 |
+
class ServeVQGANDecodeRequest(BaseModel):
|
106 |
+
tokens: SkipValidation[list[list[list[int]]]]
|
107 |
+
|
108 |
+
|
109 |
+
class ServeVQGANDecodeResponse(BaseModel):
|
110 |
+
# The audio here should be in PCM float16 format
|
111 |
+
audios: list[bytes]
|
112 |
+
|
113 |
+
|
114 |
+
class ServeReferenceAudio(BaseModel):
|
115 |
+
audio: bytes
|
116 |
+
text: str
|
117 |
+
|
118 |
+
|
119 |
+
class ServeForwardMessage(BaseModel):
|
120 |
+
role: str
|
121 |
+
content: str
|
122 |
+
|
123 |
+
|
124 |
+
class ServeResponse(BaseModel):
|
125 |
+
messages: list[ServeMessage]
|
126 |
+
finish_reason: Literal["stop", "error"] | None = None
|
127 |
+
stats: dict[str, int | float | str] = {}
|
128 |
+
|
129 |
+
|
130 |
+
class ServeStreamDelta(BaseModel):
|
131 |
+
role: Literal["system", "assistant", "user"] | None = None
|
132 |
+
part: ServeVQPart | ServeTextPart | None = None
|
133 |
+
|
134 |
+
|
135 |
+
class ServeStreamResponse(BaseModel):
|
136 |
+
sample_id: int = 0
|
137 |
+
delta: ServeStreamDelta | None = None
|
138 |
+
finish_reason: Literal["stop", "error"] | None = None
|
139 |
+
stats: dict[str, int | float | str] | None = None
|
140 |
+
|
141 |
+
|
142 |
+
class ServeReferenceAudio(BaseModel):
|
143 |
+
audio: bytes
|
144 |
+
text: str
|
145 |
+
|
146 |
+
def __repr__(self) -> str:
|
147 |
+
return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
|
148 |
+
|
149 |
+
|
150 |
+
class ServeChatRequestV1(BaseModel):
|
151 |
+
model: str = "llama3-8b"
|
152 |
+
messages: list[ServeForwardMessage] = []
|
153 |
+
audio: bytes | None = None
|
154 |
+
temperature: float = 1.0
|
155 |
+
top_p: float = 1.0
|
156 |
+
max_tokens: int = 256
|
157 |
+
voice: str = "jessica"
|
158 |
+
tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
|
159 |
+
tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
|
160 |
+
|
161 |
+
|
162 |
+
class ServeTTSRequest(BaseModel):
|
163 |
+
text: str
|
164 |
+
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
165 |
+
# Audio format
|
166 |
+
format: Literal["wav", "pcm", "mp3"] = "wav"
|
167 |
+
mp3_bitrate: Literal[64, 128, 192] = 128
|
168 |
+
# References audios for in-context learning
|
169 |
+
references: list[ServeReferenceAudio] = []
|
170 |
+
# Reference id
|
171 |
+
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
172 |
+
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
173 |
+
reference_id: str | None = None
|
174 |
+
seed: int | None = None
|
175 |
+
use_memory_cache: Literal["on-demand", "never"] = "never"
|
176 |
+
# Normalize text for en & zh, this increase stability for numbers
|
177 |
+
normalize: bool = True
|
178 |
+
mp3_bitrate: Optional[int] = 64
|
179 |
+
opus_bitrate: Optional[int] = -1000
|
180 |
+
# Balance mode will reduce latency to 300ms, but may decrease stability
|
181 |
+
latency: Literal["normal", "balanced"] = "normal"
|
182 |
+
# not usually used below
|
183 |
+
streaming: bool = False
|
184 |
+
max_new_tokens: int = 1024
|
185 |
+
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
186 |
+
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
187 |
+
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
tools/sensevoice/README.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# FunASR Command Line Interface
|
2 |
+
|
3 |
+
This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
|
4 |
+
|
5 |
+
## Requirements
|
6 |
+
|
7 |
+
- Python >= 3.10
|
8 |
+
- PyTorch <= 2.3.1
|
9 |
+
- ffmpeg, pydub, audio-separator[gpu].
|
10 |
+
|
11 |
+
## Installation
|
12 |
+
|
13 |
+
Install the required packages:
|
14 |
+
|
15 |
+
```bash
|
16 |
+
pip install -e .[stable]
|
17 |
+
```
|
18 |
+
|
19 |
+
Make sure you have `ffmpeg` installed and available in your `PATH`.
|
20 |
+
|
21 |
+
## Usage
|
22 |
+
|
23 |
+
### Basic Usage
|
24 |
+
|
25 |
+
To run the tool with default settings:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
|
29 |
+
```
|
30 |
+
|
31 |
+
## Options
|
32 |
+
|
33 |
+
| Option | Description |
|
34 |
+
| :-----------------------: | :---------------------------------------------------------------------------: |
|
35 |
+
| --audio-dir | Directory containing audio or video files. |
|
36 |
+
| --save-dir | Directory to save processed audio files. |
|
37 |
+
| --device | Device to use for processing. Options: cuda (default) or cpu. |
|
38 |
+
| --language | Language of the transcription. Default is auto. |
|
39 |
+
| --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
|
40 |
+
| --punc | Enable punctuation prediction. |
|
41 |
+
| --denoise | Enable noise reduction (vocal separation). |
|
42 |
+
|
43 |
+
## Example
|
44 |
+
|
45 |
+
To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
|
46 |
+
|
47 |
+
```bash
|
48 |
+
python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
|
49 |
+
```
|
50 |
+
|
51 |
+
## Additional Notes
|
52 |
+
|
53 |
+
- The tool supports `both audio and video files`. Videos will be converted to audio automatically.
|
54 |
+
- If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
|
55 |
+
- The script will automatically create necessary directories in the `--save-dir`.
|
56 |
+
|
57 |
+
## Troubleshooting
|
58 |
+
|
59 |
+
If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
|
tools/sensevoice/__init__.py
ADDED
File without changes
|
tools/sensevoice/auto_model.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import os.path
|
10 |
+
import random
|
11 |
+
import re
|
12 |
+
import string
|
13 |
+
import time
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from funasr.download.download_model_from_hub import download_model
|
18 |
+
from funasr.download.file import download_from_url
|
19 |
+
from funasr.register import tables
|
20 |
+
from funasr.train_utils.load_pretrained_model import load_pretrained_model
|
21 |
+
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
22 |
+
from funasr.utils import export_utils, misc
|
23 |
+
from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
|
24 |
+
from funasr.utils.misc import deep_update
|
25 |
+
from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
from .vad_utils import merge_vad, slice_padding_audio_samples
|
29 |
+
|
30 |
+
try:
|
31 |
+
from funasr.models.campplus.cluster_backend import ClusterBackend
|
32 |
+
from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
|
36 |
+
|
37 |
+
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
|
38 |
+
""" """
|
39 |
+
data_list = []
|
40 |
+
key_list = []
|
41 |
+
filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
|
42 |
+
|
43 |
+
chars = string.ascii_letters + string.digits
|
44 |
+
if isinstance(data_in, str):
|
45 |
+
if data_in.startswith("http://") or data_in.startswith("https://"): # url
|
46 |
+
data_in = download_from_url(data_in)
|
47 |
+
|
48 |
+
if isinstance(data_in, str) and os.path.exists(
|
49 |
+
data_in
|
50 |
+
): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
|
51 |
+
_, file_extension = os.path.splitext(data_in)
|
52 |
+
file_extension = file_extension.lower()
|
53 |
+
if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
|
54 |
+
with open(data_in, encoding="utf-8") as fin:
|
55 |
+
for line in fin:
|
56 |
+
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
57 |
+
if data_in.endswith(
|
58 |
+
".jsonl"
|
59 |
+
): # file.jsonl: json.dumps({"source": data})
|
60 |
+
lines = json.loads(line.strip())
|
61 |
+
data = lines["source"]
|
62 |
+
key = data["key"] if "key" in data else key
|
63 |
+
else: # filelist, wav.scp, text.txt: id \t data or data
|
64 |
+
lines = line.strip().split(maxsplit=1)
|
65 |
+
data = lines[1] if len(lines) > 1 else lines[0]
|
66 |
+
key = lines[0] if len(lines) > 1 else key
|
67 |
+
|
68 |
+
data_list.append(data)
|
69 |
+
key_list.append(key)
|
70 |
+
else:
|
71 |
+
if key is None:
|
72 |
+
# key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
73 |
+
key = misc.extract_filename_without_extension(data_in)
|
74 |
+
data_list = [data_in]
|
75 |
+
key_list = [key]
|
76 |
+
elif isinstance(data_in, (list, tuple)):
|
77 |
+
if data_type is not None and isinstance(
|
78 |
+
data_type, (list, tuple)
|
79 |
+
): # mutiple inputs
|
80 |
+
data_list_tmp = []
|
81 |
+
for data_in_i, data_type_i in zip(data_in, data_type):
|
82 |
+
key_list, data_list_i = prepare_data_iterator(
|
83 |
+
data_in=data_in_i, data_type=data_type_i
|
84 |
+
)
|
85 |
+
data_list_tmp.append(data_list_i)
|
86 |
+
data_list = []
|
87 |
+
for item in zip(*data_list_tmp):
|
88 |
+
data_list.append(item)
|
89 |
+
else:
|
90 |
+
# [audio sample point, fbank, text]
|
91 |
+
data_list = data_in
|
92 |
+
key_list = []
|
93 |
+
for data_i in data_in:
|
94 |
+
if isinstance(data_i, str) and os.path.exists(data_i):
|
95 |
+
key = misc.extract_filename_without_extension(data_i)
|
96 |
+
else:
|
97 |
+
if key is None:
|
98 |
+
key = "rand_key_" + "".join(
|
99 |
+
random.choice(chars) for _ in range(13)
|
100 |
+
)
|
101 |
+
key_list.append(key)
|
102 |
+
|
103 |
+
else: # raw text; audio sample point, fbank; bytes
|
104 |
+
if isinstance(data_in, bytes): # audio bytes
|
105 |
+
data_in = load_bytes(data_in)
|
106 |
+
if key is None:
|
107 |
+
key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
|
108 |
+
data_list = [data_in]
|
109 |
+
key_list = [key]
|
110 |
+
|
111 |
+
return key_list, data_list
|
112 |
+
|
113 |
+
|
114 |
+
class AutoModel:
|
115 |
+
|
116 |
+
def __init__(self, **kwargs):
|
117 |
+
|
118 |
+
try:
|
119 |
+
from funasr.utils.version_checker import check_for_update
|
120 |
+
|
121 |
+
print(
|
122 |
+
"Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
|
123 |
+
)
|
124 |
+
check_for_update(disable=kwargs.get("disable_update", False))
|
125 |
+
except:
|
126 |
+
pass
|
127 |
+
|
128 |
+
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
|
129 |
+
logging.basicConfig(level=log_level)
|
130 |
+
|
131 |
+
model, kwargs = self.build_model(**kwargs)
|
132 |
+
|
133 |
+
# if vad_model is not None, build vad model else None
|
134 |
+
vad_model = kwargs.get("vad_model", None)
|
135 |
+
vad_kwargs = (
|
136 |
+
{} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
|
137 |
+
)
|
138 |
+
if vad_model is not None:
|
139 |
+
logging.info("Building VAD model.")
|
140 |
+
vad_kwargs["model"] = vad_model
|
141 |
+
vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
|
142 |
+
vad_kwargs["device"] = kwargs["device"]
|
143 |
+
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
|
144 |
+
|
145 |
+
# if punc_model is not None, build punc model else None
|
146 |
+
punc_model = kwargs.get("punc_model", None)
|
147 |
+
punc_kwargs = (
|
148 |
+
{}
|
149 |
+
if kwargs.get("punc_kwargs", {}) is None
|
150 |
+
else kwargs.get("punc_kwargs", {})
|
151 |
+
)
|
152 |
+
if punc_model is not None:
|
153 |
+
logging.info("Building punc model.")
|
154 |
+
punc_kwargs["model"] = punc_model
|
155 |
+
punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
|
156 |
+
punc_kwargs["device"] = kwargs["device"]
|
157 |
+
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
|
158 |
+
|
159 |
+
# if spk_model is not None, build spk model else None
|
160 |
+
spk_model = kwargs.get("spk_model", None)
|
161 |
+
spk_kwargs = (
|
162 |
+
{} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
|
163 |
+
)
|
164 |
+
if spk_model is not None:
|
165 |
+
logging.info("Building SPK model.")
|
166 |
+
spk_kwargs["model"] = spk_model
|
167 |
+
spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
|
168 |
+
spk_kwargs["device"] = kwargs["device"]
|
169 |
+
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
|
170 |
+
self.cb_model = ClusterBackend().to(kwargs["device"])
|
171 |
+
spk_mode = kwargs.get("spk_mode", "punc_segment")
|
172 |
+
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
|
173 |
+
logging.error(
|
174 |
+
"spk_mode should be one of default, vad_segment and punc_segment."
|
175 |
+
)
|
176 |
+
self.spk_mode = spk_mode
|
177 |
+
|
178 |
+
self.kwargs = kwargs
|
179 |
+
self.model = model
|
180 |
+
self.vad_model = vad_model
|
181 |
+
self.vad_kwargs = vad_kwargs
|
182 |
+
self.punc_model = punc_model
|
183 |
+
self.punc_kwargs = punc_kwargs
|
184 |
+
self.spk_model = spk_model
|
185 |
+
self.spk_kwargs = spk_kwargs
|
186 |
+
self.model_path = kwargs.get("model_path")
|
187 |
+
|
188 |
+
@staticmethod
|
189 |
+
def build_model(**kwargs):
|
190 |
+
assert "model" in kwargs
|
191 |
+
if "model_conf" not in kwargs:
|
192 |
+
logging.info(
|
193 |
+
"download models from model hub: {}".format(kwargs.get("hub", "ms"))
|
194 |
+
)
|
195 |
+
kwargs = download_model(**kwargs)
|
196 |
+
|
197 |
+
set_all_random_seed(kwargs.get("seed", 0))
|
198 |
+
|
199 |
+
device = kwargs.get("device", "cuda")
|
200 |
+
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
|
201 |
+
device = "cpu"
|
202 |
+
kwargs["batch_size"] = 1
|
203 |
+
kwargs["device"] = device
|
204 |
+
|
205 |
+
torch.set_num_threads(kwargs.get("ncpu", 4))
|
206 |
+
|
207 |
+
# build tokenizer
|
208 |
+
tokenizer = kwargs.get("tokenizer", None)
|
209 |
+
if tokenizer is not None:
|
210 |
+
tokenizer_class = tables.tokenizer_classes.get(tokenizer)
|
211 |
+
tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
|
212 |
+
kwargs["token_list"] = (
|
213 |
+
tokenizer.token_list if hasattr(tokenizer, "token_list") else None
|
214 |
+
)
|
215 |
+
kwargs["token_list"] = (
|
216 |
+
tokenizer.get_vocab()
|
217 |
+
if hasattr(tokenizer, "get_vocab")
|
218 |
+
else kwargs["token_list"]
|
219 |
+
)
|
220 |
+
vocab_size = (
|
221 |
+
len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
|
222 |
+
)
|
223 |
+
if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
|
224 |
+
vocab_size = tokenizer.get_vocab_size()
|
225 |
+
else:
|
226 |
+
vocab_size = -1
|
227 |
+
kwargs["tokenizer"] = tokenizer
|
228 |
+
|
229 |
+
# build frontend
|
230 |
+
frontend = kwargs.get("frontend", None)
|
231 |
+
kwargs["input_size"] = None
|
232 |
+
if frontend is not None:
|
233 |
+
frontend_class = tables.frontend_classes.get(frontend)
|
234 |
+
frontend = frontend_class(**kwargs.get("frontend_conf", {}))
|
235 |
+
kwargs["input_size"] = (
|
236 |
+
frontend.output_size() if hasattr(frontend, "output_size") else None
|
237 |
+
)
|
238 |
+
kwargs["frontend"] = frontend
|
239 |
+
# build model
|
240 |
+
model_class = tables.model_classes.get(kwargs["model"])
|
241 |
+
assert model_class is not None, f'{kwargs["model"]} is not registered'
|
242 |
+
model_conf = {}
|
243 |
+
deep_update(model_conf, kwargs.get("model_conf", {}))
|
244 |
+
deep_update(model_conf, kwargs)
|
245 |
+
model = model_class(**model_conf, vocab_size=vocab_size)
|
246 |
+
|
247 |
+
# init_param
|
248 |
+
init_param = kwargs.get("init_param", None)
|
249 |
+
if init_param is not None:
|
250 |
+
if os.path.exists(init_param):
|
251 |
+
logging.info(f"Loading pretrained params from {init_param}")
|
252 |
+
load_pretrained_model(
|
253 |
+
model=model,
|
254 |
+
path=init_param,
|
255 |
+
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
|
256 |
+
oss_bucket=kwargs.get("oss_bucket", None),
|
257 |
+
scope_map=kwargs.get("scope_map", []),
|
258 |
+
excludes=kwargs.get("excludes", None),
|
259 |
+
)
|
260 |
+
else:
|
261 |
+
print(f"error, init_param does not exist!: {init_param}")
|
262 |
+
|
263 |
+
# fp16
|
264 |
+
if kwargs.get("fp16", False):
|
265 |
+
model.to(torch.float16)
|
266 |
+
elif kwargs.get("bf16", False):
|
267 |
+
model.to(torch.bfloat16)
|
268 |
+
model.to(device)
|
269 |
+
|
270 |
+
if not kwargs.get("disable_log", True):
|
271 |
+
tables.print()
|
272 |
+
|
273 |
+
return model, kwargs
|
274 |
+
|
275 |
+
def __call__(self, *args, **cfg):
|
276 |
+
kwargs = self.kwargs
|
277 |
+
deep_update(kwargs, cfg)
|
278 |
+
res = self.model(*args, kwargs)
|
279 |
+
return res
|
280 |
+
|
281 |
+
def generate(self, input, input_len=None, **cfg):
|
282 |
+
if self.vad_model is None:
|
283 |
+
return self.inference(input, input_len=input_len, **cfg)
|
284 |
+
|
285 |
+
else:
|
286 |
+
return self.inference_with_vad(input, input_len=input_len, **cfg)
|
287 |
+
|
288 |
+
def inference(
|
289 |
+
self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
|
290 |
+
):
|
291 |
+
kwargs = self.kwargs if kwargs is None else kwargs
|
292 |
+
if "cache" in kwargs:
|
293 |
+
kwargs.pop("cache")
|
294 |
+
deep_update(kwargs, cfg)
|
295 |
+
model = self.model if model is None else model
|
296 |
+
model.eval()
|
297 |
+
|
298 |
+
batch_size = kwargs.get("batch_size", 1)
|
299 |
+
# if kwargs.get("device", "cpu") == "cpu":
|
300 |
+
# batch_size = 1
|
301 |
+
|
302 |
+
key_list, data_list = prepare_data_iterator(
|
303 |
+
input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
|
304 |
+
)
|
305 |
+
|
306 |
+
speed_stats = {}
|
307 |
+
asr_result_list = []
|
308 |
+
num_samples = len(data_list)
|
309 |
+
disable_pbar = self.kwargs.get("disable_pbar", False)
|
310 |
+
pbar = (
|
311 |
+
tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
|
312 |
+
if not disable_pbar
|
313 |
+
else None
|
314 |
+
)
|
315 |
+
time_speech_total = 0.0
|
316 |
+
time_escape_total = 0.0
|
317 |
+
for beg_idx in range(0, num_samples, batch_size):
|
318 |
+
end_idx = min(num_samples, beg_idx + batch_size)
|
319 |
+
data_batch = data_list[beg_idx:end_idx]
|
320 |
+
key_batch = key_list[beg_idx:end_idx]
|
321 |
+
batch = {"data_in": data_batch, "key": key_batch}
|
322 |
+
|
323 |
+
if (end_idx - beg_idx) == 1 and kwargs.get(
|
324 |
+
"data_type", None
|
325 |
+
) == "fbank": # fbank
|
326 |
+
batch["data_in"] = data_batch[0]
|
327 |
+
batch["data_lengths"] = input_len
|
328 |
+
|
329 |
+
time1 = time.perf_counter()
|
330 |
+
with torch.no_grad():
|
331 |
+
res = model.inference(**batch, **kwargs)
|
332 |
+
if isinstance(res, (list, tuple)):
|
333 |
+
results = res[0] if len(res) > 0 else [{"text": ""}]
|
334 |
+
meta_data = res[1] if len(res) > 1 else {}
|
335 |
+
time2 = time.perf_counter()
|
336 |
+
|
337 |
+
asr_result_list.extend(results)
|
338 |
+
|
339 |
+
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
|
340 |
+
batch_data_time = meta_data.get("batch_data_time", -1)
|
341 |
+
time_escape = time2 - time1
|
342 |
+
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
|
343 |
+
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
|
344 |
+
speed_stats["forward"] = f"{time_escape:0.3f}"
|
345 |
+
speed_stats["batch_size"] = f"{len(results)}"
|
346 |
+
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
|
347 |
+
description = f"{speed_stats}, "
|
348 |
+
if pbar:
|
349 |
+
pbar.update(end_idx - beg_idx)
|
350 |
+
pbar.set_description(description)
|
351 |
+
time_speech_total += batch_data_time
|
352 |
+
time_escape_total += time_escape
|
353 |
+
|
354 |
+
if pbar:
|
355 |
+
# pbar.update(1)
|
356 |
+
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
|
357 |
+
torch.cuda.empty_cache()
|
358 |
+
return asr_result_list
|
359 |
+
|
360 |
+
def vad(self, input, input_len=None, **cfg):
|
361 |
+
kwargs = self.kwargs
|
362 |
+
# step.1: compute the vad model
|
363 |
+
deep_update(self.vad_kwargs, cfg)
|
364 |
+
beg_vad = time.time()
|
365 |
+
res = self.inference(
|
366 |
+
input,
|
367 |
+
input_len=input_len,
|
368 |
+
model=self.vad_model,
|
369 |
+
kwargs=self.vad_kwargs,
|
370 |
+
**cfg,
|
371 |
+
)
|
372 |
+
end_vad = time.time()
|
373 |
+
# FIX(gcf): concat the vad clips for sense vocie model for better aed
|
374 |
+
if cfg.get("merge_vad", False):
|
375 |
+
for i in range(len(res)):
|
376 |
+
res[i]["value"] = merge_vad(
|
377 |
+
res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
|
378 |
+
)
|
379 |
+
elapsed = end_vad - beg_vad
|
380 |
+
return elapsed, res
|
381 |
+
|
382 |
+
def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
|
383 |
+
|
384 |
+
kwargs = self.kwargs
|
385 |
+
|
386 |
+
# step.2 compute asr model
|
387 |
+
model = self.model
|
388 |
+
deep_update(kwargs, cfg)
|
389 |
+
batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
|
390 |
+
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
|
391 |
+
kwargs["batch_size"] = batch_size
|
392 |
+
|
393 |
+
key_list, data_list = prepare_data_iterator(
|
394 |
+
input, input_len=input_len, data_type=kwargs.get("data_type", None)
|
395 |
+
)
|
396 |
+
results_ret_list = []
|
397 |
+
time_speech_total_all_samples = 1e-6
|
398 |
+
|
399 |
+
beg_total = time.time()
|
400 |
+
pbar_total = (
|
401 |
+
tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
|
402 |
+
if not kwargs.get("disable_pbar", False)
|
403 |
+
else None
|
404 |
+
)
|
405 |
+
|
406 |
+
for i in range(len(vad_res)):
|
407 |
+
key = vad_res[i]["key"]
|
408 |
+
vadsegments = vad_res[i]["value"]
|
409 |
+
input_i = data_list[i]
|
410 |
+
fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
|
411 |
+
speech = load_audio_text_image_video(
|
412 |
+
input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
|
413 |
+
)
|
414 |
+
speech_lengths = len(speech)
|
415 |
+
n = len(vadsegments)
|
416 |
+
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
417 |
+
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
418 |
+
results_sorted = []
|
419 |
+
|
420 |
+
if not len(sorted_data):
|
421 |
+
results_ret_list.append({"key": key, "text": "", "timestamp": []})
|
422 |
+
logging.info("decoding, utt: {}, empty speech".format(key))
|
423 |
+
continue
|
424 |
+
|
425 |
+
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
426 |
+
batch_size = max(
|
427 |
+
batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
|
428 |
+
)
|
429 |
+
|
430 |
+
if kwargs["device"] == "cpu":
|
431 |
+
batch_size = 0
|
432 |
+
|
433 |
+
beg_idx = 0
|
434 |
+
beg_asr_total = time.time()
|
435 |
+
time_speech_total_per_sample = speech_lengths / 16000
|
436 |
+
time_speech_total_all_samples += time_speech_total_per_sample
|
437 |
+
|
438 |
+
# pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
|
439 |
+
|
440 |
+
all_segments = []
|
441 |
+
max_len_in_batch = 0
|
442 |
+
end_idx = 1
|
443 |
+
|
444 |
+
for j, _ in enumerate(range(0, n)):
|
445 |
+
# pbar_sample.update(1)
|
446 |
+
sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
|
447 |
+
potential_batch_length = max(max_len_in_batch, sample_length) * (
|
448 |
+
j + 1 - beg_idx
|
449 |
+
)
|
450 |
+
# batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
|
451 |
+
if (
|
452 |
+
j < n - 1
|
453 |
+
and sample_length < batch_size_threshold_ms
|
454 |
+
and potential_batch_length < batch_size
|
455 |
+
):
|
456 |
+
max_len_in_batch = max(max_len_in_batch, sample_length)
|
457 |
+
end_idx += 1
|
458 |
+
continue
|
459 |
+
|
460 |
+
speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
|
461 |
+
speech, speech_lengths, sorted_data[beg_idx:end_idx]
|
462 |
+
)
|
463 |
+
results = self.inference(
|
464 |
+
speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
|
465 |
+
)
|
466 |
+
|
467 |
+
for _b in range(len(speech_j)):
|
468 |
+
results[_b]["interval"] = intervals[_b]
|
469 |
+
|
470 |
+
if self.spk_model is not None:
|
471 |
+
# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
|
472 |
+
for _b in range(len(speech_j)):
|
473 |
+
vad_segments = [
|
474 |
+
[
|
475 |
+
sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
|
476 |
+
sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
|
477 |
+
np.array(speech_j[_b]),
|
478 |
+
]
|
479 |
+
]
|
480 |
+
segments = sv_chunk(vad_segments)
|
481 |
+
all_segments.extend(segments)
|
482 |
+
speech_b = [i[2] for i in segments]
|
483 |
+
spk_res = self.inference(
|
484 |
+
speech_b,
|
485 |
+
input_len=None,
|
486 |
+
model=self.spk_model,
|
487 |
+
kwargs=kwargs,
|
488 |
+
**cfg,
|
489 |
+
)
|
490 |
+
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
|
491 |
+
|
492 |
+
beg_idx = end_idx
|
493 |
+
end_idx += 1
|
494 |
+
max_len_in_batch = sample_length
|
495 |
+
if len(results) < 1:
|
496 |
+
continue
|
497 |
+
results_sorted.extend(results)
|
498 |
+
|
499 |
+
# end_asr_total = time.time()
|
500 |
+
# time_escape_total_per_sample = end_asr_total - beg_asr_total
|
501 |
+
# pbar_sample.update(1)
|
502 |
+
# pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
503 |
+
# f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
|
504 |
+
# f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
|
505 |
+
|
506 |
+
restored_data = [0] * n
|
507 |
+
for j in range(n):
|
508 |
+
index = sorted_data[j][1]
|
509 |
+
cur = results_sorted[j]
|
510 |
+
pattern = r"<\|([^|]+)\|>"
|
511 |
+
emotion_string = re.findall(pattern, cur["text"])
|
512 |
+
cur["text"] = re.sub(pattern, "", cur["text"])
|
513 |
+
cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
|
514 |
+
if self.punc_model is not None and len(cur["text"].strip()) > 0:
|
515 |
+
deep_update(self.punc_kwargs, cfg)
|
516 |
+
punc_res = self.inference(
|
517 |
+
cur["text"],
|
518 |
+
model=self.punc_model,
|
519 |
+
kwargs=self.punc_kwargs,
|
520 |
+
**cfg,
|
521 |
+
)
|
522 |
+
cur["text"] = punc_res[0]["text"]
|
523 |
+
|
524 |
+
restored_data[index] = cur
|
525 |
+
|
526 |
+
end_asr_total = time.time()
|
527 |
+
time_escape_total_per_sample = end_asr_total - beg_asr_total
|
528 |
+
if pbar_total:
|
529 |
+
pbar_total.update(1)
|
530 |
+
pbar_total.set_description(
|
531 |
+
f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
532 |
+
f"time_speech: {time_speech_total_per_sample: 0.3f}, "
|
533 |
+
f"time_escape: {time_escape_total_per_sample:0.3f}"
|
534 |
+
)
|
535 |
+
|
536 |
+
# end_total = time.time()
|
537 |
+
# time_escape_total_all_samples = end_total - beg_total
|
538 |
+
# print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
|
539 |
+
# f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
|
540 |
+
# f"time_escape_all: {time_escape_total_all_samples:0.3f}")
|
541 |
+
return restored_data
|
542 |
+
|
543 |
+
def export(self, input=None, **cfg):
|
544 |
+
"""
|
545 |
+
|
546 |
+
:param input:
|
547 |
+
:param type:
|
548 |
+
:param quantize:
|
549 |
+
:param fallback_num:
|
550 |
+
:param calib_num:
|
551 |
+
:param opset_version:
|
552 |
+
:param cfg:
|
553 |
+
:return:
|
554 |
+
"""
|
555 |
+
|
556 |
+
device = cfg.get("device", "cpu")
|
557 |
+
model = self.model.to(device=device)
|
558 |
+
kwargs = self.kwargs
|
559 |
+
deep_update(kwargs, cfg)
|
560 |
+
kwargs["device"] = device
|
561 |
+
del kwargs["model"]
|
562 |
+
model.eval()
|
563 |
+
|
564 |
+
type = kwargs.get("type", "onnx")
|
565 |
+
|
566 |
+
key_list, data_list = prepare_data_iterator(
|
567 |
+
input, input_len=None, data_type=kwargs.get("data_type", None), key=None
|
568 |
+
)
|
569 |
+
|
570 |
+
with torch.no_grad():
|
571 |
+
export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
|
572 |
+
|
573 |
+
return export_dir
|
tools/sensevoice/fun_asr.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from audio_separator.separator import Separator
|
6 |
+
|
7 |
+
os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
|
8 |
+
os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
|
9 |
+
import json
|
10 |
+
import subprocess
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import click
|
14 |
+
import torch
|
15 |
+
from loguru import logger
|
16 |
+
from pydub import AudioSegment
|
17 |
+
from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
|
21 |
+
from tools.sensevoice.auto_model import AutoModel
|
22 |
+
|
23 |
+
|
24 |
+
def uvr5_cli(
|
25 |
+
audio_dir: Path,
|
26 |
+
output_folder: Path,
|
27 |
+
audio_files: list[Path] | None = None,
|
28 |
+
output_format: str = "flac",
|
29 |
+
model: str = "BS-Roformer-Viperx-1297.ckpt",
|
30 |
+
):
|
31 |
+
# ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
|
32 |
+
sepr = Separator(
|
33 |
+
model_file_dir=os.environ["UVR5_CACHE"],
|
34 |
+
output_dir=output_folder,
|
35 |
+
output_format=output_format,
|
36 |
+
)
|
37 |
+
dictmodel = {
|
38 |
+
"BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
|
39 |
+
"BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
|
40 |
+
"BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
|
41 |
+
"Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
|
42 |
+
}
|
43 |
+
roformer_model = dictmodel[model]
|
44 |
+
sepr.load_model(roformer_model)
|
45 |
+
if audio_files is None:
|
46 |
+
audio_files = list_files(
|
47 |
+
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
48 |
+
)
|
49 |
+
total_files = len(audio_files)
|
50 |
+
|
51 |
+
print(f"{total_files} audio files found")
|
52 |
+
|
53 |
+
res = []
|
54 |
+
for audio in tqdm(audio_files, desc="Denoising: "):
|
55 |
+
file_path = str(audio_dir / audio)
|
56 |
+
sep_out = sepr.separate(file_path)
|
57 |
+
if isinstance(sep_out, str):
|
58 |
+
res.append(sep_out)
|
59 |
+
elif isinstance(sep_out, list):
|
60 |
+
res.extend(sep_out)
|
61 |
+
del sepr
|
62 |
+
gc.collect()
|
63 |
+
if torch.cuda.is_available():
|
64 |
+
torch.cuda.empty_cache()
|
65 |
+
|
66 |
+
return res, roformer_model
|
67 |
+
|
68 |
+
|
69 |
+
def get_sample_rate(media_path: Path):
|
70 |
+
result = subprocess.run(
|
71 |
+
[
|
72 |
+
"ffprobe",
|
73 |
+
"-v",
|
74 |
+
"quiet",
|
75 |
+
"-print_format",
|
76 |
+
"json",
|
77 |
+
"-show_streams",
|
78 |
+
str(media_path),
|
79 |
+
],
|
80 |
+
capture_output=True,
|
81 |
+
text=True,
|
82 |
+
check=True,
|
83 |
+
)
|
84 |
+
media_info = json.loads(result.stdout)
|
85 |
+
for stream in media_info.get("streams", []):
|
86 |
+
if stream.get("codec_type") == "audio":
|
87 |
+
return stream.get("sample_rate")
|
88 |
+
return "44100" # Default sample rate if not found
|
89 |
+
|
90 |
+
|
91 |
+
def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
|
92 |
+
sr = get_sample_rate(src_path)
|
93 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
94 |
+
if src_path.resolve() == out_path.resolve():
|
95 |
+
output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
|
96 |
+
else:
|
97 |
+
output = str(out_path)
|
98 |
+
subprocess.run(
|
99 |
+
[
|
100 |
+
"ffmpeg",
|
101 |
+
"-loglevel",
|
102 |
+
"error",
|
103 |
+
"-i",
|
104 |
+
str(src_path),
|
105 |
+
"-acodec",
|
106 |
+
"pcm_s16le" if out_fmt == "wav" else "flac",
|
107 |
+
"-ar",
|
108 |
+
sr,
|
109 |
+
"-ac",
|
110 |
+
"1",
|
111 |
+
"-y",
|
112 |
+
output,
|
113 |
+
],
|
114 |
+
check=True,
|
115 |
+
)
|
116 |
+
return out_path
|
117 |
+
|
118 |
+
|
119 |
+
def convert_video_to_audio(video_path: Path, audio_dir: Path):
|
120 |
+
cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
|
121 |
+
vocals = [
|
122 |
+
p
|
123 |
+
for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
|
124 |
+
if p.suffix in AUDIO_EXTENSIONS
|
125 |
+
]
|
126 |
+
if len(vocals) > 0:
|
127 |
+
return vocals[0]
|
128 |
+
audio_path = cur_dir / f"{video_path.stem}.wav"
|
129 |
+
convert_to_mono(video_path, audio_path)
|
130 |
+
return audio_path
|
131 |
+
|
132 |
+
|
133 |
+
@click.command()
|
134 |
+
@click.option("--audio-dir", required=True, help="Directory containing audio files")
|
135 |
+
@click.option(
|
136 |
+
"--save-dir", required=True, help="Directory to save processed audio files"
|
137 |
+
)
|
138 |
+
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
|
139 |
+
@click.option("--language", default="auto", help="Language of the transcription")
|
140 |
+
@click.option(
|
141 |
+
"--max_single_segment_time",
|
142 |
+
default=20000,
|
143 |
+
type=int,
|
144 |
+
help="Maximum of Output single audio duration(ms)",
|
145 |
+
)
|
146 |
+
@click.option("--fsmn-vad/--silero-vad", default=False)
|
147 |
+
@click.option("--punc/--no-punc", default=False)
|
148 |
+
@click.option("--denoise/--no-denoise", default=False)
|
149 |
+
@click.option("--save_emo/--no_save_emo", default=False)
|
150 |
+
def main(
|
151 |
+
audio_dir: str,
|
152 |
+
save_dir: str,
|
153 |
+
device: str,
|
154 |
+
language: str,
|
155 |
+
max_single_segment_time: int,
|
156 |
+
fsmn_vad: bool,
|
157 |
+
punc: bool,
|
158 |
+
denoise: bool,
|
159 |
+
save_emo: bool,
|
160 |
+
):
|
161 |
+
|
162 |
+
audios_path = Path(audio_dir)
|
163 |
+
save_path = Path(save_dir)
|
164 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
165 |
+
|
166 |
+
video_files = list_files(
|
167 |
+
path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
|
168 |
+
)
|
169 |
+
v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
|
170 |
+
|
171 |
+
if denoise:
|
172 |
+
VOCAL = "_(Vocals)"
|
173 |
+
original_files = [
|
174 |
+
p
|
175 |
+
for p in audios_path.glob("**/*")
|
176 |
+
if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
|
177 |
+
]
|
178 |
+
|
179 |
+
_, cur_model = uvr5_cli(
|
180 |
+
audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
|
181 |
+
)
|
182 |
+
need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
|
183 |
+
need_remove.extend(original_files)
|
184 |
+
for _ in need_remove:
|
185 |
+
_.unlink()
|
186 |
+
vocal_files = [
|
187 |
+
p
|
188 |
+
for p in audios_path.glob("**/*")
|
189 |
+
if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
|
190 |
+
]
|
191 |
+
for f in vocal_files:
|
192 |
+
fn, ext = f.stem, f.suffix
|
193 |
+
|
194 |
+
v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
|
195 |
+
if v_pos != -1:
|
196 |
+
new_fn = fn[: v_pos + len(VOCAL)]
|
197 |
+
new_f = f.with_name(new_fn + ext)
|
198 |
+
f = f.rename(new_f)
|
199 |
+
convert_to_mono(f, f, "flac")
|
200 |
+
f.unlink()
|
201 |
+
|
202 |
+
audio_files = list_files(
|
203 |
+
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
204 |
+
)
|
205 |
+
|
206 |
+
logger.info("Loading / Downloading Funasr model...")
|
207 |
+
|
208 |
+
model_dir = "iic/SenseVoiceSmall"
|
209 |
+
|
210 |
+
vad_model = "fsmn-vad" if fsmn_vad else None
|
211 |
+
vad_kwargs = {"max_single_segment_time": max_single_segment_time}
|
212 |
+
punc_model = "ct-punc" if punc else None
|
213 |
+
|
214 |
+
manager = AutoModel(
|
215 |
+
model=model_dir,
|
216 |
+
trust_remote_code=False,
|
217 |
+
vad_model=vad_model,
|
218 |
+
vad_kwargs=vad_kwargs,
|
219 |
+
punc_model=punc_model,
|
220 |
+
device=device,
|
221 |
+
)
|
222 |
+
|
223 |
+
if not fsmn_vad and vad_model is None:
|
224 |
+
vad_model = load_silero_vad()
|
225 |
+
|
226 |
+
logger.info("Model loaded.")
|
227 |
+
|
228 |
+
pattern = re.compile(r"_\d{3}\.")
|
229 |
+
|
230 |
+
for file_path in tqdm(audio_files, desc="Processing audio file"):
|
231 |
+
|
232 |
+
if pattern.search(file_path.name):
|
233 |
+
# logger.info(f"Skipping {file_path} as it has already been processed.")
|
234 |
+
continue
|
235 |
+
|
236 |
+
file_stem = file_path.stem
|
237 |
+
file_suffix = file_path.suffix
|
238 |
+
|
239 |
+
rel_path = Path(file_path).relative_to(audio_dir)
|
240 |
+
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
241 |
+
|
242 |
+
audio = AudioSegment.from_file(file_path)
|
243 |
+
|
244 |
+
cfg = dict(
|
245 |
+
cache={},
|
246 |
+
language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
|
247 |
+
use_itn=False,
|
248 |
+
batch_size_s=60,
|
249 |
+
)
|
250 |
+
|
251 |
+
if fsmn_vad:
|
252 |
+
elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
|
253 |
+
else:
|
254 |
+
wav = read_audio(
|
255 |
+
str(file_path)
|
256 |
+
) # backend (sox, soundfile, or ffmpeg) required!
|
257 |
+
audio_key = file_path.stem
|
258 |
+
audio_val = []
|
259 |
+
speech_timestamps = get_speech_timestamps(
|
260 |
+
wav,
|
261 |
+
vad_model,
|
262 |
+
max_speech_duration_s=max_single_segment_time // 1000,
|
263 |
+
return_seconds=True,
|
264 |
+
)
|
265 |
+
|
266 |
+
audio_val = [
|
267 |
+
[int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
|
268 |
+
for timestamp in speech_timestamps
|
269 |
+
]
|
270 |
+
vad_res = []
|
271 |
+
vad_res.append(dict(key=audio_key, value=audio_val))
|
272 |
+
|
273 |
+
res = manager.inference_with_vadres(
|
274 |
+
input=str(file_path), vad_res=vad_res, **cfg
|
275 |
+
)
|
276 |
+
|
277 |
+
for i, info in enumerate(res):
|
278 |
+
[start_ms, end_ms] = info["interval"]
|
279 |
+
text = info["text"]
|
280 |
+
emo = info["emo"]
|
281 |
+
sliced_audio = audio[start_ms:end_ms]
|
282 |
+
audio_save_path = (
|
283 |
+
save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
|
284 |
+
)
|
285 |
+
sliced_audio.export(audio_save_path, format=file_suffix[1:])
|
286 |
+
print(f"Exported {audio_save_path}: {text}")
|
287 |
+
|
288 |
+
transcript_save_path = (
|
289 |
+
save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
|
290 |
+
)
|
291 |
+
with open(
|
292 |
+
transcript_save_path,
|
293 |
+
"w",
|
294 |
+
encoding="utf-8",
|
295 |
+
) as f:
|
296 |
+
f.write(text)
|
297 |
+
|
298 |
+
if save_emo:
|
299 |
+
emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
|
300 |
+
with open(
|
301 |
+
emo_save_path,
|
302 |
+
"w",
|
303 |
+
encoding="utf-8",
|
304 |
+
) as f:
|
305 |
+
f.write(emo)
|
306 |
+
|
307 |
+
if audios_path.resolve() == save_path.resolve():
|
308 |
+
file_path.unlink()
|
309 |
+
|
310 |
+
|
311 |
+
if __name__ == "__main__":
|
312 |
+
main()
|
313 |
+
exit(0)
|
314 |
+
from funasr.utils.postprocess_utils import rich_transcription_postprocess
|
315 |
+
|
316 |
+
# Load the audio file
|
317 |
+
audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
|
318 |
+
model_dir = "iic/SenseVoiceSmall"
|
319 |
+
m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
|
320 |
+
m.eval()
|
321 |
+
|
322 |
+
res = m.inference(
|
323 |
+
data_in=f"{kwargs['model_path']}/example/zh.mp3",
|
324 |
+
language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
|
325 |
+
use_itn=False,
|
326 |
+
ban_emo_unk=False,
|
327 |
+
**kwargs,
|
328 |
+
)
|
329 |
+
|
330 |
+
print(res)
|
331 |
+
text = rich_transcription_postprocess(res[0][0]["text"])
|
332 |
+
print(text)
|
tools/sensevoice/vad_utils.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn.utils.rnn import pad_sequence
|
3 |
+
|
4 |
+
|
5 |
+
def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
6 |
+
speech_list = []
|
7 |
+
speech_lengths_list = []
|
8 |
+
for i, segment in enumerate(vad_segments):
|
9 |
+
|
10 |
+
bed_idx = int(segment[0][0] * 16)
|
11 |
+
end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
|
12 |
+
speech_i = speech[0, bed_idx:end_idx]
|
13 |
+
speech_lengths_i = end_idx - bed_idx
|
14 |
+
speech_list.append(speech_i)
|
15 |
+
speech_lengths_list.append(speech_lengths_i)
|
16 |
+
feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
|
17 |
+
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
|
18 |
+
return feats_pad, speech_lengths_pad
|
19 |
+
|
20 |
+
|
21 |
+
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
|
22 |
+
speech_list = []
|
23 |
+
speech_lengths_list = []
|
24 |
+
intervals = []
|
25 |
+
for i, segment in enumerate(vad_segments):
|
26 |
+
bed_idx = int(segment[0][0] * 16)
|
27 |
+
end_idx = min(int(segment[0][1] * 16), speech_lengths)
|
28 |
+
speech_i = speech[bed_idx:end_idx]
|
29 |
+
speech_lengths_i = end_idx - bed_idx
|
30 |
+
speech_list.append(speech_i)
|
31 |
+
speech_lengths_list.append(speech_lengths_i)
|
32 |
+
intervals.append([bed_idx // 16, end_idx // 16])
|
33 |
+
|
34 |
+
return speech_list, speech_lengths_list, intervals
|
35 |
+
|
36 |
+
|
37 |
+
def merge_vad(vad_result, max_length=15000, min_length=0):
|
38 |
+
new_result = []
|
39 |
+
if len(vad_result) <= 1:
|
40 |
+
return vad_result
|
41 |
+
time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
|
42 |
+
time_step = sorted(list(set(time_step)))
|
43 |
+
if len(time_step) == 0:
|
44 |
+
return []
|
45 |
+
bg = 0
|
46 |
+
for i in range(len(time_step) - 1):
|
47 |
+
time = time_step[i]
|
48 |
+
if time_step[i + 1] - bg < max_length:
|
49 |
+
continue
|
50 |
+
if time - bg > min_length:
|
51 |
+
new_result.append([bg, time])
|
52 |
+
# if time - bg < max_length * 1.5:
|
53 |
+
# new_result.append([bg, time])
|
54 |
+
# else:
|
55 |
+
# split_num = int(time - bg) // max_length + 1
|
56 |
+
# spl_l = int(time - bg) // split_num
|
57 |
+
# for j in range(split_num):
|
58 |
+
# new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
|
59 |
+
bg = time
|
60 |
+
new_result.append([bg, time_step[-1]])
|
61 |
+
return new_result
|
tools/smart_pad.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from multiprocessing import Pool
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import click
|
6 |
+
import librosa
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from tools.file import AUDIO_EXTENSIONS, list_files
|
12 |
+
|
13 |
+
threshold = 10 ** (-50 / 20.0)
|
14 |
+
|
15 |
+
|
16 |
+
def process(file):
|
17 |
+
waveform, sample_rate = torchaudio.load(str(file), backend="sox")
|
18 |
+
if waveform.size(0) > 1:
|
19 |
+
waveform = waveform.mean(dim=0, keepdim=True)
|
20 |
+
|
21 |
+
loudness = librosa.feature.rms(
|
22 |
+
y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
|
23 |
+
)[0]
|
24 |
+
|
25 |
+
for i in range(len(loudness) - 1, 0, -1):
|
26 |
+
if loudness[i] > threshold:
|
27 |
+
break
|
28 |
+
|
29 |
+
end_silent_time = (len(loudness) - i) * 512 / sample_rate
|
30 |
+
|
31 |
+
if end_silent_time <= 0.3:
|
32 |
+
random_time = random.uniform(0.3, 0.7) - end_silent_time
|
33 |
+
waveform = F.pad(
|
34 |
+
waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
|
35 |
+
)
|
36 |
+
|
37 |
+
for i in range(len(loudness)):
|
38 |
+
if loudness[i] > threshold:
|
39 |
+
break
|
40 |
+
|
41 |
+
start_silent_time = i * 512 / sample_rate
|
42 |
+
|
43 |
+
if start_silent_time > 0.02:
|
44 |
+
waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
|
45 |
+
|
46 |
+
torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
|
47 |
+
|
48 |
+
|
49 |
+
@click.command()
|
50 |
+
@click.argument("source", type=Path)
|
51 |
+
@click.option("--num-workers", type=int, default=12)
|
52 |
+
def main(source, num_workers):
|
53 |
+
files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
|
54 |
+
|
55 |
+
with Pool(num_workers) as p:
|
56 |
+
list(tqdm(p.imap_unordered(process, files), total=len(files)))
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
main()
|
tools/vqgan/__pycache__/inference.cpython-310.pyc
ADDED
Binary file (3.53 kB). View file
|
|
tools/vqgan/create_train_split.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from pathlib import Path
|
3 |
+
from random import Random
|
4 |
+
|
5 |
+
import click
|
6 |
+
from loguru import logger
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
11 |
+
|
12 |
+
|
13 |
+
@click.command()
|
14 |
+
@click.argument("root", type=click.Path(exists=True, path_type=Path))
|
15 |
+
@click.option("--val-ratio", type=float, default=None)
|
16 |
+
@click.option("--val-count", type=int, default=None)
|
17 |
+
@click.option("--filelist", default=None, type=Path)
|
18 |
+
@click.option("--min-duration", default=None, type=float)
|
19 |
+
@click.option("--max-duration", default=None, type=float)
|
20 |
+
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
|
21 |
+
if filelist:
|
22 |
+
files = [i[0] for i in load_filelist(filelist)]
|
23 |
+
else:
|
24 |
+
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
|
25 |
+
|
26 |
+
if min_duration is None and max_duration is None:
|
27 |
+
filtered_files = list(map(str, [file.relative_to(root) for file in files]))
|
28 |
+
else:
|
29 |
+
filtered_files = []
|
30 |
+
for file in tqdm(files):
|
31 |
+
try:
|
32 |
+
audio = AudioSegment.from_file(str(file))
|
33 |
+
duration = len(audio) / 1000.0
|
34 |
+
|
35 |
+
if min_duration is not None and duration < min_duration:
|
36 |
+
logger.info(
|
37 |
+
f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
|
38 |
+
)
|
39 |
+
continue
|
40 |
+
|
41 |
+
if max_duration is not None and duration > max_duration:
|
42 |
+
logger.info(
|
43 |
+
f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
|
44 |
+
)
|
45 |
+
continue
|
46 |
+
|
47 |
+
filtered_files.append(str(file.relative_to(root)))
|
48 |
+
except Exception as e:
|
49 |
+
logger.info(f"Error processing {file}: {e}")
|
50 |
+
|
51 |
+
logger.info(
|
52 |
+
f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
|
53 |
+
)
|
54 |
+
|
55 |
+
Random(42).shuffle(filtered_files)
|
56 |
+
|
57 |
+
if val_count is None and val_ratio is None:
|
58 |
+
logger.info("Validation ratio and count not specified, using min(20%, 100)")
|
59 |
+
val_size = min(100, math.ceil(len(filtered_files) * 0.2))
|
60 |
+
elif val_count is not None and val_ratio is not None:
|
61 |
+
logger.error("Cannot specify both val_count and val_ratio")
|
62 |
+
return
|
63 |
+
elif val_count is not None:
|
64 |
+
if val_count < 1 or val_count > len(filtered_files):
|
65 |
+
logger.error("val_count must be between 1 and number of files")
|
66 |
+
return
|
67 |
+
val_size = val_count
|
68 |
+
else:
|
69 |
+
val_size = math.ceil(len(filtered_files) * val_ratio)
|
70 |
+
|
71 |
+
logger.info(f"Using {val_size} files for validation")
|
72 |
+
|
73 |
+
with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
|
74 |
+
f.write("\n".join(filtered_files[val_size:]))
|
75 |
+
|
76 |
+
with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
|
77 |
+
f.write("\n".join(filtered_files[:val_size]))
|
78 |
+
|
79 |
+
logger.info("Done")
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
main()
|
tools/vqgan/extract_vq.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess as sp
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
from datetime import timedelta
|
6 |
+
from functools import lru_cache
|
7 |
+
from pathlib import Path
|
8 |
+
from random import Random
|
9 |
+
|
10 |
+
import click
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torchaudio
|
14 |
+
from hydra import compose, initialize
|
15 |
+
from hydra.utils import instantiate
|
16 |
+
from lightning import LightningModule
|
17 |
+
from loguru import logger
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
|
20 |
+
from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
21 |
+
|
22 |
+
# register eval resolver
|
23 |
+
OmegaConf.register_new_resolver("eval", eval)
|
24 |
+
# This file is used to convert the audio files to text files using the Whisper model.
|
25 |
+
# It's mainly used to generate the training data for the VQ model.
|
26 |
+
|
27 |
+
backends = torchaudio.list_audio_backends()
|
28 |
+
|
29 |
+
if "ffmpeg" in backends:
|
30 |
+
backend = "ffmpeg"
|
31 |
+
else:
|
32 |
+
backend = "soundfile"
|
33 |
+
|
34 |
+
RANK = int(os.environ.get("SLURM_PROCID", 0))
|
35 |
+
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
|
36 |
+
|
37 |
+
logger_format = (
|
38 |
+
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
39 |
+
"<level>{level: <8}</level> | "
|
40 |
+
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
|
41 |
+
"{extra[rank]} - <level>{message}</level>"
|
42 |
+
)
|
43 |
+
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
|
44 |
+
logger.remove()
|
45 |
+
logger.add(sys.stderr, format=logger_format)
|
46 |
+
|
47 |
+
|
48 |
+
@lru_cache(maxsize=1)
|
49 |
+
def get_model(
|
50 |
+
config_name: str = "firefly_gan_vq",
|
51 |
+
checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
52 |
+
device: str | torch.device = "cuda",
|
53 |
+
):
|
54 |
+
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
55 |
+
cfg = compose(config_name=config_name)
|
56 |
+
|
57 |
+
model = instantiate(cfg)
|
58 |
+
state_dict = torch.load(
|
59 |
+
checkpoint_path,
|
60 |
+
map_location=device,
|
61 |
+
)
|
62 |
+
if "state_dict" in state_dict:
|
63 |
+
state_dict = state_dict["state_dict"]
|
64 |
+
|
65 |
+
if any("generator" in k for k in state_dict):
|
66 |
+
state_dict = {
|
67 |
+
k.replace("generator.", ""): v
|
68 |
+
for k, v in state_dict.items()
|
69 |
+
if "generator." in k
|
70 |
+
}
|
71 |
+
|
72 |
+
model.load_state_dict(state_dict, strict=False)
|
73 |
+
model.eval()
|
74 |
+
model.to(device)
|
75 |
+
|
76 |
+
logger.info(f"Loaded model")
|
77 |
+
return model
|
78 |
+
|
79 |
+
|
80 |
+
@torch.inference_mode()
|
81 |
+
def process_batch(files: list[Path], model) -> float:
|
82 |
+
wavs = []
|
83 |
+
audio_lengths = []
|
84 |
+
new_files = []
|
85 |
+
max_length = total_time = 0
|
86 |
+
|
87 |
+
for file in files:
|
88 |
+
try:
|
89 |
+
wav, sr = torchaudio.load(
|
90 |
+
str(file), backend=backend
|
91 |
+
) # Need to install libsox-dev
|
92 |
+
except Exception as e:
|
93 |
+
logger.error(f"Error reading {file}: {e}")
|
94 |
+
continue
|
95 |
+
|
96 |
+
if wav.shape[0] > 1:
|
97 |
+
wav = wav.mean(dim=0, keepdim=True)
|
98 |
+
|
99 |
+
wav = torchaudio.functional.resample(
|
100 |
+
wav.cuda(), sr, model.spec_transform.sample_rate
|
101 |
+
)[0]
|
102 |
+
total_time += len(wav) / model.spec_transform.sample_rate
|
103 |
+
max_length = max(max_length, len(wav))
|
104 |
+
|
105 |
+
wavs.append(wav)
|
106 |
+
audio_lengths.append(len(wav))
|
107 |
+
new_files.append(file)
|
108 |
+
|
109 |
+
files = new_files
|
110 |
+
|
111 |
+
# Pad to max length
|
112 |
+
for i, wav in enumerate(wavs):
|
113 |
+
wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
|
114 |
+
|
115 |
+
audios = torch.stack(wavs, dim=0)[:, None]
|
116 |
+
audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
|
117 |
+
|
118 |
+
# Calculate lengths
|
119 |
+
indices, feature_lengths = model.encode(audios, audio_lengths)
|
120 |
+
|
121 |
+
# Save to disk
|
122 |
+
outputs = indices.cpu().numpy()
|
123 |
+
|
124 |
+
for file, length, feature, audio_length in zip(
|
125 |
+
files, feature_lengths, outputs, audio_lengths
|
126 |
+
):
|
127 |
+
feature = feature[:, :length]
|
128 |
+
|
129 |
+
# (T,)
|
130 |
+
with open(file.with_suffix(".npy"), "wb") as f:
|
131 |
+
np.save(f, feature)
|
132 |
+
|
133 |
+
return total_time
|
134 |
+
|
135 |
+
|
136 |
+
@click.command()
|
137 |
+
@click.argument("folder")
|
138 |
+
@click.option("--num-workers", default=1)
|
139 |
+
@click.option("--config-name", default="firefly_gan_vq")
|
140 |
+
@click.option(
|
141 |
+
"--checkpoint-path",
|
142 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
143 |
+
)
|
144 |
+
@click.option("--batch-size", default=64)
|
145 |
+
@click.option("--filelist", default=None, type=Path)
|
146 |
+
def main(
|
147 |
+
folder: str,
|
148 |
+
num_workers: int,
|
149 |
+
config_name: str,
|
150 |
+
checkpoint_path: str,
|
151 |
+
batch_size: int,
|
152 |
+
filelist: Path,
|
153 |
+
):
|
154 |
+
if num_workers > 1 and WORLD_SIZE != num_workers:
|
155 |
+
assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
|
156 |
+
|
157 |
+
logger.info(f"Spawning {num_workers} workers")
|
158 |
+
|
159 |
+
if torch.cuda.is_available():
|
160 |
+
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
161 |
+
if visible_devices is None:
|
162 |
+
visible_devices = list(range(torch.cuda.device_count()))
|
163 |
+
else:
|
164 |
+
visible_devices = visible_devices.split(",")
|
165 |
+
else:
|
166 |
+
# Set to empty string to avoid using GPU
|
167 |
+
visible_devices = [""]
|
168 |
+
|
169 |
+
processes = []
|
170 |
+
for i in range(num_workers):
|
171 |
+
env = os.environ.copy()
|
172 |
+
env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
|
173 |
+
env["SLURM_PROCID"] = str(i)
|
174 |
+
env["SLURM_NTASKS"] = str(num_workers)
|
175 |
+
|
176 |
+
processes.append(
|
177 |
+
sp.Popen(
|
178 |
+
[sys.executable] + sys.argv.copy(),
|
179 |
+
env=env,
|
180 |
+
)
|
181 |
+
)
|
182 |
+
|
183 |
+
for p in processes:
|
184 |
+
p.wait()
|
185 |
+
|
186 |
+
logger.info(f"All workers finished")
|
187 |
+
return
|
188 |
+
|
189 |
+
# This is a worker
|
190 |
+
logger.info(f"Starting worker")
|
191 |
+
if filelist:
|
192 |
+
files = [i[0] for i in load_filelist(filelist)]
|
193 |
+
else:
|
194 |
+
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
|
195 |
+
|
196 |
+
print(f"Found {len(files)} files")
|
197 |
+
files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
|
198 |
+
|
199 |
+
total_files = len(files)
|
200 |
+
files = files[RANK::WORLD_SIZE]
|
201 |
+
logger.info(f"Processing {len(files)}/{total_files} files")
|
202 |
+
|
203 |
+
# Batch processing
|
204 |
+
total_time = 0
|
205 |
+
begin_time = time.time()
|
206 |
+
processed_files = 0
|
207 |
+
model = get_model(config_name, checkpoint_path)
|
208 |
+
|
209 |
+
for n_batch, idx in enumerate(range(0, len(files), batch_size)):
|
210 |
+
batch = files[idx : idx + batch_size]
|
211 |
+
batch_time = process_batch(batch, model)
|
212 |
+
|
213 |
+
total_time += batch_time
|
214 |
+
processed_files += len(batch)
|
215 |
+
|
216 |
+
if (n_batch + 1) % 10 == 0:
|
217 |
+
eta = (
|
218 |
+
(time.time() - begin_time)
|
219 |
+
/ processed_files
|
220 |
+
* (len(files) - processed_files)
|
221 |
+
)
|
222 |
+
logger.info(
|
223 |
+
f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
|
224 |
+
+ f"ETA: {timedelta(seconds=round(eta))}s"
|
225 |
+
)
|
226 |
+
|
227 |
+
logger.info(
|
228 |
+
f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
|
229 |
+
)
|
230 |
+
|
231 |
+
|
232 |
+
if __name__ == "__main__":
|
233 |
+
main()
|
tools/vqgan/inference.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import click
|
4 |
+
import hydra
|
5 |
+
import numpy as np
|
6 |
+
import soundfile as sf
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from hydra import compose, initialize
|
10 |
+
from hydra.utils import instantiate
|
11 |
+
from loguru import logger
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
|
14 |
+
from tools.file import AUDIO_EXTENSIONS
|
15 |
+
|
16 |
+
# register eval resolver
|
17 |
+
OmegaConf.register_new_resolver("eval", eval)
|
18 |
+
|
19 |
+
|
20 |
+
def load_model(config_name, checkpoint_path, device="cuda"):
|
21 |
+
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
22 |
+
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
23 |
+
cfg = compose(config_name=config_name)
|
24 |
+
|
25 |
+
model = instantiate(cfg)
|
26 |
+
state_dict = torch.load(
|
27 |
+
checkpoint_path, map_location=device, mmap=True, weights_only=True
|
28 |
+
)
|
29 |
+
if "state_dict" in state_dict:
|
30 |
+
state_dict = state_dict["state_dict"]
|
31 |
+
|
32 |
+
if any("generator" in k for k in state_dict):
|
33 |
+
state_dict = {
|
34 |
+
k.replace("generator.", ""): v
|
35 |
+
for k, v in state_dict.items()
|
36 |
+
if "generator." in k
|
37 |
+
}
|
38 |
+
|
39 |
+
result = model.load_state_dict(state_dict, strict=False, assign=True)
|
40 |
+
model.eval()
|
41 |
+
model.to(device)
|
42 |
+
|
43 |
+
logger.info(f"Loaded model: {result}")
|
44 |
+
return model
|
45 |
+
|
46 |
+
|
47 |
+
@torch.no_grad()
|
48 |
+
@click.command()
|
49 |
+
@click.option(
|
50 |
+
"--input-path",
|
51 |
+
"-i",
|
52 |
+
default="test.wav",
|
53 |
+
type=click.Path(exists=True, path_type=Path),
|
54 |
+
)
|
55 |
+
@click.option(
|
56 |
+
"--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
|
57 |
+
)
|
58 |
+
@click.option("--config-name", default="firefly_gan_vq")
|
59 |
+
@click.option(
|
60 |
+
"--checkpoint-path",
|
61 |
+
default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
62 |
+
)
|
63 |
+
@click.option(
|
64 |
+
"--device",
|
65 |
+
"-d",
|
66 |
+
default="cuda",
|
67 |
+
)
|
68 |
+
def main(input_path, output_path, config_name, checkpoint_path, device):
|
69 |
+
model = load_model(config_name, checkpoint_path, device=device)
|
70 |
+
|
71 |
+
if input_path.suffix in AUDIO_EXTENSIONS:
|
72 |
+
logger.info(f"Processing in-place reconstruction of {input_path}")
|
73 |
+
|
74 |
+
# Load audio
|
75 |
+
audio, sr = torchaudio.load(str(input_path))
|
76 |
+
if audio.shape[0] > 1:
|
77 |
+
audio = audio.mean(0, keepdim=True)
|
78 |
+
audio = torchaudio.functional.resample(
|
79 |
+
audio, sr, model.spec_transform.sample_rate
|
80 |
+
)
|
81 |
+
|
82 |
+
audios = audio[None].to(device)
|
83 |
+
logger.info(
|
84 |
+
f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
|
85 |
+
)
|
86 |
+
|
87 |
+
# VQ Encoder
|
88 |
+
audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
|
89 |
+
indices = model.encode(audios, audio_lengths)[0][0]
|
90 |
+
|
91 |
+
logger.info(f"Generated indices of shape {indices.shape}")
|
92 |
+
|
93 |
+
# Save indices
|
94 |
+
np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
|
95 |
+
elif input_path.suffix == ".npy":
|
96 |
+
logger.info(f"Processing precomputed indices from {input_path}")
|
97 |
+
indices = np.load(input_path)
|
98 |
+
indices = torch.from_numpy(indices).to(device).long()
|
99 |
+
assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
|
100 |
+
else:
|
101 |
+
raise ValueError(f"Unknown input type: {input_path}")
|
102 |
+
|
103 |
+
# Restore
|
104 |
+
feature_lengths = torch.tensor([indices.shape[1]], device=device)
|
105 |
+
fake_audios, _ = model.decode(
|
106 |
+
indices=indices[None], feature_lengths=feature_lengths
|
107 |
+
)
|
108 |
+
audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
|
109 |
+
|
110 |
+
logger.info(
|
111 |
+
f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
|
112 |
+
)
|
113 |
+
|
114 |
+
# Save audio
|
115 |
+
fake_audio = fake_audios[0, 0].float().cpu().numpy()
|
116 |
+
sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
|
117 |
+
logger.info(f"Saved audio to {output_path}")
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
main()
|
tools/webui.py
ADDED
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import html
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import queue
|
6 |
+
import wave
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from functools import partial
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
import librosa
|
13 |
+
import numpy as np
|
14 |
+
import pyrootutils
|
15 |
+
import torch
|
16 |
+
from loguru import logger
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
|
19 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
20 |
+
|
21 |
+
|
22 |
+
from fish_speech.i18n import i18n
|
23 |
+
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
24 |
+
from fish_speech.utils import autocast_exclude_mps, set_seed
|
25 |
+
from tools.api import decode_vq_tokens, encode_reference
|
26 |
+
from tools.file import AUDIO_EXTENSIONS, list_files
|
27 |
+
from tools.llama.generate import (
|
28 |
+
GenerateRequest,
|
29 |
+
GenerateResponse,
|
30 |
+
WrappedGenerateResponse,
|
31 |
+
launch_thread_safe_queue,
|
32 |
+
)
|
33 |
+
from tools.vqgan.inference import load_model as load_decoder_model
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
# Make einx happy
|
38 |
+
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
39 |
+
|
40 |
+
|
41 |
+
HEADER_MD = f"""# 泰雅爾語TTS
|
42 |
+
|
43 |
+
#泰雅爾語測試範例
|
44 |
+
|
45 |
+
{i18n("Miyan qaniy qu binkgan bbinkesan na Yesu:Yesu Kristo ga kinbahan na Tabite, Tabite ga kinbahan na Aburaham.")}
|
46 |
+
|
47 |
+
{i18n("Aburaham ga yaba na Isak; Isak ga yaba na Yakob; Yakob ga yaba na Yuta ki mmtswe nya mlikuy.")}
|
48 |
+
|
49 |
+
{i18n("Babaw nqu kyapun rasun squ qalang Babilon lga, plqyun ni Yehoyacin qu Seltiyel; Seltiyel ga yaba na Zerubabel;")}
|
50 |
+
|
51 |
+
#若要使用自己的聲音合成請按以下步驟(Streaming Generate)
|
52 |
+
|
53 |
+
# <span style="color: red;">Streaming Generate 此功能維護中</span>
|
54 |
+
|
55 |
+
{i18n("1.在Reference Audio找到Enable Reference Audio打勾")}
|
56 |
+
|
57 |
+
{i18n("2.在左下方將錄音檔案上傳,並在Reference Text輸入上傳音檔的文字")}
|
58 |
+
|
59 |
+
{i18n("3.在Input Text輸入文字")}
|
60 |
+
|
61 |
+
{i18n("4.按下Streaming Generate即可")}
|
62 |
+
|
63 |
+
|
64 |
+
"""
|
65 |
+
|
66 |
+
TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
|
67 |
+
SPACE_IMPORTED = False
|
68 |
+
|
69 |
+
|
70 |
+
def build_html_error_message(error):
|
71 |
+
return f"""
|
72 |
+
<div style="color: red;
|
73 |
+
font-weight: bold;">
|
74 |
+
{html.escape(str(error))}
|
75 |
+
</div>
|
76 |
+
"""
|
77 |
+
|
78 |
+
|
79 |
+
@torch.inference_mode()
|
80 |
+
def inference(
|
81 |
+
text,
|
82 |
+
enable_reference_audio,
|
83 |
+
reference_audio,
|
84 |
+
reference_text,
|
85 |
+
max_new_tokens,
|
86 |
+
chunk_length,
|
87 |
+
top_p,
|
88 |
+
repetition_penalty,
|
89 |
+
temperature,
|
90 |
+
seed="0",
|
91 |
+
streaming=False,
|
92 |
+
):
|
93 |
+
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
94 |
+
return (
|
95 |
+
None,
|
96 |
+
None,
|
97 |
+
i18n("Text is too long, please keep it under {} characters.").format(
|
98 |
+
args.max_gradio_length
|
99 |
+
),
|
100 |
+
)
|
101 |
+
|
102 |
+
seed = int(seed)
|
103 |
+
if seed != 0:
|
104 |
+
set_seed(seed)
|
105 |
+
logger.warning(f"set seed: {seed}")
|
106 |
+
|
107 |
+
# Parse reference audio aka prompt
|
108 |
+
prompt_tokens = encode_reference(
|
109 |
+
decoder_model=decoder_model,
|
110 |
+
reference_audio=reference_audio,
|
111 |
+
enable_reference_audio=enable_reference_audio,
|
112 |
+
)
|
113 |
+
|
114 |
+
# LLAMA Inference
|
115 |
+
request = dict(
|
116 |
+
device=decoder_model.device,
|
117 |
+
max_new_tokens=600,
|
118 |
+
text=text,
|
119 |
+
top_p=top_p,
|
120 |
+
repetition_penalty=repetition_penalty,
|
121 |
+
temperature=temperature,
|
122 |
+
compile=args.compile,
|
123 |
+
iterative_prompt=chunk_length > 0,
|
124 |
+
chunk_length=chunk_length,
|
125 |
+
max_length=2048,
|
126 |
+
prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
127 |
+
prompt_text=reference_text if enable_reference_audio else None,
|
128 |
+
)
|
129 |
+
|
130 |
+
response_queue = queue.Queue()
|
131 |
+
llama_queue.put(
|
132 |
+
GenerateRequest(
|
133 |
+
request=request,
|
134 |
+
response_queue=response_queue,
|
135 |
+
)
|
136 |
+
)
|
137 |
+
|
138 |
+
if streaming:
|
139 |
+
yield wav_chunk_header(), None, None
|
140 |
+
|
141 |
+
segments = []
|
142 |
+
|
143 |
+
while True:
|
144 |
+
result: WrappedGenerateResponse = response_queue.get()
|
145 |
+
if result.status == "error":
|
146 |
+
yield None, None, build_html_error_message(result.response)
|
147 |
+
break
|
148 |
+
|
149 |
+
result: GenerateResponse = result.response
|
150 |
+
if result.action == "next":
|
151 |
+
break
|
152 |
+
|
153 |
+
with autocast_exclude_mps(
|
154 |
+
device_type=decoder_model.device.type, dtype=args.precision
|
155 |
+
):
|
156 |
+
fake_audios = decode_vq_tokens(
|
157 |
+
decoder_model=decoder_model,
|
158 |
+
codes=result.codes,
|
159 |
+
)
|
160 |
+
|
161 |
+
fake_audios = fake_audios.float().cpu().numpy()
|
162 |
+
segments.append(fake_audios)
|
163 |
+
|
164 |
+
if streaming:
|
165 |
+
wav_header = wav_chunk_header()
|
166 |
+
audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
|
167 |
+
yield wav_header + audio_data, None, None
|
168 |
+
|
169 |
+
if len(segments) == 0:
|
170 |
+
return (
|
171 |
+
None,
|
172 |
+
None,
|
173 |
+
build_html_error_message(
|
174 |
+
i18n("No audio generated, please check the input text.")
|
175 |
+
),
|
176 |
+
)
|
177 |
+
|
178 |
+
# No matter streaming or not, we need to return the final audio
|
179 |
+
audio = np.concatenate(segments, axis=0)
|
180 |
+
yield None, (decoder_model.spec_transform.sample_rate, audio), None
|
181 |
+
|
182 |
+
if torch.cuda.is_available():
|
183 |
+
torch.cuda.empty_cache()
|
184 |
+
gc.collect()
|
185 |
+
|
186 |
+
|
187 |
+
inference_stream = partial(inference, streaming=True)
|
188 |
+
|
189 |
+
n_audios = 4
|
190 |
+
|
191 |
+
global_audio_list = []
|
192 |
+
global_error_list = []
|
193 |
+
|
194 |
+
|
195 |
+
def inference_wrapper(
|
196 |
+
text,
|
197 |
+
enable_reference_audio,
|
198 |
+
reference_audio,
|
199 |
+
reference_text,
|
200 |
+
max_new_tokens,
|
201 |
+
chunk_length,
|
202 |
+
top_p,
|
203 |
+
repetition_penalty,
|
204 |
+
temperature,
|
205 |
+
seed,
|
206 |
+
batch_infer_num,
|
207 |
+
):
|
208 |
+
audios = []
|
209 |
+
errors = []
|
210 |
+
|
211 |
+
for _ in range(batch_infer_num):
|
212 |
+
result = inference(
|
213 |
+
text,
|
214 |
+
enable_reference_audio,
|
215 |
+
reference_audio,
|
216 |
+
reference_text,
|
217 |
+
max_new_tokens,
|
218 |
+
chunk_length,
|
219 |
+
top_p,
|
220 |
+
repetition_penalty,
|
221 |
+
temperature,
|
222 |
+
seed,
|
223 |
+
)
|
224 |
+
|
225 |
+
_, audio_data, error_message = next(result)
|
226 |
+
|
227 |
+
audios.append(
|
228 |
+
gr.Audio(value=audio_data if audio_data else None, visible=True),
|
229 |
+
)
|
230 |
+
errors.append(
|
231 |
+
gr.HTML(value=error_message if error_message else None, visible=True),
|
232 |
+
)
|
233 |
+
|
234 |
+
for _ in range(batch_infer_num, n_audios):
|
235 |
+
audios.append(
|
236 |
+
gr.Audio(value=None, visible=False),
|
237 |
+
)
|
238 |
+
errors.append(
|
239 |
+
gr.HTML(value=None, visible=False),
|
240 |
+
)
|
241 |
+
|
242 |
+
return None, *audios, *errors
|
243 |
+
|
244 |
+
|
245 |
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
246 |
+
buffer = io.BytesIO()
|
247 |
+
|
248 |
+
with wave.open(buffer, "wb") as wav_file:
|
249 |
+
wav_file.setnchannels(channels)
|
250 |
+
wav_file.setsampwidth(bit_depth // 8)
|
251 |
+
wav_file.setframerate(sample_rate)
|
252 |
+
|
253 |
+
wav_header_bytes = buffer.getvalue()
|
254 |
+
buffer.close()
|
255 |
+
return wav_header_bytes
|
256 |
+
|
257 |
+
|
258 |
+
def normalize_text(user_input, use_normalization):
|
259 |
+
if use_normalization:
|
260 |
+
return ChnNormedText(raw_text=user_input).normalize()
|
261 |
+
else:
|
262 |
+
return user_input
|
263 |
+
|
264 |
+
|
265 |
+
def update_examples():
|
266 |
+
examples_dir = Path("references")
|
267 |
+
examples_dir.mkdir(parents=True, exist_ok=True)
|
268 |
+
example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
|
269 |
+
return gr.Dropdown(choices=example_audios + [""])
|
270 |
+
|
271 |
+
|
272 |
+
def build_app():
|
273 |
+
with gr.Blocks(theme=gr.themes.Base()) as app:
|
274 |
+
gr.Markdown(HEADER_MD)
|
275 |
+
|
276 |
+
# Use light theme by default
|
277 |
+
app.load(
|
278 |
+
None,
|
279 |
+
None,
|
280 |
+
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
|
281 |
+
% args.theme,
|
282 |
+
)
|
283 |
+
|
284 |
+
# Inference
|
285 |
+
with gr.Row():
|
286 |
+
with gr.Column(scale=3):
|
287 |
+
text = gr.Textbox(
|
288 |
+
label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
|
289 |
+
)
|
290 |
+
refined_text = gr.Textbox(
|
291 |
+
label=i18n("Realtime Transform Text"),
|
292 |
+
placeholder=i18n(
|
293 |
+
"Normalization Result Preview (Currently Only Chinese)"
|
294 |
+
),
|
295 |
+
lines=5,
|
296 |
+
interactive=False,
|
297 |
+
)
|
298 |
+
|
299 |
+
with gr.Row():
|
300 |
+
if_refine_text = gr.Checkbox(
|
301 |
+
label=i18n("Text Normalization"),
|
302 |
+
value=False,
|
303 |
+
scale=1,
|
304 |
+
)
|
305 |
+
|
306 |
+
with gr.Row():
|
307 |
+
with gr.Column():
|
308 |
+
with gr.Tab(label=i18n("Advanced Config")):
|
309 |
+
with gr.Row():
|
310 |
+
chunk_length = gr.Slider(
|
311 |
+
label=i18n("Iterative Prompt Length, 0 means off"),
|
312 |
+
minimum=50,
|
313 |
+
maximum=300,
|
314 |
+
value=200,
|
315 |
+
step=8,
|
316 |
+
)
|
317 |
+
|
318 |
+
max_new_tokens = gr.Slider(
|
319 |
+
label=i18n(
|
320 |
+
"Maximum tokens per batch, 0 means no limit"
|
321 |
+
),
|
322 |
+
minimum=0,
|
323 |
+
maximum=2048,
|
324 |
+
value=0, # 0 means no limit
|
325 |
+
step=8,
|
326 |
+
)
|
327 |
+
|
328 |
+
with gr.Row():
|
329 |
+
top_p = gr.Slider(
|
330 |
+
label="Top-P",
|
331 |
+
minimum=0.6,
|
332 |
+
maximum=0.9,
|
333 |
+
value=0.7,
|
334 |
+
step=0.01,
|
335 |
+
)
|
336 |
+
|
337 |
+
repetition_penalty = gr.Slider(
|
338 |
+
label=i18n("Repetition Penalty"),
|
339 |
+
minimum=1,
|
340 |
+
maximum=1.5,
|
341 |
+
value=1.2,
|
342 |
+
step=0.01,
|
343 |
+
)
|
344 |
+
|
345 |
+
with gr.Row():
|
346 |
+
temperature = gr.Slider(
|
347 |
+
label="Temperature",
|
348 |
+
minimum=0.6,
|
349 |
+
maximum=0.9,
|
350 |
+
value=0.7,
|
351 |
+
step=0.01,
|
352 |
+
)
|
353 |
+
seed = gr.Textbox(
|
354 |
+
label="Seed",
|
355 |
+
info="0 means randomized inference, otherwise deterministic",
|
356 |
+
placeholder="any 32-bit-integer",
|
357 |
+
value="0",
|
358 |
+
)
|
359 |
+
|
360 |
+
with gr.Tab(label=i18n("Reference Audio")):
|
361 |
+
with gr.Row():
|
362 |
+
gr.Markdown(
|
363 |
+
i18n(
|
364 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker."
|
365 |
+
)
|
366 |
+
)
|
367 |
+
with gr.Row():
|
368 |
+
enable_reference_audio = gr.Checkbox(
|
369 |
+
label=i18n("Enable Reference Audio"),
|
370 |
+
)
|
371 |
+
|
372 |
+
with gr.Row():
|
373 |
+
example_audio_dropdown = gr.Dropdown(
|
374 |
+
label=i18n("Select Example Audio"),
|
375 |
+
choices=[""],
|
376 |
+
value="",
|
377 |
+
interactive=True,
|
378 |
+
allow_custom_value=True,
|
379 |
+
)
|
380 |
+
with gr.Row():
|
381 |
+
reference_audio = gr.Audio(
|
382 |
+
label=i18n("Reference Audio"),
|
383 |
+
type="filepath",
|
384 |
+
)
|
385 |
+
with gr.Row():
|
386 |
+
reference_text = gr.Textbox(
|
387 |
+
label=i18n("Reference Text"),
|
388 |
+
lines=1,
|
389 |
+
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
390 |
+
value="",
|
391 |
+
)
|
392 |
+
|
393 |
+
with gr.Tab(label=i18n("Batch Inference")):
|
394 |
+
with gr.Row():
|
395 |
+
batch_infer_num = gr.Slider(
|
396 |
+
label="Batch infer nums",
|
397 |
+
minimum=1,
|
398 |
+
maximum=n_audios,
|
399 |
+
step=1,
|
400 |
+
value=1,
|
401 |
+
)
|
402 |
+
|
403 |
+
with gr.Column(scale=3):
|
404 |
+
for _ in range(n_audios):
|
405 |
+
with gr.Row():
|
406 |
+
error = gr.HTML(
|
407 |
+
label=i18n("Error Message"),
|
408 |
+
visible=True if _ == 0 else False,
|
409 |
+
)
|
410 |
+
global_error_list.append(error)
|
411 |
+
with gr.Row():
|
412 |
+
audio = gr.Audio(
|
413 |
+
label=i18n("Generated Audio"),
|
414 |
+
type="numpy",
|
415 |
+
interactive=False,
|
416 |
+
visible=True if _ == 0 else False,
|
417 |
+
)
|
418 |
+
global_audio_list.append(audio)
|
419 |
+
|
420 |
+
with gr.Row():
|
421 |
+
stream_audio = gr.Audio(
|
422 |
+
label=i18n("Streaming Audio"),
|
423 |
+
streaming=True,
|
424 |
+
autoplay=True,
|
425 |
+
interactive=False,
|
426 |
+
show_download_button=True,
|
427 |
+
)
|
428 |
+
with gr.Row():
|
429 |
+
with gr.Column(scale=3):
|
430 |
+
generate = gr.Button(
|
431 |
+
value="\U0001F3A7 " + i18n("Generate"), variant="primary"
|
432 |
+
)
|
433 |
+
|
434 |
+
generate_stream = gr.Button(
|
435 |
+
value="\U0001F3A7 " + i18n("Streaming Generate"),
|
436 |
+
variant="primary",
|
437 |
+
visible=False # 隱藏按鈕
|
438 |
+
)
|
439 |
+
|
440 |
+
text.input(
|
441 |
+
fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
|
442 |
+
)
|
443 |
+
|
444 |
+
def select_example_audio(audio_path):
|
445 |
+
audio_path = Path(audio_path)
|
446 |
+
if audio_path.is_file():
|
447 |
+
lab_file = Path(audio_path.with_suffix(".lab"))
|
448 |
+
|
449 |
+
if lab_file.exists():
|
450 |
+
lab_content = lab_file.read_text(encoding="utf-8").strip()
|
451 |
+
else:
|
452 |
+
lab_content = ""
|
453 |
+
|
454 |
+
return str(audio_path), lab_content, True
|
455 |
+
return None, "", False
|
456 |
+
|
457 |
+
# Connect the dropdown to update reference audio and text
|
458 |
+
|
459 |
+
example_audio_dropdown.change(
|
460 |
+
fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
|
461 |
+
).then(
|
462 |
+
fn=select_example_audio,
|
463 |
+
inputs=[example_audio_dropdown],
|
464 |
+
outputs=[reference_audio, reference_text, enable_reference_audio],
|
465 |
+
)
|
466 |
+
|
467 |
+
# # Submit
|
468 |
+
generate.click(
|
469 |
+
inference_wrapper,
|
470 |
+
[
|
471 |
+
refined_text,
|
472 |
+
enable_reference_audio,
|
473 |
+
reference_audio,
|
474 |
+
reference_text,
|
475 |
+
max_new_tokens,
|
476 |
+
chunk_length,
|
477 |
+
top_p,
|
478 |
+
repetition_penalty,
|
479 |
+
temperature,
|
480 |
+
seed,
|
481 |
+
batch_infer_num,
|
482 |
+
],
|
483 |
+
[stream_audio, *global_audio_list, *global_error_list],
|
484 |
+
concurrency_limit=1,
|
485 |
+
)
|
486 |
+
|
487 |
+
generate_stream.click(
|
488 |
+
inference_stream,
|
489 |
+
[
|
490 |
+
refined_text,
|
491 |
+
enable_reference_audio,
|
492 |
+
reference_audio,
|
493 |
+
reference_text,
|
494 |
+
max_new_tokens,
|
495 |
+
chunk_length,
|
496 |
+
top_p,
|
497 |
+
repetition_penalty,
|
498 |
+
temperature,
|
499 |
+
seed,
|
500 |
+
],
|
501 |
+
[stream_audio, global_audio_list[0], global_error_list[0]],
|
502 |
+
concurrency_limit=1,
|
503 |
+
)
|
504 |
+
|
505 |
+
return app
|
506 |
+
|
507 |
+
|
508 |
+
def parse_args():
|
509 |
+
parser = ArgumentParser()
|
510 |
+
parser.add_argument(
|
511 |
+
"--llama-checkpoint-path",
|
512 |
+
type=Path,
|
513 |
+
default="checkpoints/fish-speech-1.2",
|
514 |
+
)
|
515 |
+
parser.add_argument(
|
516 |
+
"--decoder-checkpoint-path",
|
517 |
+
type=Path,
|
518 |
+
default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
|
519 |
+
)
|
520 |
+
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
521 |
+
parser.add_argument("--device", type=str, default="cuda")
|
522 |
+
parser.add_argument("--half", action="store_true")
|
523 |
+
parser.add_argument("--compile", action="store_true")
|
524 |
+
parser.add_argument("--max-gradio-length", type=int, default=0)
|
525 |
+
parser.add_argument("--theme", type=str, default="light")
|
526 |
+
|
527 |
+
return parser.parse_args()
|
528 |
+
|
529 |
+
|
530 |
+
if __name__ == "__main__":
|
531 |
+
args = parse_args()
|
532 |
+
args.precision = torch.half if args.half else torch.bfloat16
|
533 |
+
|
534 |
+
logger.info("Loading Llama model...")
|
535 |
+
llama_queue = launch_thread_safe_queue(
|
536 |
+
checkpoint_path=args.llama_checkpoint_path,
|
537 |
+
device=args.device,
|
538 |
+
precision=args.precision,
|
539 |
+
compile=args.compile,
|
540 |
+
)
|
541 |
+
logger.info("Llama model loaded, loading VQ-GAN model...")
|
542 |
+
|
543 |
+
decoder_model = load_decoder_model(
|
544 |
+
config_name=args.decoder_config_name,
|
545 |
+
checkpoint_path=args.decoder_checkpoint_path,
|
546 |
+
device=args.device,
|
547 |
+
)
|
548 |
+
|
549 |
+
logger.info("Decoder model loaded, warming up...")
|
550 |
+
|
551 |
+
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
552 |
+
list(
|
553 |
+
inference(
|
554 |
+
text="Hello, world!",
|
555 |
+
enable_reference_audio=False,
|
556 |
+
reference_audio=None,
|
557 |
+
reference_text="",
|
558 |
+
max_new_tokens=500,
|
559 |
+
chunk_length=200,
|
560 |
+
top_p=0.7,
|
561 |
+
repetition_penalty=1.2,
|
562 |
+
temperature=0.7,
|
563 |
+
)
|
564 |
+
)
|
565 |
+
|
566 |
+
logger.info("Warming up done, launching the web UI...")
|
567 |
+
|
568 |
+
app = build_app()
|
569 |
+
app.launch(show_api=True, server_name="0.0.0.0",share=True)
|
570 |
+
|
tools/whisper_asr.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Used to transcribe all audio files in one folder into another folder.
|
3 |
+
e.g.
|
4 |
+
Directory structure:
|
5 |
+
--pre_data_root
|
6 |
+
----SP_1
|
7 |
+
------01.wav
|
8 |
+
------02.wav
|
9 |
+
------......
|
10 |
+
----SP_2
|
11 |
+
------01.wav
|
12 |
+
------02.wav
|
13 |
+
------......
|
14 |
+
Use
|
15 |
+
python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
|
16 |
+
to transcribe the first speaker.
|
17 |
+
|
18 |
+
Use
|
19 |
+
python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
|
20 |
+
to transcribe the second speaker.
|
21 |
+
|
22 |
+
Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import re
|
26 |
+
from pathlib import Path
|
27 |
+
|
28 |
+
import click
|
29 |
+
import soundfile as sf
|
30 |
+
from faster_whisper import WhisperModel
|
31 |
+
from loguru import logger
|
32 |
+
from pydub import AudioSegment
|
33 |
+
from tqdm import tqdm
|
34 |
+
|
35 |
+
from tools.file import AUDIO_EXTENSIONS, list_files
|
36 |
+
|
37 |
+
|
38 |
+
@click.command()
|
39 |
+
@click.option("--model-size", default="large-v3", help="Size of the Whisper model")
|
40 |
+
@click.option(
|
41 |
+
"--compute-type",
|
42 |
+
default="float16",
|
43 |
+
help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
|
44 |
+
)
|
45 |
+
@click.option("--audio-dir", required=True, help="Directory containing audio files")
|
46 |
+
@click.option(
|
47 |
+
"--save-dir", required=True, help="Directory to save processed audio files"
|
48 |
+
)
|
49 |
+
@click.option(
|
50 |
+
"--sample-rate",
|
51 |
+
default=44100,
|
52 |
+
type=int,
|
53 |
+
help="Output sample rate, default to input sample rate",
|
54 |
+
)
|
55 |
+
@click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
|
56 |
+
@click.option("--language", default="auto", help="Language of the transcription")
|
57 |
+
@click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
|
58 |
+
def main(
|
59 |
+
model_size,
|
60 |
+
compute_type,
|
61 |
+
audio_dir,
|
62 |
+
save_dir,
|
63 |
+
sample_rate,
|
64 |
+
device,
|
65 |
+
language,
|
66 |
+
initial_prompt,
|
67 |
+
):
|
68 |
+
logger.info("Loading / Downloading Faster Whisper model...")
|
69 |
+
|
70 |
+
model = WhisperModel(
|
71 |
+
model_size,
|
72 |
+
device=device,
|
73 |
+
compute_type=compute_type,
|
74 |
+
download_root="faster_whisper",
|
75 |
+
)
|
76 |
+
|
77 |
+
logger.info("Model loaded.")
|
78 |
+
|
79 |
+
save_path = Path(save_dir)
|
80 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
81 |
+
|
82 |
+
audio_files = list_files(
|
83 |
+
path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
|
84 |
+
)
|
85 |
+
|
86 |
+
for file_path in tqdm(audio_files, desc="Processing audio file"):
|
87 |
+
file_stem = file_path.stem
|
88 |
+
file_suffix = file_path.suffix
|
89 |
+
|
90 |
+
rel_path = Path(file_path).relative_to(audio_dir)
|
91 |
+
(save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
|
92 |
+
|
93 |
+
audio = AudioSegment.from_file(file_path)
|
94 |
+
|
95 |
+
segments, info = model.transcribe(
|
96 |
+
file_path,
|
97 |
+
beam_size=5,
|
98 |
+
language=None if language == "auto" else language,
|
99 |
+
initial_prompt=initial_prompt,
|
100 |
+
)
|
101 |
+
|
102 |
+
print(
|
103 |
+
"Detected language '%s' with probability %f"
|
104 |
+
% (info.language, info.language_probability)
|
105 |
+
)
|
106 |
+
print("Total len(ms): ", len(audio))
|
107 |
+
|
108 |
+
whole_text = None
|
109 |
+
for segment in segments:
|
110 |
+
id, start, end, text = (
|
111 |
+
segment.id,
|
112 |
+
segment.start,
|
113 |
+
segment.end,
|
114 |
+
segment.text,
|
115 |
+
)
|
116 |
+
print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
|
117 |
+
if not whole_text:
|
118 |
+
whole_text = text
|
119 |
+
else:
|
120 |
+
whole_text += ", " + text
|
121 |
+
|
122 |
+
whole_text += "."
|
123 |
+
|
124 |
+
audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
|
125 |
+
audio.export(audio_save_path, format=file_suffix[1:])
|
126 |
+
print(f"Exported {audio_save_path}")
|
127 |
+
|
128 |
+
transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
|
129 |
+
with open(
|
130 |
+
transcript_save_path,
|
131 |
+
"w",
|
132 |
+
encoding="utf-8",
|
133 |
+
) as f:
|
134 |
+
f.write(whole_text)
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
main()
|
139 |
+
exit(0)
|
140 |
+
|
141 |
+
audio = AudioSegment.from_wav(
|
142 |
+
r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
|
143 |
+
)
|
144 |
+
|
145 |
+
model_size = "large-v3"
|
146 |
+
|
147 |
+
model = WhisperModel(
|
148 |
+
model_size,
|
149 |
+
device="cuda",
|
150 |
+
compute_type="float16",
|
151 |
+
download_root="faster_whisper",
|
152 |
+
)
|
153 |
+
|
154 |
+
segments, info = model.transcribe(
|
155 |
+
r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
|
156 |
+
beam_size=5,
|
157 |
+
)
|
158 |
+
|
159 |
+
print(
|
160 |
+
"Detected language '%s' with probability %f"
|
161 |
+
% (info.language, info.language_probability)
|
162 |
+
)
|
163 |
+
print("Total len(ms): ", len(audio))
|
164 |
+
|
165 |
+
for i, segment in enumerate(segments):
|
166 |
+
print(
|
167 |
+
"Segment %03d [%.2fs -> %.2fs] %s"
|
168 |
+
% (i, segment.start, segment.end, segment.text)
|
169 |
+
)
|
170 |
+
start_ms = int(segment.start * 1000)
|
171 |
+
end_ms = int(segment.end * 1000)
|
172 |
+
segment_audio = audio[start_ms:end_ms]
|
173 |
+
segment_audio.export(f"segment_{i:03d}.wav", format="wav")
|
174 |
+
print(f"Exported segment_{i:03d}.wav")
|
175 |
+
|
176 |
+
print("All segments have been exported.")
|