uto1125 committed
Commit c667b6b · verified · 1 Parent(s): 28c311d

Upload 37 files

tools/__pycache__/api.cpython-310.pyc ADDED
Binary file (22.2 kB).
 
tools/__pycache__/api.cpython-311.pyc ADDED
Binary file (45 kB).
 
tools/__pycache__/auto_rerank.cpython-310.pyc ADDED
Binary file (3.49 kB).
 
tools/__pycache__/commons.cpython-310.pyc ADDED
Binary file (1.49 kB).
 
tools/__pycache__/file.cpython-310.pyc ADDED
Binary file (2.99 kB).
 
tools/__pycache__/schema.cpython-310.pyc ADDED
Binary file (7.67 kB).
 
tools/__pycache__/webui.cpython-310.pyc ADDED
Binary file (11.6 kB).
 
tools/api.py ADDED
@@ -0,0 +1,943 @@
1
+ import io
2
+ import os
3
+ import queue
4
+ import re
5
+ import time
6
+ import traceback
7
+ import wave
8
+ from argparse import ArgumentParser
9
+ from http import HTTPStatus
10
+ from pathlib import Path
11
+ from typing import Annotated, Any
12
+
13
+ import librosa
14
+ import numpy as np
15
+ import ormsgpack
16
+ import pyrootutils
17
+ import soundfile as sf
18
+ import torch
19
+ import torchaudio
20
+ from baize.datastructures import ContentType
21
+ from kui.asgi import (
22
+ Body,
23
+ FactoryClass,
24
+ HTTPException,
25
+ HttpRequest,
26
+ HttpView,
27
+ JSONResponse,
28
+ Kui,
29
+ OpenAPI,
30
+ StreamResponse,
31
+ request,
32
+ )
33
+ from kui.asgi.routing import MultimethodRoutes
34
+ from loguru import logger
35
+ from transformers import AutoTokenizer
36
+
37
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
38
+ import struct
39
+ from threading import Lock
40
+
41
+ import httpx
42
+ from cachetools import LRUCache, cached
43
+ from funasr import AutoModel
44
+ from silero_vad import get_speech_timestamps, load_silero_vad
45
+
46
+ from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
47
+ from fish_speech.models.text2semantic.llama import BaseModelArgs
48
+
49
+ # from fish_speech.models.vqgan.lit_module import VQGAN
50
+ from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
51
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
52
+ from fish_speech.utils import autocast_exclude_mps, set_seed
53
+ from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
54
+ from tools.llama.generate import (
55
+ GenerateRequest,
56
+ GenerateResponse,
57
+ WrappedGenerateResponse,
58
+ launch_thread_safe_queue,
59
+ launch_thread_safe_queue_agent,
60
+ )
61
+ from tools.schema import (
62
+ GLOBAL_NUM_SAMPLES,
63
+ ASRPackRequest,
64
+ ServeASRRequest,
65
+ ServeASRResponse,
66
+ ServeASRSegment,
67
+ ServeAudioPart,
68
+ ServeForwardMessage,
69
+ ServeMessage,
70
+ ServeRequest,
71
+ ServeResponse,
72
+ ServeStreamDelta,
73
+ ServeStreamResponse,
74
+ ServeTextPart,
75
+ ServeTimedASRResponse,
76
+ ServeTTSRequest,
77
+ ServeVQGANDecodeRequest,
78
+ ServeVQGANDecodeResponse,
79
+ ServeVQGANEncodeRequest,
80
+ ServeVQGANEncodeResponse,
81
+ ServeVQPart,
82
+ )
83
+ from tools.vqgan.inference import load_model as load_decoder_model
84
+
85
+ global_lock = Lock()
86
+
87
+ # Whether to disable keepalive (which is helpful if the server is in the same cluster)
88
+ DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
89
+ async_client = httpx.AsyncClient(
90
+ timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
91
+ )
92
+ backends = torchaudio.list_audio_backends()
93
+
94
+ if "ffmpeg" in backends:
95
+ backend = "ffmpeg"
96
+ else:
97
+ backend = "soundfile"
98
+
99
+
100
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
101
+ buffer = io.BytesIO()
102
+
103
+ with wave.open(buffer, "wb") as wav_file:
104
+ wav_file.setnchannels(channels)
105
+ wav_file.setsampwidth(bit_depth // 8)
106
+ wav_file.setframerate(sample_rate)
107
+
108
+ wav_header_bytes = buffer.getvalue()
109
+ buffer.close()
110
+ return wav_header_bytes
111
+
112
+
113
+ # Define utils for web server
114
+ async def http_exception_handler(exc: HTTPException):
115
+ return JSONResponse(
116
+ dict(
117
+ statusCode=exc.status_code,
118
+ message=exc.content,
119
+ error=HTTPStatus(exc.status_code).phrase,
120
+ ),
121
+ exc.status_code,
122
+ exc.headers,
123
+ )
124
+
125
+
126
+ async def other_exception_handler(exc: "Exception"):
127
+ traceback.print_exc()
128
+
129
+ status = HTTPStatus.INTERNAL_SERVER_ERROR
130
+ return JSONResponse(
131
+ dict(statusCode=status, message=str(exc), error=status.phrase),
132
+ status,
133
+ )
134
+
135
+
136
+ def load_audio(reference_audio, sr):
137
+ if len(reference_audio) > 255 or not Path(reference_audio).exists():
138
+ audio_data = reference_audio
139
+ reference_audio = io.BytesIO(audio_data)
140
+
141
+ waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
142
+
143
+ if waveform.shape[0] > 1:
144
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
145
+
146
+ if original_sr != sr:
147
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
148
+ waveform = resampler(waveform)
149
+
150
+ audio = waveform.squeeze().numpy()
151
+ return audio
152
+
153
+
154
+ def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
155
+ if enable_reference_audio and reference_audio is not None:
156
+ # Load audios, and prepare basic info here
157
+ reference_audio_content = load_audio(
158
+ reference_audio, decoder_model.spec_transform.sample_rate
159
+ )
160
+
161
+ audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
162
+ None, None, :
163
+ ]
164
+ audio_lengths = torch.tensor(
165
+ [audios.shape[2]], device=decoder_model.device, dtype=torch.long
166
+ )
167
+ logger.info(
168
+ f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
169
+ )
170
+
171
+ # VQ Encoder
172
+ if isinstance(decoder_model, FireflyArchitecture):
173
+ prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
174
+
175
+ logger.info(f"Encoded prompt: {prompt_tokens.shape}")
176
+ else:
177
+ prompt_tokens = None
178
+ logger.info("No reference audio provided")
179
+
180
+ return prompt_tokens
181
+
182
+
183
+ def decode_vq_tokens(
184
+ *,
185
+ decoder_model,
186
+ codes,
187
+ ):
188
+ feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
189
+ logger.info(f"VQ features: {codes.shape}")
190
+
191
+ if isinstance(decoder_model, FireflyArchitecture):
192
+ # VQGAN Inference
193
+ return decoder_model.decode(
194
+ indices=codes[None],
195
+ feature_lengths=feature_lengths,
196
+ )[0].squeeze()
197
+
198
+ raise ValueError(f"Unknown model type: {type(decoder_model)}")
199
+
200
+
201
+ routes = MultimethodRoutes(base_class=HttpView)
202
+
203
+
204
+ def get_content_type(audio_format):
205
+ if audio_format == "wav":
206
+ return "audio/wav"
207
+ elif audio_format == "flac":
208
+ return "audio/flac"
209
+ elif audio_format == "mp3":
210
+ return "audio/mpeg"
211
+ else:
212
+ return "application/octet-stream"
213
+
214
+
215
+ @torch.no_grad()
216
+ @torch.autocast(device_type="cuda", dtype=torch.half)
217
+ def batch_encode(model, audios: list[bytes | torch.Tensor]):
218
+ audios = [
219
+ (
220
+ torch.from_numpy(
221
+ librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
222
+ )[None]
223
+ if isinstance(audio, bytes)
224
+ else audio
225
+ )
226
+ for audio in audios
227
+ ]
228
+
229
+ # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
230
+ # raise ValueError("Single audio length is too long (>120s)")
231
+
232
+ max_length = max(audio.shape[-1] for audio in audios)
233
+ print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
234
+
235
+ lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
236
+ max_length = lengths.max().item()
237
+ padded = torch.stack(
238
+ [
239
+ torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
240
+ for audio in audios
241
+ ]
242
+ ).to(model.device)
243
+
244
+ features, feature_lengths = model.encode(padded, audio_lengths=lengths)
245
+ features, feature_lengths = features.cpu(), feature_lengths.cpu()
246
+
247
+ return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
248
+
249
+
250
+ @cached(
251
+ cache=LRUCache(maxsize=10000),
252
+ key=lambda model, audios: (model.device, tuple(audios)),
253
+ )
254
+ def cached_vqgan_batch_encode(model, audios: list[bytes]):
255
+ return batch_encode(model, audios)
256
+
257
+
258
+ @routes.http.post("/v1/vqgan/encode")
259
+ def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
260
+
261
+ start_time = time.time()
262
+ tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
263
+ logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
264
+
265
+ return ormsgpack.packb(
266
+ ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
267
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
268
+ )
269
+
270
+
271
+ @torch.no_grad()
272
+ @torch.autocast(device_type="cuda", dtype=torch.half)
273
+ def vqgan_decode(model, features):
274
+ lengths = torch.tensor(
275
+ [feature.shape[-1] for feature in features], device=model.device
276
+ )
277
+ max_length = lengths.max().item()
278
+ padded = torch.stack(
279
+ [
280
+ torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
281
+ for feature in features
282
+ ]
283
+ ).to(model.device)
284
+
285
+ # If the batch size is too large, decode in micro-batches of 8
286
+ audios, audio_lengths = [], []
287
+ for i in range(0, padded.shape[0], 8):
288
+ audio, audio_length = model.decode(
289
+ padded[i : i + 8], feature_lengths=lengths[i : i + 8]
290
+ )
291
+ audios.append(audio)
292
+ audio_lengths.append(audio_length)
293
+ audios = torch.cat(audios, dim=0)
294
+ audio_lengths = torch.cat(audio_lengths, dim=0)
295
+ audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
296
+
297
+ return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
298
+
299
+
300
+ @routes.http.post("/v1/vqgan/decode")
301
+ def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
302
+ tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
303
+ start_time = time.time()
304
+ audios = vqgan_decode(decoder_model, tokens)
305
+ logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
306
+ audios = [audio.astype(np.float16).tobytes() for audio in audios]
307
+ return ormsgpack.packb(
308
+ ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
309
+ )
310
+
311
+
312
+ @torch.no_grad()
313
+ def batch_asr(model, audios, sr, language="auto"):
314
+ resampled_audios = []
315
+ for audio in audios:
316
+ audio = torchaudio.functional.resample(audio, sr, 16000)
317
+ assert audio.ndim == 1
318
+ resampled_audios.append(audio)
319
+
320
+ with global_lock:
321
+ res = model.generate(
322
+ input=resampled_audios,
323
+ batch_size=len(resampled_audios),
324
+ language=language,
325
+ use_itn=True,
326
+ )
327
+
328
+ results = []
329
+ for r, audio in zip(res, audios):
330
+ text = r["text"]
331
+ text = re.sub(r"<\|.*?\|>", "", text)
332
+ duration = len(audio) / sr * 1000
333
+ huge_gap = False
334
+
335
+ if "timestamp" in r and len(r["timestamp"]) > 2:
336
+ for timestamp_a, timestamp_b in zip(
337
+ r["timestamp"][:-1], r["timestamp"][1:]
338
+ ):
339
+ # If there is a gap of more than 5 seconds, we consider it as a huge gap
340
+ if timestamp_b[0] - timestamp_a[1] > 5000:
341
+ huge_gap = True
342
+ break
343
+
344
+ # Doesn't make sense to have a huge gap at the end
345
+ if duration - r["timestamp"][-1][1] > 3000:
346
+ huge_gap = True
347
+
348
+ results.append(
349
+ {
350
+ "text": text,
351
+ "duration": duration,
352
+ "huge_gap": huge_gap,
353
+ }
354
+ )
355
+
356
+ return results
357
+
358
+
359
+ @routes.http.post("/v1/asr")
360
+ def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
361
+ start_time = time.time()
362
+ audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
363
+ audios = [torch.from_numpy(audio).float() for audio in audios]
364
+
365
+ if any(audio.shape[-1] >= 30 * payload.sample_rate for audio in audios):
366
+ raise HTTPException(status_code=400, detail="Audio length is too long")
367
+
368
+ transcriptions = batch_asr(
369
+ asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
370
+ )
371
+ logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
372
+
373
+ return ormsgpack.packb(
374
+ ServeASRResponse(transcriptions=transcriptions),
375
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
376
+ )
377
+
378
+
379
+ from fish_speech.conversation import Conversation, Message
380
+
381
+
382
+ def execute_request(
383
+ input_queue: queue.Queue,
384
+ tokenizer: AutoTokenizer,
385
+ config: BaseModelArgs,
386
+ request: ServeRequest,
387
+ device: str = "cuda:0",
388
+ ):
389
+ semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
390
+ [SEMANTIC_TOKEN, IM_END_TOKEN]
391
+ )
392
+ messages = []
393
+ for message in request.messages:
394
+ messages.append(message.to_conversation_message())
395
+
396
+ assert len(messages) >= 1, "At least one message is required"
397
+ # assert messages[-1].role == "user", "The last message must be from the user"
398
+
399
+ if messages[-1].role == "user":
400
+ messages.append(Message(role="assistant", parts=[], add_im_end=False))
401
+ else:
402
+ assert (
403
+ messages[-1].role == "assistant"
404
+ ), "The last message must be from the assistant"
405
+ messages[-1].add_im_end = False
406
+
407
+ conv = Conversation(messages=messages)
408
+ prompt = conv.encode_for_inference(
409
+ tokenizer=tokenizer, num_codebooks=config.num_codebooks
410
+ ).to(device)
411
+
412
+ if request.streaming:
413
+ for i in range(request.num_samples):
414
+ yield ServeStreamResponse(
415
+ sample_id=i,
416
+ delta=ServeStreamDelta(
417
+ role="assistant",
418
+ ),
419
+ )
420
+
421
+ req = {
422
+ "prompt": prompt,
423
+ "max_new_tokens": request.max_new_tokens,
424
+ "im_end_id": im_end_id,
425
+ "semantic_id": semantic_id,
426
+ "temperature": request.temperature,
427
+ "top_p": request.top_p,
428
+ "repetition_penalty": request.repetition_penalty,
429
+ "num_samples": request.num_samples,
430
+ "early_stop_threshold": request.early_stop_threshold,
431
+ }
432
+
433
+ start = time.time()
434
+ response_queue = queue.Queue()
435
+ input_queue.put(GenerateRequest(req, response_queue))
436
+
437
+ # Decoding
438
+ decode_buffer = [[] for _ in range(request.num_samples)]
439
+ parts = [[] for _ in range(request.num_samples)]
440
+
441
+ def send_reset_buffer(sample_id):
442
+ nonlocal decode_buffer
443
+ if len(decode_buffer[sample_id]) == 0:
444
+ return
445
+
446
+ decoded = tokenizer.decode(decode_buffer[sample_id])
447
+ part = ServeTextPart(text=decoded)
448
+
449
+ if request.streaming:
450
+ yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
451
+ else:
452
+ parts[sample_id].append(part)
453
+
454
+ decode_buffer[sample_id] = []
455
+
456
+ # Decode process
457
+ finished = [False for _ in range(request.num_samples)]
458
+ stats = {}
459
+ idx = 0
460
+ while True:
461
+ response = response_queue.get()
462
+
463
+ if response in ["stop", "error"]:
464
+ break
465
+
466
+ for sample_id, tokens in enumerate(response):
467
+ if finished[sample_id]:
468
+ continue
469
+
470
+ if tokens[0] == im_end_id:
471
+ finished[sample_id] = True
472
+ if request.streaming:
473
+ yield from send_reset_buffer(sample_id)
474
+ yield ServeStreamResponse(
475
+ sample_id=sample_id,
476
+ finish_reason="stop",
477
+ stats=stats,
478
+ )
479
+ continue
480
+
481
+ if tokens[0] == semantic_id and request.streaming:
482
+ yield from send_reset_buffer(sample_id)
483
+ # Streaming vq
484
+ _tokens = tokens[1:].clone() - 1
485
+
486
+ if config.share_codebook_embeddings is False:
487
+ for i in range(len(_tokens)):
488
+ _tokens[i] -= config.codebook_size * i
489
+
490
+ yield ServeStreamResponse(
491
+ sample_id=sample_id,
492
+ delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
493
+ )
494
+ continue
495
+
496
+ # Not streaming vq
497
+ if tokens[0] == semantic_id:
498
+ yield from send_reset_buffer(sample_id)
499
+ # Non-streaming vq
500
+ if len(parts[sample_id]) == 0 or not isinstance(
501
+ parts[sample_id][-1], ServeVQPart
502
+ ):
503
+ _tokens = tokens[1:].clone() - 1
504
+
505
+ if config.share_codebook_embeddings is False:
506
+ for i in range(len(_tokens)):
507
+ _tokens[i] -= config.codebook_size * i
508
+
509
+ parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
510
+ else:
511
+ for codebook_id, value in enumerate(tokens[1:, :]):
512
+ val = value.item() - 1
513
+ if config.share_codebook_embeddings is False:
514
+ val -= config.codebook_size * codebook_id
515
+
516
+ parts[sample_id][-1].codes[codebook_id].append(val)
517
+ continue
518
+
519
+ if tokens[0] != semantic_id:
520
+ # Stream text decode is not supported now
521
+ decode_buffer[sample_id].append(tokens[0, 0])
522
+
523
+ if idx == 0:
524
+ stats["time_to_first_token"] = (time.time() - start) * 1000
525
+
526
+ idx += 1
527
+
528
+ for sample_id in range(request.num_samples):
529
+ yield from send_reset_buffer(sample_id)
530
+
531
+ stats["total_time"] = (time.time() - start) * 1000
532
+ stats["total_tokens"] = idx
533
+
534
+ if request.streaming:
535
+ for sample_id in range(request.num_samples):
536
+ if finished[sample_id]:
537
+ continue
538
+ yield ServeStreamResponse(
539
+ finish_reason=response, stats=stats, sample_id=sample_id
540
+ )
541
+ return
542
+
543
+ yield ServeResponse(
544
+ messages=[
545
+ ServeMessage(role="assistant", parts=parts[i])
546
+ for i in range(request.num_samples)
547
+ ],
548
+ finish_reason=response,
549
+ stats=stats,
550
+ )
551
+
552
+
553
+ @routes.http.post("/v1/chat")
554
+ def api_invoke_chat(
555
+ req: Annotated[ServeRequest, Body(exclusive=True)],
556
+ ):
557
+ """
558
+ Invoke model and generate audio
559
+ """
560
+
561
+ # This makes torch compile happy
562
+ assert (
563
+ req.num_samples == GLOBAL_NUM_SAMPLES
564
+ ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
565
+
566
+ content_type = request.headers.get("Content-Type", "application/json")
567
+ json_mode = "application/json" in content_type
568
+
569
+ async def wrapped_generator():
570
+ generator = execute_request(llama_queue, tokenizer, config, req, args.device)
571
+
572
+ for i in generator:
573
+ if json_mode:
574
+ body = i.model_dump_json().encode("utf-8")
575
+ yield b"data: " + body + b"\n\n"
576
+ else:
577
+ body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
578
+ yield struct.pack("I", len(body)) + body
579
+
580
+ # Naive mode
581
+ if req.streaming is False:
582
+ result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
583
+
584
+ if json_mode:
585
+ return JSONResponse(result.model_dump())
586
+ else:
587
+ return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
588
+
589
+ return StreamResponse(
590
+ iterable=wrapped_generator(), content_type="text/event-stream"
591
+ )
592
+
593
+
594
+ @torch.inference_mode()
595
+ def inference(req: ServeTTSRequest):
596
+
597
+ global prompt_tokens, prompt_texts
598
+
599
+ idstr: str | None = req.reference_id
600
+ if idstr is not None:
601
+ ref_folder = Path("references") / idstr
602
+ ref_folder.mkdir(parents=True, exist_ok=True)
603
+ ref_audios = list_files(
604
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
605
+ )
606
+
607
+ if req.use_memory_cache == "never" or (
608
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
609
+ ):
610
+ prompt_tokens = [
611
+ encode_reference(
612
+ decoder_model=decoder_model,
613
+ reference_audio=audio_to_bytes(str(ref_audio)),
614
+ enable_reference_audio=True,
615
+ )
616
+ for ref_audio in ref_audios
617
+ ]
618
+ prompt_texts = [
619
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
620
+ for ref_audio in ref_audios
621
+ ]
622
+ else:
623
+ logger.info("Use same references")
624
+
625
+ else:
626
+ # Parse reference audio aka prompt
627
+ refs = req.references
628
+
629
+ if req.use_memory_cache == "never" or (
630
+ req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
631
+ ):
632
+ prompt_tokens = [
633
+ encode_reference(
634
+ decoder_model=decoder_model,
635
+ reference_audio=ref.audio,
636
+ enable_reference_audio=True,
637
+ )
638
+ for ref in refs
639
+ ]
640
+ prompt_texts = [ref.text for ref in refs]
641
+ else:
642
+ logger.info("Use same references")
643
+
644
+ if req.seed is not None:
645
+ set_seed(req.seed)
646
+ logger.warning(f"set seed: {req.seed}")
647
+
648
+ # LLAMA Inference
649
+ request = dict(
650
+ device=decoder_model.device,
651
+ max_new_tokens=req.max_new_tokens,
652
+ text=(
653
+ req.text
654
+ if not req.normalize
655
+ else ChnNormedText(raw_text=req.text).normalize()
656
+ ),
657
+ top_p=req.top_p,
658
+ repetition_penalty=req.repetition_penalty,
659
+ temperature=req.temperature,
660
+ compile=args.compile,
661
+ iterative_prompt=req.chunk_length > 0,
662
+ chunk_length=req.chunk_length,
663
+ max_length=4096,
664
+ prompt_tokens=prompt_tokens,
665
+ prompt_text=prompt_texts,
666
+ )
667
+
668
+ response_queue = queue.Queue()
669
+ llama_queue.put(
670
+ GenerateRequest(
671
+ request=request,
672
+ response_queue=response_queue,
673
+ )
674
+ )
675
+
676
+ if req.streaming:
677
+ yield wav_chunk_header()
678
+
679
+ segments = []
680
+ while True:
681
+ result: WrappedGenerateResponse = response_queue.get()
682
+ if result.status == "error":
683
+ raise result.response
684
+ break
685
+
686
+ result: GenerateResponse = result.response
687
+ if result.action == "next":
688
+ break
689
+
690
+ with autocast_exclude_mps(
691
+ device_type=decoder_model.device.type, dtype=args.precision
692
+ ):
693
+ fake_audios = decode_vq_tokens(
694
+ decoder_model=decoder_model,
695
+ codes=result.codes,
696
+ )
697
+
698
+ fake_audios = fake_audios.float().cpu().numpy()
699
+
700
+ if req.streaming:
701
+ yield (fake_audios * 32768).astype(np.int16).tobytes()
702
+ else:
703
+ segments.append(fake_audios)
704
+
705
+ if req.streaming:
706
+ return
707
+
708
+ if len(segments) == 0:
709
+ raise HTTPException(
710
+ HTTPStatus.INTERNAL_SERVER_ERROR,
711
+ content="No audio generated, please check the input text.",
712
+ )
713
+
714
+ fake_audios = np.concatenate(segments, axis=0)
715
+ yield fake_audios
716
+
717
+
718
+ async def inference_async(req: ServeTTSRequest):
719
+ for chunk in inference(req):
720
+ yield chunk
721
+
722
+
723
+ async def buffer_to_async_generator(buffer):
724
+ yield buffer
725
+
726
+
727
+ @routes.http.post("/v1/tts")
728
+ async def api_invoke_model(
729
+ req: Annotated[ServeTTSRequest, Body(exclusive=True)],
730
+ ):
731
+ """
732
+ Invoke model and generate audio
733
+ """
734
+
735
+ if args.max_text_length > 0 and len(req.text) > args.max_text_length:
736
+ raise HTTPException(
737
+ HTTPStatus.BAD_REQUEST,
738
+ content=f"Text is too long, max length is {args.max_text_length}",
739
+ )
740
+
741
+ if req.streaming and req.format != "wav":
742
+ raise HTTPException(
743
+ HTTPStatus.BAD_REQUEST,
744
+ content="Streaming only supports WAV format",
745
+ )
746
+
747
+ if req.streaming:
748
+ return StreamResponse(
749
+ iterable=inference_async(req),
750
+ headers={
751
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
752
+ },
753
+ content_type=get_content_type(req.format),
754
+ )
755
+ else:
756
+ fake_audios = next(inference(req))
757
+ buffer = io.BytesIO()
758
+ sf.write(
759
+ buffer,
760
+ fake_audios,
761
+ decoder_model.spec_transform.sample_rate,
762
+ format=req.format,
763
+ )
764
+
765
+ return StreamResponse(
766
+ iterable=buffer_to_async_generator(buffer.getvalue()),
767
+ headers={
768
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
769
+ },
770
+ content_type=get_content_type(req.format),
771
+ )
772
+
773
+
774
+ @routes.http.post("/v1/health")
775
+ async def api_health():
776
+ """
777
+ Health check
778
+ """
779
+
780
+ return JSONResponse({"status": "ok"})
781
+
782
+
783
+ def parse_args():
784
+ parser = ArgumentParser()
785
+ parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
786
+ parser.add_argument("--load-asr-model", action="store_true")
787
+ parser.add_argument(
788
+ "--llama-checkpoint-path",
789
+ type=str,
790
+ default="checkpoints/fish-speech-1.4",
791
+ )
792
+ parser.add_argument(
793
+ "--decoder-checkpoint-path",
794
+ type=str,
795
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
796
+ )
797
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
798
+ parser.add_argument("--device", type=str, default="cuda")
799
+ parser.add_argument("--half", action="store_true")
800
+ parser.add_argument("--compile", action="store_true")
801
+ parser.add_argument("--max-text-length", type=int, default=0)
802
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
803
+ parser.add_argument("--workers", type=int, default=1)
804
+
805
+ return parser.parse_args()
806
+
807
+
808
+ # Define Kui app
809
+ openapi = OpenAPI(
810
+ {
811
+ "title": "Fish Speech API",
812
+ "version": "1.4.2",
813
+ },
814
+ ).routes
815
+
816
+
817
+ class MsgPackRequest(HttpRequest):
818
+ async def data(
819
+ self,
820
+ ) -> Annotated[
821
+ Any, ContentType("application/msgpack"), ContentType("application/json")
822
+ ]:
823
+ if self.content_type == "application/msgpack":
824
+ return ormsgpack.unpackb(await self.body)
825
+
826
+ elif self.content_type == "application/json":
827
+ return await self.json
828
+
829
+ raise HTTPException(
830
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
831
+ headers={"Accept": "application/msgpack, application/json"},
832
+ )
833
+
834
+
835
+ app = Kui(
836
+ routes=routes + openapi[1:], # Remove the default route
837
+ exception_handlers={
838
+ HTTPException: http_exception_handler,
839
+ Exception: other_exception_handler,
840
+ },
841
+ factory_class=FactoryClass(http=MsgPackRequest),
842
+ cors_config={},
843
+ )
844
+
845
+
846
+ def load_asr_model(*, device="cuda", hub="ms"):
847
+ return AutoModel(
848
+ model="iic/SenseVoiceSmall",
849
+ device=device,
850
+ disable_pbar=True,
851
+ hub=hub,
852
+ )
853
+
854
+
855
+ # Each worker process created by Uvicorn has its own memory space,
856
+ # meaning that models and variables are not shared between processes.
857
+ # Therefore, any global variables (like `llama_queue` or `decoder_model`)
858
+ # will not be shared across workers.
859
+
860
+
861
+ # Multi-threading for deep learning can cause issues, such as inconsistent
862
+ # outputs if multiple threads access the same buffers simultaneously.
863
+ # Instead, it's better to use multiprocessing or independent models per thread.
864
+ @app.on_startup
865
+ def initialize_app(app: Kui):
866
+
867
+ global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts
868
+
869
+ prompt_tokens, prompt_texts = [], []
870
+
871
+ args = parse_args() # args same as ones in other processes
872
+ args.precision = torch.half if args.half else torch.bfloat16
873
+
874
+ if args.load_asr_model:
875
+ logger.info(f"Loading ASR model...")
876
+ asr_model = load_asr_model(device=args.device)
877
+
878
+ logger.info("Loading Llama model...")
879
+
880
+ if args.mode == "tts":
881
+ llama_queue = launch_thread_safe_queue(
882
+ checkpoint_path=args.llama_checkpoint_path,
883
+ device=args.device,
884
+ precision=args.precision,
885
+ compile=args.compile,
886
+ )
887
+ else:
888
+ llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
889
+ checkpoint_path=args.llama_checkpoint_path,
890
+ device=args.device,
891
+ precision=args.precision,
892
+ compile=args.compile,
893
+ )
894
+
895
+ logger.info("Llama model loaded, loading VQ-GAN model...")
896
+
897
+ decoder_model = load_decoder_model(
898
+ config_name=args.decoder_config_name,
899
+ checkpoint_path=args.decoder_checkpoint_path,
900
+ device=args.device,
901
+ )
902
+
903
+ logger.info("VQ-GAN model loaded, warming up...")
904
+
905
+ vad_model = load_silero_vad()
906
+
907
+ logger.info("VAD model loaded, warming up...")
908
+
909
+ if args.mode == "tts":
910
+ # Dry run to ensure models work and avoid first-time latency
911
+ list(
912
+ inference(
913
+ ServeTTSRequest(
914
+ text="Hello world.",
915
+ references=[],
916
+ reference_id=None,
917
+ max_new_tokens=0,
918
+ chunk_length=200,
919
+ top_p=0.7,
920
+ repetition_penalty=1.2,
921
+ temperature=0.7,
922
+ emotion=None,
923
+ format="wav",
924
+ )
925
+ )
926
+ )
927
+
928
+ logger.info(f"Warming up done, starting server at http://{args.listen}")
929
+
930
+
931
+ if __name__ == "__main__":
932
+
933
+ import uvicorn
934
+
935
+ args = parse_args()
936
+ host, port = args.listen.split(":")
937
+ uvicorn.run(
938
+ "tools.api:app",
939
+ host='0.0.0.0',
940
+ port=int(port),
941
+ workers=args.workers,
942
+ log_level="info",
943
+ )
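For orientation, a minimal client sketch for the /v1/tts route defined above. It is a sketch only: it assumes the server is reachable on the default port 8080, that httpx and ormsgpack are installed, and that the request fields match the ServeTTSRequest schema the API imports from tools/schema.py (the tools/commons.py model later in this commit lists the same fields).

import httpx
import ormsgpack

# Request body mirroring the ServeTTSRequest fields (non-streaming WAV output).
payload = {
    "text": "Hello world.",
    "format": "wav",
    "streaming": False,
    "references": [],
    "reference_id": None,
}

response = httpx.post(
    "http://127.0.0.1:8080/v1/tts",
    content=ormsgpack.packb(payload),
    headers={"Content-Type": "application/msgpack"},
    timeout=120,
)
response.raise_for_status()

# The endpoint answers with raw audio bytes in the requested format.
with open("output.wav", "wb") as f:
    f.write(response.content)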
tools/auto_rerank.py ADDED
@@ -0,0 +1,159 @@
1
+ import os
2
+
3
+ os.environ["MODELSCOPE_CACHE"] = ".cache/"
4
+
5
+ import string
6
+ import time
7
+ from threading import Lock
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import opencc
12
+ import torch
13
+ from faster_whisper import WhisperModel
14
+
15
+ t2s_converter = opencc.OpenCC("t2s")
16
+
17
+
18
+ def load_model(*, device="cuda"):
19
+ model = WhisperModel(
20
+ "medium",
21
+ device=device,
22
+ compute_type="float16",
23
+ download_root="faster_whisper",
24
+ )
25
+ print("faster_whisper loaded!")
26
+ return model
27
+
28
+
29
+ @torch.no_grad()
30
+ def batch_asr_internal(model: WhisperModel, audios, sr):
31
+ resampled_audios = []
32
+ for audio in audios:
33
+
34
+ if isinstance(audio, np.ndarray):
35
+ audio = torch.from_numpy(audio).float()
36
+
37
+ if audio.dim() > 1:
38
+ audio = audio.squeeze()
39
+
40
+ assert audio.dim() == 1
41
+ audio_np = audio.numpy()
42
+ resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
43
+ resampled_audios.append(resampled_audio)
44
+
45
+ trans_results = []
46
+
47
+ for resampled_audio in resampled_audios:
48
+ segments, info = model.transcribe(
49
+ resampled_audio,
50
+ language=None,
51
+ beam_size=5,
52
+ initial_prompt="Punctuation is needed in any language.",
53
+ )
54
+ trans_results.append(list(segments))
55
+
56
+ results = []
57
+ for trans_res, audio in zip(trans_results, audios):
58
+
59
+ duration = len(audio) / sr * 1000
60
+ huge_gap = False
61
+ max_gap = 0.0
62
+
63
+ text = None
64
+ last_tr = None
65
+
66
+ for tr in trans_res:
67
+ delta = tr.text.strip()
68
+ if tr.id > 1:
69
+ max_gap = max(tr.start - last_tr.end, max_gap)
70
+ text += delta
71
+ else:
72
+ text = delta
73
+
74
+ last_tr = tr
75
+ if max_gap > 3.0:
76
+ huge_gap = True
77
+ break
78
+
79
+ sim_text = t2s_converter.convert(text)
80
+ results.append(
81
+ {
82
+ "text": sim_text,
83
+ "duration": duration,
84
+ "huge_gap": huge_gap,
85
+ }
86
+ )
87
+
88
+ return results
89
+
90
+
91
+ global_lock = Lock()
92
+
93
+
94
+ def batch_asr(model, audios, sr):
95
+ return batch_asr_internal(model, audios, sr)
96
+
97
+
98
+ def is_chinese(text):
99
+ return True
100
+
101
+
102
+ def calculate_wer(text1, text2, debug=False):
103
+ chars1 = remove_punctuation(text1)
104
+ chars2 = remove_punctuation(text2)
105
+
106
+ m, n = len(chars1), len(chars2)
107
+
108
+ if m > n:
109
+ chars1, chars2 = chars2, chars1
110
+ m, n = n, m
111
+
112
+ prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
113
+ curr = [0] * (m + 1)
114
+
115
+ for j in range(1, n + 1):
116
+ curr[0] = j
117
+ for i in range(1, m + 1):
118
+ if chars1[i - 1] == chars2[j - 1]:
119
+ curr[i] = prev[i - 1]
120
+ else:
121
+ curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
122
+ prev, curr = curr, prev
123
+
124
+ edits = prev[m]
125
+ tot = max(len(chars1), len(chars2))
126
+ wer = edits / tot
127
+
128
+ if debug:
129
+ print(" gt: ", chars1)
130
+ print(" pred: ", chars2)
131
+ print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
132
+
133
+ return wer
134
+
135
+
136
+ def remove_punctuation(text):
137
+ chinese_punctuation = (
138
+ " \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
139
+ '‛""„‟…‧﹏'
140
+ )
141
+ all_punctuation = string.punctuation + chinese_punctuation
142
+ translator = str.maketrans("", "", all_punctuation)
143
+ text_without_punctuation = text.translate(translator)
144
+ return text_without_punctuation
145
+
146
+
147
+ if __name__ == "__main__":
148
+ model = load_model()
149
+ audios = [
150
+ librosa.load("44100.wav", sr=44100)[0],
151
+ librosa.load("lengyue.wav", sr=44100)[0],
152
+ ]
153
+ print(np.array(audios[0]))
154
+ print(batch_asr(model, audios, 44100))
155
+
156
+ start_time = time.time()
157
+ for _ in range(10):
158
+ print(batch_asr(model, audios, 44100))
159
+ print("Time taken:", time.time() - start_time)
tools/commons.py ADDED
@@ -0,0 +1,35 @@
1
+ from typing import Annotated, Literal, Optional
2
+
3
+ from pydantic import BaseModel, Field, conint
4
+
5
+
6
+ class ServeReferenceAudio(BaseModel):
7
+ audio: bytes
8
+ text: str
9
+
10
+
11
+ class ServeTTSRequest(BaseModel):
12
+ text: str
13
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
14
+ # Audio format
15
+ format: Literal["wav", "pcm", "mp3"] = "wav"
16
+ mp3_bitrate: Literal[64, 128, 192] = 128
17
+ # Reference audios for in-context learning
18
+ references: list[ServeReferenceAudio] = []
19
+ # Reference id
20
+ # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
21
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
22
+ reference_id: str | None = None
23
+ # Normalize text for en & zh; this increases stability for numbers
24
+ normalize: bool = True
26
+ opus_bitrate: Optional[int] = -1000
27
+ # Balance mode will reduce latency to 300ms, but may decrease stability
28
+ latency: Literal["normal", "balanced"] = "normal"
29
+ # not usually used below
30
+ streaming: bool = False
31
+ emotion: Optional[str] = None
32
+ max_new_tokens: int = 1024
33
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
34
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
35
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
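As a quick illustration of how the field constraints in this schema behave, a small sketch (it assumes pydantic v2, which the api.py above appears to target, and that tools.commons is importable from the project root):

from pydantic import ValidationError

from tools.commons import ServeTTSRequest

# Values inside the declared bounds validate normally.
req = ServeTTSRequest(text="Hello world.", top_p=0.8, temperature=0.7)
print(req.format, req.chunk_length)  # wav 200

# top_p is constrained to the range 0.1-1.0, so 1.5 is rejected.
try:
    ServeTTSRequest(text="Hello world.", top_p=1.5)
except ValidationError as err:
    print(err.errors()[0]["loc"])  # ('top_p',)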
tools/download_models.py ADDED
@@ -0,0 +1,55 @@
1
+ import os
2
+
3
+ from huggingface_hub import hf_hub_download
4
+
5
+
6
+ # Download
7
+ def check_and_download_files(repo_id, file_list, local_dir):
8
+ os.makedirs(local_dir, exist_ok=True)
9
+ for file in file_list:
10
+ file_path = os.path.join(local_dir, file)
11
+ if not os.path.exists(file_path):
12
+ print(f"{file} does not exist, downloading from the Hugging Face repository...")
13
+ hf_hub_download(
14
+ repo_id=repo_id,
15
+ filename=file,
16
+ resume_download=True,
17
+ local_dir=local_dir,
18
+ local_dir_use_symlinks=False,
19
+ )
20
+ else:
21
+ print(f"{file} already exists, skipping download.")
22
+
23
+
24
+ # 1st
25
+ repo_id_1 = "fishaudio/fish-speech-1.4"
26
+ local_dir_1 = "./checkpoints/fish-speech-1.4"
27
+ files_1 = [
28
+ "model.pth",
29
+ "README.md",
30
+ "special_tokens_map.json",
31
+ "tokenizer_config.json",
32
+ "tokenizer.json",
33
+ "config.json",
34
+ "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
35
+ ]
36
+
37
+ # 3rd
38
+ repo_id_3 = "fishaudio/fish-speech-1"
39
+ local_dir_3 = "./"
40
+ files_3 = [
41
+ "ffmpeg.exe",
42
+ "ffprobe.exe",
43
+ ]
44
+
45
+ # 4th
46
+ repo_id_4 = "SpicyqSama007/fish-speech-packed"
47
+ local_dir_4 = "./"
48
+ files_4 = [
49
+ "asr-label-win-x64.exe",
50
+ ]
51
+
52
+ check_and_download_files(repo_id_1, files_1, local_dir_1)
53
+
54
+ check_and_download_files(repo_id_3, files_3, local_dir_3)
55
+ check_and_download_files(repo_id_4, files_4, local_dir_4)
tools/e2e_webui.py ADDED
@@ -0,0 +1,232 @@
1
+ import io
2
+ import re
3
+ import wave
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+
8
+ from .fish_e2e import FishE2EAgent, FishE2EEventType
9
+ from .schema import ServeMessage, ServeTextPart, ServeVQPart
10
+
11
+
12
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
13
+ buffer = io.BytesIO()
14
+
15
+ with wave.open(buffer, "wb") as wav_file:
16
+ wav_file.setnchannels(channels)
17
+ wav_file.setsampwidth(bit_depth // 8)
18
+ wav_file.setframerate(sample_rate)
19
+
20
+ wav_header_bytes = buffer.getvalue()
21
+ buffer.close()
22
+ return wav_header_bytes
23
+
24
+
25
+ class ChatState:
26
+ def __init__(self):
27
+ self.conversation = []
28
+ self.added_systext = False
29
+ self.added_sysaudio = False
30
+
31
+ def get_history(self):
32
+ results = []
33
+ for msg in self.conversation:
34
+ results.append({"role": msg.role, "content": self.repr_message(msg)})
35
+
36
+ # Process assistant messages to extract questions and update user messages
37
+ for i, msg in enumerate(results):
38
+ if msg["role"] == "assistant":
39
+ match = re.search(r"Question: (.*?)\n\nAnswer:", msg["content"])
40
+ if match and i > 0 and results[i - 1]["role"] == "user":
41
+ # Update previous user message with extracted question
42
+ results[i - 1]["content"] += "\n" + match.group(1)
43
+ # Remove the Question/Answer format from assistant message
44
+ msg["content"] = msg["content"].split("\n\nAnswer: ", 1)[1]
45
+ return results
46
+
47
+ def repr_message(self, msg: ServeMessage):
48
+ response = ""
49
+ for part in msg.parts:
50
+ if isinstance(part, ServeTextPart):
51
+ response += part.text
52
+ elif isinstance(part, ServeVQPart):
53
+ response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
54
+ return response
55
+
56
+
57
+ def clear_fn():
58
+ return [], ChatState(), None, None, None
59
+
60
+
61
+ async def process_audio_input(
62
+ sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
63
+ ):
64
+ if audio_input is None and not text_input:
65
+ raise gr.Error("No input provided")
66
+
67
+ agent = FishE2EAgent() # Create new agent instance for each request
68
+
69
+ # Convert audio input to numpy array
70
+ if isinstance(audio_input, tuple):
71
+ sr, audio_data = audio_input
72
+ elif text_input:
73
+ sr = 44100
74
+ audio_data = None
75
+ else:
76
+ raise gr.Error("Invalid audio format")
77
+
78
+ if isinstance(sys_audio_input, tuple):
79
+ sr, sys_audio_data = sys_audio_input
80
+ else:
81
+ sr = 44100
82
+ sys_audio_data = None
83
+
84
+ def append_to_chat_ctx(
85
+ part: ServeTextPart | ServeVQPart, role: str = "assistant"
86
+ ) -> None:
87
+ if not state.conversation or state.conversation[-1].role != role:
88
+ state.conversation.append(ServeMessage(role=role, parts=[part]))
89
+ else:
90
+ state.conversation[-1].parts.append(part)
91
+
92
+ if state.added_systext is False and sys_text_input:
93
+ state.added_systext = True
94
+ append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
95
+ if text_input:
96
+ append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
97
+ audio_data = None
98
+
99
+ result_audio = b""
100
+ async for event in agent.stream(
101
+ sys_audio_data,
102
+ audio_data,
103
+ sr,
104
+ 1,
105
+ chat_ctx={
106
+ "messages": state.conversation,
107
+ "added_sysaudio": state.added_sysaudio,
108
+ },
109
+ ):
110
+ if event.type == FishE2EEventType.USER_CODES:
111
+ append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
112
+ elif event.type == FishE2EEventType.SPEECH_SEGMENT:
113
+ append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
114
+ yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
115
+ elif event.type == FishE2EEventType.TEXT_SEGMENT:
116
+ append_to_chat_ctx(ServeTextPart(text=event.text))
117
+ yield state.get_history(), None, None, None
118
+
119
+ yield state.get_history(), None, None, None
120
+
121
+
122
+ async def process_text_input(
123
+ sys_audio_input, sys_text_input, state: ChatState, text_input: str
124
+ ):
125
+ async for event in process_audio_input(
126
+ sys_audio_input, sys_text_input, None, state, text_input
127
+ ):
128
+ yield event
129
+
130
+
131
+ def create_demo():
132
+ with gr.Blocks() as demo:
133
+ state = gr.State(ChatState())
134
+
135
+ with gr.Row():
136
+ # Left column (70%) for chatbot and notes
137
+ with gr.Column(scale=7):
138
+ chatbot = gr.Chatbot(
139
+ [],
140
+ elem_id="chatbot",
141
+ bubble_full_width=False,
142
+ height=600,
143
+ type="messages",
144
+ )
145
+
146
+ # notes = gr.Markdown(
147
+ # """
148
+ # # Fish Agent
149
+ # 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
150
+ # 2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
151
+ # 3. Demo为早期灰度测试版本,推理速度尚待优化.
152
+ # # 特色
153
+ # 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
154
+ # 2. 模型可以使用reference audio控制说话音色.
155
+ # 3. 可以生成具有较强情感与韵律的音频.
156
+ # """
157
+ # )
158
+ notes = gr.Markdown(
159
+ """
160
+ # Fish Agent
161
+ 1. This demo is Fish Audio's self-researched end-to-end language model, Fish Agent version 3B.
162
+ 2. You can find the code and weights in our official repo on [GitHub](https://github.com/fishaudio/fish-speech) and [Hugging Face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b), but the content is released under a CC BY-NC-SA 4.0 licence.
163
+ 3. The demo is an early alpha test version; the inference speed still needs to be optimised.
164
+ # Features
165
+ 1. The model automatically integrates the ASR and TTS parts, so no other models need to be plugged in, i.e., it is truly end-to-end rather than three-stage (ASR+LLM+TTS).
166
+ 2. The model can use reference audio to control the speech timbre.
167
+ 3. The model can generate speech with strong emotion.
168
+ """
169
+ )
170
+
171
+ # Right column (30%) for controls
172
+ with gr.Column(scale=3):
173
+ sys_audio_input = gr.Audio(
174
+ sources=["upload"],
175
+ type="numpy",
176
+ label="Give a timbre for your assistant",
177
+ )
178
+ sys_text_input = gr.Textbox(
179
+ label="What is your assistant's role?",
180
+ value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nAnswer: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
181
+ type="text",
182
+ )
183
+ audio_input = gr.Audio(
184
+ sources=["microphone"], type="numpy", label="Speak your message"
185
+ )
186
+
187
+ text_input = gr.Textbox(label="Or type your message", type="text")
188
+
189
+ output_audio = gr.Audio(
190
+ label="Assistant's Voice",
191
+ streaming=True,
192
+ autoplay=True,
193
+ interactive=False,
194
+ )
195
+
196
+ send_button = gr.Button("Send", variant="primary")
197
+ clear_button = gr.Button("Clear")
198
+
199
+ # Event handlers
200
+ audio_input.stop_recording(
201
+ process_audio_input,
202
+ inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
203
+ outputs=[chatbot, output_audio, audio_input, text_input],
204
+ show_progress=True,
205
+ )
206
+
207
+ send_button.click(
208
+ process_text_input,
209
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
210
+ outputs=[chatbot, output_audio, audio_input, text_input],
211
+ show_progress=True,
212
+ )
213
+
214
+ text_input.submit(
215
+ process_text_input,
216
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
217
+ outputs=[chatbot, output_audio, audio_input, text_input],
218
+ show_progress=True,
219
+ )
220
+
221
+ clear_button.click(
222
+ clear_fn,
223
+ inputs=[],
224
+ outputs=[chatbot, state, audio_input, output_audio, text_input],
225
+ )
226
+
227
+ return demo
228
+
229
+
230
+ if __name__ == "__main__":
231
+ demo = create_demo()
232
+ demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
tools/extract_model.py ADDED
@@ -0,0 +1,21 @@
1
+ import click
2
+ import torch
3
+ from loguru import logger
4
+
5
+
6
+ @click.command()
7
+ @click.argument("model_path")
8
+ @click.argument("output_path")
9
+ def main(model_path, output_path):
10
+ if model_path == output_path:
11
+ logger.error("Model path and output path are the same")
12
+ return
13
+
14
+ logger.info(f"Loading model from {model_path}")
15
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
16
+ torch.save(state_dict, output_path)
17
+ logger.info(f"Model saved to {output_path}")
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
tools/file.py ADDED
@@ -0,0 +1,125 @@
1
+ import base64
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ from loguru import logger
6
+ from natsort import natsorted
7
+
8
+ AUDIO_EXTENSIONS = {
9
+ ".mp3",
10
+ ".wav",
11
+ ".flac",
12
+ ".ogg",
13
+ ".m4a",
14
+ ".wma",
15
+ ".aac",
16
+ ".aiff",
17
+ ".aif",
18
+ ".aifc",
19
+ }
20
+
21
+ VIDEO_EXTENSIONS = {
22
+ ".mp4",
23
+ ".avi",
24
+ }
25
+
26
+
27
+ def audio_to_bytes(file_path):
28
+ if not file_path or not Path(file_path).exists():
29
+ return None
30
+ with open(file_path, "rb") as wav_file:
31
+ wav = wav_file.read()
32
+ return wav
33
+
34
+
35
+ def read_ref_text(ref_text):
36
+ path = Path(ref_text)
37
+ if path.exists() and path.is_file():
38
+ with path.open("r", encoding="utf-8") as file:
39
+ return file.read()
40
+ return ref_text
41
+
42
+
43
+ def list_files(
44
+ path: Union[Path, str],
45
+ extensions: set[str] = None,
46
+ recursive: bool = False,
47
+ sort: bool = True,
48
+ ) -> list[Path]:
49
+ """List files in a directory.
50
+
51
+ Args:
52
+ path (Path): Path to the directory.
53
+ extensions (set, optional): Extensions to filter. Defaults to None.
54
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
55
+ sort (bool, optional): Whether to sort the files. Defaults to True.
56
+
57
+ Returns:
58
+ list: List of files.
59
+ """
60
+
61
+ if isinstance(path, str):
62
+ path = Path(path)
63
+
64
+ if not path.exists():
65
+ raise FileNotFoundError(f"Directory {path} does not exist.")
66
+
67
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
68
+
69
+ if sort:
70
+ files = natsorted(files)
71
+
72
+ return files
73
+
74
+
75
+ def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
76
+ """
77
+ Load a Bert-VITS2 style filelist.
78
+ """
79
+
80
+ files = set()
81
+ results = []
82
+ count_duplicated, count_not_found = 0, 0
83
+
84
+ LANGUAGE_TO_LANGUAGES = {
85
+ "zh": ["zh", "en"],
86
+ "jp": ["jp", "en"],
87
+ "en": ["en"],
88
+ }
89
+
90
+ with open(path, "r", encoding="utf-8") as f:
91
+ for line in f.readlines():
92
+ splits = line.strip().split("|", maxsplit=3)
93
+ if len(splits) != 4:
94
+ logger.warning(f"Invalid line: {line}")
95
+ continue
96
+
97
+ filename, speaker, language, text = splits
98
+ file = Path(filename)
99
+ language = language.strip().lower()
100
+
101
+ if language == "ja":
102
+ language = "jp"
103
+
104
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
105
+ languages = LANGUAGE_TO_LANGUAGES[language]
106
+
107
+ if file in files:
108
+ logger.warning(f"Duplicated file: {file}")
109
+ count_duplicated += 1
110
+ continue
111
+
112
+ if not file.exists():
113
+ logger.warning(f"File not found: {file}")
114
+ count_not_found += 1
115
+ continue
116
+
117
+ results.append((file, speaker, languages, text))
118
+
119
+ if count_duplicated > 0:
120
+ logger.warning(f"Total duplicated files: {count_duplicated}")
121
+
122
+ if count_not_found > 0:
123
+ logger.warning(f"Total files not found: {count_not_found}")
124
+
125
+ return results
tools/fish_e2e.py ADDED
@@ -0,0 +1,298 @@
1
+ import base64
2
+ import ctypes
3
+ import io
4
+ import json
5
+ import os
6
+ import struct
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+ from typing import AsyncGenerator, Union
10
+
11
+ import httpx
12
+ import numpy as np
13
+ import ormsgpack
14
+ import soundfile as sf
15
+
16
+ from .schema import (
17
+ ServeMessage,
18
+ ServeRequest,
19
+ ServeTextPart,
20
+ ServeVQGANDecodeRequest,
21
+ ServeVQGANEncodeRequest,
22
+ ServeVQPart,
23
+ )
24
+
25
+
26
+ class CustomAudioFrame:
27
+ def __init__(self, data, sample_rate, num_channels, samples_per_channel):
28
+ if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
29
+ ctypes.c_int16
30
+ ):
31
+ raise ValueError(
32
+ "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
33
+ )
34
+
35
+ self._data = bytearray(data)
36
+ self._sample_rate = sample_rate
37
+ self._num_channels = num_channels
38
+ self._samples_per_channel = samples_per_channel
39
+
40
+ @property
41
+ def data(self):
42
+ return memoryview(self._data).cast("h")
43
+
44
+ @property
45
+ def sample_rate(self):
46
+ return self._sample_rate
47
+
48
+ @property
49
+ def num_channels(self):
50
+ return self._num_channels
51
+
52
+ @property
53
+ def samples_per_channel(self):
54
+ return self._samples_per_channel
55
+
56
+ @property
57
+ def duration(self):
58
+ return self.samples_per_channel / self.sample_rate
59
+
60
+ def __repr__(self):
61
+ return (
62
+ f"CustomAudioFrame(sample_rate={self.sample_rate}, "
63
+ f"num_channels={self.num_channels}, "
64
+ f"samples_per_channel={self.samples_per_channel}, "
65
+ f"duration={self.duration:.3f})"
66
+ )
67
+
68
+
69
+ class FishE2EEventType(Enum):
70
+ SPEECH_SEGMENT = 1
71
+ TEXT_SEGMENT = 2
72
+ END_OF_TEXT = 3
73
+ END_OF_SPEECH = 4
74
+ ASR_RESULT = 5
75
+ USER_CODES = 6
76
+
77
+
78
+ @dataclass
79
+ class FishE2EEvent:
80
+ type: FishE2EEventType
81
+ frame: np.ndarray = None
82
+ text: str = None
83
+ vq_codes: list[list[int]] = None
84
+
85
+
86
+ client = httpx.AsyncClient(
87
+ timeout=None,
88
+ limits=httpx.Limits(
89
+ max_connections=None,
90
+ max_keepalive_connections=None,
91
+ keepalive_expiry=None,
92
+ ),
93
+ )
94
+
95
+
96
+ class FishE2EAgent:
97
+ def __init__(self):
98
+ self.llm_url = "http://localhost:8080/v1/chat"
99
+ self.vqgan_url = "http://localhost:8080"
100
+ self.client = httpx.AsyncClient(timeout=None)
101
+
102
+ async def get_codes(self, audio_data, sample_rate):
103
+ audio_buffer = io.BytesIO()
104
+ sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
105
+ audio_buffer.seek(0)
106
+ # Step 1: Encode audio using VQGAN
107
+ encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
108
+ encode_request_bytes = ormsgpack.packb(
109
+ encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
110
+ )
111
+ encode_response = await self.client.post(
112
+ f"{self.vqgan_url}/v1/vqgan/encode",
113
+ data=encode_request_bytes,
114
+ headers={"Content-Type": "application/msgpack"},
115
+ )
116
+ encode_response_data = ormsgpack.unpackb(encode_response.content)
117
+ codes = encode_response_data["tokens"][0]
118
+ return codes
119
+
120
+ async def stream(
121
+ self,
122
+ system_audio_data: np.ndarray | None,
123
+ user_audio_data: np.ndarray | None,
124
+ sample_rate: int,
125
+ num_channels: int,
126
+ chat_ctx: dict | None = None,
127
+ ) -> AsyncGenerator[bytes, None]:
128
+
129
+ if system_audio_data is not None:
130
+ sys_codes = await self.get_codes(system_audio_data, sample_rate)
131
+ else:
132
+ sys_codes = None
133
+ if user_audio_data is not None:
134
+ user_codes = await self.get_codes(user_audio_data, sample_rate)
135
+ # Step 2: Prepare LLM request
136
+ if chat_ctx is None:
137
+ sys_parts = [
138
+ ServeTextPart(
139
+ text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
140
+ ),
141
+ ]
142
+ if system_audio_data is not None:
143
+ sys_parts.append(ServeVQPart(codes=sys_codes))
144
+ chat_ctx = {
145
+ "messages": [
146
+ ServeMessage(
147
+ role="system",
148
+ parts=sys_parts,
149
+ ),
150
+ ],
151
+ }
152
+ else:
153
+ if chat_ctx["added_sysaudio"] is False and sys_codes:
154
+ chat_ctx["added_sysaudio"] = True
155
+ chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
156
+
157
+ prev_messages = chat_ctx["messages"].copy()
158
+ if user_audio_data is not None:
159
+ yield FishE2EEvent(
160
+ type=FishE2EEventType.USER_CODES,
161
+ vq_codes=user_codes,
162
+ )
163
+ else:
164
+ user_codes = None
165
+
166
+ request = ServeRequest(
167
+ messages=prev_messages
168
+ + (
169
+ [
170
+ ServeMessage(
171
+ role="user",
172
+ parts=[ServeVQPart(codes=user_codes)],
173
+ )
174
+ ]
175
+ if user_codes
176
+ else []
177
+ ),
178
+ streaming=True,
179
+ num_samples=1,
180
+ )
181
+
182
+ # Step 3: Stream LLM response and decode audio
183
+ buffer = b""
184
+ vq_codes = []
185
+ current_vq = False
186
+
187
+ async def decode_send():
188
+ nonlocal current_vq
189
+ nonlocal vq_codes
190
+
191
+ data = np.concatenate(vq_codes, axis=1).tolist()
192
+ # Decode VQ codes to audio
193
+ decode_request = ServeVQGANDecodeRequest(tokens=[data])
194
+ decode_response = await self.client.post(
195
+ f"{self.vqgan_url}/v1/vqgan/decode",
196
+ data=ormsgpack.packb(
197
+ decode_request,
198
+ option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
199
+ ),
200
+ headers={"Content-Type": "application/msgpack"},
201
+ )
202
+ decode_data = ormsgpack.unpackb(decode_response.content)
203
+
204
+ # Convert float16 audio data to int16
205
+ audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
206
+ audio_data = (audio_data * 32768).astype(np.int16).tobytes()
207
+
208
+ audio_frame = CustomAudioFrame(
209
+ data=audio_data,
210
+ samples_per_channel=len(audio_data) // 2,
211
+ sample_rate=44100,
212
+ num_channels=1,
213
+ )
214
+ yield FishE2EEvent(
215
+ type=FishE2EEventType.SPEECH_SEGMENT,
216
+ frame=audio_frame,
217
+ vq_codes=data,
218
+ )
219
+
220
+ current_vq = False
221
+ vq_codes = []
222
+
223
+ async with self.client.stream(
224
+ "POST",
225
+ self.llm_url,
226
+ data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
227
+ headers={"Content-Type": "application/msgpack"},
228
+ ) as response:
229
+
230
+ async for chunk in response.aiter_bytes():
231
+ buffer += chunk
232
+
233
+ while len(buffer) >= 4:
234
+ read_length = struct.unpack("I", buffer[:4])[0]
235
+ if len(buffer) < 4 + read_length:
236
+ break
237
+
238
+ body = buffer[4 : 4 + read_length]
239
+ buffer = buffer[4 + read_length :]
240
+ data = ormsgpack.unpackb(body)
241
+
242
+ if data["delta"] and data["delta"]["part"]:
243
+ if current_vq and data["delta"]["part"]["type"] == "text":
244
+ async for event in decode_send():
245
+ yield event
246
+ if data["delta"]["part"]["type"] == "text":
247
+ yield FishE2EEvent(
248
+ type=FishE2EEventType.TEXT_SEGMENT,
249
+ text=data["delta"]["part"]["text"],
250
+ )
251
+ elif data["delta"]["part"]["type"] == "vq":
252
+ vq_codes.append(np.array(data["delta"]["part"]["codes"]))
253
+ current_vq = True
254
+
255
+ if current_vq and vq_codes:
256
+ async for event in decode_send():
257
+ yield event
258
+
259
+ yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
260
+ yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
261
+
262
+
263
+ # Example usage:
264
+ async def main():
265
+ import torchaudio
266
+
267
+ agent = FishE2EAgent()
268
+
269
+ # Replace this with actual audio data loading
270
+ with open("uz_story_en.m4a", "rb") as f:
271
+ audio_data = f.read()
272
+
273
+ audio_data, sample_rate = torchaudio.load("uz_story_en.m4a")
274
+ audio_data = (audio_data.numpy() * 32768).astype(np.int16)
275
+
276
+ stream = agent.stream(None, audio_data, sample_rate, 1)
277
+ if os.path.exists("audio_segment.wav"):
278
+ os.remove("audio_segment.wav")
279
+
280
+ async for event in stream:
281
+ if event.type == FishE2EEventType.SPEECH_SEGMENT:
282
+ # Handle speech segment (e.g., play audio or save to file)
283
+ with open("audio_segment.wav", "ab+") as f:
284
+ f.write(event.frame.data)
285
+ elif event.type == FishE2EEventType.ASR_RESULT:
286
+ print(event.text, flush=True)
287
+ elif event.type == FishE2EEventType.TEXT_SEGMENT:
288
+ print(event.text, flush=True, end="")
289
+ elif event.type == FishE2EEventType.END_OF_TEXT:
290
+ print("\nEnd of text reached.")
291
+ elif event.type == FishE2EEventType.END_OF_SPEECH:
292
+ print("End of speech reached.")
293
+
294
+
295
+ if __name__ == "__main__":
296
+ import asyncio
297
+
298
+ asyncio.run(main())
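Note on the streaming protocol used by FishE2EAgent.stream above: the chat endpoint is assumed to frame every msgpack message with a 4-byte length prefix, which is what the struct.unpack("I", buffer[:4]) loop decodes. A minimal standalone sketch of that framing follows; pack_frame and iter_frames are illustrative helper names, not part of this repository.

import struct
import ormsgpack

def pack_frame(obj) -> bytes:
    # Prefix the msgpack body with its byte length, matching the reader loop above
    body = ormsgpack.packb(obj)
    return struct.pack("I", len(body)) + body

def iter_frames(buffer: bytes):
    # Yield every complete msgpack message currently held in the buffer
    while len(buffer) >= 4:
        (length,) = struct.unpack("I", buffer[:4])
        if len(buffer) < 4 + length:
            break
        yield ormsgpack.unpackb(buffer[4 : 4 + length])
        buffer = buffer[4 + length :]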
tools/llama/__pycache__/generate.cpython-310.pyc ADDED
Binary file (21.1 kB). View file
 
tools/llama/build_dataset.py ADDED
@@ -0,0 +1,169 @@
1
+ import itertools
2
+ import os
3
+ import re
4
+ from collections import defaultdict
5
+ from functools import partial
6
+ from multiprocessing import Pool
7
+ from pathlib import Path
8
+
9
+ import click
10
+ import numpy as np
11
+ from loguru import logger
12
+ from tqdm import tqdm
13
+
14
+ from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
+ from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
+ from tools.file import load_filelist
17
+
18
+ # To avoid CPU overload
19
+ os.environ["MKL_NUM_THREADS"] = "1"
20
+ os.environ["OMP_NUM_THREADS"] = "1"
21
+
22
+
23
+ def task_generator_folder(root: Path, text_extension: str):
24
+ files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
25
+ files = sorted(files)
26
+
27
+ grouped_files = defaultdict(list)
28
+ for file in tqdm(files, desc=f"Grouping {root}"):
29
+ p = str(file.parent)
30
+ speaker = file.parent.name
31
+
32
+ try:
33
+ if isinstance(text_extension, str):
34
+ texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
35
+ else:
36
+ texts = [
37
+ file.with_suffix(ext).read_text(encoding="utf-8")
38
+ for ext in text_extension
39
+ ]
40
+ except Exception as e:
41
+ logger.error(f"Failed to read text {file}: {e}")
42
+ continue
43
+
44
+ grouped_files[p].append((speaker, file, texts))
45
+
46
+ logger.info(
47
+ f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
48
+ )
49
+
50
+ for i in grouped_files.values():
51
+ subset = [(f, t) for _, f, t in i]
52
+ yield i[0][0], subset, "folder"
53
+
54
+
55
+ def task_generator_filelist(filelist):
56
+ grouped_files = defaultdict(list)
57
+ for filename, speaker, _, text in load_filelist(filelist):
58
+ grouped_files[speaker].append((Path(filename), [text]))
59
+
60
+ logger.info(f"Found {len(grouped_files)} groups in {filelist}")
61
+ for speaker, values in grouped_files.items():
62
+ yield speaker, values, "filelist"
63
+
64
+
65
+ def run_task(task):
66
+ name, subset, source = task
67
+
68
+ # Parse the files
69
+ sentences = []
70
+ for file, texts in subset:
71
+ np_file = file.with_suffix(".npy")
72
+ if np_file.exists() is False:
73
+ logger.warning(f"Can't find {np_file}")
74
+ continue
75
+
76
+ new_texts = []
77
+
78
+ for text in texts:
79
+ # Simple cleaning: replace { xxx } and < xxx > with space
80
+ text = re.sub(r"\{.*?\}", " ", text)
81
+ text = re.sub(r"<.*?>", " ", text)
82
+ text = re.sub(r"\s+", " ", text)
83
+ new_texts.append(text)
84
+
85
+ try:
86
+ semantics = np.load(np_file)
87
+ except Exception as e:
88
+ logger.error(f"Failed to parse {file}: {e}")
89
+ continue
90
+
91
+ if isinstance(semantics, np.ndarray):
92
+ semantics = semantics.tolist()
93
+
94
+ sentences.append(
95
+ Sentence(
96
+ texts=new_texts,
97
+ semantics=[Semantics(values=s) for s in semantics],
98
+ )
99
+ )
100
+
101
+ # Pack the sentences
102
+ return pack_pb_stream(
103
+ TextData(
104
+ source=source,
105
+ name=name,
106
+ sentences=sentences,
107
+ )
108
+ )
109
+
110
+
111
+ @click.command()
112
+ @click.option(
113
+ "--input",
114
+ type=click.Path(path_type=Path),
115
+ required=True,
116
+ help="A folder containing the dataset or a filelist",
117
+ multiple=True,
118
+ )
119
+ @click.option(
120
+ "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
121
+ )
122
+ @click.option("--num-workers", type=int, default=16)
123
+ @click.option("--text-extension", type=str, default=[".txt"], multiple=True)
124
+ @click.option(
125
+ "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
126
+ )
127
+ def main(input, output, num_workers, text_extension, shard_size):
128
+ generator_fns = []
129
+
130
+ for f in input:
131
+ assert f.exists(), f"{f} not found"
132
+
133
+ if f.is_dir():
134
+ generator_fn = task_generator_folder(f, text_extension)
135
+ else:
136
+ generator_fn = task_generator_filelist(f)
137
+
138
+ generator_fns.append(generator_fn)
139
+
140
+ generator_fn = itertools.chain(*generator_fns)
141
+ output.mkdir(parents=True, exist_ok=True)
142
+
143
+ dataset_fp = None
144
+ tar_idx = 0
145
+ written_size = 0
146
+
147
+ with Pool(num_workers) as p:
148
+ for result in tqdm(p.imap_unordered(run_task, generator_fn)):
149
+ if dataset_fp is None:
150
+ dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
151
+
152
+ dataset_fp.write(result)
153
+ written_size += len(result)
154
+
155
+ if written_size > shard_size * 1024 * 1024:
156
+ logger.info(f"Finished writing {tar_idx} shards to {output}")
157
+ dataset_fp.close()
158
+ dataset_fp = None
159
+ written_size = 0
160
+ tar_idx += 1
161
+
162
+ if dataset_fp is not None:
163
+ dataset_fp.close()
164
+
165
+ logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
166
+
167
+
168
+ if __name__ == "__main__":
169
+ main()
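For reference, the cleaning applied in run_task only strips {...} and <...> spans and collapses whitespace before packing sentences. A standalone sketch of the same regexes for quick verification; clean_label is an illustrative name, not part of the repository.

import re

def clean_label(text: str) -> str:
    # Mirrors the cleaning in run_task: drop {...} / <...> spans, collapse whitespace
    text = re.sub(r"\{.*?\}", " ", text)
    text = re.sub(r"<.*?>", " ", text)
    return re.sub(r"\s+", " ", text)

print(clean_label("Hello {laugh} and <noise> world"))  # -> "Hello and world"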
tools/llama/eval_in_context.py ADDED
@@ -0,0 +1,171 @@
1
+ import pyrootutils
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from matplotlib import pyplot as plt
5
+ from transformers import AutoTokenizer
6
+
7
+ # register eval resolver and root
8
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
9
+
10
+ from torch.utils.data import DataLoader
11
+
12
+ from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
13
+ from tools.llama.generate import load_model
14
+
15
+
16
+ def smooth(
17
+ scalars: list[float], weight: float
18
+ ) -> list[float]: # Weight between 0 and 1
19
+ last = scalars[0] # First value in the plot (first timestep)
20
+ smoothed = list()
21
+ for point in scalars:
22
+ smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
23
+ smoothed.append(smoothed_val) # Save it
24
+ last = smoothed_val # Anchor the last smoothed value
25
+
26
+ return smoothed
27
+
28
+
29
+ @torch.inference_mode()
30
+ def analyze_one_model(loader, config, weight, max_length):
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+ model = load_model(
33
+ config,
34
+ weight,
35
+ device,
36
+ torch.bfloat16,
37
+ max_length,
38
+ compile=False,
39
+ )[0]
40
+
41
+ current_step = 0
42
+ model.eval()
43
+
44
+ semantic_loss_sum = torch.zeros(
45
+ max_length,
46
+ dtype=torch.float32,
47
+ device=device,
48
+ )
49
+ counter = torch.zeros(
50
+ max_length,
51
+ dtype=torch.long,
52
+ device=device,
53
+ )
54
+
55
+ for batch in loader:
56
+ batch = {k: v.to(device) for k, v in batch.items()}
57
+
58
+ labels = batch["labels"]
59
+ outputs = model(
60
+ inp=batch["inputs"],
61
+ key_padding_mask=batch["attention_masks"],
62
+ )
63
+
64
+ token_logits = outputs.token_logits
65
+ codebook_logits = outputs.codebook_logits
66
+
67
+ # Generate labels
68
+ base_loss = F.cross_entropy(
69
+ token_logits.reshape(-1, token_logits.size(-1)),
70
+ labels[:, 0].reshape(-1),
71
+ ignore_index=-100,
72
+ reduction="none",
73
+ )
74
+
75
+ codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
76
+ semantic_loss = F.cross_entropy(
77
+ codebook_logits.reshape(-1, codebook_logits.size(-1)),
78
+ codebook_labels.reshape(-1),
79
+ ignore_index=-100,
80
+ reduction="none",
81
+ )
82
+
83
+ base_loss = base_loss.reshape(labels[:, 0].shape)
84
+ semantic_loss = semantic_loss.reshape(codebook_labels.shape)
85
+
86
+ semantic_loss_frame = semantic_loss.mean(-1)
87
+ pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
88
+
89
+ for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
90
+ semantic_loss_sum[~pad] += loss_sample[~pad]
91
+ counter[~pad] += 1
92
+
93
+ current_step += 1
94
+ if current_step == 10:
95
+ break
96
+
97
+ semantic_loss = semantic_loss.cpu()
98
+ counter = counter.cpu()
99
+ xs, ys = [], []
100
+
101
+ for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
102
+ if count > 0:
103
+ xs.append(i)
104
+ ys.append((loss / count).item()) # for better loss visualization
105
+
106
+ smoothed_ys = smooth(ys, 0.95)
107
+
108
+ # Unload model
109
+ del model
110
+ torch.cuda.empty_cache()
111
+
112
+ return xs, ys, smoothed_ys
113
+
114
+
115
+ def main():
116
+ tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
117
+ max_length = 4096
118
+
119
+ ds = AutoAugTextDataset(
120
+ ["data/protos/sft/云天河"],
121
+ tokenizer=tokenizer,
122
+ use_speaker=False,
123
+ interactive_prob=1.0,
124
+ max_length=max_length,
125
+ )
126
+
127
+ loader = DataLoader(
128
+ ds,
129
+ batch_size=8,
130
+ collate_fn=TextDataCollator(tokenizer, max_length=max_length),
131
+ num_workers=0,
132
+ shuffle=False,
133
+ )
134
+
135
+ plt.figure(figsize=(10, 5), dpi=200)
136
+
137
+ plt.xlabel("Frame")
138
+ plt.ylabel("Loss")
139
+ plt.yscale("log")
140
+ plt.title("Semantic Loss")
141
+ plt.grid(which="both", axis="both")
142
+ plt.xlim(0, max_length)
143
+
144
+ tests = [
145
+ (
146
+ "pertrain-medium",
147
+ "dual_ar_2_codebook_medium",
148
+ "checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
149
+ ),
150
+ (
151
+ "sft-medium",
152
+ "dual_ar_2_codebook_medium",
153
+ "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
154
+ ),
155
+ (
156
+ "sft-large",
157
+ "dual_ar_2_codebook_large",
158
+ "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
159
+ ),
160
+ ]
161
+
162
+ for name, config, weight in tests:
163
+ xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
164
+ plt.plot(xs, smoothed_ys, label=name)
165
+
166
+ plt.legend()
167
+ plt.savefig("semantic_loss.png")
168
+
169
+
170
+ if __name__ == "__main__":
171
+ main()
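The smooth() helper above is a TensorBoard-style exponential moving average, s_t = w * s_(t-1) + (1 - w) * x_t, seeded with the first value. A tiny worked example (the input values are chosen purely for illustration):

values = [1.0, 3.0, 5.0]
weight = 0.5
last = values[0]
smoothed = []
for x in values:
    # Each step blends the previous smoothed value with the new observation
    last = weight * last + (1 - weight) * x
    smoothed.append(last)
print(smoothed)  # [1.0, 2.0, 3.5]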
tools/llama/generate.py ADDED
@@ -0,0 +1,1087 @@
1
+ import os
2
+ import queue
3
+ import threading
4
+ import time
5
+ from contextlib import nullcontext
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Literal, Optional, Tuple, Union
9
+
10
+ import click
11
+ import hydra
12
+ import numpy as np
13
+ import torch
14
+ import torch._dynamo.config
15
+ import torch._inductor.config
16
+ from loguru import logger
17
+ from tqdm import tqdm
18
+ from transformers import AutoTokenizer
19
+
20
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
21
+ from fish_speech.models.text2semantic.llama import BaseModelArgs
22
+ from fish_speech.text import clean_text, split_text
23
+
24
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
25
+ torch._inductor.config.coordinate_descent_tuning = True
26
+ torch._inductor.config.triton.unique_kernel_names = True
27
+
28
+ if hasattr(torch._inductor.config, "fx_graph_cache"):
29
+ # Experimental feature to reduce compilation times, will be on by default in future
30
+ torch._inductor.config.fx_graph_cache = True
31
+
32
+
33
+ from torch.nn.attention import SDPBackend, sdpa_kernel
34
+
35
+ from fish_speech.models.text2semantic.llama import (
36
+ BaseTransformer,
37
+ DualARTransformer,
38
+ NaiveTransformer,
39
+ )
40
+
41
+
42
+ def multinomial_sample_one_no_sync(
43
+ probs_sort,
44
+ ): # Does multinomial sampling without a cuda synchronization
45
+ q = torch.empty_like(probs_sort).exponential_(1)
46
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
47
+
48
+
49
+ def logits_to_probs(
50
+ logits,
51
+ previous_tokens: Optional[torch.Tensor] = None,
52
+ temperature: torch.Tensor = 1.0,
53
+ top_p: torch.Tensor = 1.0,
54
+ repetition_penalty: torch.Tensor = 1.0,
55
+ ) -> torch.Tensor:
56
+ # Apply repetition penalty
57
+ if previous_tokens is not None:
58
+ previous_tokens = previous_tokens.long()
59
+ score = torch.gather(logits, dim=0, index=previous_tokens)
60
+ score = torch.where(
61
+ score < 0, score * repetition_penalty, score / repetition_penalty
62
+ )
63
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
64
+
65
+ # Apply top-p sampling
66
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
67
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
68
+ sorted_indices_to_remove = cum_probs > top_p
69
+ sorted_indices_to_remove[0] = False # keep at least one option
70
+ indices_to_remove = sorted_indices_to_remove.scatter(
71
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
72
+ )
73
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
74
+
75
+ logits = logits / max(temperature, 1e-5)
76
+
77
+ probs = torch.nn.functional.softmax(logits, dim=-1)
78
+ return probs
79
+
80
+
81
+ def multinomial_sample_one_no_sync_agent(
82
+ probs_sort,
83
+ ): # Does multinomial sampling without a cuda synchronization
84
+ q = torch.empty_like(probs_sort).exponential_(1)
85
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
86
+
87
+
88
+ def logits_to_probs_agent(
89
+ logits,
90
+ previous_tokens: Optional[torch.Tensor] = None,
91
+ temperature: torch.Tensor = 1.0,
92
+ top_p: torch.Tensor = 1.0,
93
+ repetition_penalty: torch.Tensor = 1.0,
94
+ ) -> torch.Tensor:
95
+ # Apply repetition penalty
96
+ if previous_tokens is not None:
97
+ previous_tokens = previous_tokens.long()
98
+ score = torch.gather(logits, dim=-1, index=previous_tokens)
99
+ score = torch.where(
100
+ score < 0, score * repetition_penalty, score / repetition_penalty
101
+ )
102
+ logits.scatter_(dim=-1, index=previous_tokens, src=score)
103
+
104
+ # Apply top-p sampling
105
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
106
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
107
+ sorted_indices_to_remove = cum_probs > top_p
108
+ sorted_indices_to_remove[..., 0] = False # keep at least one option
109
+ indices_to_remove = sorted_indices_to_remove.scatter(
110
+ dim=-1, index=sorted_indices, src=sorted_indices_to_remove
111
+ )
112
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
113
+
114
+ logits = logits / max(temperature, 1e-5)
115
+
116
+ probs = torch.nn.functional.softmax(logits, dim=-1)
117
+ return probs
118
+
119
+
120
+ def sample(
121
+ logits,
122
+ previous_tokens: Optional[torch.Tensor] = None,
123
+ **sampling_kwargs,
124
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
125
+ probs = logits_to_probs(
126
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
127
+ )
128
+ idx_next = multinomial_sample_one_no_sync(probs)
129
+ return idx_next, probs
130
+
131
+
132
+ def sample_agent(
133
+ logits,
134
+ previous_tokens: Optional[torch.Tensor] = None,
135
+ **sampling_kwargs,
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ probs = logits_to_probs_agent(
138
+ logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
139
+ )
140
+ idx_next = multinomial_sample_one_no_sync_agent(probs)
141
+ return idx_next, probs
142
+
143
+
144
+ def decode_one_token_ar_agent(
145
+ model: DualARTransformer,
146
+ x: torch.Tensor,
147
+ input_pos: torch.Tensor,
148
+ previous_tokens: torch.Tensor = None,
149
+ semantic_id: int = 32003,
150
+ **sampling_kwargs,
151
+ ) -> torch.Tensor:
152
+ # print(x, input_pos)
153
+ x = model.forward_generate(x, input_pos)
154
+ logits = x.logits # [:, -1:]
155
+ hidden_states = x.hidden_states # [:, -1:]
156
+
157
+ sampling_kwargs_main = sampling_kwargs.copy()
158
+ sampling_kwargs_main["temperature"] = 0.1
159
+ sampling_kwargs_main["top_p"] = 0.1
160
+ sampling_kwargs_main["repetition_penalty"] = 1.0
161
+
162
+ codebooks = [
163
+ sample_agent(
164
+ logits,
165
+ previous_tokens=None, # Disable repetition penalty for the token codebook
166
+ **sampling_kwargs_main,
167
+ )[0]
168
+ ]
169
+
170
+ # Cleanup the cache
171
+ for layer in model.fast_layers:
172
+ layer.attention.kv_cache.k_cache.fill_(0)
173
+ layer.attention.kv_cache.v_cache.fill_(0)
174
+
175
+ for codebook_idx in range(model.config.num_codebooks):
176
+ input_pos = torch.tensor(
177
+ [codebook_idx], device=hidden_states.device, dtype=torch.long
178
+ )
179
+ logits = model.forward_generate_fast(hidden_states, input_pos)
180
+ a = sample_agent(
181
+ logits,
182
+ previous_tokens=(
183
+ previous_tokens[:, codebook_idx + 1]
184
+ if previous_tokens is not None
185
+ else None
186
+ ),
187
+ **sampling_kwargs,
188
+ )[0]
189
+ hidden_states = model.fast_embeddings(a)
190
+ codebooks.append(a)
191
+
192
+ codebooks = torch.stack(codebooks, dim=1)
193
+ codebooks[:, 1:, :] = torch.masked_fill(
194
+ codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
195
+ )
196
+
197
+ # for i in range(codebooks.size(1) - 1):
198
+ # codebooks[:, i + 1, :] = torch.masked_fill(
199
+ # codebooks[:, i + 1, :],
200
+ # codebooks[:, :1, :] != semantic_id,
201
+ # CODEBOOK_PAD_TOKEN_ID + i * 1024,
202
+ # )
203
+
204
+ # print(codebooks)
205
+
206
+ return codebooks
207
+
208
+
209
+ def decode_one_token_naive_agent(
210
+ model: NaiveTransformer,
211
+ x: torch.Tensor,
212
+ input_pos: torch.Tensor,
213
+ previous_tokens: torch.Tensor = None,
214
+ semantic_id: int = 32003,
215
+ **sampling_kwargs,
216
+ ) -> torch.Tensor:
217
+ x = model.forward_generate(x, input_pos)
218
+
219
+ codebooks = [
220
+ sample(
221
+ x.token_logits,
222
+ previous_tokens=None, # Disable repetition penalty for the token codebook
223
+ **sampling_kwargs,
224
+ )[0]
225
+ ]
226
+
227
+ for i in range(model.config.num_codebooks):
228
+ codebooks.append(
229
+ sample_agent(
230
+ x.codebook_logits[:, :, i],
231
+ previous_tokens=(
232
+ previous_tokens[:, i + 1] if previous_tokens is not None else None
233
+ ),
234
+ **sampling_kwargs,
235
+ )[0]
236
+ )
237
+
238
+ codebooks = torch.stack(codebooks, dim=1)
239
+ codebooks[:, 1:, :] = torch.masked_fill(
240
+ codebooks[:, 1:, :], codebooks[:, :1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
241
+ )
242
+
243
+ return codebooks
244
+
245
+
246
+ def decode_one_token_ar(
247
+ model: DualARTransformer,
248
+ x: torch.Tensor,
249
+ input_pos: torch.Tensor,
250
+ previous_tokens: torch.Tensor = None,
251
+ semantic_id: int = 0,
252
+ **sampling_kwargs,
253
+ ) -> torch.Tensor:
254
+ x = model.forward_generate(x, input_pos)
255
+
256
+ sampling_kwargs_main = sampling_kwargs.copy()
257
+ # sampling_kwargs_main["temperature"] = 0.1
258
+ # sampling_kwargs_main["top_p"] = 0.1
259
+ # sampling_kwargs_main["repetition_penalty"] = 1.0
260
+
261
+ codebooks = [
262
+ sample(
263
+ x.logits,
264
+ previous_tokens=None, # Disable repetition penalty for the token codebook
265
+ **sampling_kwargs_main,
266
+ )[0]
267
+ ]
268
+
269
+ x = x.hidden_states
270
+
271
+ # Cleanup the cache
272
+ for layer in model.fast_layers:
273
+ layer.attention.kv_cache.k_cache.fill_(0)
274
+ layer.attention.kv_cache.v_cache.fill_(0)
275
+
276
+ for codebook_idx in range(model.config.num_codebooks):
277
+ input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
278
+ logits = model.forward_generate_fast(x, input_pos)
279
+ a = sample(
280
+ logits,
281
+ previous_tokens=(
282
+ previous_tokens[codebook_idx + 1]
283
+ if previous_tokens is not None
284
+ else None
285
+ ),
286
+ **sampling_kwargs,
287
+ )[0]
288
+ x = model.fast_embeddings(a)
289
+ codebooks.append(a)
290
+
291
+ codebooks = torch.stack(codebooks, dim=0)
292
+ codebooks[1:, :] = torch.masked_fill(
293
+ codebooks[1:, :], codebooks[:1, :] != semantic_id, CODEBOOK_PAD_TOKEN_ID
294
+ )
295
+
296
+ return codebooks
297
+
298
+
299
+ def decode_one_token_naive(
300
+ model: NaiveTransformer,
301
+ x: torch.Tensor,
302
+ input_pos: torch.Tensor,
303
+ previous_tokens: torch.Tensor = None,
304
+ **sampling_kwargs,
305
+ ) -> torch.Tensor:
306
+ x = model.forward_generate(x, input_pos)
307
+
308
+ sampling_kwargs_main = sampling_kwargs.copy()
309
+ sampling_kwargs_main["temperature"] = 0.1
310
+ sampling_kwargs_main["top_p"] = 0.1
311
+ sampling_kwargs_main["repetition_penalty"] = 1.0
312
+
313
+ codebooks = [
314
+ sample(
315
+ x.logits,
316
+ previous_tokens=None, # Disable repetition penalty for the token codebook
317
+ **sampling_kwargs_main,
318
+ )[0]
319
+ ]
320
+
321
+ for i in range(model.config.num_codebooks):
322
+ codebooks.append(
323
+ sample(
324
+ x.codebook_logits[:, :, i],
325
+ previous_tokens=(
326
+ previous_tokens[i + 1] if previous_tokens is not None else None
327
+ ),
328
+ **sampling_kwargs,
329
+ )[0]
330
+ )
331
+
332
+ return torch.stack(codebooks, dim=0)
333
+
334
+
335
+ def decode_n_tokens(
336
+ model: NaiveTransformer,
337
+ cur_token: torch.Tensor,
338
+ input_pos: torch.Tensor,
339
+ num_new_tokens: int,
340
+ im_end_id: int = 4,
341
+ decode_one_token=decode_one_token_naive,
342
+ semantic_id: int = 0,
343
+ **sampling_kwargs,
344
+ ):
345
+ previous_tokens = torch.zeros(
346
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
347
+ dtype=torch.int,
348
+ device=cur_token.device,
349
+ )
350
+
351
+ for i in tqdm(range(num_new_tokens)):
352
+ # We need to get windowed repeat penalty
353
+ win_size = 16
354
+ if i < win_size:
355
+ window = previous_tokens[:, :win_size]
356
+ else:
357
+ window = previous_tokens[:, i - win_size : i]
358
+
359
+ with (
360
+ torch.backends.cuda.sdp_kernel(
361
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
362
+ )
363
+ if torch.cuda.is_available()
364
+ else nullcontext()
365
+ ): # Actually better for Inductor to codegen attention here
366
+ next_token = decode_one_token(
367
+ model=model,
368
+ x=cur_token,
369
+ input_pos=input_pos,
370
+ previous_tokens=window,
371
+ semantic_id=semantic_id,
372
+ **sampling_kwargs,
373
+ )
374
+
375
+ input_pos += 1
376
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
377
+ previous_tokens[:, i : i + 1] = next_token.view(
378
+ model.config.num_codebooks + 1, -1
379
+ )
380
+
381
+ if cur_token[0, 0, -1] == im_end_id:
382
+ break
383
+
384
+ return previous_tokens[:, : i + 1]
385
+
386
+
387
+ @torch.no_grad()
388
+ @torch.inference_mode()
389
+ def generate(
390
+ *,
391
+ model: NaiveTransformer,
392
+ prompt: torch.Tensor,
393
+ max_new_tokens: int =600,
394
+ im_end_id: int = 4,
395
+ decode_one_token=decode_one_token_naive,
396
+ **sampling_kwargs,
397
+ ) -> torch.Tensor:
398
+ """
399
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
400
+ """
401
+
402
+ # create an empty tensor of the expected final shape and fill in the current tokens
403
+ T = prompt.size(1)
404
+ semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
405
+
406
+ if max_new_tokens:
407
+ if T + max_new_tokens > model.config.max_seq_len:
408
+ max_new_tokens = model.config.max_seq_len - T
409
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
410
+
411
+ T_new = T + max_new_tokens
412
+ else:
413
+ T_new = model.config.max_seq_len
414
+ max_new_tokens = T_new - T
415
+
416
+ device, dtype = prompt.device, prompt.dtype
417
+
418
+ codebook_dim = 1 + model.config.num_codebooks
419
+ # create an empty tensor of the expected final shape and fill in the current tokens
420
+ empty = torch.empty(
421
+ (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
422
+ )
423
+ empty[:, :T] = prompt
424
+ seq = empty
425
+ input_pos = torch.arange(0, T, device=device)
426
+
427
+ # Use non-accelerated version for now, to avoid compilation overhead
428
+ prefill_decode = (
429
+ decode_one_token_naive
430
+ if isinstance(model, NaiveTransformer)
431
+ else decode_one_token_ar
432
+ )
433
+
434
+ next_token = prefill_decode(
435
+ model,
436
+ prompt.view(1, codebook_dim, -1),
437
+ input_pos,
438
+ semantic_id=semantic_id,
439
+ **sampling_kwargs,
440
+ )
441
+ seq[:, T : T + 1] = next_token
442
+
443
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
444
+ x = decode_n_tokens(
445
+ model,
446
+ next_token.view(1, codebook_dim, -1),
447
+ input_pos,
448
+ max_new_tokens - 1,
449
+ im_end_id=im_end_id,
450
+ decode_one_token=decode_one_token,
451
+ semantic_id=semantic_id,
452
+ **sampling_kwargs,
453
+ )
454
+ # x = torch.cat(generated_tokens, dim=1)
455
+ seq = seq[:, : T + 1 + x.size(1)]
456
+ seq[:, T + 1 :] = x
457
+
458
+ return seq
459
+
460
+
461
+ def decode_n_tokens_agent(
462
+ model: NaiveTransformer,
463
+ cur_token: torch.Tensor,
464
+ input_pos: torch.Tensor,
465
+ num_new_tokens: int,
466
+ im_end_id: int = 4,
467
+ semantic_id: int = 32003,
468
+ decode_one_token=decode_one_token_naive_agent,
469
+ early_stop_threshold: float = 0.6,
470
+ **sampling_kwargs,
471
+ ):
472
+ batch_size = cur_token.size(0)
473
+ previous_tokens = torch.zeros(
474
+ (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
475
+ dtype=torch.int,
476
+ device=cur_token.device,
477
+ )
478
+ finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
479
+ finished = finished | (cur_token[:, 0, -1] == im_end_id)
480
+ start_time = time.time()
481
+
482
+ for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
483
+ # We need to get windowed repeat penalty
484
+ win_size = 16
485
+ if i < win_size:
486
+ window = previous_tokens[:, :, :win_size]
487
+ else:
488
+ window = previous_tokens[:, :, i - win_size : i]
489
+
490
+ with sdpa_kernel(
491
+ SDPBackend.MATH
492
+ ): # Actually better for Inductor to codegen attention here
493
+ next_token = decode_one_token(
494
+ model=model,
495
+ x=cur_token,
496
+ input_pos=input_pos,
497
+ previous_tokens=window,
498
+ semantic_id=semantic_id,
499
+ **sampling_kwargs,
500
+ )
501
+
502
+ input_pos += 1
503
+ cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
504
+ previous_tokens[:, :, i : i + 1] = next_token.view(
505
+ batch_size, model.config.num_codebooks + 1, -1
506
+ )
507
+
508
+ yield cur_token.cpu()
509
+
510
+ finished = finished | (cur_token[:, 0, -1] == im_end_id)
511
+ if finished.all() or (
512
+ 0 < early_stop_threshold < 1
513
+ and finished.sum() >= round(batch_size * early_stop_threshold)
514
+ ):
515
+ break
516
+
517
+ total_time = time.time() - start_time
518
+ generated_tokens = i + 1
519
+ tokens_per_second = (generated_tokens / total_time) * batch_size
520
+ logger.info(
521
+ f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
522
+ )
523
+
524
+
525
+ @torch.no_grad()
526
+ @torch.inference_mode()
527
+ def generate_agent(
528
+ *,
529
+ model: BaseTransformer,
530
+ prompt: torch.Tensor,
531
+ max_new_tokens: int =500,
532
+ im_end_id: int = 4,
533
+ semantic_id: int = 32003,
534
+ decode_one_token=decode_one_token_naive_agent,
535
+ num_samples: int = 1,
536
+ early_stop_threshold: float = 0.6,
537
+ **sampling_kwargs,
538
+ ):
539
+ """
540
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
541
+ """
542
+
543
+ # create an empty tensor of the expected final shape and fill in the current tokens
544
+ T = prompt.size(1)
545
+ prompt = prompt[None].repeat(num_samples, 1, 1)
546
+
547
+ if T >= model.config.max_seq_len:
548
+ raise ValueError(
549
+ f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
550
+ )
551
+
552
+ if max_new_tokens:
553
+ if T + max_new_tokens > model.config.max_seq_len:
554
+ max_new_tokens = model.config.max_seq_len - T
555
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
556
+
557
+ T_new = T + max_new_tokens
558
+ else:
559
+ T_new = model.config.max_seq_len
560
+ max_new_tokens = T_new - T
561
+
562
+ device, dtype = prompt.device, prompt.dtype
563
+
564
+ codebook_dim = 1 + model.config.num_codebooks
565
+ input_pos = torch.arange(0, T, device=device)
566
+
567
+ # Use non-accelerated version for now, to avoid compilation overhead
568
+ prefill_decode = (
569
+ decode_one_token_naive_agent
570
+ if isinstance(model, NaiveTransformer)
571
+ else decode_one_token_ar_agent
572
+ )
573
+ next_token = prefill_decode(
574
+ model,
575
+ prompt,
576
+ input_pos,
577
+ semantic_id=semantic_id,
578
+ **sampling_kwargs,
579
+ ).view(num_samples, codebook_dim, -1)
580
+ yield next_token.cpu()
581
+
582
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
583
+
584
+ yield from decode_n_tokens_agent(
585
+ model,
586
+ next_token,
587
+ input_pos,
588
+ max_new_tokens - 1,
589
+ im_end_id=im_end_id,
590
+ semantic_id=semantic_id,
591
+ decode_one_token=decode_one_token,
592
+ early_stop_threshold=early_stop_threshold,
593
+ **sampling_kwargs,
594
+ )
595
+
596
+
597
+ def encode_tokens(
598
+ tokenizer,
599
+ string,
600
+ device="cuda",
601
+ prompt_tokens=None,
602
+ num_codebooks=4,
603
+ ):
604
+ string = clean_text(string)
605
+ string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
606
+
607
+ new_tokens = tokenizer.encode(
608
+ string,
609
+ add_special_tokens=False,
610
+ max_length=10**6,
611
+ truncation=False,
612
+ )
613
+ tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
614
+
615
+ # Codebooks
616
+ zeros = (
617
+ torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
618
+ * CODEBOOK_PAD_TOKEN_ID
619
+ )
620
+ prompt = torch.cat((tokens, zeros), dim=0)
621
+
622
+ if prompt_tokens is None:
623
+ return prompt
624
+
625
+ # Get prompt tokens
626
+ if prompt_tokens.ndim == 3:
627
+ assert (
628
+ prompt_tokens.shape[0] == 1
629
+ ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
630
+ prompt_tokens = prompt_tokens[0]
631
+
632
+ assert prompt_tokens.ndim == 2
633
+ data = prompt_tokens + 1
634
+
635
+ if prompt_tokens.shape[0] > num_codebooks:
636
+ logger.warning(
637
+ f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
638
+ )
639
+ data = data[:num_codebooks]
640
+
641
+ # Add pad token for each codebook
642
+ data = torch.cat(
643
+ (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
644
+ dim=1,
645
+ )
646
+
647
+ # Since 1.0, we use <|semantic|>
648
+ s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
649
+ end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
650
+ main_token_ids = (
651
+ torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
652
+ )
653
+ main_token_ids[0, -1] = end_token_id
654
+
655
+ data = torch.cat((main_token_ids, data), dim=0)
656
+ prompt = torch.cat((prompt, data), dim=1)
657
+
658
+ return prompt
659
+
660
+
661
+ def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
662
+ model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
663
+ checkpoint_path, load_weights=True
664
+ )
665
+
666
+ model = model.to(device=device, dtype=precision)
667
+ logger.info(f"Restored model from checkpoint")
668
+
669
+ if isinstance(model, DualARTransformer):
670
+ decode_one_token = (
671
+ decode_one_token_ar_agent if is_agent else decode_one_token_ar
672
+ )
673
+ logger.info("Using DualARTransformer")
674
+ else:
675
+ decode_one_token = (
676
+ decode_one_token_naive_agent if is_agent else decode_one_token_naive
677
+ )
678
+ logger.info("Using NaiveTransformer")
679
+
680
+ if compile:
681
+ logger.info("Compiling function...")
682
+ decode_one_token = torch.compile(
683
+ decode_one_token,
684
+ fullgraph=True,
685
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
686
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
687
+ )
688
+
689
+ return model.eval(), decode_one_token
690
+
691
+
692
+ @dataclass
693
+ class GenerateResponse:
694
+ action: Literal["sample", "next"]
695
+ codes: Optional[torch.Tensor] = None
696
+ text: Optional[str] = None
697
+
698
+
699
+ def generate_long(
700
+ *,
701
+ model,
702
+ device: str | torch.device,
703
+ decode_one_token: callable,
704
+ text: str,
705
+ num_samples: int = 1,
706
+ max_new_tokens: int = 600,
707
+ top_p: int = 0.7,
708
+ repetition_penalty: float = 1.5,
709
+ temperature: float = 0.7,
710
+ compile: bool = False,
711
+ iterative_prompt: bool = True,
712
+ max_length: int = 2048,
713
+ chunk_length: int = 150,
714
+ prompt_text: Optional[str | list[str]] = None,
715
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
716
+ ):
717
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
718
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
719
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
720
+
721
+ use_prompt = prompt_text is not None and prompt_tokens is not None
722
+ if use_prompt and isinstance(prompt_text, str):
723
+ prompt_text = [prompt_text]
724
+ prompt_tokens = [prompt_tokens]
725
+
726
+ assert use_prompt is False or len(prompt_text) == len(
727
+ prompt_tokens
728
+ ), "Prompt text and tokens must have the same length"
729
+
730
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
731
+ tokenizer = model.tokenizer
732
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
733
+
734
+ encoded = []
735
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
736
+ encoded_prompts = []
737
+
738
+ if use_prompt:
739
+ for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
740
+ encoded_prompts.append(
741
+ encode_tokens(
742
+ tokenizer,
743
+ string=t,
744
+ device=device,
745
+ prompt_tokens=c,
746
+ num_codebooks=model.config.num_codebooks,
747
+ )
748
+ )
749
+
750
+ for idx, text in enumerate(texts):
751
+ encoded.append(
752
+ encode_tokens(
753
+ tokenizer,
754
+ string=text,
755
+ device=device,
756
+ num_codebooks=model.config.num_codebooks,
757
+ )
758
+ )
759
+ logger.info(f"Encoded text: {text}")
760
+
761
+ # Move temperature, top_p, repetition_penalty to device
762
+ # This is important so that changing params doesn't trigger recompile
763
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
764
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
765
+ repetition_penalty = torch.tensor(
766
+ repetition_penalty, device=device, dtype=torch.float
767
+ )
768
+
769
+ for sample_idx in range(num_samples):
770
+ if torch.cuda.is_available():
771
+ torch.cuda.synchronize()
772
+
773
+ global_encoded = []
774
+ seg_idx = 0
775
+
776
+ while seg_idx < len(encoded):
777
+ logger.info(
778
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
779
+ )
780
+
781
+ seg = encoded[seg_idx]
782
+ global_encoded.append(seg)
783
+
784
+ lengths = reversed([seg.size(1) for seg in global_encoded])
785
+
786
+ # Pick last 2000 tokens
787
+ count = 0
788
+ for i, length in enumerate(lengths):
789
+ count += length
790
+ if count + length > max_length - 1024 - sum(
791
+ t.shape[1] for t in encoded_prompts
792
+ ):
793
+ break
794
+
795
+ if i != 0 and i % 2 == 0:
796
+ i -= 1
797
+
798
+ # Rotate the list, always make sure first segment is included to avoid drift
799
+ if i < len(global_encoded) - 2:
800
+ partial_encoded = global_encoded[:2] + global_encoded[-i:]
801
+ else:
802
+ partial_encoded = global_encoded
803
+
804
+ if use_prompt:
805
+ partial_encoded = encoded_prompts + partial_encoded
806
+
807
+ cat_encoded = torch.cat(partial_encoded, dim=1)
808
+ prompt_length = cat_encoded.size(1)
809
+
810
+ t0 = time.perf_counter()
811
+ y = generate(
812
+ model=model,
813
+ prompt=cat_encoded,
814
+ max_new_tokens=max_new_tokens,
815
+ im_end_id=im_end_id,
816
+ decode_one_token=decode_one_token,
817
+ temperature=temperature,
818
+ top_p=top_p,
819
+ repetition_penalty=repetition_penalty,
820
+ )
821
+
822
+ if sample_idx == 0 and seg_idx == 0 and compile:
823
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
824
+
825
+ if torch.cuda.is_available():
826
+ torch.cuda.synchronize()
827
+
828
+ t = time.perf_counter() - t0
829
+
830
+ tokens_generated = y.size(1) - prompt_length
831
+ tokens_sec = tokens_generated / t
832
+ logger.info(
833
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
834
+ )
835
+ logger.info(
836
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
837
+ )
838
+
839
+ if torch.cuda.is_available():
840
+ logger.info(
841
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
842
+ )
843
+
844
+ # Put the generated tokens
845
+ # since there is <im_end> and <eos> tokens, we remove last 2 tokens
846
+ codes = y[1:, prompt_length:-1].clone()
847
+ codes = codes - 1
848
+ assert (codes >= 0).all(), f"Negative code found"
849
+
850
+ decoded = y[:, prompt_length:-1].clone()
851
+ # But for global encoding, we should keep the <im_end> token
852
+
853
+ global_encoded.append(decoded)
854
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
855
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
856
+ seg_idx += 1
857
+
858
+ # This indicates the end of the current sample
859
+ yield GenerateResponse(action="next")
860
+
861
+
862
+ @dataclass
863
+ class WrappedGenerateResponse:
864
+ status: Literal["success", "error"]
865
+ response: Optional[GenerateResponse | Exception] = None
866
+
867
+
868
+ @dataclass
869
+ class GenerateRequest:
870
+ request: dict
871
+ response_queue: queue.Queue
872
+
873
+
874
+ def launch_thread_safe_queue(
875
+ checkpoint_path,
876
+ device,
877
+ precision,
878
+ compile: bool = False,
879
+ ):
880
+ input_queue = queue.Queue()
881
+ init_event = threading.Event()
882
+
883
+ def worker():
884
+ model, decode_one_token = load_model(
885
+ checkpoint_path, device, precision, compile=compile
886
+ )
887
+ with torch.device(device):
888
+ model.setup_caches(
889
+ max_batch_size=1,
890
+ max_seq_len=model.config.max_seq_len,
891
+ dtype=next(model.parameters()).dtype,
892
+ )
893
+ init_event.set()
894
+
895
+ while True:
896
+ item: GenerateRequest | None = input_queue.get()
897
+ if item is None:
898
+ break
899
+
900
+ kwargs = item.request
901
+ response_queue = item.response_queue
902
+
903
+ try:
904
+ for chunk in generate_long(
905
+ model=model, decode_one_token=decode_one_token, **kwargs
906
+ ):
907
+ response_queue.put(
908
+ WrappedGenerateResponse(status="success", response=chunk)
909
+ )
910
+ except Exception as e:
911
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
912
+
913
+ threading.Thread(target=worker, daemon=True).start()
914
+ init_event.wait()
915
+
916
+ return input_queue
917
+
918
+
919
+ def launch_thread_safe_queue_agent(
920
+ checkpoint_path,
921
+ device,
922
+ precision,
923
+ compile: bool = False,
924
+ ):
925
+ input_queue = queue.Queue()
926
+ init_event = threading.Event()
927
+
928
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
929
+ config = BaseModelArgs.from_pretrained(checkpoint_path)
930
+
931
+ def worker():
932
+ model, decode_one_token = load_model(
933
+ checkpoint_path, device, precision, compile=compile, is_agent=True
934
+ )
935
+
936
+ with torch.device(device):
937
+ model.setup_caches(
938
+ max_batch_size=1,
939
+ max_seq_len=model.config.max_seq_len,
940
+ dtype=next(model.parameters()).dtype,
941
+ )
942
+ init_event.set()
943
+
944
+ while True:
945
+ item: GenerateRequest | None = input_queue.get()
946
+ if item is None:
947
+ break
948
+
949
+ kwargs = item.request
950
+ response_queue = item.response_queue
951
+
952
+ try:
953
+ for token in generate_agent(
954
+ model=model,
955
+ decode_one_token=decode_one_token,
956
+ **kwargs,
957
+ ):
958
+ response_queue.put(token)
959
+
960
+ response_queue.put("stop")
961
+ except Exception as e:
962
+ import traceback
963
+
964
+ logger.exception(f"Error in worker: {traceback.format_exc()}")
965
+ response_queue.put("error")
966
+
967
+ threading.Thread(target=worker, daemon=True).start()
968
+ init_event.wait()
969
+
970
+ return input_queue, tokenizer, config
971
+
972
+
973
+ @click.command()
974
+ @click.option(
975
+ "--text",
976
+ type=str,
977
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
978
+ )
979
+ @click.option("--prompt-text", type=str, default=None, multiple=True)
980
+ @click.option(
981
+ "--prompt-tokens",
982
+ type=click.Path(path_type=Path, exists=True),
983
+ default=None,
984
+ multiple=True,
985
+ )
986
+ @click.option("--num-samples", type=int, default=1)
987
+ @click.option("--max-new-tokens", type=int, default=0)
988
+ @click.option("--top-p", type=float, default=0.7)
989
+ @click.option("--repetition-penalty", type=float, default=1.2)
990
+ @click.option("--temperature", type=float, default=0.7)
991
+ @click.option(
992
+ "--checkpoint-path",
993
+ type=click.Path(path_type=Path, exists=True),
994
+ default="checkpoints/fish-speech-1.4",
995
+ )
996
+ @click.option("--device", type=str, default="cuda")
997
+ @click.option("--compile/--no-compile", default=False)
998
+ @click.option("--seed", type=int, default=42)
999
+ @click.option("--half/--no-half", default=False)
1000
+ @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
1001
+ @click.option("--chunk-length", type=int, default=100)
1002
+ def main(
1003
+ text: str,
1004
+ prompt_text: Optional[list[str]],
1005
+ prompt_tokens: Optional[list[Path]],
1006
+ num_samples: int,
1007
+ max_new_tokens: int,
1008
+ top_p: int,
1009
+ repetition_penalty: float,
1010
+ temperature: float,
1011
+ checkpoint_path: Path,
1012
+ device: str,
1013
+ compile: bool,
1014
+ seed: int,
1015
+ half: bool,
1016
+ iterative_prompt: bool,
1017
+ chunk_length: int,
1018
+ ) -> None:
1019
+
1020
+ precision = torch.half if half else torch.bfloat16
1021
+
1022
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
1023
+ raise ValueError(
1024
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
1025
+ )
1026
+
1027
+ logger.info("Loading model ...")
1028
+ t0 = time.time()
1029
+ model, decode_one_token = load_model(
1030
+ checkpoint_path, device, precision, compile=compile
1031
+ )
1032
+ with torch.device(device):
1033
+ model.setup_caches(
1034
+ max_batch_size=1,
1035
+ max_seq_len=model.config.max_seq_len,
1036
+ dtype=next(model.parameters()).dtype,
1037
+ )
1038
+ if torch.cuda.is_available():
1039
+ torch.cuda.synchronize()
1040
+
1041
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
1042
+
1043
+ if prompt_tokens is not None:
1044
+ prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
1045
+
1046
+ torch.manual_seed(seed)
1047
+
1048
+ if torch.cuda.is_available():
1049
+ torch.cuda.manual_seed(seed)
1050
+
1051
+ generator = generate_long(
1052
+ model=model,
1053
+ device=device,
1054
+ decode_one_token=decode_one_token,
1055
+ text=text,
1056
+ num_samples=num_samples,
1057
+ max_new_tokens=max_new_tokens,
1058
+ top_p=top_p,
1059
+ repetition_penalty=repetition_penalty,
1060
+ temperature=temperature,
1061
+ compile=compile,
1062
+ iterative_prompt=iterative_prompt,
1063
+ chunk_length=chunk_length,
1064
+ prompt_text=prompt_text,
1065
+ prompt_tokens=prompt_tokens,
1066
+ )
1067
+
1068
+ idx = 0
1069
+ codes = []
1070
+
1071
+ for response in generator:
1072
+ if response.action == "sample":
1073
+ codes.append(response.codes)
1074
+ logger.info(f"Sampled text: {response.text}")
1075
+ elif response.action == "next":
1076
+ if codes:
1077
+ np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
1078
+ logger.info(f"Saved codes to codes_{idx}.npy")
1079
+ logger.info(f"Next sample")
1080
+ codes = []
1081
+ idx += 1
1082
+ else:
1083
+ logger.error(f"Error: {response}")
1084
+
1085
+
1086
+ if __name__ == "__main__":
1087
+ main()
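A minimal sketch of how the thread-safe queue above is intended to be consumed. It assumes launch_thread_safe_queue, GenerateRequest, and WrappedGenerateResponse from this file are in scope; the checkpoint path and input text are placeholders.

import queue
import torch

# Start the background worker that owns the model
input_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/fish-speech-1.4",
    device="cuda",
    precision=torch.bfloat16,
    compile=False,
)

# Submit one request; results come back on a per-request queue
response_queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request=dict(device="cuda", text="Hello, world."),
        response_queue=response_queue,
    )
)

while True:
    wrapped = response_queue.get()
    if wrapped.status == "error":
        raise wrapped.response
    if wrapped.response.action == "next":
        break  # end of the current sample
    codes = wrapped.response.codes  # [num_codebooks, T] semantic codes for this chunk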
tools/llama/merge_lora.py ADDED
@@ -0,0 +1,95 @@
1
+ import shutil
2
+ from copy import deepcopy
3
+ from pathlib import Path
4
+
5
+ import click
6
+ import hydra
7
+ import torch
8
+ from hydra import compose, initialize
9
+ from hydra.utils import instantiate
10
+ from loguru import logger
11
+
12
+ from fish_speech.models.text2semantic.llama import BaseTransformer
13
+ from fish_speech.models.text2semantic.lora import get_merged_state_dict
14
+
15
+
16
+ @click.command()
17
+ @click.option("--lora-config", type=str, default="r_8_alpha_16")
18
+ @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
19
+ @click.option("--lora-weight", type=str, required=True)
20
+ @click.option("--output", type=str, required=True)
21
+ def merge(lora_config, base_weight, lora_weight, output):
22
+ output = Path(output)
23
+ logger.info(
24
+ f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
25
+ )
26
+
27
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
28
+ cfg = compose(config_name=lora_config)
29
+
30
+ lora_config = instantiate(cfg)
31
+ logger.info(f"Loaded lora model with config {lora_config}")
32
+
33
+ llama_model = BaseTransformer.from_pretrained(
34
+ path=base_weight,
35
+ load_weights=True,
36
+ lora_config=lora_config,
37
+ )
38
+ logger.info(f"Loaded llama model")
39
+
40
+ llama_state_dict = llama_model.state_dict()
41
+ llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
42
+ llama_state_dict_copy = deepcopy(llama_state_dict)
43
+ lora_state_dict = torch.load(lora_weight, map_location="cpu")
44
+
45
+ if "state_dict" in llama_state_dict:
46
+ llama_state_dict = llama_state_dict["state_dict"]
47
+
48
+ if "state_dict" in lora_state_dict:
49
+ lora_state_dict = lora_state_dict["state_dict"]
50
+
51
+ # remove prefix model.
52
+ if any(k.startswith("model.") for k in llama_state_dict.keys()):
53
+ llama_state_dict = {
54
+ k.replace("model.", ""): v
55
+ for k, v in llama_state_dict.items()
56
+ if k.startswith("model.")
57
+ }
58
+ if any(k.startswith("model.") for k in lora_state_dict.keys()):
59
+ lora_state_dict = {
60
+ k.replace("model.", ""): v
61
+ for k, v in lora_state_dict.items()
62
+ if k.startswith("model.")
63
+ }
64
+
65
+ logger.info(f"Found {len(llama_state_dict)} keys in llama model")
66
+ logger.info(f"Found {len(lora_state_dict)} keys in lora model")
67
+
68
+ merged_state_dict = llama_state_dict | lora_state_dict
69
+ llama_model.load_state_dict(merged_state_dict, strict=True)
70
+ logger.info(f"Merged model loaded")
71
+
72
+ # Trigger eval mode to merge lora
73
+ llama_model.eval()
74
+ llama_model.save_pretrained(output, drop_lora=True)
75
+ logger.info(f"Saved merged model to {output}, validating")
76
+
77
+ new_state_dict = torch.load(output / "model.pth", map_location="cpu")
78
+ original_keys = set(llama_state_dict_copy.keys())
79
+ merged_keys = set(new_state_dict.keys())
80
+
81
+ assert original_keys == merged_keys, "Keys should be same"
82
+
83
+ for key in original_keys:
84
+ diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
85
+ if diff_l1 != 0:
86
+ break
87
+ else:
88
+ logger.error("Merged model is same as the original model")
89
+ exit(1)
90
+
91
+ logger.info("Merged model is different from the original model, check passed")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ merge()
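A sketch of invoking the merge step programmatically through the click command's callback instead of the CLI. The LoRA checkpoint and output paths below are placeholders for illustration.

from tools.llama.merge_lora import merge

# merge is a click.Command; .callback is the undecorated function it wraps.
merge.callback(
    lora_config="r_8_alpha_16",
    base_weight="checkpoints/fish-speech-1.4",
    lora_weight="results/lora_checkpoint.ckpt",   # placeholder path
    output="checkpoints/fish-speech-1.4-merged",  # placeholder path
)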
tools/llama/quantize.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ import datetime
4
+ import shutil
5
+
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+ import time
9
+ from pathlib import Path
10
+
11
+ import click
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from fish_speech.models.text2semantic.llama import find_multiple
17
+ from tools.llama.generate import load_model
18
+
19
+ ##### Quantization Primitives ######
20
+
21
+
22
+ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
23
+ # assumes symmetric quantization
24
+ # assumes axis == 0
25
+ # assumes dense memory format
26
+ # TODO(future): relax ^ as needed
27
+
28
+ # default setup for affine quantization of activations
29
+ eps = torch.finfo(torch.float32).eps
30
+
31
+ # get min and max
32
+ min_val, max_val = torch.aminmax(x, dim=1)
33
+
34
+ # calculate scales and zero_points based on min and max
35
+ # reference: https://fburl.com/code/srbiybme
36
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
37
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
38
+ device = min_val_neg.device
39
+
40
+ # reference: https://fburl.com/code/4wll53rk
41
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
42
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
43
+ # ensure scales is the same dtype as the original tensor
44
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
45
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
46
+
47
+ # quantize based on qmin/qmax/scales/zp
48
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
49
+ x_div = x / scales.unsqueeze(-1)
50
+ x_round = torch.round(x_div)
51
+ x_zp = x_round + zero_points.unsqueeze(-1)
52
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
53
+
54
+ return quant, scales, zero_points
55
+
56
+
57
+ def get_group_qparams(w, n_bit=4, groupsize=128):
58
+ # needed for GPTQ with padding
59
+ if groupsize > w.shape[-1]:
60
+ groupsize = w.shape[-1]
61
+ assert groupsize > 1
62
+ assert w.shape[-1] % groupsize == 0
63
+ assert w.dim() == 2
64
+
65
+ to_quant = w.reshape(-1, groupsize)
66
+ assert torch.isnan(to_quant).sum() == 0
67
+
68
+ max_val = to_quant.amax(dim=1, keepdim=True)
69
+ min_val = to_quant.amin(dim=1, keepdim=True)
70
+ max_int = 2**n_bit - 1
71
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
72
+ zeros = min_val + scales * (2 ** (n_bit - 1))
73
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
74
+ torch.bfloat16
75
+ ).reshape(w.shape[0], -1)
76
+
77
+
78
+ def pack_scales_and_zeros(scales, zeros):
79
+ assert scales.shape == zeros.shape
80
+ assert scales.dtype == torch.bfloat16
81
+ assert zeros.dtype == torch.bfloat16
82
+ return (
83
+ torch.cat(
84
+ [
85
+ scales.reshape(scales.size(0), scales.size(1), 1),
86
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
87
+ ],
88
+ 2,
89
+ )
90
+ .transpose(0, 1)
91
+ .contiguous()
92
+ )
93
+
94
+
95
+ def unpack_scales_and_zeros(scales_and_zeros):
96
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
97
+ assert scales_and_zeros.dtype == torch.float
98
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
99
+
100
+
101
+ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
102
+ assert groupsize > 1
103
+ # needed for GPTQ single column quantize
104
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
105
+ groupsize = w.shape[-1]
106
+
107
+ assert w.shape[-1] % groupsize == 0
108
+ assert w.dim() == 2
109
+
110
+ to_quant = w.reshape(-1, groupsize)
111
+ assert torch.isnan(to_quant).sum() == 0
112
+
113
+ scales = scales.reshape(-1, 1)
114
+ zeros = zeros.reshape(-1, 1)
115
+ min_val = zeros - scales * (2 ** (n_bit - 1))
116
+ max_int = 2**n_bit - 1
117
+ min_int = 0
118
+ w_int32 = (
119
+ to_quant.sub(min_val)
120
+ .div(scales)
121
+ .round()
122
+ .clamp_(min_int, max_int)
123
+ .to(torch.int32)
124
+ .reshape_as(w)
125
+ )
126
+
127
+ return w_int32
128
+
129
+
130
+ def group_quantize_tensor(w, n_bit=4, groupsize=128):
131
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
132
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
133
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
134
+ return w_int32, scales_and_zeros
135
+
136
+
137
+ def group_dequantize_tensor_from_qparams(
138
+ w_int32, scales, zeros, n_bit=4, groupsize=128
139
+ ):
140
+ assert groupsize > 1
141
+ # needed for GPTQ single column dequantize
142
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
143
+ groupsize = w_int32.shape[-1]
144
+ assert w_int32.shape[-1] % groupsize == 0
145
+ assert w_int32.dim() == 2
146
+
147
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
148
+ scales = scales.reshape(-1, 1)
149
+ zeros = zeros.reshape(-1, 1)
150
+
151
+ w_dq = (
152
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
153
+ )
154
+ return w_dq
155
+
156
+
157
+ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
158
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
159
+ return group_dequantize_tensor_from_qparams(
160
+ w_int32, scales, zeros, n_bit, groupsize
161
+ )
162
+
163
+
164
+ class QuantHandler:
165
+ def __init__(self, mod):
166
+ self.mod = mod
167
+
168
+ def create_quantized_state_dict(self) -> "StateDict":
169
+ pass
170
+
171
+ def convert_for_runtime(self) -> "nn.Module":
172
+ pass
173
+
174
+
175
+ ##### Weight-only int8 per-channel quantized code ######
176
+
177
+
178
+ def replace_linear_weight_only_int8_per_channel(module):
179
+ for name, child in module.named_children():
180
+ if isinstance(child, nn.Linear):
181
+ setattr(
182
+ module,
183
+ name,
184
+ WeightOnlyInt8Linear(child.in_features, child.out_features),
185
+ )
186
+ else:
187
+ replace_linear_weight_only_int8_per_channel(child)
188
+
189
+
190
+ class WeightOnlyInt8QuantHandler:
191
+ def __init__(self, mod):
192
+ self.mod = mod
193
+
194
+ @torch.no_grad()
195
+ def create_quantized_state_dict(self):
196
+ cur_state_dict = self.mod.state_dict()
197
+ for fqn, mod in self.mod.named_modules():
198
+ if isinstance(mod, torch.nn.Linear):
199
+ int8_weight, scales, _ = dynamically_quantize_per_channel(
200
+ mod.weight.float(), -128, 127, torch.int8
201
+ )
202
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
203
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
204
+
205
+ return cur_state_dict
206
+
207
+ def convert_for_runtime(self):
208
+ replace_linear_weight_only_int8_per_channel(self.mod)
209
+ return self.mod
210
+
211
+
212
+ class WeightOnlyInt8Linear(torch.nn.Module):
213
+ __constants__ = ["in_features", "out_features"]
214
+ in_features: int
215
+ out_features: int
216
+ weight: torch.Tensor
217
+
218
+ def __init__(
219
+ self,
220
+ in_features: int,
221
+ out_features: int,
222
+ bias: bool = True,
223
+ device=None,
224
+ dtype=None,
225
+ ) -> None:
226
+ factory_kwargs = {"device": device, "dtype": dtype}
227
+ super().__init__()
228
+ self.in_features = in_features
229
+ self.out_features = out_features
230
+ self.register_buffer(
231
+ "weight", torch.empty((out_features, in_features), dtype=torch.int8)
232
+ )
233
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
234
+
235
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
236
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
237
+
238
+
239
+ ##### weight only int4 per channel groupwise quantized code ######
240
+
241
+
242
+ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
243
+ weight_int32, scales_and_zeros = group_quantize_tensor(
244
+ weight_bf16, n_bit=4, groupsize=groupsize
245
+ )
246
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
247
+ weight_int32, inner_k_tiles
248
+ )
249
+ return weight_int4pack, scales_and_zeros
250
+
251
+
252
+ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
253
+ origin_x_size = x.size()
254
+ x = x.reshape(-1, origin_x_size[-1])
255
+ c = torch.ops.aten._weight_int4pack_mm(
256
+ x, weight_int4pack, groupsize, scales_and_zeros
257
+ )
258
+ new_shape = origin_x_size[:-1] + (out_features,)
259
+ c = c.reshape(new_shape)
260
+ return c
261
+
262
+
263
+ def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
264
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
265
+
266
+
267
+ def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
268
+ for name, child in module.named_children():
269
+ if isinstance(child, nn.Linear):
270
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
271
+ setattr(
272
+ module,
273
+ name,
274
+ WeightOnlyInt4Linear(
275
+ child.in_features,
276
+ child.out_features,
277
+ bias=False,
278
+ groupsize=groupsize,
279
+ inner_k_tiles=inner_k_tiles,
280
+ padding=False,
281
+ ),
282
+ )
283
+ elif padding:
284
+ setattr(
285
+ module,
286
+ name,
287
+ WeightOnlyInt4Linear(
288
+ child.in_features,
289
+ child.out_features,
290
+ bias=False,
291
+ groupsize=groupsize,
292
+ inner_k_tiles=inner_k_tiles,
293
+ padding=True,
294
+ ),
295
+ )
296
+ else:
297
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
298
+
299
+
300
+ class WeightOnlyInt4QuantHandler:
301
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
302
+ self.mod = mod
303
+ self.groupsize = groupsize
304
+ self.inner_k_tiles = inner_k_tiles
305
+ self.padding = padding
306
+ assert groupsize in [32, 64, 128, 256]
307
+ assert inner_k_tiles in [2, 4, 8]
308
+
309
+ @torch.no_grad()
310
+ def create_quantized_state_dict(self):
311
+ cur_state_dict = self.mod.state_dict()
312
+ for fqn, mod in self.mod.named_modules():
313
+ if isinstance(mod, torch.nn.Linear):
314
+ assert not mod.bias
315
+ out_features = mod.out_features
316
+ in_features = mod.in_features
317
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
318
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
319
+
320
+ weight = mod.weight.data
321
+ if not _check_linear_int4_k(
322
+ in_features, self.groupsize, self.inner_k_tiles
323
+ ):
324
+ if self.padding:
325
+ import torch.nn.functional as F
326
+
327
+ print(
328
+ f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
329
+ )
330
+ padded_in_features = find_multiple(in_features, 1024)
331
+ weight = F.pad(
332
+ weight, pad=(0, padded_in_features - in_features)
333
+ )
334
+ else:
335
+ print(
336
+ f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
337
+ + "and that groupsize and inner_k_tiles*16 evenly divide into it"
338
+ )
339
+ continue
340
+ (
341
+ weight_int4pack,
342
+ scales_and_zeros,
343
+ ) = prepare_int4_weight_and_scales_and_zeros(
344
+ weight.to(torch.bfloat16).to("cuda"),
345
+ self.groupsize,
346
+ self.inner_k_tiles,
347
+ )
348
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
349
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
350
+
351
+ return cur_state_dict
352
+
353
+ def convert_for_runtime(self):
354
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
355
+ return self.mod
356
+
357
+
358
+ class WeightOnlyInt4Linear(torch.nn.Module):
359
+ __constants__ = ["in_features", "out_features"]
360
+ in_features: int
361
+ out_features: int
362
+ weight: torch.Tensor
363
+
364
+ def __init__(
365
+ self,
366
+ in_features: int,
367
+ out_features: int,
368
+ bias=True,
369
+ device=None,
370
+ dtype=None,
371
+ groupsize: int = 128,
372
+ inner_k_tiles: int = 8,
373
+ padding: bool = True,
374
+ ) -> None:
375
+ super().__init__()
376
+ self.padding = padding
377
+ if padding:
378
+ self.origin_in_features = in_features
379
+ in_features = find_multiple(in_features, 1024)
380
+
381
+ self.in_features = in_features
382
+ self.out_features = out_features
383
+ assert not bias, "require bias=False"
384
+ self.groupsize = groupsize
385
+ self.inner_k_tiles = inner_k_tiles
386
+
387
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
388
+ assert (
389
+ in_features % (inner_k_tiles * 16) == 0
390
+ ), "require in_features % (innerKTiles * 16) == 0"
391
+ self.register_buffer(
392
+ "weight",
393
+ torch.empty(
394
+ (
395
+ out_features // 8,
396
+ in_features // (inner_k_tiles * 16),
397
+ 32,
398
+ inner_k_tiles // 2,
399
+ ),
400
+ dtype=torch.int32,
401
+ ),
402
+ )
403
+ self.register_buffer(
404
+ "scales_and_zeros",
405
+ torch.empty(
406
+ (in_features // groupsize, out_features, 2), dtype=torch.bfloat16
407
+ ),
408
+ )
409
+
410
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
411
+ input = input.to(torch.bfloat16)
412
+ if self.padding:
413
+ import torch.nn.functional as F
414
+
415
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
416
+ return linear_forward_int4(
417
+ input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
418
+ )
419
+
420
+
421
+ def generate_folder_name():
422
+ now = datetime.datetime.now()
423
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
424
+ return folder_name
425
+
426
+
427
+ @click.command()
428
+ @click.option(
429
+ "--checkpoint-path",
430
+ type=click.Path(path_type=Path, exists=True),
431
+ default="checkpoints/fish-speech-1.4",
432
+ )
433
+ @click.option(
434
+ "--mode", type=str, default="int8", help="type of quantization to perform"
435
+ )
436
+ @click.option(
437
+ "--groupsize", type=int, default=128, help="Group size for int4 quantization."
438
+ )
439
+ @click.option("--timestamp", type=str, default="None", help="When to do quantization")
440
+ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
441
+
442
+ device = "cpu"
443
+ precision = torch.bfloat16
444
+
445
+ print("Loading model ...")
446
+ t0 = time.time()
447
+
448
+ model, _ = load_model(
449
+ checkpoint_path=checkpoint_path,
450
+ device=device,
451
+ precision=precision,
452
+ compile=False,
453
+ )
454
+ vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
455
+ now = timestamp if timestamp != "None" else generate_folder_name()
456
+
457
+ if mode == "int8":
458
+ print(
459
+ "Quantizing model weights for int8 weight-only symmetric per-channel quantization"
460
+ )
461
+ quant_handler = WeightOnlyInt8QuantHandler(model)
462
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
463
+
464
+ dir_name = checkpoint_path
465
+ dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
466
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
467
+ if (dst_name / vq_model).exists():
468
+ (dst_name / vq_model).unlink()
469
+ quantize_path = dst_name / "model.pth"
470
+
471
+ elif mode == "int4":
472
+ print(
473
+ "Quantizing model weights for int4 weight-only affine per-channel groupwise quantization"
474
+ )
475
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
476
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
477
+
478
+ dir_name = checkpoint_path
479
+ dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
480
+ shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
481
+ if (dst_name / vq_model).exists():
482
+ (dst_name / vq_model).unlink()
483
+ quantize_path = dst_name / "model.pth"
484
+
485
+ else:
486
+ raise ValueError(
487
+ f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]"
488
+ )
489
+
490
+ print(f"Writing quantized weights to {quantize_path}")
491
+ quantize_path.unlink(missing_ok=True) # remove the existing file if one is already there
492
+ torch.save(quantized_state_dict, quantize_path)
493
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
494
+
495
+
496
+ if __name__ == "__main__":
497
+ quantize()
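
For context on the handler API defined in this file, here is a minimal sketch of int8 weight-only quantization applied to a toy module. It assumes the classes above (e.g. `WeightOnlyInt8QuantHandler`) are in scope or importable from this module; the toy model and shapes are purely illustrative, not the fish-speech checkpoint, and the real entry point remains the `quantize` command above.

```python
import torch
import torch.nn as nn

# Toy model (illustrative only); bias is disabled because WeightOnlyInt8Linear
# carries no bias term after conversion.
model = nn.Sequential(
    nn.Linear(64, 128, bias=False),
    nn.ReLU(),
    nn.Linear(128, 16, bias=False),
)

handler = WeightOnlyInt8QuantHandler(model)
int8_state_dict = handler.create_quantized_state_dict()  # int8 weights + per-channel scales
model = handler.convert_for_runtime()  # swaps nn.Linear for WeightOnlyInt8Linear
model.load_state_dict(int8_state_dict)

with torch.no_grad():
    out = model(torch.randn(2, 64))
print(out.shape)  # torch.Size([2, 16])
```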
tools/llama/rebuild_tokenizer.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
2
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
3
+
4
+ # Initialize a tokenizer
5
+ tokenizer = Tokenizer(models.BPE())
6
+
7
+ # Customize pre-tokenization and decoding
8
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
9
+ tokenizer.decoder = decoders.ByteLevel()
10
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
11
+
12
+ # Don't train the tokenizer
13
+ trainer = trainers.BpeTrainer(
14
+ vocab_size=0,
15
+ min_frequency=2,
16
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
17
+ special_tokens=[
18
+ "<|begin_of_sequence|>",
19
+ "<|end_of_sequence|>",
20
+ "<|im_start|>",
21
+ "<|im_sep|>", # system, user, assistant, etc.
22
+ "<|im_end|>",
23
+ "<|semantic|>", # audio features
24
+ "<|pad|>",
25
+ ],
26
+ )
27
+
28
+ # <|im_start|>user<|im_sep|>...<|im_end|>
29
+ # <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
30
+ tokenizer.train_from_iterator([], trainer=trainer)
31
+
32
+ print(len(tokenizer.get_vocab()))
33
+ x = tokenizer.encode(
34
+ "Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
35
+ ).ids
36
+ print(x, len(x))
37
+ print(tokenizer.decode(x, skip_special_tokens=True))
38
+
39
+
40
+ tokenizer = PreTrainedTokenizerFast(
41
+ tokenizer_object=tokenizer,
42
+ pad_token="<|pad|>",
43
+ bos_token="<|begin_of_sequence|>",
44
+ eos_token="<|end_of_sequence|>",
45
+ )
46
+
47
+ # Try tokenizing a new sequence
48
+ sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
49
+ encoded = tokenizer(sequence).input_ids
50
+
51
+ print("Test encoding....")
52
+ print(f"\tSentence: {sequence}")
53
+ print(f"\tEncoded: {encoded}")
54
+ print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
55
+ print(f"\tDecoded: {tokenizer.decode(encoded)}")
56
+
57
+ tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
tools/msgpack_api.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ import httpx
6
+ import ormsgpack
7
+
8
+ from tools.schema import ServeReferenceAudio, ServeTTSRequest
9
+
10
+ api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
11
+
12
+
13
+ def audio_request():
14
+ # priority: ref_id > references
15
+ request = ServeTTSRequest(
16
+ text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
17
+ # reference_id="114514",
18
+ references=[
19
+ ServeReferenceAudio(
20
+ audio=open("lengyue.wav", "rb").read(),
21
+ text=open("lengyue.lab", "r", encoding="utf-8").read(),
22
+ )
23
+ ],
24
+ streaming=True,
25
+ )
26
+
27
+ api_key = os.environ.get("FISH_API_KEY", "YOUR_API_KEY")
28
+
29
+ with (
30
+ httpx.Client() as client,
31
+ open("hello.wav", "wb") as f,
32
+ ):
33
+ with client.stream(
34
+ "POST",
35
+ "http://127.0.0.1:8080/v1/tts",
36
+ content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
37
+ headers={
38
+ "authorization": f"Bearer {api_key}",
39
+ "content-type": "application/msgpack",
40
+ },
41
+ timeout=None,
42
+ ) as response:
43
+ for chunk in response.iter_bytes():
44
+ f.write(chunk)
45
+
46
+
47
+ def asr_request(audio_path: Path):
48
+
49
+ # Read the audio file
50
+ with open(
51
+ str(audio_path),
52
+ "rb",
53
+ ) as audio_file:
54
+ audio_data = audio_file.read()
55
+
56
+ # Prepare the request data
57
+ request_data = {
58
+ "audio": audio_data,
59
+ "language": "en", # Optional: specify the language
60
+ "ignore_timestamps": False, # Optional: set to True to ignore precise timestamps
61
+ }
62
+
63
+ # Send the request
64
+ with httpx.Client() as client:
65
+ response = client.post(
66
+ "https://api.fish.audio/v1/asr",
67
+ headers={
68
+ "Authorization": f"Bearer {api_key}",
69
+ "Content-Type": "application/msgpack",
70
+ },
71
+ content=ormsgpack.packb(request_data),
72
+ )
73
+
74
+ # Parse the response
75
+ result = response.json()
76
+
77
+ print(f"Transcribed text: {result['text']}")
78
+ print(f"Audio duration: {result['duration']} seconds")
79
+
80
+ for segment in result["segments"]:
81
+ print(f"Segment: {segment['text']}")
82
+ print(f"Start time: {segment['start']}, End time: {segment['end']}")
83
+
84
+
85
+ def parse_args():
86
+ parser = ArgumentParser()
87
+ parser.add_argument("--audio_path", type=Path, default="audio/ref/trump.mp3")
88
+
89
+ return parser.parse_args()
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = parse_args()
94
+
95
+ asr_request(args.audio_path)
tools/post_api.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import wave
4
+
5
+ import ormsgpack
6
+ import pyaudio
7
+ import requests
8
+ from pydub import AudioSegment
9
+ from pydub.playback import play
10
+
11
+ from tools.file import audio_to_bytes, read_ref_text
12
+ from tools.schema import ServeReferenceAudio, ServeTTSRequest
13
+
14
+
15
+ def parse_args():
16
+
17
+ parser = argparse.ArgumentParser(
18
+ description="Send a WAV file and text to a server and receive synthesized audio.",
19
+ formatter_class=argparse.RawTextHelpFormatter,
20
+ )
21
+
22
+ parser.add_argument(
23
+ "--url",
24
+ "-u",
25
+ type=str,
26
+ default="http://127.0.0.1:8080/v1/tts",
27
+ help="URL of the server",
28
+ )
29
+ parser.add_argument(
30
+ "--text", "-t", type=str, required=True, help="Text to be synthesized"
31
+ )
32
+ parser.add_argument(
33
+ "--reference_id",
34
+ "-id",
35
+ type=str,
36
+ default=None,
37
+ help="ID of the reference model to be used for the speech\n(Local: name of folder containing audios and files)",
38
+ )
39
+ parser.add_argument(
40
+ "--reference_audio",
41
+ "-ra",
42
+ type=str,
43
+ nargs="+",
44
+ default=None,
45
+ help="Path to the audio file",
46
+ )
47
+ parser.add_argument(
48
+ "--reference_text",
49
+ "-rt",
50
+ type=str,
51
+ nargs="+",
52
+ default=None,
53
+ help="Reference text for voice synthesis",
54
+ )
55
+ parser.add_argument(
56
+ "--output",
57
+ "-o",
58
+ type=str,
59
+ default="generated_audio",
60
+ help="Output audio file name",
61
+ )
62
+ parser.add_argument(
63
+ "--play",
64
+ type=bool,
65
+ default=True,
66
+ help="Whether to play audio after receiving data",
67
+ )
68
+ parser.add_argument("--normalize", type=bool, default=True)
69
+ parser.add_argument(
70
+ "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
71
+ )
72
+ parser.add_argument(
73
+ "--mp3_bitrate", type=int, choices=[64, 128, 192], default=64, help="kHz"
74
+ )
75
+ parser.add_argument("--opus_bitrate", type=int, default=-1000)
76
+ parser.add_argument(
77
+ "--latency",
78
+ type=str,
79
+ default="normal",
80
+ choices=["normal", "balanced"],
81
+ help="Used in api.fish.audio/v1/tts",
82
+ )
83
+ parser.add_argument(
84
+ "--max_new_tokens",
85
+ type=int,
86
+ default=0,
87
+ help="Maximum new tokens to generate. \n0 means no limit.",
88
+ )
89
+ parser.add_argument(
90
+ "--chunk_length", type=int, default=200, help="Chunk length for synthesis"
91
+ )
92
+ parser.add_argument(
93
+ "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
94
+ )
95
+ parser.add_argument(
96
+ "--repetition_penalty",
97
+ type=float,
98
+ default=1.2,
99
+ help="Repetition penalty for synthesis",
100
+ )
101
+ parser.add_argument(
102
+ "--temperature", type=float, default=0.7, help="Temperature for sampling"
103
+ )
104
+
105
+ parser.add_argument(
106
+ "--streaming", type=bool, default=False, help="Enable streaming response"
107
+ )
108
+ parser.add_argument(
109
+ "--channels", type=int, default=1, help="Number of audio channels"
110
+ )
111
+ parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
112
+ parser.add_argument(
113
+ "--use_memory_cache",
114
+ type=str,
115
+ default="never",
116
+ choices=["on-demand", "never"],
117
+ help="Cache encoded references codes in memory.\n"
118
+ "If `on-demand`, the server will use cached encodings\n "
119
+ "instead of encoding reference audio again.",
120
+ )
121
+ parser.add_argument(
122
+ "--seed",
123
+ type=int,
124
+ default=None,
125
+ help="`None` means randomized inference, otherwise deterministic.\n"
126
+ "It can't be used for fixing a timbre.",
127
+ )
128
+
129
+ return parser.parse_args()
130
+
131
+
132
+ if __name__ == "__main__":
133
+
134
+ args = parse_args()
135
+
136
+ idstr: str | None = args.reference_id
137
+ # priority: ref_id > [{text, audio},...]
138
+ if idstr is None:
139
+ ref_audios = args.reference_audio
140
+ ref_texts = args.reference_text
141
+ if ref_audios is None:
142
+ byte_audios = []
143
+ else:
144
+ byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
145
+ if ref_texts is None:
146
+ ref_texts = []
147
+ else:
148
+ ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
149
+ else:
150
+ byte_audios = []
151
+ ref_texts = []
152
+ pass # in api.py
153
+
154
+ data = {
155
+ "text": args.text,
156
+ "references": [
157
+ ServeReferenceAudio(audio=ref_audio, text=ref_text)
158
+ for ref_text, ref_audio in zip(ref_texts, byte_audios)
159
+ ],
160
+ "reference_id": idstr,
161
+ "normalize": args.normalize,
162
+ "format": args.format,
163
+ "mp3_bitrate": args.mp3_bitrate,
164
+ "opus_bitrate": args.opus_bitrate,
165
+ "max_new_tokens": args.max_new_tokens,
166
+ "chunk_length": args.chunk_length,
167
+ "top_p": args.top_p,
168
+ "repetition_penalty": args.repetition_penalty,
169
+ "temperature": args.temperature,
170
+ "streaming": args.streaming,
171
+ "use_memory_cache": args.use_memory_cache,
172
+ "seed": args.seed,
173
+ }
174
+
175
+ pydantic_data = ServeTTSRequest(**data)
176
+
177
+ response = requests.post(
178
+ args.url,
179
+ data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
180
+ stream=args.streaming,
181
+ headers={
182
+ "authorization": "Bearer YOUR_API_KEY",
183
+ "content-type": "application/msgpack",
184
+ },
185
+ )
186
+
187
+ if response.status_code == 200:
188
+ if args.streaming:
189
+ p = pyaudio.PyAudio()
190
+ audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
191
+ stream = p.open(
192
+ format=audio_format, channels=args.channels, rate=args.rate, output=True
193
+ )
194
+
195
+ wf = wave.open(f"{args.output}.wav", "wb")
196
+ wf.setnchannels(args.channels)
197
+ wf.setsampwidth(p.get_sample_size(audio_format))
198
+ wf.setframerate(args.rate)
199
+
200
+ stream_stopped_flag = False
201
+
202
+ try:
203
+ for chunk in response.iter_content(chunk_size=1024):
204
+ if chunk:
205
+ stream.write(chunk)
206
+ wf.writeframesraw(chunk)
207
+ else:
208
+ if not stream_stopped_flag:
209
+ stream.stop_stream()
210
+ stream_stopped_flag = True
211
+ finally:
212
+ stream.close()
213
+ p.terminate()
214
+ wf.close()
215
+ else:
216
+ audio_content = response.content
217
+ audio_path = f"{args.output}.{args.format}"
218
+ with open(audio_path, "wb") as audio_file:
219
+ audio_file.write(audio_content)
220
+
221
+ audio = AudioSegment.from_file(audio_path, format=args.format)
222
+ if args.play:
223
+ play(audio)
224
+ print(f"Audio has been saved to '{audio_path}'.")
225
+ else:
226
+ print(f"Request failed with status code {response.status_code}")
227
+ print(response.json())
tools/schema.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import queue
3
+ from dataclasses import dataclass
4
+ from typing import Annotated, Literal, Optional
5
+
6
+ import torch
7
+ from pydantic import AfterValidator, BaseModel, Field, confloat, conint, conlist
8
+ from pydantic.functional_validators import SkipValidation
9
+
10
+ from fish_speech.conversation import Message, TextPart, VQPart
11
+
12
+ GLOBAL_NUM_SAMPLES = int(os.getenv("GLOBAL_NUM_SAMPLES", 1))
13
+
14
+
15
+ class ServeVQPart(BaseModel):
16
+ type: Literal["vq"] = "vq"
17
+ codes: SkipValidation[list[list[int]]]
18
+
19
+
20
+ class ServeTextPart(BaseModel):
21
+ type: Literal["text"] = "text"
22
+ text: str
23
+
24
+
25
+ class ServeAudioPart(BaseModel):
26
+ type: Literal["audio"] = "audio"
27
+ audio: bytes
28
+
29
+
30
+ @dataclass
31
+ class ASRPackRequest:
32
+ audio: torch.Tensor
33
+ result_queue: queue.Queue
34
+ language: str
35
+
36
+
37
+ class ServeASRRequest(BaseModel):
38
+ # The audio should be uncompressed PCM float16 audio
39
+ audios: list[bytes]
40
+ sample_rate: int = 44100
41
+ language: Literal["zh", "en", "ja", "auto"] = "auto"
42
+
43
+
44
+ class ServeASRTranscription(BaseModel):
45
+ text: str
46
+ duration: float
47
+ huge_gap: bool
48
+
49
+
50
+ class ServeASRSegment(BaseModel):
51
+ text: str
52
+ start: float
53
+ end: float
54
+
55
+
56
+ class ServeTimedASRResponse(BaseModel):
57
+ text: str
58
+ segments: list[ServeASRSegment]
59
+ duration: float
60
+
61
+
62
+ class ServeASRResponse(BaseModel):
63
+ transcriptions: list[ServeASRTranscription]
64
+
65
+
66
+ class ServeMessage(BaseModel):
67
+ role: Literal["system", "assistant", "user"]
68
+ parts: list[ServeVQPart | ServeTextPart]
69
+
70
+ def to_conversation_message(self):
71
+ new_message = Message(role=self.role, parts=[])
72
+ for part in self.parts:
73
+ if isinstance(part, ServeTextPart):
74
+ new_message.parts.append(TextPart(text=part.text))
75
+ elif isinstance(part, ServeVQPart):
76
+ new_message.parts.append(
77
+ VQPart(codes=torch.tensor(part.codes, dtype=torch.int))
78
+ )
79
+ else:
80
+ raise ValueError(f"Unsupported part type: {part}")
81
+
82
+ return new_message
83
+
84
+
85
+ class ServeRequest(BaseModel):
86
+ messages: Annotated[list[ServeMessage], conlist(ServeMessage, min_length=1)]
87
+ max_new_tokens: int = 1024
88
+ top_p: float = 0.7
89
+ repetition_penalty: float = 1.2
90
+ temperature: float = 0.7
91
+ streaming: bool = False
92
+ num_samples: int = 1
93
+ early_stop_threshold: float = 1.0
94
+
95
+
96
+ class ServeVQGANEncodeRequest(BaseModel):
97
+ # The audio here should be in wav, mp3, etc
98
+ audios: list[bytes]
99
+
100
+
101
+ class ServeVQGANEncodeResponse(BaseModel):
102
+ tokens: SkipValidation[list[list[list[int]]]]
103
+
104
+
105
+ class ServeVQGANDecodeRequest(BaseModel):
106
+ tokens: SkipValidation[list[list[list[int]]]]
107
+
108
+
109
+ class ServeVQGANDecodeResponse(BaseModel):
110
+ # The audio here should be in PCM float16 format
111
+ audios: list[bytes]
112
+
113
+
114
+ class ServeReferenceAudio(BaseModel):
115
+ audio: bytes
116
+ text: str
117
+
118
+
119
+ class ServeForwardMessage(BaseModel):
120
+ role: str
121
+ content: str
122
+
123
+
124
+ class ServeResponse(BaseModel):
125
+ messages: list[ServeMessage]
126
+ finish_reason: Literal["stop", "error"] | None = None
127
+ stats: dict[str, int | float | str] = {}
128
+
129
+
130
+ class ServeStreamDelta(BaseModel):
131
+ role: Literal["system", "assistant", "user"] | None = None
132
+ part: ServeVQPart | ServeTextPart | None = None
133
+
134
+
135
+ class ServeStreamResponse(BaseModel):
136
+ sample_id: int = 0
137
+ delta: ServeStreamDelta | None = None
138
+ finish_reason: Literal["stop", "error"] | None = None
139
+ stats: dict[str, int | float | str] | None = None
140
+
141
+
142
+ class ServeReferenceAudio(BaseModel):
143
+ audio: bytes
144
+ text: str
145
+
146
+ def __repr__(self) -> str:
147
+ return f"ServeReferenceAudio(text={self.text!r}, audio_size={len(self.audio)})"
148
+
149
+
150
+ class ServeChatRequestV1(BaseModel):
151
+ model: str = "llama3-8b"
152
+ messages: list[ServeForwardMessage] = []
153
+ audio: bytes | None = None
154
+ temperature: float = 1.0
155
+ top_p: float = 1.0
156
+ max_tokens: int = 256
157
+ voice: str = "jessica"
158
+ tts_audio_format: Literal["mp3", "pcm", "opus"] = "mp3"
159
+ tts_audio_bitrate: Literal[16, 24, 32, 48, 64, 96, 128, 192] = 128
160
+
161
+
162
+ class ServeTTSRequest(BaseModel):
163
+ text: str
164
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
165
+ # Audio format
166
+ format: Literal["wav", "pcm", "mp3"] = "wav"
167
+ mp3_bitrate: Literal[64, 128, 192] = 128
168
+ # References audios for in-context learning
169
+ references: list[ServeReferenceAudio] = []
170
+ # Reference id
171
+ # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
172
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
173
+ reference_id: str | None = None
174
+ seed: int | None = None
175
+ use_memory_cache: Literal["on-demand", "never"] = "never"
176
+ # Normalize text for en & zh, this increase stability for numbers
177
+ normalize: bool = True
178
+ mp3_bitrate: Optional[int] = 64
179
+ opus_bitrate: Optional[int] = -1000
180
+ # Balance mode will reduce latency to 300ms, but may decrease stability
181
+ latency: Literal["normal", "balanced"] = "normal"
182
+ # not usually used below
183
+ streaming: bool = False
184
+ max_new_tokens: int = 1024
185
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
186
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
187
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
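
As a quick, hedged illustration of how the constrained fields above behave (assuming `ServeTTSRequest` is importable as `tools.schema.ServeTTSRequest`; the text values are arbitrary), pydantic enforces the `Annotated` bounds at construction time:

```python
from pydantic import ValidationError

from tools.schema import ServeTTSRequest  # import path assumed from this upload

req = ServeTTSRequest(text="Hello there.", chunk_length=200, top_p=0.8)
print(req.format, req.latency)  # wav normal (defaults)

try:
    ServeTTSRequest(text="Hello there.", chunk_length=50)  # violates the ge=100 bound
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('chunk_length',)
```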
tools/sensevoice/README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FunASR Command Line Interface
2
+
3
+ This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
4
+
5
+ ## Requirements
6
+
7
+ - Python >= 3.10
8
+ - PyTorch <= 2.3.1
9
+ - ffmpeg, pydub, audio-separator[gpu].
10
+
11
+ ## Installation
12
+
13
+ Install the required packages:
14
+
15
+ ```bash
16
+ pip install -e .[stable]
17
+ ```
18
+
19
+ Make sure you have `ffmpeg` installed and available in your `PATH`.
20
+
21
+ ## Usage
22
+
23
+ ### Basic Usage
24
+
25
+ To run the tool with default settings:
26
+
27
+ ```bash
28
+ python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
29
+ ```
30
+
31
+ ## Options
32
+
33
+ | Option | Description |
34
+ | :-----------------------: | :---------------------------------------------------------------------------: |
35
+ | --audio-dir | Directory containing audio or video files. |
36
+ | --save-dir | Directory to save processed audio files. |
37
+ | --device | Device to use for processing. Options: cuda (default) or cpu. |
38
+ | --language | Language of the transcription. Default is auto. |
39
+ | --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
40
+ | --punc | Enable punctuation prediction. |
41
+ | --denoise | Enable noise reduction (vocal separation). |
42
+
43
+ ## Example
44
+
45
+ To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
46
+
47
+ ```bash
48
+ python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
49
+ ```
50
+
51
+ ## Additional Notes
52
+
53
+ - The tool supports both audio and video files. Videos will be converted to audio automatically.
54
+ - If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
55
+ - The script will automatically create the necessary directories under `--save-dir`.
56
+
57
+ ## Troubleshooting
58
+
59
+ If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
tools/sensevoice/__init__.py ADDED
File without changes
tools/sensevoice/auto_model.py ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ import copy
7
+ import json
8
+ import logging
9
+ import os.path
10
+ import random
11
+ import re
12
+ import string
13
+ import time
14
+
15
+ import numpy as np
16
+ import torch
17
+ from funasr.download.download_model_from_hub import download_model
18
+ from funasr.download.file import download_from_url
19
+ from funasr.register import tables
20
+ from funasr.train_utils.load_pretrained_model import load_pretrained_model
21
+ from funasr.train_utils.set_all_random_seed import set_all_random_seed
22
+ from funasr.utils import export_utils, misc
23
+ from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
24
+ from funasr.utils.misc import deep_update
25
+ from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
26
+ from tqdm import tqdm
27
+
28
+ from .vad_utils import merge_vad, slice_padding_audio_samples
29
+
30
+ try:
31
+ from funasr.models.campplus.cluster_backend import ClusterBackend
32
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
33
+ except:
34
+ pass
35
+
36
+
37
+ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
38
+ """ """
39
+ data_list = []
40
+ key_list = []
41
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
42
+
43
+ chars = string.ascii_letters + string.digits
44
+ if isinstance(data_in, str):
45
+ if data_in.startswith("http://") or data_in.startswith("https://"): # url
46
+ data_in = download_from_url(data_in)
47
+
48
+ if isinstance(data_in, str) and os.path.exists(
49
+ data_in
50
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
51
+ _, file_extension = os.path.splitext(data_in)
52
+ file_extension = file_extension.lower()
53
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
54
+ with open(data_in, encoding="utf-8") as fin:
55
+ for line in fin:
56
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
57
+ if data_in.endswith(
58
+ ".jsonl"
59
+ ): # file.jsonl: json.dumps({"source": data})
60
+ lines = json.loads(line.strip())
61
+ data = lines["source"]
62
+ key = data["key"] if "key" in data else key
63
+ else: # filelist, wav.scp, text.txt: id \t data or data
64
+ lines = line.strip().split(maxsplit=1)
65
+ data = lines[1] if len(lines) > 1 else lines[0]
66
+ key = lines[0] if len(lines) > 1 else key
67
+
68
+ data_list.append(data)
69
+ key_list.append(key)
70
+ else:
71
+ if key is None:
72
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
73
+ key = misc.extract_filename_without_extension(data_in)
74
+ data_list = [data_in]
75
+ key_list = [key]
76
+ elif isinstance(data_in, (list, tuple)):
77
+ if data_type is not None and isinstance(
78
+ data_type, (list, tuple)
79
+ ): # mutiple inputs
80
+ data_list_tmp = []
81
+ for data_in_i, data_type_i in zip(data_in, data_type):
82
+ key_list, data_list_i = prepare_data_iterator(
83
+ data_in=data_in_i, data_type=data_type_i
84
+ )
85
+ data_list_tmp.append(data_list_i)
86
+ data_list = []
87
+ for item in zip(*data_list_tmp):
88
+ data_list.append(item)
89
+ else:
90
+ # [audio sample point, fbank, text]
91
+ data_list = data_in
92
+ key_list = []
93
+ for data_i in data_in:
94
+ if isinstance(data_i, str) and os.path.exists(data_i):
95
+ key = misc.extract_filename_without_extension(data_i)
96
+ else:
97
+ if key is None:
98
+ key = "rand_key_" + "".join(
99
+ random.choice(chars) for _ in range(13)
100
+ )
101
+ key_list.append(key)
102
+
103
+ else: # raw text; audio sample point, fbank; bytes
104
+ if isinstance(data_in, bytes): # audio bytes
105
+ data_in = load_bytes(data_in)
106
+ if key is None:
107
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
108
+ data_list = [data_in]
109
+ key_list = [key]
110
+
111
+ return key_list, data_list
112
+
113
+
114
+ class AutoModel:
115
+
116
+ def __init__(self, **kwargs):
117
+
118
+ try:
119
+ from funasr.utils.version_checker import check_for_update
120
+
121
+ print(
122
+ "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
123
+ )
124
+ check_for_update(disable=kwargs.get("disable_update", False))
125
+ except:
126
+ pass
127
+
128
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
129
+ logging.basicConfig(level=log_level)
130
+
131
+ model, kwargs = self.build_model(**kwargs)
132
+
133
+ # if vad_model is not None, build vad model else None
134
+ vad_model = kwargs.get("vad_model", None)
135
+ vad_kwargs = (
136
+ {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
137
+ )
138
+ if vad_model is not None:
139
+ logging.info("Building VAD model.")
140
+ vad_kwargs["model"] = vad_model
141
+ vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
142
+ vad_kwargs["device"] = kwargs["device"]
143
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
144
+
145
+ # if punc_model is not None, build punc model else None
146
+ punc_model = kwargs.get("punc_model", None)
147
+ punc_kwargs = (
148
+ {}
149
+ if kwargs.get("punc_kwargs", {}) is None
150
+ else kwargs.get("punc_kwargs", {})
151
+ )
152
+ if punc_model is not None:
153
+ logging.info("Building punc model.")
154
+ punc_kwargs["model"] = punc_model
155
+ punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
156
+ punc_kwargs["device"] = kwargs["device"]
157
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
158
+
159
+ # if spk_model is not None, build spk model else None
160
+ spk_model = kwargs.get("spk_model", None)
161
+ spk_kwargs = (
162
+ {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
163
+ )
164
+ if spk_model is not None:
165
+ logging.info("Building SPK model.")
166
+ spk_kwargs["model"] = spk_model
167
+ spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
168
+ spk_kwargs["device"] = kwargs["device"]
169
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
170
+ self.cb_model = ClusterBackend().to(kwargs["device"])
171
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
172
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
173
+ logging.error(
174
+ "spk_mode should be one of default, vad_segment and punc_segment."
175
+ )
176
+ self.spk_mode = spk_mode
177
+
178
+ self.kwargs = kwargs
179
+ self.model = model
180
+ self.vad_model = vad_model
181
+ self.vad_kwargs = vad_kwargs
182
+ self.punc_model = punc_model
183
+ self.punc_kwargs = punc_kwargs
184
+ self.spk_model = spk_model
185
+ self.spk_kwargs = spk_kwargs
186
+ self.model_path = kwargs.get("model_path")
187
+
188
+ @staticmethod
189
+ def build_model(**kwargs):
190
+ assert "model" in kwargs
191
+ if "model_conf" not in kwargs:
192
+ logging.info(
193
+ "download models from model hub: {}".format(kwargs.get("hub", "ms"))
194
+ )
195
+ kwargs = download_model(**kwargs)
196
+
197
+ set_all_random_seed(kwargs.get("seed", 0))
198
+
199
+ device = kwargs.get("device", "cuda")
200
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
201
+ device = "cpu"
202
+ kwargs["batch_size"] = 1
203
+ kwargs["device"] = device
204
+
205
+ torch.set_num_threads(kwargs.get("ncpu", 4))
206
+
207
+ # build tokenizer
208
+ tokenizer = kwargs.get("tokenizer", None)
209
+ if tokenizer is not None:
210
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
211
+ tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
212
+ kwargs["token_list"] = (
213
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
214
+ )
215
+ kwargs["token_list"] = (
216
+ tokenizer.get_vocab()
217
+ if hasattr(tokenizer, "get_vocab")
218
+ else kwargs["token_list"]
219
+ )
220
+ vocab_size = (
221
+ len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
222
+ )
223
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
224
+ vocab_size = tokenizer.get_vocab_size()
225
+ else:
226
+ vocab_size = -1
227
+ kwargs["tokenizer"] = tokenizer
228
+
229
+ # build frontend
230
+ frontend = kwargs.get("frontend", None)
231
+ kwargs["input_size"] = None
232
+ if frontend is not None:
233
+ frontend_class = tables.frontend_classes.get(frontend)
234
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
235
+ kwargs["input_size"] = (
236
+ frontend.output_size() if hasattr(frontend, "output_size") else None
237
+ )
238
+ kwargs["frontend"] = frontend
239
+ # build model
240
+ model_class = tables.model_classes.get(kwargs["model"])
241
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
242
+ model_conf = {}
243
+ deep_update(model_conf, kwargs.get("model_conf", {}))
244
+ deep_update(model_conf, kwargs)
245
+ model = model_class(**model_conf, vocab_size=vocab_size)
246
+
247
+ # init_param
248
+ init_param = kwargs.get("init_param", None)
249
+ if init_param is not None:
250
+ if os.path.exists(init_param):
251
+ logging.info(f"Loading pretrained params from {init_param}")
252
+ load_pretrained_model(
253
+ model=model,
254
+ path=init_param,
255
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
256
+ oss_bucket=kwargs.get("oss_bucket", None),
257
+ scope_map=kwargs.get("scope_map", []),
258
+ excludes=kwargs.get("excludes", None),
259
+ )
260
+ else:
261
+ print(f"error, init_param does not exist!: {init_param}")
262
+
263
+ # fp16
264
+ if kwargs.get("fp16", False):
265
+ model.to(torch.float16)
266
+ elif kwargs.get("bf16", False):
267
+ model.to(torch.bfloat16)
268
+ model.to(device)
269
+
270
+ if not kwargs.get("disable_log", True):
271
+ tables.print()
272
+
273
+ return model, kwargs
274
+
275
+ def __call__(self, *args, **cfg):
276
+ kwargs = self.kwargs
277
+ deep_update(kwargs, cfg)
278
+ res = self.model(*args, kwargs)
279
+ return res
280
+
281
+ def generate(self, input, input_len=None, **cfg):
282
+ if self.vad_model is None:
283
+ return self.inference(input, input_len=input_len, **cfg)
284
+
285
+ else:
286
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
287
+
288
+ def inference(
289
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
290
+ ):
291
+ kwargs = self.kwargs if kwargs is None else kwargs
292
+ if "cache" in kwargs:
293
+ kwargs.pop("cache")
294
+ deep_update(kwargs, cfg)
295
+ model = self.model if model is None else model
296
+ model.eval()
297
+
298
+ batch_size = kwargs.get("batch_size", 1)
299
+ # if kwargs.get("device", "cpu") == "cpu":
300
+ # batch_size = 1
301
+
302
+ key_list, data_list = prepare_data_iterator(
303
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
304
+ )
305
+
306
+ speed_stats = {}
307
+ asr_result_list = []
308
+ num_samples = len(data_list)
309
+ disable_pbar = self.kwargs.get("disable_pbar", False)
310
+ pbar = (
311
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
312
+ if not disable_pbar
313
+ else None
314
+ )
315
+ time_speech_total = 0.0
316
+ time_escape_total = 0.0
317
+ for beg_idx in range(0, num_samples, batch_size):
318
+ end_idx = min(num_samples, beg_idx + batch_size)
319
+ data_batch = data_list[beg_idx:end_idx]
320
+ key_batch = key_list[beg_idx:end_idx]
321
+ batch = {"data_in": data_batch, "key": key_batch}
322
+
323
+ if (end_idx - beg_idx) == 1 and kwargs.get(
324
+ "data_type", None
325
+ ) == "fbank": # fbank
326
+ batch["data_in"] = data_batch[0]
327
+ batch["data_lengths"] = input_len
328
+
329
+ time1 = time.perf_counter()
330
+ with torch.no_grad():
331
+ res = model.inference(**batch, **kwargs)
332
+ if isinstance(res, (list, tuple)):
333
+ results = res[0] if len(res) > 0 else [{"text": ""}]
334
+ meta_data = res[1] if len(res) > 1 else {}
335
+ time2 = time.perf_counter()
336
+
337
+ asr_result_list.extend(results)
338
+
339
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
340
+ batch_data_time = meta_data.get("batch_data_time", -1)
341
+ time_escape = time2 - time1
342
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
343
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
344
+ speed_stats["forward"] = f"{time_escape:0.3f}"
345
+ speed_stats["batch_size"] = f"{len(results)}"
346
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
347
+ description = f"{speed_stats}, "
348
+ if pbar:
349
+ pbar.update(end_idx - beg_idx)
350
+ pbar.set_description(description)
351
+ time_speech_total += batch_data_time
352
+ time_escape_total += time_escape
353
+
354
+ if pbar:
355
+ # pbar.update(1)
356
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
357
+ torch.cuda.empty_cache()
358
+ return asr_result_list
359
+
360
+ def vad(self, input, input_len=None, **cfg):
361
+ kwargs = self.kwargs
362
+ # step.1: compute the vad model
363
+ deep_update(self.vad_kwargs, cfg)
364
+ beg_vad = time.time()
365
+ res = self.inference(
366
+ input,
367
+ input_len=input_len,
368
+ model=self.vad_model,
369
+ kwargs=self.vad_kwargs,
370
+ **cfg,
371
+ )
372
+ end_vad = time.time()
373
+ # FIX(gcf): concat the vad clips for the sense voice model for better aed
374
+ if cfg.get("merge_vad", False):
375
+ for i in range(len(res)):
376
+ res[i]["value"] = merge_vad(
377
+ res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
378
+ )
379
+ elapsed = end_vad - beg_vad
380
+ return elapsed, res
381
+
382
+ def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
383
+
384
+ kwargs = self.kwargs
385
+
386
+ # step.2 compute asr model
387
+ model = self.model
388
+ deep_update(kwargs, cfg)
389
+ batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
390
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
391
+ kwargs["batch_size"] = batch_size
392
+
393
+ key_list, data_list = prepare_data_iterator(
394
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
395
+ )
396
+ results_ret_list = []
397
+ time_speech_total_all_samples = 1e-6
398
+
399
+ beg_total = time.time()
400
+ pbar_total = (
401
+ tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
402
+ if not kwargs.get("disable_pbar", False)
403
+ else None
404
+ )
405
+
406
+ for i in range(len(vad_res)):
407
+ key = vad_res[i]["key"]
408
+ vadsegments = vad_res[i]["value"]
409
+ input_i = data_list[i]
410
+ fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
411
+ speech = load_audio_text_image_video(
412
+ input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
413
+ )
414
+ speech_lengths = len(speech)
415
+ n = len(vadsegments)
416
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
417
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
418
+ results_sorted = []
419
+
420
+ if not len(sorted_data):
421
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
422
+ logging.info("decoding, utt: {}, empty speech".format(key))
423
+ continue
424
+
425
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
426
+ batch_size = max(
427
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
428
+ )
429
+
430
+ if kwargs["device"] == "cpu":
431
+ batch_size = 0
432
+
433
+ beg_idx = 0
434
+ beg_asr_total = time.time()
435
+ time_speech_total_per_sample = speech_lengths / 16000
436
+ time_speech_total_all_samples += time_speech_total_per_sample
437
+
438
+ # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
439
+
440
+ all_segments = []
441
+ max_len_in_batch = 0
442
+ end_idx = 1
443
+
444
+ for j, _ in enumerate(range(0, n)):
445
+ # pbar_sample.update(1)
446
+ sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
447
+ potential_batch_length = max(max_len_in_batch, sample_length) * (
448
+ j + 1 - beg_idx
449
+ )
450
+ # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
451
+ if (
452
+ j < n - 1
453
+ and sample_length < batch_size_threshold_ms
454
+ and potential_batch_length < batch_size
455
+ ):
456
+ max_len_in_batch = max(max_len_in_batch, sample_length)
457
+ end_idx += 1
458
+ continue
459
+
460
+ speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
461
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
462
+ )
463
+ results = self.inference(
464
+ speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
465
+ )
466
+
467
+ for _b in range(len(speech_j)):
468
+ results[_b]["interval"] = intervals[_b]
469
+
470
+ if self.spk_model is not None:
471
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
472
+ for _b in range(len(speech_j)):
473
+ vad_segments = [
474
+ [
475
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
476
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
477
+ np.array(speech_j[_b]),
478
+ ]
479
+ ]
480
+ segments = sv_chunk(vad_segments)
481
+ all_segments.extend(segments)
482
+ speech_b = [i[2] for i in segments]
483
+ spk_res = self.inference(
484
+ speech_b,
485
+ input_len=None,
486
+ model=self.spk_model,
487
+ kwargs=kwargs,
488
+ **cfg,
489
+ )
490
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
491
+
492
+ beg_idx = end_idx
493
+ end_idx += 1
494
+ max_len_in_batch = sample_length
495
+ if len(results) < 1:
496
+ continue
497
+ results_sorted.extend(results)
498
+
499
+ # end_asr_total = time.time()
500
+ # time_escape_total_per_sample = end_asr_total - beg_asr_total
501
+ # pbar_sample.update(1)
502
+ # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
503
+ # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
504
+ # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
505
+
506
+ restored_data = [0] * n
507
+ for j in range(n):
508
+ index = sorted_data[j][1]
509
+ cur = results_sorted[j]
510
+ pattern = r"<\|([^|]+)\|>"
511
+ emotion_string = re.findall(pattern, cur["text"])
512
+ cur["text"] = re.sub(pattern, "", cur["text"])
513
+ cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
514
+ if self.punc_model is not None and len(cur["text"].strip()) > 0:
515
+ deep_update(self.punc_kwargs, cfg)
516
+ punc_res = self.inference(
517
+ cur["text"],
518
+ model=self.punc_model,
519
+ kwargs=self.punc_kwargs,
520
+ **cfg,
521
+ )
522
+ cur["text"] = punc_res[0]["text"]
523
+
524
+ restored_data[index] = cur
525
+
526
+ end_asr_total = time.time()
527
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
528
+ if pbar_total:
529
+ pbar_total.update(1)
530
+ pbar_total.set_description(
531
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
532
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
533
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
534
+ )
535
+
536
+ # end_total = time.time()
537
+ # time_escape_total_all_samples = end_total - beg_total
538
+ # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
539
+ # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
540
+ # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
541
+ return restored_data
542
+
543
+ def export(self, input=None, **cfg):
544
+ """
545
+
546
+ :param input:
547
+ :param type:
548
+ :param quantize:
549
+ :param fallback_num:
550
+ :param calib_num:
551
+ :param opset_version:
552
+ :param cfg:
553
+ :return:
554
+ """
555
+
556
+ device = cfg.get("device", "cpu")
557
+ model = self.model.to(device=device)
558
+ kwargs = self.kwargs
559
+ deep_update(kwargs, cfg)
560
+ kwargs["device"] = device
561
+ del kwargs["model"]
562
+ model.eval()
563
+
564
+ type = kwargs.get("type", "onnx")
565
+
566
+ key_list, data_list = prepare_data_iterator(
567
+ input, input_len=None, data_type=kwargs.get("data_type", None), key=None
568
+ )
569
+
570
+ with torch.no_grad():
571
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
572
+
573
+ return export_dir
tools/sensevoice/fun_asr.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import re
4
+
5
+ from audio_separator.separator import Separator
6
+
7
+ os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
8
+ os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
9
+ import json
10
+ import subprocess
11
+ from pathlib import Path
12
+
13
+ import click
14
+ import torch
15
+ from loguru import logger
16
+ from pydub import AudioSegment
17
+ from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
18
+ from tqdm import tqdm
19
+
20
+ from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
21
+ from tools.sensevoice.auto_model import AutoModel
22
+
23
+
24
+ def uvr5_cli(
25
+ audio_dir: Path,
26
+ output_folder: Path,
27
+ audio_files: list[Path] | None = None,
28
+ output_format: str = "flac",
29
+ model: str = "BS-Roformer-Viperx-1297.ckpt",
30
+ ):
31
+ # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
32
+ sepr = Separator(
33
+ model_file_dir=os.environ["UVR5_CACHE"],
34
+ output_dir=output_folder,
35
+ output_format=output_format,
36
+ )
37
+ dictmodel = {
38
+ "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
39
+ "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
40
+ "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
41
+ "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
42
+ }
43
+ roformer_model = dictmodel[model]
44
+ sepr.load_model(roformer_model)
45
+ if audio_files is None:
46
+ audio_files = list_files(
47
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
48
+ )
49
+ total_files = len(audio_files)
50
+
51
+ print(f"{total_files} audio files found")
52
+
53
+ res = []
54
+ for audio in tqdm(audio_files, desc="Denoising: "):
55
+ file_path = str(audio_dir / audio)
56
+ sep_out = sepr.separate(file_path)
57
+ if isinstance(sep_out, str):
58
+ res.append(sep_out)
59
+ elif isinstance(sep_out, list):
60
+ res.extend(sep_out)
61
+ del sepr
62
+ gc.collect()
63
+ if torch.cuda.is_available():
64
+ torch.cuda.empty_cache()
65
+
66
+ return res, roformer_model
67
+
68
+
69
+ def get_sample_rate(media_path: Path):
70
+ result = subprocess.run(
71
+ [
72
+ "ffprobe",
73
+ "-v",
74
+ "quiet",
75
+ "-print_format",
76
+ "json",
77
+ "-show_streams",
78
+ str(media_path),
79
+ ],
80
+ capture_output=True,
81
+ text=True,
82
+ check=True,
83
+ )
84
+ media_info = json.loads(result.stdout)
85
+ for stream in media_info.get("streams", []):
86
+ if stream.get("codec_type") == "audio":
87
+ return stream.get("sample_rate")
88
+ return "44100" # Default sample rate if not found
89
+
90
+
91
+ def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
92
+ sr = get_sample_rate(src_path)
93
+ out_path.parent.mkdir(parents=True, exist_ok=True)
94
+ if src_path.resolve() == out_path.resolve():
95
+ output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
96
+ else:
97
+ output = str(out_path)
98
+ subprocess.run(
99
+ [
100
+ "ffmpeg",
101
+ "-loglevel",
102
+ "error",
103
+ "-i",
104
+ str(src_path),
105
+ "-acodec",
106
+ "pcm_s16le" if out_fmt == "wav" else "flac",
107
+ "-ar",
108
+ sr,
109
+ "-ac",
110
+ "1",
111
+ "-y",
112
+ output,
113
+ ],
114
+ check=True,
115
+ )
116
+ return out_path
117
+
118
+
119
+ def convert_video_to_audio(video_path: Path, audio_dir: Path):
120
+ cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
121
+ vocals = [
122
+ p
123
+ for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
124
+ if p.suffix in AUDIO_EXTENSIONS
125
+ ]
126
+ if len(vocals) > 0:
127
+ return vocals[0]
128
+ audio_path = cur_dir / f"{video_path.stem}.wav"
129
+ convert_to_mono(video_path, audio_path)
130
+ return audio_path
131
+
132
+
133
+ @click.command()
134
+ @click.option("--audio-dir", required=True, help="Directory containing audio files")
135
+ @click.option(
136
+ "--save-dir", required=True, help="Directory to save processed audio files"
137
+ )
138
+ @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
139
+ @click.option("--language", default="auto", help="Language of the transcription")
140
+ @click.option(
141
+ "--max_single_segment_time",
142
+ default=20000,
143
+ type=int,
144
+ help="Maximum of Output single audio duration(ms)",
145
+ )
146
+ @click.option("--fsmn-vad/--silero-vad", default=False)
147
+ @click.option("--punc/--no-punc", default=False)
148
+ @click.option("--denoise/--no-denoise", default=False)
149
+ @click.option("--save_emo/--no_save_emo", default=False)
150
+ def main(
151
+ audio_dir: str,
152
+ save_dir: str,
153
+ device: str,
154
+ language: str,
155
+ max_single_segment_time: int,
156
+ fsmn_vad: bool,
157
+ punc: bool,
158
+ denoise: bool,
159
+ save_emo: bool,
160
+ ):
161
+
162
+ audios_path = Path(audio_dir)
163
+ save_path = Path(save_dir)
164
+ save_path.mkdir(parents=True, exist_ok=True)
165
+
166
+ video_files = list_files(
167
+ path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
168
+ )
169
+ v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
170
+
171
+ if denoise:
172
+ VOCAL = "_(Vocals)"
173
+ original_files = [
174
+ p
175
+ for p in audios_path.glob("**/*")
176
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
177
+ ]
178
+
179
+ _, cur_model = uvr5_cli(
180
+ audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
181
+ )
182
+ need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
183
+ need_remove.extend(original_files)
184
+ for _ in need_remove:
185
+ _.unlink()
186
+ vocal_files = [
187
+ p
188
+ for p in audios_path.glob("**/*")
189
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
190
+ ]
191
+ for f in vocal_files:
192
+ fn, ext = f.stem, f.suffix
193
+
194
+ v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
195
+ if v_pos != -1:
196
+ new_fn = fn[: v_pos + len(VOCAL)]
197
+ new_f = f.with_name(new_fn + ext)
198
+ f = f.rename(new_f)
199
+ convert_to_mono(f, f, "flac")
200
+ f.unlink()
201
+
202
+ audio_files = list_files(
203
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
204
+ )
205
+
206
+ logger.info("Loading / Downloading Funasr model...")
207
+
208
+ model_dir = "iic/SenseVoiceSmall"
209
+
210
+ vad_model = "fsmn-vad" if fsmn_vad else None
211
+ vad_kwargs = {"max_single_segment_time": max_single_segment_time}
212
+ punc_model = "ct-punc" if punc else None
213
+
214
+ manager = AutoModel(
215
+ model=model_dir,
216
+ trust_remote_code=False,
217
+ vad_model=vad_model,
218
+ vad_kwargs=vad_kwargs,
219
+ punc_model=punc_model,
220
+ device=device,
221
+ )
222
+
223
+ if not fsmn_vad and vad_model is None:
224
+ vad_model = load_silero_vad()
225
+
226
+ logger.info("Model loaded.")
227
+
228
+ pattern = re.compile(r"_\d{3}\.")
229
+
230
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
231
+
232
+ if pattern.search(file_path.name):
233
+ # logger.info(f"Skipping {file_path} as it has already been processed.")
234
+ continue
235
+
236
+ file_stem = file_path.stem
237
+ file_suffix = file_path.suffix
238
+
239
+ rel_path = Path(file_path).relative_to(audio_dir)
240
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
241
+
242
+ audio = AudioSegment.from_file(file_path)
243
+
244
+ cfg = dict(
245
+ cache={},
246
+ language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
247
+ use_itn=False,
248
+ batch_size_s=60,
249
+ )
250
+
251
+ if fsmn_vad:
252
+ elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
253
+ else:
254
+ wav = read_audio(
255
+ str(file_path)
256
+ ) # backend (sox, soundfile, or ffmpeg) required!
257
+ audio_key = file_path.stem
258
+ audio_val = []
259
+ speech_timestamps = get_speech_timestamps(
260
+ wav,
261
+ vad_model,
262
+ max_speech_duration_s=max_single_segment_time // 1000,
263
+ return_seconds=True,
264
+ )
265
+
266
+ audio_val = [
267
+ [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
268
+ for timestamp in speech_timestamps
269
+ ]
270
+ vad_res = []
271
+ vad_res.append(dict(key=audio_key, value=audio_val))
272
+
273
+ res = manager.inference_with_vadres(
274
+ input=str(file_path), vad_res=vad_res, **cfg
275
+ )
276
+
277
+ for i, info in enumerate(res):
278
+ [start_ms, end_ms] = info["interval"]
279
+ text = info["text"]
280
+ emo = info["emo"]
281
+ sliced_audio = audio[start_ms:end_ms]
282
+ audio_save_path = (
283
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
284
+ )
285
+ sliced_audio.export(audio_save_path, format=file_suffix[1:])
286
+ print(f"Exported {audio_save_path}: {text}")
287
+
288
+ transcript_save_path = (
289
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
290
+ )
291
+ with open(
292
+ transcript_save_path,
293
+ "w",
294
+ encoding="utf-8",
295
+ ) as f:
296
+ f.write(text)
297
+
298
+ if save_emo:
299
+ emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
300
+ with open(
301
+ emo_save_path,
302
+ "w",
303
+ encoding="utf-8",
304
+ ) as f:
305
+ f.write(emo)
306
+
307
+ if audios_path.resolve() == save_path.resolve():
308
+ file_path.unlink()
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()
313
+ exit(0)
314
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
315
+
316
+ # Load the audio file
317
+ audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
318
+ model_dir = "iic/SenseVoiceSmall"
319
+ m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
320
+ m.eval()
321
+
322
+ res = m.inference(
323
+ data_in=f"{kwargs['model_path']}/example/zh.mp3",
324
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
325
+ use_itn=False,
326
+ ban_emo_unk=False,
327
+ **kwargs,
328
+ )
329
+
330
+ print(res)
331
+ text = rich_transcription_postprocess(res[0][0]["text"])
332
+ print(text)
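
Typical invocation of the script above; the paths are placeholders and only flags defined by its click options are used:

python tools/sensevoice/fun_asr.py \
    --audio-dir data/raw --save-dir data/clean \
    --denoise --punc --language auto
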
tools/sensevoice/vad_utils.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from torch.nn.utils.rnn import pad_sequence
3
+
4
+
5
+ def slice_padding_fbank(speech, speech_lengths, vad_segments):
6
+ speech_list = []
7
+ speech_lengths_list = []
8
+ for i, segment in enumerate(vad_segments):
9
+
10
+ bed_idx = int(segment[0][0] * 16)
11
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
12
+ speech_i = speech[0, bed_idx:end_idx]
13
+ speech_lengths_i = end_idx - bed_idx
14
+ speech_list.append(speech_i)
15
+ speech_lengths_list.append(speech_lengths_i)
16
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
17
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
18
+ return feats_pad, speech_lengths_pad
19
+
20
+
21
+ def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
22
+ speech_list = []
23
+ speech_lengths_list = []
24
+ intervals = []
25
+ for i, segment in enumerate(vad_segments):
26
+ bed_idx = int(segment[0][0] * 16)
27
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
28
+ speech_i = speech[bed_idx:end_idx]
29
+ speech_lengths_i = end_idx - bed_idx
30
+ speech_list.append(speech_i)
31
+ speech_lengths_list.append(speech_lengths_i)
32
+ intervals.append([bed_idx // 16, end_idx // 16])
33
+
34
+ return speech_list, speech_lengths_list, intervals
35
+
36
+
37
+ def merge_vad(vad_result, max_length=15000, min_length=0):
38
+ new_result = []
39
+ if len(vad_result) <= 1:
40
+ return vad_result
41
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
42
+ time_step = sorted(list(set(time_step)))
43
+ if len(time_step) == 0:
44
+ return []
45
+ bg = 0
46
+ for i in range(len(time_step) - 1):
47
+ time = time_step[i]
48
+ if time_step[i + 1] - bg < max_length:
49
+ continue
50
+ if time - bg > min_length:
51
+ new_result.append([bg, time])
52
+ # if time - bg < max_length * 1.5:
53
+ # new_result.append([bg, time])
54
+ # else:
55
+ # split_num = int(time - bg) // max_length + 1
56
+ # spl_l = int(time - bg) // split_num
57
+ # for j in range(split_num):
58
+ # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
59
+ bg = time
60
+ new_result.append([bg, time_step[-1]])
61
+ return new_result
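
A quick sketch of merge_vad on millisecond [start, end] segments (the assumed input format, matching the VAD results built in fun_asr.py); the example segments are made up:

from tools.sensevoice.vad_utils import merge_vad

segments = [[0, 4000], [4500, 9000], [10000, 20000]]
# Merge adjacent segments until a chunk would exceed max_length ms.
print(merge_vad(segments, max_length=15000))  # -> [[0, 10000], [10000, 20000]]
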
tools/smart_pad.py ADDED
@@ -0,0 +1,60 @@
1
+ import random
2
+ from multiprocessing import Pool
3
+ from pathlib import Path
4
+
5
+ import click
6
+ import librosa
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from tqdm import tqdm
10
+
11
+ from tools.file import AUDIO_EXTENSIONS, list_files
12
+
13
+ threshold = 10 ** (-50 / 20.0)
14
+
15
+
16
+ def process(file):
17
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
18
+ if waveform.size(0) > 1:
19
+ waveform = waveform.mean(dim=0, keepdim=True)
20
+
21
+ loudness = librosa.feature.rms(
22
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
23
+ )[0]
24
+
25
+ for i in range(len(loudness) - 1, 0, -1):
26
+ if loudness[i] > threshold:
27
+ break
28
+
29
+ end_silent_time = (len(loudness) - i) * 512 / sample_rate
30
+
31
+ if end_silent_time <= 0.3:
32
+ random_time = random.uniform(0.3, 0.7) - end_silent_time
33
+ waveform = F.pad(
34
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
35
+ )
36
+
37
+ for i in range(len(loudness)):
38
+ if loudness[i] > threshold:
39
+ break
40
+
41
+ start_silent_time = i * 512 / sample_rate
42
+
43
+ if start_silent_time > 0.02:
44
+ waveform = waveform[:, int((start_silent_time - 0.02) * sample_rate) :]
45
+
46
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
47
+
48
+
49
+ @click.command()
50
+ @click.argument("source", type=Path)
51
+ @click.option("--num-workers", type=int, default=12)
52
+ def main(source, num_workers):
53
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
54
+
55
+ with Pool(num_workers) as p:
56
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
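
Example run of the in-place padding/trimming script above (the path is a placeholder):

python tools/smart_pad.py data/clean --num-workers 8
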
tools/vqgan/__pycache__/inference.cpython-310.pyc ADDED
Binary file (3.53 kB). View file
 
tools/vqgan/create_train_split.py ADDED
@@ -0,0 +1,83 @@
1
+ import math
2
+ from pathlib import Path
3
+ from random import Random
4
+
5
+ import click
6
+ from loguru import logger
7
+ from pydub import AudioSegment
8
+ from tqdm import tqdm
9
+
10
+ from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
11
+
12
+
13
+ @click.command()
14
+ @click.argument("root", type=click.Path(exists=True, path_type=Path))
15
+ @click.option("--val-ratio", type=float, default=None)
16
+ @click.option("--val-count", type=int, default=None)
17
+ @click.option("--filelist", default=None, type=Path)
18
+ @click.option("--min-duration", default=None, type=float)
19
+ @click.option("--max-duration", default=None, type=float)
20
+ def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
21
+ if filelist:
22
+ files = [i[0] for i in load_filelist(filelist)]
23
+ else:
24
+ files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
25
+
26
+ if min_duration is None and max_duration is None:
27
+ filtered_files = list(map(str, [file.relative_to(root) for file in files]))
28
+ else:
29
+ filtered_files = []
30
+ for file in tqdm(files):
31
+ try:
32
+ audio = AudioSegment.from_file(str(file))
33
+ duration = len(audio) / 1000.0
34
+
35
+ if min_duration is not None and duration < min_duration:
36
+ logger.info(
37
+ f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
38
+ )
39
+ continue
40
+
41
+ if max_duration is not None and duration > max_duration:
42
+ logger.info(
43
+ f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
44
+ )
45
+ continue
46
+
47
+ filtered_files.append(str(file.relative_to(root)))
48
+ except Exception as e:
49
+ logger.info(f"Error processing {file}: {e}")
50
+
51
+ logger.info(
52
+ f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
53
+ )
54
+
55
+ Random(42).shuffle(filtered_files)
56
+
57
+ if val_count is None and val_ratio is None:
58
+ logger.info("Validation ratio and count not specified, using min(20%, 100)")
59
+ val_size = min(100, math.ceil(len(filtered_files) * 0.2))
60
+ elif val_count is not None and val_ratio is not None:
61
+ logger.error("Cannot specify both val_count and val_ratio")
62
+ return
63
+ elif val_count is not None:
64
+ if val_count < 1 or val_count > len(filtered_files):
65
+ logger.error("val_count must be between 1 and number of files")
66
+ return
67
+ val_size = val_count
68
+ else:
69
+ val_size = math.ceil(len(filtered_files) * val_ratio)
70
+
71
+ logger.info(f"Using {val_size} files for validation")
72
+
73
+ with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
74
+ f.write("\n".join(filtered_files[val_size:]))
75
+
76
+ with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
77
+ f.write("\n".join(filtered_files[:val_size]))
78
+
79
+ logger.info("Done")
80
+
81
+
82
+ if __name__ == "__main__":
83
+ main()
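
Example run of the split script above; it writes vq_train_filelist.txt and vq_val_filelist.txt under the given root (the path and values are placeholders):

python tools/vqgan/create_train_split.py data/clean --val-count 100 --min-duration 1.0
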
tools/vqgan/extract_vq.py ADDED
@@ -0,0 +1,233 @@
1
+ import os
2
+ import subprocess as sp
3
+ import sys
4
+ import time
5
+ from datetime import timedelta
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from random import Random
9
+
10
+ import click
11
+ import numpy as np
12
+ import torch
13
+ import torchaudio
14
+ from hydra import compose, initialize
15
+ from hydra.utils import instantiate
16
+ from lightning import LightningModule
17
+ from loguru import logger
18
+ from omegaconf import OmegaConf
19
+
20
+ from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
21
+
22
+ # register eval resolver
23
+ OmegaConf.register_new_resolver("eval", eval)
24
+ # This file is used to extract VQ token indices (.npy) from audio files with the VQ-GAN encoder.
25
+ # It's mainly used to generate the training data for the VQ model.
26
+
27
+ backends = torchaudio.list_audio_backends()
28
+
29
+ if "ffmpeg" in backends:
30
+ backend = "ffmpeg"
31
+ else:
32
+ backend = "soundfile"
33
+
34
+ RANK = int(os.environ.get("SLURM_PROCID", 0))
35
+ WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
36
+
37
+ logger_format = (
38
+ "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
39
+ "<level>{level: <8}</level> | "
40
+ "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
41
+ "{extra[rank]} - <level>{message}</level>"
42
+ )
43
+ logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
44
+ logger.remove()
45
+ logger.add(sys.stderr, format=logger_format)
46
+
47
+
48
+ @lru_cache(maxsize=1)
49
+ def get_model(
50
+ config_name: str = "firefly_gan_vq",
51
+ checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
52
+ device: str | torch.device = "cuda",
53
+ ):
54
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
55
+ cfg = compose(config_name=config_name)
56
+
57
+ model = instantiate(cfg)
58
+ state_dict = torch.load(
59
+ checkpoint_path,
60
+ map_location=device,
61
+ )
62
+ if "state_dict" in state_dict:
63
+ state_dict = state_dict["state_dict"]
64
+
65
+ if any("generator" in k for k in state_dict):
66
+ state_dict = {
67
+ k.replace("generator.", ""): v
68
+ for k, v in state_dict.items()
69
+ if "generator." in k
70
+ }
71
+
72
+ model.load_state_dict(state_dict, strict=False)
73
+ model.eval()
74
+ model.to(device)
75
+
76
+ logger.info(f"Loaded model")
77
+ return model
78
+
79
+
80
+ @torch.inference_mode()
81
+ def process_batch(files: list[Path], model) -> float:
82
+ wavs = []
83
+ audio_lengths = []
84
+ new_files = []
85
+ max_length = total_time = 0
86
+
87
+ for file in files:
88
+ try:
89
+ wav, sr = torchaudio.load(
90
+ str(file), backend=backend
91
+ ) # Need to install libsox-dev
92
+ except Exception as e:
93
+ logger.error(f"Error reading {file}: {e}")
94
+ continue
95
+
96
+ if wav.shape[0] > 1:
97
+ wav = wav.mean(dim=0, keepdim=True)
98
+
99
+ wav = torchaudio.functional.resample(
100
+ wav.cuda(), sr, model.spec_transform.sample_rate
101
+ )[0]
102
+ total_time += len(wav) / model.spec_transform.sample_rate
103
+ max_length = max(max_length, len(wav))
104
+
105
+ wavs.append(wav)
106
+ audio_lengths.append(len(wav))
107
+ new_files.append(file)
108
+
109
+ files = new_files
110
+
111
+ # Pad to max length
112
+ for i, wav in enumerate(wavs):
113
+ wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
114
+
115
+ audios = torch.stack(wavs, dim=0)[:, None]
116
+ audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
117
+
118
+ # Calculate lengths
119
+ indices, feature_lengths = model.encode(audios, audio_lengths)
120
+
121
+ # Save to disk
122
+ outputs = indices.cpu().numpy()
123
+
124
+ for file, length, feature, audio_length in zip(
125
+ files, feature_lengths, outputs, audio_lengths
126
+ ):
127
+ feature = feature[:, :length]
128
+
129
+ # (T,)
130
+ with open(file.with_suffix(".npy"), "wb") as f:
131
+ np.save(f, feature)
132
+
133
+ return total_time
134
+
135
+
136
+ @click.command()
137
+ @click.argument("folder")
138
+ @click.option("--num-workers", default=1)
139
+ @click.option("--config-name", default="firefly_gan_vq")
140
+ @click.option(
141
+ "--checkpoint-path",
142
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
143
+ )
144
+ @click.option("--batch-size", default=64)
145
+ @click.option("--filelist", default=None, type=Path)
146
+ def main(
147
+ folder: str,
148
+ num_workers: int,
149
+ config_name: str,
150
+ checkpoint_path: str,
151
+ batch_size: int,
152
+ filelist: Path,
153
+ ):
154
+ if num_workers > 1 and WORLD_SIZE != num_workers:
155
+ assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
156
+
157
+ logger.info(f"Spawning {num_workers} workers")
158
+
159
+ if torch.cuda.is_available():
160
+ visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
161
+ if visible_devices is None:
162
+ visible_devices = list(range(torch.cuda.device_count()))
163
+ else:
164
+ visible_devices = visible_devices.split(",")
165
+ else:
166
+ # Set to empty string to avoid using GPU
167
+ visible_devices = [""]
168
+
169
+ processes = []
170
+ for i in range(num_workers):
171
+ env = os.environ.copy()
172
+ env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
173
+ env["SLURM_PROCID"] = str(i)
174
+ env["SLURM_NTASKS"] = str(num_workers)
175
+
176
+ processes.append(
177
+ sp.Popen(
178
+ [sys.executable] + sys.argv.copy(),
179
+ env=env,
180
+ )
181
+ )
182
+
183
+ for p in processes:
184
+ p.wait()
185
+
186
+ logger.info(f"All workers finished")
187
+ return
188
+
189
+ # This is a worker
190
+ logger.info(f"Starting worker")
191
+ if filelist:
192
+ files = [i[0] for i in load_filelist(filelist)]
193
+ else:
194
+ files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
195
+
196
+ print(f"Found {len(files)} files")
197
+ files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
198
+
199
+ total_files = len(files)
200
+ files = files[RANK::WORLD_SIZE]
201
+ logger.info(f"Processing {len(files)}/{total_files} files")
202
+
203
+ # Batch processing
204
+ total_time = 0
205
+ begin_time = time.time()
206
+ processed_files = 0
207
+ model = get_model(config_name, checkpoint_path)
208
+
209
+ for n_batch, idx in enumerate(range(0, len(files), batch_size)):
210
+ batch = files[idx : idx + batch_size]
211
+ batch_time = process_batch(batch, model)
212
+
213
+ total_time += batch_time
214
+ processed_files += len(batch)
215
+
216
+ if (n_batch + 1) % 10 == 0:
217
+ eta = (
218
+ (time.time() - begin_time)
219
+ / processed_files
220
+ * (len(files) - processed_files)
221
+ )
222
+ logger.info(
223
+ f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
224
+ + f"ETA: {timedelta(seconds=round(eta))}s"
225
+ )
226
+
227
+ logger.info(
228
+ f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
229
+ )
230
+
231
+
232
+ if __name__ == "__main__":
233
+ main()
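
Example run of the VQ extraction script above; it saves one .npy of codebook indices next to each audio file (the folder is a placeholder):

python tools/vqgan/extract_vq.py data/clean --num-workers 2 --batch-size 32
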
tools/vqgan/inference.py ADDED
@@ -0,0 +1,121 @@
1
+ from pathlib import Path
2
+
3
+ import click
4
+ import hydra
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import torch
8
+ import torchaudio
9
+ from hydra import compose, initialize
10
+ from hydra.utils import instantiate
11
+ from loguru import logger
12
+ from omegaconf import OmegaConf
13
+
14
+ from tools.file import AUDIO_EXTENSIONS
15
+
16
+ # register eval resolver
17
+ OmegaConf.register_new_resolver("eval", eval)
18
+
19
+
20
+ def load_model(config_name, checkpoint_path, device="cuda"):
21
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
22
+ with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
23
+ cfg = compose(config_name=config_name)
24
+
25
+ model = instantiate(cfg)
26
+ state_dict = torch.load(
27
+ checkpoint_path, map_location=device, mmap=True, weights_only=True
28
+ )
29
+ if "state_dict" in state_dict:
30
+ state_dict = state_dict["state_dict"]
31
+
32
+ if any("generator" in k for k in state_dict):
33
+ state_dict = {
34
+ k.replace("generator.", ""): v
35
+ for k, v in state_dict.items()
36
+ if "generator." in k
37
+ }
38
+
39
+ result = model.load_state_dict(state_dict, strict=False, assign=True)
40
+ model.eval()
41
+ model.to(device)
42
+
43
+ logger.info(f"Loaded model: {result}")
44
+ return model
45
+
46
+
47
+ @torch.no_grad()
48
+ @click.command()
49
+ @click.option(
50
+ "--input-path",
51
+ "-i",
52
+ default="test.wav",
53
+ type=click.Path(exists=True, path_type=Path),
54
+ )
55
+ @click.option(
56
+ "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
57
+ )
58
+ @click.option("--config-name", default="firefly_gan_vq")
59
+ @click.option(
60
+ "--checkpoint-path",
61
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
62
+ )
63
+ @click.option(
64
+ "--device",
65
+ "-d",
66
+ default="cuda",
67
+ )
68
+ def main(input_path, output_path, config_name, checkpoint_path, device):
69
+ model = load_model(config_name, checkpoint_path, device=device)
70
+
71
+ if input_path.suffix in AUDIO_EXTENSIONS:
72
+ logger.info(f"Processing in-place reconstruction of {input_path}")
73
+
74
+ # Load audio
75
+ audio, sr = torchaudio.load(str(input_path))
76
+ if audio.shape[0] > 1:
77
+ audio = audio.mean(0, keepdim=True)
78
+ audio = torchaudio.functional.resample(
79
+ audio, sr, model.spec_transform.sample_rate
80
+ )
81
+
82
+ audios = audio[None].to(device)
83
+ logger.info(
84
+ f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
85
+ )
86
+
87
+ # VQ Encoder
88
+ audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
89
+ indices = model.encode(audios, audio_lengths)[0][0]
90
+
91
+ logger.info(f"Generated indices of shape {indices.shape}")
92
+
93
+ # Save indices
94
+ np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
95
+ elif input_path.suffix == ".npy":
96
+ logger.info(f"Processing precomputed indices from {input_path}")
97
+ indices = np.load(input_path)
98
+ indices = torch.from_numpy(indices).to(device).long()
99
+ assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
100
+ else:
101
+ raise ValueError(f"Unknown input type: {input_path}")
102
+
103
+ # Restore
104
+ feature_lengths = torch.tensor([indices.shape[1]], device=device)
105
+ fake_audios, _ = model.decode(
106
+ indices=indices[None], feature_lengths=feature_lengths
107
+ )
108
+ audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
109
+
110
+ logger.info(
111
+ f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
112
+ )
113
+
114
+ # Save audio
115
+ fake_audio = fake_audios[0, 0].float().cpu().numpy()
116
+ sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
117
+ logger.info(f"Saved audio to {output_path}")
118
+
119
+
120
+ if __name__ == "__main__":
121
+ main()
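
Example use of the script above with its default config and checkpoint (file names are placeholders): the first command encodes test.wav, saves fake.npy, and reconstructs fake.wav; the second decodes precomputed indices:

python tools/vqgan/inference.py -i test.wav -o fake.wav
python tools/vqgan/inference.py -i fake.npy -o restored.wav
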
tools/webui.py ADDED
@@ -0,0 +1,570 @@
1
+ import gc
2
+ import html
3
+ import io
4
+ import os
5
+ import queue
6
+ import wave
7
+ from argparse import ArgumentParser
8
+ from functools import partial
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ import librosa
13
+ import numpy as np
14
+ import pyrootutils
15
+ import torch
16
+ from loguru import logger
17
+ from transformers import AutoTokenizer
18
+
19
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
+
21
+
22
+ from fish_speech.i18n import i18n
23
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
24
+ from fish_speech.utils import autocast_exclude_mps, set_seed
25
+ from tools.api import decode_vq_tokens, encode_reference
26
+ from tools.file import AUDIO_EXTENSIONS, list_files
27
+ from tools.llama.generate import (
28
+ GenerateRequest,
29
+ GenerateResponse,
30
+ WrappedGenerateResponse,
31
+ launch_thread_safe_queue,
32
+ )
33
+ from tools.vqgan.inference import load_model as load_decoder_model
34
+
35
+
36
+
37
+ # Make einx happy
38
+ os.environ["EINX_FILTER_TRACEBACK"] = "false"
39
+
40
+
41
+ HEADER_MD = f"""# 泰雅爾語TTS
42
+
43
+ #泰雅爾語測試範例
44
+
45
+ {i18n("Miyan qaniy qu binkgan bbinkesan na Yesu:Yesu Kristo ga kinbahan na Tabite, Tabite ga kinbahan na Aburaham.")}
46
+
47
+ {i18n("Aburaham ga yaba na Isak; Isak ga yaba na Yakob; Yakob ga yaba na Yuta ki mmtswe nya mlikuy.")}
48
+
49
+ {i18n("Babaw nqu kyapun rasun squ qalang Babilon lga, plqyun ni Yehoyacin qu Seltiyel; Seltiyel ga yaba na Zerubabel;")}
50
+
51
+ #若要使用自己的聲音合成請按以下步驟(Streaming Generate)
52
+
53
+ # <span style="color: red;">Streaming Generate 此功能維護中</span>
54
+
55
+ {i18n("1.在Reference Audio找到Enable Reference Audio打勾")}
56
+
57
+ {i18n("2.在左下方將錄音檔案上傳,並在Reference Text輸入上傳音檔的文字")}
58
+
59
+ {i18n("3.在Input Text輸入文字")}
60
+
61
+ {i18n("4.按下Streaming Generate即可")}
62
+
63
+
64
+ """
65
+
66
+ TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
67
+ SPACE_IMPORTED = False
68
+
69
+
70
+ def build_html_error_message(error):
71
+ return f"""
72
+ <div style="color: red;
73
+ font-weight: bold;">
74
+ {html.escape(str(error))}
75
+ </div>
76
+ """
77
+
78
+
79
+ @torch.inference_mode()
80
+ def inference(
81
+ text,
82
+ enable_reference_audio,
83
+ reference_audio,
84
+ reference_text,
85
+ max_new_tokens,
86
+ chunk_length,
87
+ top_p,
88
+ repetition_penalty,
89
+ temperature,
90
+ seed="0",
91
+ streaming=False,
92
+ ):
93
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
94
+ return (
95
+ None,
96
+ None,
97
+ i18n("Text is too long, please keep it under {} characters.").format(
98
+ args.max_gradio_length
99
+ ),
100
+ )
101
+
102
+ seed = int(seed)
103
+ if seed != 0:
104
+ set_seed(seed)
105
+ logger.warning(f"set seed: {seed}")
106
+
107
+ # Parse reference audio aka prompt
108
+ prompt_tokens = encode_reference(
109
+ decoder_model=decoder_model,
110
+ reference_audio=reference_audio,
111
+ enable_reference_audio=enable_reference_audio,
112
+ )
113
+
114
+ # LLAMA Inference
115
+ request = dict(
116
+ device=decoder_model.device,
117
+ max_new_tokens=600,
118
+ text=text,
119
+ top_p=top_p,
120
+ repetition_penalty=repetition_penalty,
121
+ temperature=temperature,
122
+ compile=args.compile,
123
+ iterative_prompt=chunk_length > 0,
124
+ chunk_length=chunk_length,
125
+ max_length=2048,
126
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
127
+ prompt_text=reference_text if enable_reference_audio else None,
128
+ )
129
+
130
+ response_queue = queue.Queue()
131
+ llama_queue.put(
132
+ GenerateRequest(
133
+ request=request,
134
+ response_queue=response_queue,
135
+ )
136
+ )
137
+
138
+ if streaming:
139
+ yield wav_chunk_header(), None, None
140
+
141
+ segments = []
142
+
143
+ while True:
144
+ result: WrappedGenerateResponse = response_queue.get()
145
+ if result.status == "error":
146
+ yield None, None, build_html_error_message(result.response)
147
+ break
148
+
149
+ result: GenerateResponse = result.response
150
+ if result.action == "next":
151
+ break
152
+
153
+ with autocast_exclude_mps(
154
+ device_type=decoder_model.device.type, dtype=args.precision
155
+ ):
156
+ fake_audios = decode_vq_tokens(
157
+ decoder_model=decoder_model,
158
+ codes=result.codes,
159
+ )
160
+
161
+ fake_audios = fake_audios.float().cpu().numpy()
162
+ segments.append(fake_audios)
163
+
164
+ if streaming:
165
+ wav_header = wav_chunk_header()
166
+ audio_data = (fake_audios * 32768).astype(np.int16).tobytes()
167
+ yield wav_header + audio_data, None, None
168
+
169
+ if len(segments) == 0:
170
+ return (
171
+ None,
172
+ None,
173
+ build_html_error_message(
174
+ i18n("No audio generated, please check the input text.")
175
+ ),
176
+ )
177
+
178
+ # No matter streaming or not, we need to return the final audio
179
+ audio = np.concatenate(segments, axis=0)
180
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
181
+
182
+ if torch.cuda.is_available():
183
+ torch.cuda.empty_cache()
184
+ gc.collect()
185
+
186
+
187
+ inference_stream = partial(inference, streaming=True)
188
+
189
+ n_audios = 4
190
+
191
+ global_audio_list = []
192
+ global_error_list = []
193
+
194
+
195
+ def inference_wrapper(
196
+ text,
197
+ enable_reference_audio,
198
+ reference_audio,
199
+ reference_text,
200
+ max_new_tokens,
201
+ chunk_length,
202
+ top_p,
203
+ repetition_penalty,
204
+ temperature,
205
+ seed,
206
+ batch_infer_num,
207
+ ):
208
+ audios = []
209
+ errors = []
210
+
211
+ for _ in range(batch_infer_num):
212
+ result = inference(
213
+ text,
214
+ enable_reference_audio,
215
+ reference_audio,
216
+ reference_text,
217
+ max_new_tokens,
218
+ chunk_length,
219
+ top_p,
220
+ repetition_penalty,
221
+ temperature,
222
+ seed,
223
+ )
224
+
225
+ _, audio_data, error_message = next(result)
226
+
227
+ audios.append(
228
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
229
+ )
230
+ errors.append(
231
+ gr.HTML(value=error_message if error_message else None, visible=True),
232
+ )
233
+
234
+ for _ in range(batch_infer_num, n_audios):
235
+ audios.append(
236
+ gr.Audio(value=None, visible=False),
237
+ )
238
+ errors.append(
239
+ gr.HTML(value=None, visible=False),
240
+ )
241
+
242
+ return None, *audios, *errors
243
+
244
+
245
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
246
+ buffer = io.BytesIO()
247
+
248
+ with wave.open(buffer, "wb") as wav_file:
249
+ wav_file.setnchannels(channels)
250
+ wav_file.setsampwidth(bit_depth // 8)
251
+ wav_file.setframerate(sample_rate)
252
+
253
+ wav_header_bytes = buffer.getvalue()
254
+ buffer.close()
255
+ return wav_header_bytes
256
+
257
+
258
+ def normalize_text(user_input, use_normalization):
259
+ if use_normalization:
260
+ return ChnNormedText(raw_text=user_input).normalize()
261
+ else:
262
+ return user_input
263
+
264
+
265
+ def update_examples():
266
+ examples_dir = Path("references")
267
+ examples_dir.mkdir(parents=True, exist_ok=True)
268
+ example_audios = list_files(examples_dir, AUDIO_EXTENSIONS, recursive=True)
269
+ return gr.Dropdown(choices=example_audios + [""])
270
+
271
+
272
+ def build_app():
273
+ with gr.Blocks(theme=gr.themes.Base()) as app:
274
+ gr.Markdown(HEADER_MD)
275
+
276
+ # Use light theme by default
277
+ app.load(
278
+ None,
279
+ None,
280
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
281
+ % args.theme,
282
+ )
283
+
284
+ # Inference
285
+ with gr.Row():
286
+ with gr.Column(scale=3):
287
+ text = gr.Textbox(
288
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
289
+ )
290
+ refined_text = gr.Textbox(
291
+ label=i18n("Realtime Transform Text"),
292
+ placeholder=i18n(
293
+ "Normalization Result Preview (Currently Only Chinese)"
294
+ ),
295
+ lines=5,
296
+ interactive=False,
297
+ )
298
+
299
+ with gr.Row():
300
+ if_refine_text = gr.Checkbox(
301
+ label=i18n("Text Normalization"),
302
+ value=False,
303
+ scale=1,
304
+ )
305
+
306
+ with gr.Row():
307
+ with gr.Column():
308
+ with gr.Tab(label=i18n("Advanced Config")):
309
+ with gr.Row():
310
+ chunk_length = gr.Slider(
311
+ label=i18n("Iterative Prompt Length, 0 means off"),
312
+ minimum=50,
313
+ maximum=300,
314
+ value=200,
315
+ step=8,
316
+ )
317
+
318
+ max_new_tokens = gr.Slider(
319
+ label=i18n(
320
+ "Maximum tokens per batch, 0 means no limit"
321
+ ),
322
+ minimum=0,
323
+ maximum=2048,
324
+ value=0, # 0 means no limit
325
+ step=8,
326
+ )
327
+
328
+ with gr.Row():
329
+ top_p = gr.Slider(
330
+ label="Top-P",
331
+ minimum=0.6,
332
+ maximum=0.9,
333
+ value=0.7,
334
+ step=0.01,
335
+ )
336
+
337
+ repetition_penalty = gr.Slider(
338
+ label=i18n("Repetition Penalty"),
339
+ minimum=1,
340
+ maximum=1.5,
341
+ value=1.2,
342
+ step=0.01,
343
+ )
344
+
345
+ with gr.Row():
346
+ temperature = gr.Slider(
347
+ label="Temperature",
348
+ minimum=0.6,
349
+ maximum=0.9,
350
+ value=0.7,
351
+ step=0.01,
352
+ )
353
+ seed = gr.Textbox(
354
+ label="Seed",
355
+ info="0 means randomized inference, otherwise deterministic",
356
+ placeholder="any 32-bit-integer",
357
+ value="0",
358
+ )
359
+
360
+ with gr.Tab(label=i18n("Reference Audio")):
361
+ with gr.Row():
362
+ gr.Markdown(
363
+ i18n(
364
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
365
+ )
366
+ )
367
+ with gr.Row():
368
+ enable_reference_audio = gr.Checkbox(
369
+ label=i18n("Enable Reference Audio"),
370
+ )
371
+
372
+ with gr.Row():
373
+ example_audio_dropdown = gr.Dropdown(
374
+ label=i18n("Select Example Audio"),
375
+ choices=[""],
376
+ value="",
377
+ interactive=True,
378
+ allow_custom_value=True,
379
+ )
380
+ with gr.Row():
381
+ reference_audio = gr.Audio(
382
+ label=i18n("Reference Audio"),
383
+ type="filepath",
384
+ )
385
+ with gr.Row():
386
+ reference_text = gr.Textbox(
387
+ label=i18n("Reference Text"),
388
+ lines=1,
389
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
390
+ value="",
391
+ )
392
+
393
+ with gr.Tab(label=i18n("Batch Inference")):
394
+ with gr.Row():
395
+ batch_infer_num = gr.Slider(
396
+ label="Batch infer nums",
397
+ minimum=1,
398
+ maximum=n_audios,
399
+ step=1,
400
+ value=1,
401
+ )
402
+
403
+ with gr.Column(scale=3):
404
+ for _ in range(n_audios):
405
+ with gr.Row():
406
+ error = gr.HTML(
407
+ label=i18n("Error Message"),
408
+ visible=True if _ == 0 else False,
409
+ )
410
+ global_error_list.append(error)
411
+ with gr.Row():
412
+ audio = gr.Audio(
413
+ label=i18n("Generated Audio"),
414
+ type="numpy",
415
+ interactive=False,
416
+ visible=True if _ == 0 else False,
417
+ )
418
+ global_audio_list.append(audio)
419
+
420
+ with gr.Row():
421
+ stream_audio = gr.Audio(
422
+ label=i18n("Streaming Audio"),
423
+ streaming=True,
424
+ autoplay=True,
425
+ interactive=False,
426
+ show_download_button=True,
427
+ )
428
+ with gr.Row():
429
+ with gr.Column(scale=3):
430
+ generate = gr.Button(
431
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
432
+ )
433
+
434
+ generate_stream = gr.Button(
435
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
436
+ variant="primary",
437
+ visible=False,  # hide the button (Streaming Generate is under maintenance)
438
+ )
439
+
440
+ text.input(
441
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
442
+ )
443
+
444
+ def select_example_audio(audio_path):
445
+ audio_path = Path(audio_path)
446
+ if audio_path.is_file():
447
+ lab_file = Path(audio_path.with_suffix(".lab"))
448
+
449
+ if lab_file.exists():
450
+ lab_content = lab_file.read_text(encoding="utf-8").strip()
451
+ else:
452
+ lab_content = ""
453
+
454
+ return str(audio_path), lab_content, True
455
+ return None, "", False
456
+
457
+ # Connect the dropdown to update reference audio and text
458
+
459
+ example_audio_dropdown.change(
460
+ fn=update_examples, inputs=[], outputs=[example_audio_dropdown]
461
+ ).then(
462
+ fn=select_example_audio,
463
+ inputs=[example_audio_dropdown],
464
+ outputs=[reference_audio, reference_text, enable_reference_audio],
465
+ )
466
+
467
+ # # Submit
468
+ generate.click(
469
+ inference_wrapper,
470
+ [
471
+ refined_text,
472
+ enable_reference_audio,
473
+ reference_audio,
474
+ reference_text,
475
+ max_new_tokens,
476
+ chunk_length,
477
+ top_p,
478
+ repetition_penalty,
479
+ temperature,
480
+ seed,
481
+ batch_infer_num,
482
+ ],
483
+ [stream_audio, *global_audio_list, *global_error_list],
484
+ concurrency_limit=1,
485
+ )
486
+
487
+ generate_stream.click(
488
+ inference_stream,
489
+ [
490
+ refined_text,
491
+ enable_reference_audio,
492
+ reference_audio,
493
+ reference_text,
494
+ max_new_tokens,
495
+ chunk_length,
496
+ top_p,
497
+ repetition_penalty,
498
+ temperature,
499
+ seed,
500
+ ],
501
+ [stream_audio, global_audio_list[0], global_error_list[0]],
502
+ concurrency_limit=1,
503
+ )
504
+
505
+ return app
506
+
507
+
508
+ def parse_args():
509
+ parser = ArgumentParser()
510
+ parser.add_argument(
511
+ "--llama-checkpoint-path",
512
+ type=Path,
513
+ default="checkpoints/fish-speech-1.2",
514
+ )
515
+ parser.add_argument(
516
+ "--decoder-checkpoint-path",
517
+ type=Path,
518
+ default="checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
519
+ )
520
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
521
+ parser.add_argument("--device", type=str, default="cuda")
522
+ parser.add_argument("--half", action="store_true")
523
+ parser.add_argument("--compile", action="store_true")
524
+ parser.add_argument("--max-gradio-length", type=int, default=0)
525
+ parser.add_argument("--theme", type=str, default="light")
526
+
527
+ return parser.parse_args()
528
+
529
+
530
+ if __name__ == "__main__":
531
+ args = parse_args()
532
+ args.precision = torch.half if args.half else torch.bfloat16
533
+
534
+ logger.info("Loading Llama model...")
535
+ llama_queue = launch_thread_safe_queue(
536
+ checkpoint_path=args.llama_checkpoint_path,
537
+ device=args.device,
538
+ precision=args.precision,
539
+ compile=args.compile,
540
+ )
541
+ logger.info("Llama model loaded, loading VQ-GAN model...")
542
+
543
+ decoder_model = load_decoder_model(
544
+ config_name=args.decoder_config_name,
545
+ checkpoint_path=args.decoder_checkpoint_path,
546
+ device=args.device,
547
+ )
548
+
549
+ logger.info("Decoder model loaded, warming up...")
550
+
551
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
552
+ list(
553
+ inference(
554
+ text="Hello, world!",
555
+ enable_reference_audio=False,
556
+ reference_audio=None,
557
+ reference_text="",
558
+ max_new_tokens=500,
559
+ chunk_length=200,
560
+ top_p=0.7,
561
+ repetition_penalty=1.2,
562
+ temperature=0.7,
563
+ )
564
+ )
565
+
566
+ logger.info("Warming up done, launching the web UI...")
567
+
568
+ app = build_app()
569
+ app.launch(show_api=True, server_name="0.0.0.0", share=True)
570
+
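
Example launch of the web UI above; the checkpoint paths follow the argparse defaults and are assumptions about the local directory layout:

python tools/webui.py \
    --llama-checkpoint-path checkpoints/fish-speech-1.2 \
    --decoder-checkpoint-path checkpoints/fish-speech-1.2/firefly-gan-vq-fsq-8x1024-21hz-generator.pth \
    --compile
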
tools/whisper_asr.py ADDED
@@ -0,0 +1,176 @@
1
+ """
2
+ Used to transcribe all audio files in one folder into another folder.
3
+ e.g.
4
+ Directory structure:
5
+ --pre_data_root
6
+ ----SP_1
7
+ ------01.wav
8
+ ------02.wav
9
+ ------......
10
+ ----SP_2
11
+ ------01.wav
12
+ ------02.wav
13
+ ------......
14
+ Use
15
+ python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
16
+ to transcribe the first speaker.
17
+
18
+ Use
19
+ python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
20
+ to transcribe the second speaker.
21
+
22
+ Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
23
+ """
24
+
25
+ import re
26
+ from pathlib import Path
27
+
28
+ import click
29
+ import soundfile as sf
30
+ from faster_whisper import WhisperModel
31
+ from loguru import logger
32
+ from pydub import AudioSegment
33
+ from tqdm import tqdm
34
+
35
+ from tools.file import AUDIO_EXTENSIONS, list_files
36
+
37
+
38
+ @click.command()
39
+ @click.option("--model-size", default="large-v3", help="Size of the Whisper model")
40
+ @click.option(
41
+ "--compute-type",
42
+ default="float16",
43
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
44
+ )
45
+ @click.option("--audio-dir", required=True, help="Directory containing audio files")
46
+ @click.option(
47
+ "--save-dir", required=True, help="Directory to save processed audio files"
48
+ )
49
+ @click.option(
50
+ "--sample-rate",
51
+ default=44100,
52
+ type=int,
53
+ help="Output sample rate, default to input sample rate",
54
+ )
55
+ @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
56
+ @click.option("--language", default="auto", help="Language of the transcription")
57
+ @click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
58
+ def main(
59
+ model_size,
60
+ compute_type,
61
+ audio_dir,
62
+ save_dir,
63
+ sample_rate,
64
+ device,
65
+ language,
66
+ initial_prompt,
67
+ ):
68
+ logger.info("Loading / Downloading Faster Whisper model...")
69
+
70
+ model = WhisperModel(
71
+ model_size,
72
+ device=device,
73
+ compute_type=compute_type,
74
+ download_root="faster_whisper",
75
+ )
76
+
77
+ logger.info("Model loaded.")
78
+
79
+ save_path = Path(save_dir)
80
+ save_path.mkdir(parents=True, exist_ok=True)
81
+
82
+ audio_files = list_files(
83
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
84
+ )
85
+
86
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
87
+ file_stem = file_path.stem
88
+ file_suffix = file_path.suffix
89
+
90
+ rel_path = Path(file_path).relative_to(audio_dir)
91
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
92
+
93
+ audio = AudioSegment.from_file(file_path)
94
+
95
+ segments, info = model.transcribe(
96
+ file_path,
97
+ beam_size=5,
98
+ language=None if language == "auto" else language,
99
+ initial_prompt=initial_prompt,
100
+ )
101
+
102
+ print(
103
+ "Detected language '%s' with probability %f"
104
+ % (info.language, info.language_probability)
105
+ )
106
+ print("Total len(ms): ", len(audio))
107
+
108
+ whole_text = None
109
+ for segment in segments:
110
+ id, start, end, text = (
111
+ segment.id,
112
+ segment.start,
113
+ segment.end,
114
+ segment.text,
115
+ )
116
+ print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
117
+ if not whole_text:
118
+ whole_text = text
119
+ else:
120
+ whole_text += ", " + text
121
+
122
+ whole_text += "."
123
+
124
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
125
+ audio.export(audio_save_path, format=file_suffix[1:])
126
+ print(f"Exported {audio_save_path}")
127
+
128
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
129
+ with open(
130
+ transcript_save_path,
131
+ "w",
132
+ encoding="utf-8",
133
+ ) as f:
134
+ f.write(whole_text)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()
139
+ exit(0)
140
+
141
+ audio = AudioSegment.from_wav(
142
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
143
+ )
144
+
145
+ model_size = "large-v3"
146
+
147
+ model = WhisperModel(
148
+ model_size,
149
+ device="cuda",
150
+ compute_type="float16",
151
+ download_root="faster_whisper",
152
+ )
153
+
154
+ segments, info = model.transcribe(
155
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
156
+ beam_size=5,
157
+ )
158
+
159
+ print(
160
+ "Detected language '%s' with probability %f"
161
+ % (info.language, info.language_probability)
162
+ )
163
+ print("Total len(ms): ", len(audio))
164
+
165
+ for i, segment in enumerate(segments):
166
+ print(
167
+ "Segment %03d [%.2fs -> %.2fs] %s"
168
+ % (i, segment.start, segment.end, segment.text)
169
+ )
170
+ start_ms = int(segment.start * 1000)
171
+ end_ms = int(segment.end * 1000)
172
+ segment_audio = audio[start_ms:end_ms]
173
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
174
+ print(f"Exported segment_{i:03d}.wav")
175
+
176
+ print("All segments have been exported.")