uto1125 committed
Commit 41e786b · verified · 1 Parent(s): a40e20b

Update tools/llama/generate.py

Files changed (1)
  1. tools/llama/generate.py +713 -713
tools/llama/generate.py CHANGED
@@ -1,713 +1,713 @@
1
- import os
2
- import queue
3
- import threading
4
- import time
5
- from contextlib import nullcontext
6
- from dataclasses import dataclass
7
- from pathlib import Path
8
- from typing import Literal, Optional, Tuple, Union
9
-
10
- import click
11
- import hydra
12
- import numpy as np
13
- import torch
14
- import torch._dynamo.config
15
- import torch._inductor.config
16
- from loguru import logger
17
- from tqdm import tqdm
18
-
19
- from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
- from fish_speech.text import clean_text, split_text
21
-
22
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
- torch._inductor.config.coordinate_descent_tuning = True
24
- torch._inductor.config.triton.unique_kernel_names = True
25
-
26
- if hasattr(torch._inductor.config, "fx_graph_cache"):
27
- # Experimental feature to reduce compilation times, will be on by default in future
28
- torch._inductor.config.fx_graph_cache = True
29
-
30
-
31
- from fish_speech.models.text2semantic.llama import (
32
- BaseTransformer,
33
- DualARTransformer,
34
- NaiveTransformer,
35
- )
36
-
37
-
38
- def multinomial_sample_one_no_sync(
39
- probs_sort,
40
- ): # Does multinomial sampling without a cuda synchronization
41
- q = torch.empty_like(probs_sort).exponential_(1)
42
- return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
43
-
44
-
45
- def logits_to_probs(
46
- logits,
47
- previous_tokens: Optional[torch.Tensor] = None,
48
- temperature: torch.Tensor = 1.0,
49
- top_p: torch.Tensor = 1.0,
50
- repetition_penalty: torch.Tensor = 1.0,
51
- ) -> torch.Tensor:
52
- # Apply repetition penalty
53
- if previous_tokens is not None:
54
- previous_tokens = previous_tokens.long()
55
- score = torch.gather(logits, dim=0, index=previous_tokens)
56
- score = torch.where(
57
- score < 0, score * repetition_penalty, score / repetition_penalty
58
- )
59
- logits.scatter_(dim=0, index=previous_tokens, src=score)
60
-
61
- # Apply top-p sampling
62
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
63
- cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
64
- sorted_indices_to_remove = cum_probs > top_p
65
- sorted_indices_to_remove[0] = False # keep at least one option
66
- indices_to_remove = sorted_indices_to_remove.scatter(
67
- dim=0, index=sorted_indices, src=sorted_indices_to_remove
68
- )
69
- logits = logits.masked_fill(indices_to_remove, -float("Inf"))
70
-
71
- logits = logits / max(temperature, 1e-5)
72
-
73
- probs = torch.nn.functional.softmax(logits, dim=-1)
74
- return probs
75
-
76
-
77
- def sample(
78
- logits,
79
- previous_tokens: Optional[torch.Tensor] = None,
80
- **sampling_kwargs,
81
- ) -> Tuple[torch.Tensor, torch.Tensor]:
82
- probs = logits_to_probs(
83
- logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
84
- )
85
- idx_next = multinomial_sample_one_no_sync(probs)
86
- return idx_next, probs
87
-
88
-
89
- def decode_one_token_ar(
90
- model: DualARTransformer,
91
- x: torch.Tensor,
92
- input_pos: torch.Tensor,
93
- previous_tokens: torch.Tensor = None,
94
- **sampling_kwargs,
95
- ) -> torch.Tensor:
96
- x = model.forward_generate(x, input_pos)
97
-
98
- sampling_kwargs_main = sampling_kwargs.copy()
99
- sampling_kwargs_main["temperature"] = 0.1
100
- sampling_kwargs_main["top_p"] = 0.1
101
- sampling_kwargs_main["repetition_penalty"] = 1.0
102
-
103
- codebooks = [
104
- sample(
105
- x.logits,
106
- previous_tokens=None, # Disable repetition penalty for the token codebook
107
- **sampling_kwargs_main,
108
- )[0]
109
- ]
110
-
111
- x = x.hidden_states
112
-
113
- # Cleanup the cache
114
- for layer in model.fast_layers:
115
- layer.attention.kv_cache.k_cache.fill_(0)
116
- layer.attention.kv_cache.v_cache.fill_(0)
117
-
118
- for codebook_idx in range(model.config.num_codebooks):
119
- input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
120
- logits = model.forward_generate_fast(x, input_pos)
121
- a = sample(
122
- logits,
123
- previous_tokens=(
124
- previous_tokens[codebook_idx + 1]
125
- if previous_tokens is not None
126
- else None
127
- ),
128
- **sampling_kwargs,
129
- )[0]
130
- x = model.fast_embeddings(a)
131
- codebooks.append(a)
132
-
133
- return torch.stack(codebooks, dim=0)
134
-
135
-
136
- def decode_one_token_naive(
137
- model: NaiveTransformer,
138
- x: torch.Tensor,
139
- input_pos: torch.Tensor,
140
- previous_tokens: torch.Tensor = None,
141
- **sampling_kwargs,
142
- ) -> torch.Tensor:
143
- x = model.forward_generate(x, input_pos)
144
-
145
- sampling_kwargs_main = sampling_kwargs.copy()
146
- sampling_kwargs_main["temperature"] = 0.1
147
- sampling_kwargs_main["top_p"] = 0.1
148
- sampling_kwargs_main["repetition_penalty"] = 1.0
149
-
150
- codebooks = [
151
- sample(
152
- x.logits,
153
- previous_tokens=None, # Disable repetition penalty for the token codebook
154
- **sampling_kwargs_main,
155
- )[0]
156
- ]
157
-
158
- for i in range(model.config.num_codebooks):
159
- codebooks.append(
160
- sample(
161
- x.codebook_logits[:, :, i],
162
- previous_tokens=(
163
- previous_tokens[i + 1] if previous_tokens is not None else None
164
- ),
165
- **sampling_kwargs,
166
- )[0]
167
- )
168
-
169
- return torch.stack(codebooks, dim=0)
170
-
171
-
172
- def decode_n_tokens(
173
- model: NaiveTransformer,
174
- cur_token: torch.Tensor,
175
- input_pos: torch.Tensor,
176
- num_new_tokens: int,
177
- im_end_id: int = 4,
178
- decode_one_token=decode_one_token_naive,
179
- **sampling_kwargs,
180
- ):
181
- previous_tokens = torch.zeros(
182
- (model.config.num_codebooks + 1, model.config.max_seq_len),
183
- dtype=torch.int,
184
- device=cur_token.device,
185
- )
186
-
187
- for i in tqdm(range(num_new_tokens)):
188
- # We need to get windowed repeat penalty
189
- win_size = 16
190
- if i < win_size:
191
- window = previous_tokens[:, :win_size]
192
- else:
193
- window = previous_tokens[:, i - win_size : i]
194
-
195
- with (
196
- torch.backends.cuda.sdp_kernel(
197
- enable_flash=False, enable_mem_efficient=False, enable_math=True
198
- )
199
- if torch.cuda.is_available()
200
- else nullcontext()
201
- ): # Actually better for Inductor to codegen attention here
202
- next_token = decode_one_token(
203
- model=model,
204
- x=cur_token,
205
- input_pos=input_pos,
206
- previous_tokens=window,
207
- **sampling_kwargs,
208
- )
209
-
210
- input_pos += 1
211
- cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
212
- previous_tokens[:, i : i + 1] = next_token.view(
213
- model.config.num_codebooks + 1, -1
214
- )
215
-
216
- if cur_token[0, 0, -1] == im_end_id:
217
- break
218
-
219
- return previous_tokens[:, : i + 1]
220
-
221
-
222
- @torch.no_grad()
223
- @torch.inference_mode()
224
- def generate(
225
- *,
226
- model: NaiveTransformer,
227
- prompt: torch.Tensor,
228
- max_new_tokens: int,
229
- im_end_id: int = 4,
230
- decode_one_token=decode_one_token_naive,
231
- **sampling_kwargs,
232
- ) -> torch.Tensor:
233
- """
234
- Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
235
- """
236
-
237
- # create an empty tensor of the expected final shape and fill in the current tokens
238
- T = prompt.size(1)
239
-
240
- if max_new_tokens:
241
- if T + max_new_tokens > model.config.max_seq_len:
242
- max_new_tokens = model.config.max_seq_len - T
243
- logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
244
-
245
- T_new = T + max_new_tokens
246
- else:
247
- T_new = model.config.max_seq_len
248
- max_new_tokens = T_new - T
249
-
250
- device, dtype = prompt.device, prompt.dtype
251
- with torch.device(device):
252
- model.setup_caches(
253
- max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
254
- )
255
-
256
- codebook_dim = 1 + model.config.num_codebooks
257
- # create an empty tensor of the expected final shape and fill in the current tokens
258
- empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
259
- empty[:, :T] = prompt
260
- seq = empty
261
- input_pos = torch.arange(0, T, device=device)
262
-
263
- # Use non-accelerated version for now, to avoid compilation overhead
264
- prefill_decode = (
265
- decode_one_token_naive
266
- if isinstance(model, NaiveTransformer)
267
- else decode_one_token_ar
268
- )
269
-
270
- next_token = prefill_decode(
271
- model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
272
- )
273
- seq[:, T : T + 1] = next_token
274
-
275
- input_pos = torch.tensor([T], device=device, dtype=torch.int)
276
- x = decode_n_tokens(
277
- model,
278
- next_token.view(1, codebook_dim, -1),
279
- input_pos,
280
- max_new_tokens - 1,
281
- im_end_id=im_end_id,
282
- decode_one_token=decode_one_token,
283
- **sampling_kwargs,
284
- )
285
- # x = torch.cat(generated_tokens, dim=1)
286
- seq = seq[:, : T + 1 + x.size(1)]
287
- seq[:, T + 1 :] = x
288
-
289
- return seq
290
-
291
-
292
- def encode_tokens(
293
- tokenizer,
294
- string,
295
- device="cuda",
296
- prompt_tokens=None,
297
- num_codebooks=4,
298
- ):
299
- string = clean_text(string)
300
- string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
301
-
302
- new_tokens = tokenizer.encode(
303
- string,
304
- add_special_tokens=False,
305
- max_length=10**6,
306
- truncation=False,
307
- )
308
- tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
309
-
310
- # Codebooks
311
- zeros = (
312
- torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
313
- * CODEBOOK_PAD_TOKEN_ID
314
- )
315
- prompt = torch.cat((tokens, zeros), dim=0)
316
-
317
- if prompt_tokens is None:
318
- return prompt
319
-
320
- # Get prompt tokens
321
- if prompt_tokens.ndim == 3:
322
- assert (
323
- prompt_tokens.shape[0] == 1
324
- ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
325
- prompt_tokens = prompt_tokens[0]
326
-
327
- assert prompt_tokens.ndim == 2
328
- data = prompt_tokens + 1
329
-
330
- if prompt_tokens.shape[0] > num_codebooks:
331
- logger.warning(
332
- f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
333
- )
334
- data = data[:num_codebooks]
335
-
336
- # Add pad token for each codebook
337
- data = torch.cat(
338
- (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
339
- dim=1,
340
- )
341
-
342
- # Since 1.0, we use <|semantic|>
343
- s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
344
- end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
345
- main_token_ids = (
346
- torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
347
- )
348
- main_token_ids[0, -1] = end_token_id
349
-
350
- data = torch.cat((main_token_ids, data), dim=0)
351
- prompt = torch.cat((prompt, data), dim=1)
352
-
353
- return prompt
354
-
355
-
356
- def load_model(checkpoint_path, device, precision, compile=False):
357
- model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
358
- checkpoint_path, load_weights=True
359
- )
360
-
361
- model = model.to(device=device, dtype=precision)
362
- logger.info(f"Restored model from checkpoint")
363
-
364
- if isinstance(model, DualARTransformer):
365
- decode_one_token = decode_one_token_ar
366
- logger.info("Using DualARTransformer")
367
- else:
368
- decode_one_token = decode_one_token_naive
369
- logger.info("Using NaiveTransformer")
370
-
371
- if compile:
372
- logger.info("Compiling function...")
373
- decode_one_token = torch.compile(
374
- decode_one_token,
375
- fullgraph=True,
376
- backend="inductor" if torch.cuda.is_available() else "aot_eager",
377
- mode="reduce-overhead" if torch.cuda.is_available() else None,
378
- )
379
-
380
- return model.eval(), decode_one_token
381
-
382
-
383
- @dataclass
384
- class GenerateResponse:
385
- action: Literal["sample", "next"]
386
- codes: Optional[torch.Tensor] = None
387
- text: Optional[str] = None
388
-
389
-
390
- def generate_long(
391
- *,
392
- model,
393
- device: str | torch.device,
394
- decode_one_token: callable,
395
- text: str,
396
- num_samples: int = 1,
397
- max_new_tokens: int = 0,
398
- top_p: int = 0.7,
399
- repetition_penalty: float = 1.5,
400
- temperature: float = 0.7,
401
- compile: bool = False,
402
- iterative_prompt: bool = True,
403
- max_length: int = 2048,
404
- chunk_length: int = 150,
405
- prompt_text: Optional[str | list[str]] = None,
406
- prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
407
- ):
408
- assert 0 < top_p <= 1, "top_p must be in (0, 1]"
409
- assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
410
- assert 0 < temperature < 2, "temperature must be in (0, 2)"
411
-
412
- use_prompt = prompt_text is not None and prompt_tokens is not None
413
- if use_prompt and isinstance(prompt_text, str):
414
- prompt_text = [prompt_text]
415
- prompt_tokens = [prompt_tokens]
416
-
417
- assert use_prompt is False or len(prompt_text) == len(
418
- prompt_tokens
419
- ), "Prompt text and tokens must have the same length"
420
-
421
- model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
422
- tokenizer = model.tokenizer
423
- im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
424
-
425
- encoded = []
426
- texts = split_text(text, chunk_length) if iterative_prompt else [text]
427
- encoded_prompts = []
428
-
429
- if use_prompt:
430
- for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
431
- encoded_prompts.append(
432
- encode_tokens(
433
- tokenizer,
434
- string=t,
435
- device=device,
436
- prompt_tokens=c,
437
- num_codebooks=model.config.num_codebooks,
438
- )
439
- )
440
-
441
- for idx, text in enumerate(texts):
442
- encoded.append(
443
- encode_tokens(
444
- tokenizer,
445
- string=text,
446
- device=device,
447
- num_codebooks=model.config.num_codebooks,
448
- )
449
- )
450
- logger.info(f"Encoded text: {text}")
451
-
452
- # Move temperature, top_p, repetition_penalty to device
453
- # This is important so that changing params doesn't trigger recompile
454
- temperature = torch.tensor(temperature, device=device, dtype=torch.float)
455
- top_p = torch.tensor(top_p, device=device, dtype=torch.float)
456
- repetition_penalty = torch.tensor(
457
- repetition_penalty, device=device, dtype=torch.float
458
- )
459
-
460
- for sample_idx in range(num_samples):
461
- if torch.cuda.is_available():
462
- torch.cuda.synchronize()
463
-
464
- global_encoded = []
465
- seg_idx = 0
466
-
467
- while seg_idx < len(encoded):
468
- logger.info(
469
- f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
470
- )
471
-
472
- seg = encoded[seg_idx]
473
- global_encoded.append(seg)
474
-
475
- lengths = reversed([seg.size(1) for seg in global_encoded])
476
-
477
- # Pick last 2000 tokens
478
- count = 0
479
- for i, length in enumerate(lengths):
480
- count += length
481
- if count + length > max_length - 1024 - sum(
482
- t.shape[1] for t in encoded_prompts
483
- ):
484
- break
485
-
486
- if i != 0 and i % 2 == 0:
487
- i -= 1
488
-
489
- # Rotate the list, always make sure first segment is included to avoid drift
490
- if i < len(global_encoded) - 2:
491
- partial_encoded = global_encoded[:2] + global_encoded[-i:]
492
- else:
493
- partial_encoded = global_encoded
494
-
495
- if use_prompt:
496
- partial_encoded = encoded_prompts + partial_encoded
497
-
498
- cat_encoded = torch.cat(partial_encoded, dim=1)
499
- prompt_length = cat_encoded.size(1)
500
-
501
- t0 = time.perf_counter()
502
- y = generate(
503
- model=model,
504
- prompt=cat_encoded,
505
- max_new_tokens=max_new_tokens,
506
- im_end_id=im_end_id,
507
- decode_one_token=decode_one_token,
508
- temperature=temperature,
509
- top_p=top_p,
510
- repetition_penalty=repetition_penalty,
511
- )
512
-
513
- if sample_idx == 0 and seg_idx == 0 and compile:
514
- logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
515
-
516
- if torch.cuda.is_available():
517
- torch.cuda.synchronize()
518
-
519
- t = time.perf_counter() - t0
520
-
521
- tokens_generated = y.size(1) - prompt_length
522
- tokens_sec = tokens_generated / t
523
- logger.info(
524
- f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
525
- )
526
- logger.info(
527
- f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
528
- )
529
-
530
- if torch.cuda.is_available():
531
- logger.info(
532
- f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
533
- )
534
-
535
- # Put the generated tokens
536
- # since there is <im_end> and <eos> tokens, we remove last 2 tokens
537
- codes = y[1:, prompt_length:-1].clone()
538
- codes = codes - 1
539
- assert (codes >= 0).all(), f"Negative code found"
540
-
541
- decoded = y[:, prompt_length:-1].clone()
542
- # But for global encoding, we should keep the <im_end> token
543
-
544
- global_encoded.append(decoded)
545
- assert (codes >= 0).all(), f"Negative code found: {codes}"
546
- yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
547
- seg_idx += 1
548
-
549
- # This indicates the end of the current sample
550
- yield GenerateResponse(action="next")
551
-
552
-
553
- @dataclass
554
- class WrappedGenerateResponse:
555
- status: Literal["success", "error"]
556
- response: Optional[GenerateResponse | Exception] = None
557
-
558
-
559
- @dataclass
560
- class GenerateRequest:
561
- request: dict
562
- response_queue: queue.Queue
563
-
564
-
565
- def launch_thread_safe_queue(
566
- checkpoint_path,
567
- device,
568
- precision,
569
- compile: bool = False,
570
- ):
571
- input_queue = queue.Queue()
572
- init_event = threading.Event()
573
-
574
- def worker():
575
- model, decode_one_token = load_model(
576
- checkpoint_path, device, precision, compile=compile
577
- )
578
- init_event.set()
579
-
580
- while True:
581
- item: GenerateRequest | None = input_queue.get()
582
- if item is None:
583
- break
584
-
585
- kwargs = item.request
586
- response_queue = item.response_queue
587
-
588
- try:
589
- for chunk in generate_long(
590
- model=model, decode_one_token=decode_one_token, **kwargs
591
- ):
592
- response_queue.put(
593
- WrappedGenerateResponse(status="success", response=chunk)
594
- )
595
- except Exception as e:
596
- response_queue.put(WrappedGenerateResponse(status="error", response=e))
597
-
598
- threading.Thread(target=worker, daemon=True).start()
599
- init_event.wait()
600
-
601
- return input_queue
602
-
603
-
604
- @click.command()
605
- @click.option(
606
- "--text",
607
- type=str,
608
- default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
609
- )
610
- @click.option("--prompt-text", type=str, default=None, multiple=True)
611
- @click.option(
612
- "--prompt-tokens",
613
- type=click.Path(path_type=Path, exists=True),
614
- default=None,
615
- multiple=True,
616
- )
617
- @click.option("--num-samples", type=int, default=1)
618
- @click.option("--max-new-tokens", type=int, default=0)
619
- @click.option("--top-p", type=float, default=0.7)
620
- @click.option("--repetition-penalty", type=float, default=1.2)
621
- @click.option("--temperature", type=float, default=0.7)
622
- @click.option(
623
- "--checkpoint-path",
624
- type=click.Path(path_type=Path, exists=True),
625
- default="checkpoints/fish-speech-1.4",
626
- )
627
- @click.option("--device", type=str, default="cuda")
628
- @click.option("--compile/--no-compile", default=False)
629
- @click.option("--seed", type=int, default=42)
630
- @click.option("--half/--no-half", default=False)
631
- @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
632
- @click.option("--chunk-length", type=int, default=100)
633
- def main(
634
- text: str,
635
- prompt_text: Optional[list[str]],
636
- prompt_tokens: Optional[list[Path]],
637
- num_samples: int,
638
- max_new_tokens: int,
639
- top_p: int,
640
- repetition_penalty: float,
641
- temperature: float,
642
- checkpoint_path: Path,
643
- device: str,
644
- compile: bool,
645
- seed: int,
646
- half: bool,
647
- iterative_prompt: bool,
648
- chunk_length: int,
649
- ) -> None:
650
-
651
- precision = torch.half if half else torch.bfloat16
652
-
653
- if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
654
- raise ValueError(
655
- f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
656
- )
657
-
658
- logger.info("Loading model ...")
659
- t0 = time.time()
660
- model, decode_one_token = load_model(
661
- checkpoint_path, device, precision, compile=compile
662
- )
663
-
664
- if torch.cuda.is_available():
665
- torch.cuda.synchronize()
666
-
667
- logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
668
-
669
- if prompt_tokens is not None:
670
- prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
671
-
672
- torch.manual_seed(seed)
673
-
674
- if torch.cuda.is_available():
675
- torch.cuda.manual_seed(seed)
676
-
677
- generator = generate_long(
678
- model=model,
679
- device=device,
680
- decode_one_token=decode_one_token,
681
- text=text,
682
- num_samples=num_samples,
683
- max_new_tokens=max_new_tokens,
684
- top_p=top_p,
685
- repetition_penalty=repetition_penalty,
686
- temperature=temperature,
687
- compile=compile,
688
- iterative_prompt=iterative_prompt,
689
- chunk_length=chunk_length,
690
- prompt_text=prompt_text,
691
- prompt_tokens=prompt_tokens,
692
- )
693
-
694
- idx = 0
695
- codes = []
696
-
697
- for response in generator:
698
- if response.action == "sample":
699
- codes.append(response.codes)
700
- logger.info(f"Sampled text: {response.text}")
701
- elif response.action == "next":
702
- if codes:
703
- np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
704
- logger.info(f"Saved codes to codes_{idx}.npy")
705
- logger.info(f"Next sample")
706
- codes = []
707
- idx += 1
708
- else:
709
- logger.error(f"Error: {response}")
710
-
711
-
712
- if __name__ == "__main__":
713
- main()
 
1
+ import os
2
+ import queue
3
+ import threading
4
+ import time
5
+ from contextlib import nullcontext
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Literal, Optional, Tuple, Union
9
+
10
+ import click
11
+ import hydra
12
+ import numpy as np
13
+ import torch
14
+ import torch._dynamo.config
15
+ import torch._inductor.config
16
+ from loguru import logger
17
+ from tqdm import tqdm
18
+
19
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
+ from fish_speech.text import clean_text, split_text
21
+ torch.cuda.is_available = lambda: False
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+ torch._inductor.config.coordinate_descent_tuning = True
24
+ torch._inductor.config.triton.unique_kernel_names = True
25
+
26
+ if hasattr(torch._inductor.config, "fx_graph_cache"):
27
+ # Experimental feature to reduce compilation times, will be on by default in future
28
+ torch._inductor.config.fx_graph_cache = True
29
+
30
+
31
+ from fish_speech.models.text2semantic.llama import (
32
+ BaseTransformer,
33
+ DualARTransformer,
34
+ NaiveTransformer,
35
+ )
36
+
37
+
38
+ def multinomial_sample_one_no_sync(
39
+ probs_sort,
40
+ ): # Does multinomial sampling without a cuda synchronization
41
+ q = torch.empty_like(probs_sort).exponential_(1)
42
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
43
+
44
+
45
+ def logits_to_probs(
46
+ logits,
47
+ previous_tokens: Optional[torch.Tensor] = None,
48
+ temperature: torch.Tensor = 1.0,
49
+ top_p: torch.Tensor = 1.0,
50
+ repetition_penalty: torch.Tensor = 1.0,
51
+ ) -> torch.Tensor:
52
+ # Apply repetition penalty
53
+ if previous_tokens is not None:
54
+ previous_tokens = previous_tokens.long()
55
+ score = torch.gather(logits, dim=0, index=previous_tokens)
56
+ score = torch.where(
57
+ score < 0, score * repetition_penalty, score / repetition_penalty
58
+ )
59
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
60
+
61
+ # Apply top-p sampling
62
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
63
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
64
+ sorted_indices_to_remove = cum_probs > top_p
65
+ sorted_indices_to_remove[0] = False # keep at least one option
66
+ indices_to_remove = sorted_indices_to_remove.scatter(
67
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
68
+ )
69
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
70
+
71
+ logits = logits / max(temperature, 1e-5)
72
+
73
+ probs = torch.nn.functional.softmax(logits, dim=-1)
74
+ return probs
75
+
76
+
77
+ def sample(
78
+ logits,
79
+ previous_tokens: Optional[torch.Tensor] = None,
80
+ **sampling_kwargs,
81
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ probs = logits_to_probs(
83
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
84
+ )
85
+ idx_next = multinomial_sample_one_no_sync(probs)
86
+ return idx_next, probs
87
+
88
+
89
+ def decode_one_token_ar(
90
+ model: DualARTransformer,
91
+ x: torch.Tensor,
92
+ input_pos: torch.Tensor,
93
+ previous_tokens: torch.Tensor = None,
94
+ **sampling_kwargs,
95
+ ) -> torch.Tensor:
96
+ x = model.forward_generate(x, input_pos)
97
+
98
+ sampling_kwargs_main = sampling_kwargs.copy()
99
+ sampling_kwargs_main["temperature"] = 0.1
100
+ sampling_kwargs_main["top_p"] = 0.1
101
+ sampling_kwargs_main["repetition_penalty"] = 1.0
102
+
103
+ codebooks = [
104
+ sample(
105
+ x.logits,
106
+ previous_tokens=None, # Disable repetition penalty for the token codebook
107
+ **sampling_kwargs_main,
108
+ )[0]
109
+ ]
110
+
111
+ x = x.hidden_states
112
+
113
+ # Cleanup the cache
114
+ for layer in model.fast_layers:
115
+ layer.attention.kv_cache.k_cache.fill_(0)
116
+ layer.attention.kv_cache.v_cache.fill_(0)
117
+
118
+ for codebook_idx in range(model.config.num_codebooks):
119
+ input_pos = torch.tensor([codebook_idx], device=x.device, dtype=torch.long)
120
+ logits = model.forward_generate_fast(x, input_pos)
121
+ a = sample(
122
+ logits,
123
+ previous_tokens=(
124
+ previous_tokens[codebook_idx + 1]
125
+ if previous_tokens is not None
126
+ else None
127
+ ),
128
+ **sampling_kwargs,
129
+ )[0]
130
+ x = model.fast_embeddings(a)
131
+ codebooks.append(a)
132
+
133
+ return torch.stack(codebooks, dim=0)
134
+
135
+
136
+ def decode_one_token_naive(
137
+ model: NaiveTransformer,
138
+ x: torch.Tensor,
139
+ input_pos: torch.Tensor,
140
+ previous_tokens: torch.Tensor = None,
141
+ **sampling_kwargs,
142
+ ) -> torch.Tensor:
143
+ x = model.forward_generate(x, input_pos)
144
+
145
+ sampling_kwargs_main = sampling_kwargs.copy()
146
+ sampling_kwargs_main["temperature"] = 0.1
147
+ sampling_kwargs_main["top_p"] = 0.1
148
+ sampling_kwargs_main["repetition_penalty"] = 1.0
149
+
150
+ codebooks = [
151
+ sample(
152
+ x.logits,
153
+ previous_tokens=None, # Disable repetition penalty for the token codebook
154
+ **sampling_kwargs_main,
155
+ )[0]
156
+ ]
157
+
158
+ for i in range(model.config.num_codebooks):
159
+ codebooks.append(
160
+ sample(
161
+ x.codebook_logits[:, :, i],
162
+ previous_tokens=(
163
+ previous_tokens[i + 1] if previous_tokens is not None else None
164
+ ),
165
+ **sampling_kwargs,
166
+ )[0]
167
+ )
168
+
169
+ return torch.stack(codebooks, dim=0)
170
+
171
+
172
+ def decode_n_tokens(
173
+ model: NaiveTransformer,
174
+ cur_token: torch.Tensor,
175
+ input_pos: torch.Tensor,
176
+ num_new_tokens: int,
177
+ im_end_id: int = 4,
178
+ decode_one_token=decode_one_token_naive,
179
+ **sampling_kwargs,
180
+ ):
181
+ previous_tokens = torch.zeros(
182
+ (model.config.num_codebooks + 1, model.config.max_seq_len),
183
+ dtype=torch.int,
184
+ device=cur_token.device,
185
+ )
186
+
187
+ for i in tqdm(range(num_new_tokens)):
188
+ # We need to get windowed repeat penalty
189
+ win_size = 16
190
+ if i < win_size:
191
+ window = previous_tokens[:, :win_size]
192
+ else:
193
+ window = previous_tokens[:, i - win_size : i]
194
+
195
+ with (
196
+ torch.backends.cuda.sdp_kernel(
197
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
198
+ )
199
+ if torch.cuda.is_available()
200
+ else nullcontext()
201
+ ): # Actually better for Inductor to codegen attention here
202
+ next_token = decode_one_token(
203
+ model=model,
204
+ x=cur_token,
205
+ input_pos=input_pos,
206
+ previous_tokens=window,
207
+ **sampling_kwargs,
208
+ )
209
+
210
+ input_pos += 1
211
+ cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
212
+ previous_tokens[:, i : i + 1] = next_token.view(
213
+ model.config.num_codebooks + 1, -1
214
+ )
215
+
216
+ if cur_token[0, 0, -1] == im_end_id:
217
+ break
218
+
219
+ return previous_tokens[:, : i + 1]
220
+
221
+
222
+ @torch.no_grad()
223
+ @torch.inference_mode()
224
+ def generate(
225
+ *,
226
+ model: NaiveTransformer,
227
+ prompt: torch.Tensor,
228
+ max_new_tokens: int,
229
+ im_end_id: int = 4,
230
+ decode_one_token=decode_one_token_naive,
231
+ **sampling_kwargs,
232
+ ) -> torch.Tensor:
233
+ """
234
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
235
+ """
236
+
237
+ # create an empty tensor of the expected final shape and fill in the current tokens
238
+ T = prompt.size(1)
239
+
240
+ if max_new_tokens:
241
+ if T + max_new_tokens > model.config.max_seq_len:
242
+ max_new_tokens = model.config.max_seq_len - T
243
+ logger.info(f"Truncating max_new_tokens to {max_new_tokens}")
244
+
245
+ T_new = T + max_new_tokens
246
+ else:
247
+ T_new = model.config.max_seq_len
248
+ max_new_tokens = T_new - T
249
+
250
+ device, dtype = prompt.device, prompt.dtype
251
+ with torch.device(device):
252
+ model.setup_caches(
253
+ max_batch_size=1, max_seq_len=T_new, dtype=next(model.parameters()).dtype
254
+ )
255
+
256
+ codebook_dim = 1 + model.config.num_codebooks
257
+ # create an empty tensor of the expected final shape and fill in the current tokens
258
+ empty = torch.empty((codebook_dim, T_new), dtype=dtype, device=device)
259
+ empty[:, :T] = prompt
260
+ seq = empty
261
+ input_pos = torch.arange(0, T, device=device)
262
+
263
+ # Use non-accelerated version for now, to avoid compilation overhead
264
+ prefill_decode = (
265
+ decode_one_token_naive
266
+ if isinstance(model, NaiveTransformer)
267
+ else decode_one_token_ar
268
+ )
269
+
270
+ next_token = prefill_decode(
271
+ model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
272
+ )
273
+ seq[:, T : T + 1] = next_token
274
+
275
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
276
+ x = decode_n_tokens(
277
+ model,
278
+ next_token.view(1, codebook_dim, -1),
279
+ input_pos,
280
+ max_new_tokens - 1,
281
+ im_end_id=im_end_id,
282
+ decode_one_token=decode_one_token,
283
+ **sampling_kwargs,
284
+ )
285
+ # x = torch.cat(generated_tokens, dim=1)
286
+ seq = seq[:, : T + 1 + x.size(1)]
287
+ seq[:, T + 1 :] = x
288
+
289
+ return seq
290
+
291
+
292
+ def encode_tokens(
293
+ tokenizer,
294
+ string,
295
+ device="cuda",
296
+ prompt_tokens=None,
297
+ num_codebooks=4,
298
+ ):
299
+ string = clean_text(string)
300
+ string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
301
+
302
+ new_tokens = tokenizer.encode(
303
+ string,
304
+ add_special_tokens=False,
305
+ max_length=10**6,
306
+ truncation=False,
307
+ )
308
+ tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
309
+
310
+ # Codebooks
311
+ zeros = (
312
+ torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
313
+ * CODEBOOK_PAD_TOKEN_ID
314
+ )
315
+ prompt = torch.cat((tokens, zeros), dim=0)
316
+
317
+ if prompt_tokens is None:
318
+ return prompt
319
+
320
+ # Get prompt tokens
321
+ if prompt_tokens.ndim == 3:
322
+ assert (
323
+ prompt_tokens.shape[0] == 1
324
+ ), f"3 dim prompt tokens should have shape (1, num_codebooks, seq_len)"
325
+ prompt_tokens = prompt_tokens[0]
326
+
327
+ assert prompt_tokens.ndim == 2
328
+ data = prompt_tokens + 1
329
+
330
+ if prompt_tokens.shape[0] > num_codebooks:
331
+ logger.warning(
332
+ f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
333
+ )
334
+ data = data[:num_codebooks]
335
+
336
+ # Add pad token for each codebook
337
+ data = torch.cat(
338
+ (data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
339
+ dim=1,
340
+ )
341
+
342
+ # Since 1.0, we use <|semantic|>
343
+ s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
344
+ end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
345
+ main_token_ids = (
346
+ torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
347
+ )
348
+ main_token_ids[0, -1] = end_token_id
349
+
350
+ data = torch.cat((main_token_ids, data), dim=0)
351
+ prompt = torch.cat((prompt, data), dim=1)
352
+
353
+ return prompt
354
+
355
+
356
+ def load_model(checkpoint_path, device, precision, compile=False):
357
+ model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
358
+ checkpoint_path, load_weights=True
359
+ )
360
+
361
+ model = model.to(device=device, dtype=precision)
362
+ logger.info(f"Restored model from checkpoint")
363
+
364
+ if isinstance(model, DualARTransformer):
365
+ decode_one_token = decode_one_token_ar
366
+ logger.info("Using DualARTransformer")
367
+ else:
368
+ decode_one_token = decode_one_token_naive
369
+ logger.info("Using NaiveTransformer")
370
+
371
+ if compile:
372
+ logger.info("Compiling function...")
373
+ decode_one_token = torch.compile(
374
+ decode_one_token,
375
+ fullgraph=True,
376
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
377
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
378
+ )
379
+
380
+ return model.eval(), decode_one_token
381
+
382
+
383
+ @dataclass
384
+ class GenerateResponse:
385
+ action: Literal["sample", "next"]
386
+ codes: Optional[torch.Tensor] = None
387
+ text: Optional[str] = None
388
+
389
+
390
+ def generate_long(
391
+ *,
392
+ model,
393
+ device: str | torch.device,
394
+ decode_one_token: callable,
395
+ text: str,
396
+ num_samples: int = 1,
397
+ max_new_tokens: int = 0,
398
+ top_p: int = 0.7,
399
+ repetition_penalty: float = 1.5,
400
+ temperature: float = 0.7,
401
+ compile: bool = False,
402
+ iterative_prompt: bool = True,
403
+ max_length: int = 2048,
404
+ chunk_length: int = 150,
405
+ prompt_text: Optional[str | list[str]] = None,
406
+ prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
407
+ ):
408
+ assert 0 < top_p <= 1, "top_p must be in (0, 1]"
409
+ assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
410
+ assert 0 < temperature < 2, "temperature must be in (0, 2)"
411
+
412
+ use_prompt = prompt_text is not None and prompt_tokens is not None
413
+ if use_prompt and isinstance(prompt_text, str):
414
+ prompt_text = [prompt_text]
415
+ prompt_tokens = [prompt_tokens]
416
+
417
+ assert use_prompt is False or len(prompt_text) == len(
418
+ prompt_tokens
419
+ ), "Prompt text and tokens must have the same length"
420
+
421
+ model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
422
+ tokenizer = model.tokenizer
423
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
424
+
425
+ encoded = []
426
+ texts = split_text(text, chunk_length) if iterative_prompt else [text]
427
+ encoded_prompts = []
428
+
429
+ if use_prompt:
430
+ for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
431
+ encoded_prompts.append(
432
+ encode_tokens(
433
+ tokenizer,
434
+ string=t,
435
+ device=device,
436
+ prompt_tokens=c,
437
+ num_codebooks=model.config.num_codebooks,
438
+ )
439
+ )
440
+
441
+ for idx, text in enumerate(texts):
442
+ encoded.append(
443
+ encode_tokens(
444
+ tokenizer,
445
+ string=text,
446
+ device=device,
447
+ num_codebooks=model.config.num_codebooks,
448
+ )
449
+ )
450
+ logger.info(f"Encoded text: {text}")
451
+
452
+ # Move temperature, top_p, repetition_penalty to device
453
+ # This is important so that changing params doesn't trigger recompile
454
+ temperature = torch.tensor(temperature, device=device, dtype=torch.float)
455
+ top_p = torch.tensor(top_p, device=device, dtype=torch.float)
456
+ repetition_penalty = torch.tensor(
457
+ repetition_penalty, device=device, dtype=torch.float
458
+ )
459
+
460
+ for sample_idx in range(num_samples):
461
+ if torch.cuda.is_available():
462
+ torch.cuda.synchronize()
463
+
464
+ global_encoded = []
465
+ seg_idx = 0
466
+
467
+ while seg_idx < len(encoded):
468
+ logger.info(
469
+ f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
470
+ )
471
+
472
+ seg = encoded[seg_idx]
473
+ global_encoded.append(seg)
474
+
475
+ lengths = reversed([seg.size(1) for seg in global_encoded])
476
+
477
+ # Pick last 2000 tokens
478
+ count = 0
479
+ for i, length in enumerate(lengths):
480
+ count += length
481
+ if count + length > max_length - 1024 - sum(
482
+ t.shape[1] for t in encoded_prompts
483
+ ):
484
+ break
485
+
486
+ if i != 0 and i % 2 == 0:
487
+ i -= 1
488
+
489
+ # Rotate the list, always make sure first segment is included to avoid drift
490
+ if i < len(global_encoded) - 2:
491
+ partial_encoded = global_encoded[:2] + global_encoded[-i:]
492
+ else:
493
+ partial_encoded = global_encoded
494
+
495
+ if use_prompt:
496
+ partial_encoded = encoded_prompts + partial_encoded
497
+
498
+ cat_encoded = torch.cat(partial_encoded, dim=1)
499
+ prompt_length = cat_encoded.size(1)
500
+
501
+ t0 = time.perf_counter()
502
+ y = generate(
503
+ model=model,
504
+ prompt=cat_encoded,
505
+ max_new_tokens=max_new_tokens,
506
+ im_end_id=im_end_id,
507
+ decode_one_token=decode_one_token,
508
+ temperature=temperature,
509
+ top_p=top_p,
510
+ repetition_penalty=repetition_penalty,
511
+ )
512
+
513
+ if sample_idx == 0 and seg_idx == 0 and compile:
514
+ logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
515
+
516
+ if torch.cuda.is_available():
517
+ torch.cuda.synchronize()
518
+
519
+ t = time.perf_counter() - t0
520
+
521
+ tokens_generated = y.size(1) - prompt_length
522
+ tokens_sec = tokens_generated / t
523
+ logger.info(
524
+ f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
525
+ )
526
+ logger.info(
527
+ f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
528
+ )
529
+
530
+ if torch.cuda.is_available():
531
+ logger.info(
532
+ f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
533
+ )
534
+
535
+ # Put the generated tokens
536
+ # since there is <im_end> and <eos> tokens, we remove last 2 tokens
537
+ codes = y[1:, prompt_length:-1].clone()
538
+ codes = codes - 1
539
+ assert (codes >= 0).all(), f"Negative code found"
540
+
541
+ decoded = y[:, prompt_length:-1].clone()
542
+ # But for global encoding, we should keep the <im_end> token
543
+
544
+ global_encoded.append(decoded)
545
+ assert (codes >= 0).all(), f"Negative code found: {codes}"
546
+ yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
547
+ seg_idx += 1
548
+
549
+ # This indicates the end of the current sample
550
+ yield GenerateResponse(action="next")
551
+
552
+
553
+ @dataclass
554
+ class WrappedGenerateResponse:
555
+ status: Literal["success", "error"]
556
+ response: Optional[GenerateResponse | Exception] = None
557
+
558
+
559
+ @dataclass
560
+ class GenerateRequest:
561
+ request: dict
562
+ response_queue: queue.Queue
563
+
564
+
565
+ def launch_thread_safe_queue(
566
+ checkpoint_path,
567
+ device,
568
+ precision,
569
+ compile: bool = False,
570
+ ):
571
+ input_queue = queue.Queue()
572
+ init_event = threading.Event()
573
+
574
+ def worker():
575
+ model, decode_one_token = load_model(
576
+ checkpoint_path, device, precision, compile=compile
577
+ )
578
+ init_event.set()
579
+
580
+ while True:
581
+ item: GenerateRequest | None = input_queue.get()
582
+ if item is None:
583
+ break
584
+
585
+ kwargs = item.request
586
+ response_queue = item.response_queue
587
+
588
+ try:
589
+ for chunk in generate_long(
590
+ model=model, decode_one_token=decode_one_token, **kwargs
591
+ ):
592
+ response_queue.put(
593
+ WrappedGenerateResponse(status="success", response=chunk)
594
+ )
595
+ except Exception as e:
596
+ response_queue.put(WrappedGenerateResponse(status="error", response=e))
597
+
598
+ threading.Thread(target=worker, daemon=True).start()
599
+ init_event.wait()
600
+
601
+ return input_queue
602
+
603
+
604
+ @click.command()
605
+ @click.option(
606
+ "--text",
607
+ type=str,
608
+ default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
609
+ )
610
+ @click.option("--prompt-text", type=str, default=None, multiple=True)
611
+ @click.option(
612
+ "--prompt-tokens",
613
+ type=click.Path(path_type=Path, exists=True),
614
+ default=None,
615
+ multiple=True,
616
+ )
617
+ @click.option("--num-samples", type=int, default=1)
618
+ @click.option("--max-new-tokens", type=int, default=0)
619
+ @click.option("--top-p", type=float, default=0.7)
620
+ @click.option("--repetition-penalty", type=float, default=1.2)
621
+ @click.option("--temperature", type=float, default=0.7)
622
+ @click.option(
623
+ "--checkpoint-path",
624
+ type=click.Path(path_type=Path, exists=True),
625
+ default="checkpoints/fish-speech-1.4",
626
+ )
627
+ @click.option("--device", type=str, default="cuda")
628
+ @click.option("--compile/--no-compile", default=False)
629
+ @click.option("--seed", type=int, default=42)
630
+ @click.option("--half/--no-half", default=False)
631
+ @click.option("--iterative-prompt/--no-iterative-prompt", default=True)
632
+ @click.option("--chunk-length", type=int, default=100)
633
+ def main(
634
+ text: str,
635
+ prompt_text: Optional[list[str]],
636
+ prompt_tokens: Optional[list[Path]],
637
+ num_samples: int,
638
+ max_new_tokens: int,
639
+ top_p: int,
640
+ repetition_penalty: float,
641
+ temperature: float,
642
+ checkpoint_path: Path,
643
+ device: str,
644
+ compile: bool,
645
+ seed: int,
646
+ half: bool,
647
+ iterative_prompt: bool,
648
+ chunk_length: int,
649
+ ) -> None:
650
+
651
+ precision = torch.half if half else torch.bfloat16
652
+
653
+ if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
654
+ raise ValueError(
655
+ f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
656
+ )
657
+
658
+ logger.info("Loading model ...")
659
+ t0 = time.time()
660
+ model, decode_one_token = load_model(
661
+ checkpoint_path, device, precision, compile=compile
662
+ )
663
+
664
+ if torch.cuda.is_available():
665
+ torch.cuda.synchronize()
666
+
667
+ logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
668
+
669
+ if prompt_tokens is not None:
670
+ prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
671
+
672
+ torch.manual_seed(seed)
673
+
674
+ if torch.cuda.is_available():
675
+ torch.cuda.manual_seed(seed)
676
+
677
+ generator = generate_long(
678
+ model=model,
679
+ device=device,
680
+ decode_one_token=decode_one_token,
681
+ text=text,
682
+ num_samples=num_samples,
683
+ max_new_tokens=max_new_tokens,
684
+ top_p=top_p,
685
+ repetition_penalty=repetition_penalty,
686
+ temperature=temperature,
687
+ compile=compile,
688
+ iterative_prompt=iterative_prompt,
689
+ chunk_length=chunk_length,
690
+ prompt_text=prompt_text,
691
+ prompt_tokens=prompt_tokens,
692
+ )
693
+
694
+ idx = 0
695
+ codes = []
696
+
697
+ for response in generator:
698
+ if response.action == "sample":
699
+ codes.append(response.codes)
700
+ logger.info(f"Sampled text: {response.text}")
701
+ elif response.action == "next":
702
+ if codes:
703
+ np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
704
+ logger.info(f"Saved codes to codes_{idx}.npy")
705
+ logger.info(f"Next sample")
706
+ codes = []
707
+ idx += 1
708
+ else:
709
+ logger.error(f"Error: {response}")
710
+
711
+
712
+ if __name__ == "__main__":
713
+ main()
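
For context, the queue returned by launch_thread_safe_queue in this file can also be driven directly from Python instead of the click entry point. The following is a minimal sketch, not part of this commit: the checkpoint path reuses the CLI default shown in the diff, the text and sampling values are illustrative placeholders, and device is set to "cpu" because importing this module now makes torch.cuda.is_available() return False.

# Minimal sketch (placeholder values) of driving the worker queue from tools/llama/generate.py.
import queue

import torch

from tools.llama.generate import (
    GenerateRequest,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)

# Loads the model in a daemon worker thread and returns the queue that feeds it;
# all generation requests are serialized through that single thread.
input_queue = launch_thread_safe_queue(
    checkpoint_path="checkpoints/fish-speech-1.4",  # CLI default from the diff
    device="cpu",  # CPU path, since torch.cuda.is_available is overridden to return False
    precision=torch.bfloat16,
    compile=False,
)

response_queue: queue.Queue = queue.Queue()
input_queue.put(
    GenerateRequest(
        request=dict(
            device="cpu",
            text="Hello, this is a placeholder sentence.",  # placeholder text
            num_samples=1,
            max_new_tokens=0,
            top_p=0.7,
            repetition_penalty=1.2,
            temperature=0.7,
            chunk_length=100,
        ),
        response_queue=response_queue,
    )
)

codes = []
while True:
    item: WrappedGenerateResponse = response_queue.get()
    if item.status == "error":
        raise item.response  # the worker forwards the exception it hit
    if item.response.action == "next":
        break  # end of the current sample
    codes.append(item.response.codes)  # (num_codebooks, T) tensor of semantic codes

if codes:
    semantic_codes = torch.cat(codes, dim=1)  # same concatenation main() uses before np.save

This mirrors what main() does with the generator, but goes through the thread-safe queue so a long-lived process (e.g. a server) can keep one model resident and submit requests from other threads.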