Stardust-minus committed
Commit 440bab4 · verified · Parent: 2616e46

Upload folder using huggingface_hub

examples/Arabic.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:00f7671345e2c8ac3ce573459b0dc7a363b2dbccfd73dcf4247d0f60afeebdf7
-size 501258
+oid sha256:4a3c902c13fcf408c95353d91ab65f839d27584d8929c7345317956d1e9ea5bd
+size 131
examples/English.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bc489e06dc43b82e4162db8b161bd3f627802ac9b32ee644f7ab8cf94728d0e3
-size 367804
+oid sha256:ed744820849c8f16e03cb68e45b7d7d4b8697476a162d50ffe2cd6612a621aa6
+size 131
examples/French.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:74087d181b2410c79b83068be60990a369dddf064d8ec79dbfc6cbb5dd46f84f
-size 406794
+oid sha256:dee830ddff631df6e0db0911a20099ddf6438a80d1da597536470ba36e2d645c
+size 131
examples/German.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3b77818cd825a71239cd85bfc256cb84f31b091793fb033388b766cf7535359e
-size 572682
+oid sha256:cc076529638f0a4bb8d19b509b7781372c26abadcc74a7dcbc5b72b6b1e680fd
+size 131
examples/Japanese.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3034a38260884be854cb4a3f6cb648db85ebdeeb8cab74cfae2a578dc7aaedc2
-size 132
+oid sha256:ba2a2c07770cb6ab36a5aa6ee953c9914773368e223359e4710897d425a25402
+size 128
examples/Korean.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5767663f0c26f4dc94f45227f385c2be568aac065272466915d65eaa64fdda0f
-size 132
+oid sha256:09c122b25a3ad99247179be77deeaa6ead7d93b40092347801948fea34797e48
+size 128
examples/Nice English Ref.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4b707de0cfc5d2eee59dcc3fea495603fe28d95ca64d8202bcdb31537d588782
-size 132
+oid sha256:b895ec0d49173630cf9253c70579888cde65129fbaeda167e3b4f91593715eca
+size 128
examples/Spanish.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:268308ad9aaa303456664f305749574d7a7f8cf4c784caf42306e3e612a16ecd
-size 379146
+oid sha256:c22d63058f58f46c6a65b6ced8faa969f403b065e822a274342b520e8e20b65f
+size 131
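
Each examples/*.wav diff above changes a Git LFS pointer file: three lines of text recording the pointer spec version, the sha256 digest of the tracked content, and that content's size in bytes. A minimal sketch for reading those fields (the parse_lfs_pointer helper is ours, not part of this repo):

# A minimal sketch (not part of this repo) for reading the three-line
# Git LFS pointer format shown in the diffs above: version / oid / size.
from pathlib import Path


def parse_lfs_pointer(path: str) -> dict:
    """Parse a Git LFS pointer file into its version, oid, and size fields."""
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # `size` is the byte length of the tracked content, not of the pointer file.
    fields["size"] = int(fields["size"])
    return fields


ptr = parse_lfs_pointer("examples/Arabic.wav")
print(ptr["oid"], ptr["size"])  # sha256:4a3c90... 131 after this commit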
fish_speech/models/dac/modded_dac.py CHANGED
@@ -976,49 +976,3 @@ class DAC(BaseModel, CodecMixin):
         z = vq_results[0] if isinstance(vq_results, tuple) else vq_results.z
         x = self.decode(z)
         return x[..., :length], vq_results
-
-
-if __name__ == "__main__":
-
-    def filter_state_dict_shapes(params, model):
-        model_state_dict = model.state_dict()
-        filtered_state_dict = {
-            k: v
-            for k, v in params.items()
-            if k in model_state_dict and v.shape == model_state_dict[k].shape
-        }
-        skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
-        if skipped_keys:
-            print(
-                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
-            )
-        return filtered_state_dict, skipped_keys
-
-    model = hydra.utils.instantiate(
-        OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
-    )
-    sd = torch.load("checkpoints/openaudio-s1-mini/firefly-gan-large.pth")
-    filtered_sd, skipped_keys = filter_state_dict_shapes(sd, model)
-    print(f"Skipped keys: {skipped_keys}")
-    model.load_state_dict(filtered_sd, strict=False)
-    model.eval()
-
-    src_audio_path = "./test.wav"
-    wave_np, _ = librosa.load(src_audio_path, sr=44100, mono=False)
-    if len(wave_np.shape) == 1:
-        wave_np = wave_np[None, :]
-    wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)
-
-    with torch.no_grad():
-        # encode returns (indices, indices_lens)
-        indices, indices_lens = model.encode(wave_tensor)
-        print(f"Indices shape: {indices.shape}")
-        print(f"Indices lengths: {indices_lens}")
-
-        # decode requires both indices and feature_lengths
-        fake_audio, audio_lengths = model.decode(indices, indices_lens)
-        print(f"Decoded audio shape: {fake_audio.shape}")
-        print(f"Audio lengths: {audio_lengths}")
-
-        # save the reconstructed audio
-        sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)
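
The deleted __main__ block above was this module's only encode/decode smoke test. A condensed standalone sketch of the same round trip follows; the config path comes from the removed code, while the codec.pth checkpoint name is an assumption based on the tools/download_models.py change below (the removed code loaded firefly-gan-large.pth):

# Condensed round-trip smoke test based on the removed __main__ block.
# Assumption: the codec weights now ship as codec.pth (see download_models.py
# below); the removed code loaded firefly-gan-large.pth instead.
import hydra
import librosa
import soundfile as sf
import torch
from omegaconf import OmegaConf

model = hydra.utils.instantiate(
    OmegaConf.load("fish_speech/configs/modded_dac_vq.yaml")
)
sd = torch.load("checkpoints/openaudio-s1-mini/codec.pth", map_location="cpu")
model.load_state_dict(sd, strict=False)
model.eval()

wave_np, _ = librosa.load("./test.wav", sr=44100, mono=False)
if wave_np.ndim == 1:
    wave_np = wave_np[None, :]
wave_tensor = torch.from_numpy(wave_np).unsqueeze(1)  # (channels, 1, samples)

with torch.no_grad():
    indices, indices_lens = model.encode(wave_tensor)  # encode returns (indices, indices_lens)
    fake_audio, audio_lengths = model.decode(indices, indices_lens)  # decode takes both

sf.write("fake.wav", fake_audio.squeeze(1).cpu().numpy().T, 44100)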
fish_speech/models/text2semantic/inference.py CHANGED
@@ -10,7 +10,6 @@ from typing import Literal, Optional, Tuple, Union
 import click
 import numpy as np
 import torch
-import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
 from tqdm import tqdm
@@ -21,9 +20,8 @@ from fish_speech.content_sequence import (
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import BaseModelArgs
-from fish_speech.text import clean_text, split_text
-from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
+from fish_speech.text import split_text
+from fish_speech.tokenizer import IM_END_TOKEN
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 torch._inductor.config.coordinate_descent_tuning = True
@@ -37,7 +35,6 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
 from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
     DualARTransformer,
     NaiveTransformer,
 )
@@ -98,16 +95,27 @@ def decode_one_token_ar(
     model: DualARTransformer,
     x: torch.Tensor,
     input_pos: torch.Tensor,
-    semantic_ids: list,
     previous_tokens: torch.Tensor = None,
     **sampling_kwargs,
 ) -> torch.Tensor:
+    """
+    Generate one token using the dual autoregressive transformer for text-to-speech.
+
+    First generates the semantic token, then generates the acoustic codebook tokens sequentially.
+
+    Args:
+        x: Input token tensor (1, num_codebooks+1, seq_len)
+        input_pos: Position indices for input tokens (seq_len,)
+        temperature/top_p/repetition_penalty: Sampling parameters (1, 1)
+        previous_tokens: Previous tokens for repetition penalty (1, num_codebooks+1, history_seq_len)
+        audio_masks/audio_parts: Audio conditioning tensors (num_codebooks, seq_len)
+
+    Returns:
+        Generated tokens tensor (num_codebooks+1, 1) - one token per codebook
+    """
     x = model.forward_generate(x, input_pos)
 
     sampling_kwargs_main = sampling_kwargs.copy()
-    # sampling_kwargs_main["temperature"] = 0.1
-    # sampling_kwargs_main["top_p"] = 0.1
-    # sampling_kwargs_main["repetition_penalty"] = 1.0
 
     codebooks = [
         sample(
@@ -152,12 +160,7 @@ def decode_one_token_ar(
         codebooks.append(a)
 
     codebooks = torch.stack(codebooks, dim=0)
-    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
-    # codebooks[1:, :] = torch.masked_fill(
-    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
-    # )
 
-    # print(codebooks)
     return codebooks
@@ -166,10 +169,24 @@ def decode_n_tokens(
     cur_token: torch.Tensor,
     input_pos: torch.Tensor,
     num_new_tokens: int,
-    semantic_ids: list,
     decode_one_token=decode_one_token_ar,
     **sampling_kwargs,
 ):
+    """
+    Generate n tokens iteratively using the model.
+
+    Args:
+        model: The transformer model
+        cur_token: Current token tensor of shape (1, num_codebooks+1, seq_len)
+        input_pos: Current input position tensor
+        num_new_tokens: Number of new tokens to generate
+        semantic_ids: List of semantic token IDs
+        decode_one_token: Function to decode one token
+        **sampling_kwargs: Additional sampling parameters
+
+    Returns:
+        Generated tokens tensor of shape (num_codebooks+1, generated_len)
+    """
     previous_tokens = torch.zeros(
         (model.config.num_codebooks + 1, model.config.max_seq_len),
         dtype=torch.int,
@@ -184,21 +201,14 @@ def decode_n_tokens(
         else:
             window = previous_tokens[:, i - win_size : i]
 
-        with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
-            )
-            if torch.cuda.is_available()
-            else nullcontext()
-        ):  # Actually better for Inductor to codegen attention here
+        with sdpa_kernel(SDPBackend.MATH):
             next_token = decode_one_token(
                 model=model,
                 x=cur_token,
                 input_pos=input_pos,
                 previous_tokens=window,
-                semantic_ids=semantic_ids,
                 **sampling_kwargs,
-            )
+            ).clone()
 
         input_pos += 1
         cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
@@ -223,15 +233,21 @@ def generate(
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
-    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    Generate tokens from a text prompt using the transformer model.
+
+    Args:
+        model: The transformer model for generation
+        prompt: Input token tensor of shape (num_codebooks+1, seq_len)
+        max_new_tokens: Maximum number of new tokens to generate
+        decode_one_token: Function to decode one token at a time
+        **sampling_kwargs: Additional sampling parameters (temperature, top_p, repetition_penalty)
+
+    Returns:
+        Generated sequence tensor of shape (num_codebooks+1, total_seq_len),
+        where total_seq_len = original_seq_len + generated_tokens_len
     """
 
-    # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(1)
-    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
-    semantic_ids = [
-        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
-    ]
 
     if max_new_tokens:
         if T + max_new_tokens > model.config.max_seq_len:
@@ -246,7 +262,6 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
 
     codebook_dim = 1 + model.config.num_codebooks
-    # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(
         (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
@@ -257,33 +272,30 @@ def generate(
     # Use non-accelerated version for now, to avoid compilation overhead
     prefill_decode = decode_one_token_ar
 
-    next_token = prefill_decode(
+    first_token = prefill_decode(
         model,
         prompt.view(1, codebook_dim, -1),
         input_pos,
-        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
-    seq[:, T : T + 1] = next_token
+    seq[:, T : T + 1] = first_token
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     x = decode_n_tokens(
         model,
-        next_token.view(1, codebook_dim, -1),
+        first_token.view(1, codebook_dim, -1),
         input_pos,
         max_new_tokens - 1,
         decode_one_token=decode_one_token,
-        semantic_ids=semantic_ids,
         **sampling_kwargs,
     )
-    # x = torch.cat(generated_tokens, dim=1)
     seq = seq[:, : T + 1 + x.size(1)]
     seq[:, T + 1 :] = x
 
     return seq
 
 
-def load_model(checkpoint_path, device, precision, compile=False):
+def init_model(checkpoint_path, device, precision, compile=False):
     model = DualARTransformer.from_pretrained(checkpoint_path, load_weights=True)
 
     model = model.to(device=device, dtype=precision)
@@ -405,26 +417,6 @@ def generate_long(
         seg = encoded[seg_idx]
         global_encoded.append(seg)
 
-        # Do not use previous segments to generate current segment for now
-        # lengths = reversed([seg.size(1) for seg in global_encoded])
-
-        # # Pick last 2000 tokens
-        # count = 0
-        # for i, length in enumerate(lengths):
-        #     count += length
-        #     if count + length > max_length - 2048 - encoded_prompts.size(1):
-        #         break
-
-        # if i != 0 and i % 2 == 0:
-        #     i -= 1
-
-        # # Rotate the list, always make sure first segment is included to avoid drift
-        # if i < len(global_encoded) - 2:
-        #     partial_encoded = global_encoded[:2] + global_encoded[-i:]
-        # else:
-        #     partial_encoded = global_encoded
-
-        # cat_encoded = torch.cat([encoded_prompts, *partial_encoded], dim=1)
         if len(base_content_sequence.parts) <= 1 and len(global_encoded) >= 2:
             cat_encoded = torch.cat(
                 [encoded_prompts, global_encoded[0], global_encoded[1], seg], dim=1
@@ -507,7 +499,7 @@ def launch_thread_safe_queue(
     init_event = threading.Event()
 
     def worker():
-        model, decode_one_token = load_model(
+        model, decode_one_token = init_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
@@ -542,60 +534,6 @@ def launch_thread_safe_queue(
     return input_queue
 
 
-def launch_thread_safe_queue_agent(
-    checkpoint_path,
-    device,
-    precision,
-    compile: bool = False,
-):
-    input_queue = queue.Queue()
-    init_event = threading.Event()
-
-    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
-    config = BaseModelArgs.from_pretrained(checkpoint_path)
-
-    def worker():
-        model, decode_one_token = load_model(
-            checkpoint_path, device, precision, compile=compile, is_agent=True
-        )
-
-        with torch.device(device):
-            model.setup_caches(
-                max_batch_size=1,
-                max_seq_len=model.config.max_seq_len,
-                dtype=next(model.parameters()).dtype,
-            )
-        init_event.set()
-
-        while True:
-            item: GenerateRequest | None = input_queue.get()
-            if item is None:
-                break
-
-            kwargs = item.request
-            response_queue = item.response_queue
-
-            try:
-                for token in generate_agent(
-                    model=model,
-                    decode_one_token=decode_one_token,
-                    **kwargs,
-                ):
-                    response_queue.put(token)
-
-                response_queue.put("stop")
-            except Exception as e:
-                import traceback
-
-                logger.exception(f"Error in worker: {traceback.format_exc()}")
-                response_queue.put("error")
-
-    threading.Thread(target=worker, daemon=True).start()
-    init_event.wait()
-
-    return input_queue, tokenizer, config
-
-
 @click.command()
 @click.option(
     "--text",
@@ -654,7 +592,7 @@ def main(
 
     logger.info("Loading model ...")
     t0 = time.time()
-    model, decode_one_token = load_model(
+    model, decode_one_token = init_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
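
The decode_n_tokens hunk above also migrates attention-backend pinning from the deprecated torch.backends.cuda.sdp_kernel context manager to torch.nn.attention.sdpa_kernel, which the file already imports. A minimal sketch of the new API on toy tensors (the sizes are ours):

# Minimal sketch of the SDPA backend pinning now used in decode_n_tokens.
# torch.nn.attention.sdpa_kernel replaces the deprecated
# torch.backends.cuda.sdp_kernel(enable_flash=..., enable_mem_efficient=...,
# enable_math=...) context manager.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim), toy sizes

# Restrict scaled_dot_product_attention to the math backend; per the removed
# comment, this is the form Inductor codegens best inside the compiled decode step.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 16, 64])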
tools/download_models.py CHANGED
@@ -22,7 +22,7 @@ def check_and_download_files(repo_id, file_list, local_dir):
 
 
 # 1st
-repo_id_1 = "fishaudio/fish-speech-1.5"
+repo_id_1 = "fishaudio/openaudio-s1-mini"
 local_dir_1 = "./checkpoints/openaudio-s1-mini"
 files_1 = [
     ".gitattributes",
@@ -31,7 +31,7 @@ files_1 = [
     "special_tokens.json",
     "tokenizer.tiktoken",
     "config.json",
-    "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
+    "codec.pth",
 ]
 
 # 3rd
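
For context, only the signature of check_and_download_files appears in the hunk header above; its body is untouched by this commit. A hedged sketch of what such a helper typically looks like with huggingface_hub (this implementation is our assumption, not the repo's code):

# Hedged sketch: only the signature of check_and_download_files appears in the
# hunk header above; this body is an assumption, not the repo's actual code.
import os

from huggingface_hub import hf_hub_download


def check_and_download_files(repo_id, file_list, local_dir):
    for filename in file_list:
        if not os.path.exists(os.path.join(local_dir, filename)):
            # hf_hub_download fetches a single file from the Hub into local_dir
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)


# After this commit, the first checkpoint set resolves against the
# openaudio-s1-mini repo and includes codec.pth:
check_and_download_files(repo_id_1, files_1, local_dir_1)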