	Enable compile on A10G
Files changed:
- app.py (+2 -2)
- tools/llama/generate.py (+37 -23)
    	
app.py CHANGED

@@ -251,7 +251,7 @@ def build_app():
                 # speaker,
             ],
             [audio, error],
-
+            concurrency_limit=1,
         )
 
     return app
@@ -287,7 +287,7 @@ if __name__ == "__main__":
     args = parse_args()
 
     args.precision = torch.half if args.half else torch.bfloat16
-
+    args.compile = True
 
     logger.info("Loading Llama model...")
     llama_model, decode_one_token = load_llama_model(
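
The two edits above cap the Gradio event at one in-flight request (concurrency_limit=1) and force compilation on (args.compile = True), which matches the commit title. A minimal sketch of how a concurrency_limit-constrained event looks in a Gradio Blocks app; the handler and component names here are illustrative, not taken from this repo:

    # Sketch: an event bound with concurrency_limit=1, so requests queue
    # instead of running the GPU-heavy handler in parallel.
    import gradio as gr

    def fake_tts(text):
        # placeholder for the real inference call
        return None, ""

    with gr.Blocks() as demo:
        text = gr.Textbox(label="Text")
        audio = gr.Audio(label="Generated audio")
        error = gr.Textbox(label="Error")
        gr.Button("Generate").click(
            fake_tts, [text], [audio, error], concurrency_limit=1
        )

Serializing requests is the usual companion to compiled inference, since compiled decode kernels and their CUDA graphs are generally not safe to share across concurrent generations; presumably load_llama_model only wraps decode_one_token with torch.compile when args.compile is set.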
    	
tools/llama/generate.py CHANGED

@@ -14,7 +14,7 @@ from loguru import logger
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
-from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
+from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
 from fish_speech.text.clean import clean_text
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -291,11 +291,11 @@ def encode_tokens(
 ):
     string = clean_text(string)
 
-    if speaker is
-
+    if speaker is None:
+        speaker = "assistant"
 
     string = (
-        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>
+        f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
     )
     if bos:
         string = f"<|begin_of_sequence|>{string}"
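
With this change, encode_tokens falls back to the "assistant" speaker and interpolates it into the chat template instead of hard-coding the role. A standalone sketch of the string it builds (the helper name here is made up for illustration):

    # Sketch: what the template above produces for a given text and speaker.
    def render_prompt(string: str, speaker: str | None = None) -> str:
        if speaker is None:
            speaker = "assistant"
        return f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"

    print(render_prompt("Hello there"))
    # <|im_start|>user<|im_sep|>Hello there<|im_end|><|im_start|>assistant<|im_sep|>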
@@ -309,7 +309,10 @@ def encode_tokens(
     tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
 
     # Codebooks
-    zeros =
+    zeros = (
+        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
+        * CODEBOOK_PAD_TOKEN_ID
+    )
     prompt = torch.cat((tokens, zeros), dim=0)
 
     if prompt_tokens is None:
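
The rows stacked under the text tokens are now filled with CODEBOOK_PAD_TOKEN_ID rather than zeros. A shape-only sketch of the resulting layout, using a placeholder value for the constant (the real one is defined in fish_speech.datasets.text):

    import torch

    CODEBOOK_PAD_TOKEN_ID = 0  # placeholder value for illustration
    num_codebooks = 4
    tokens = torch.tensor([[11, 12, 13]], dtype=torch.int)   # (1, T) row of text-token ids
    pad_rows = (
        torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int)
        * CODEBOOK_PAD_TOKEN_ID
    )                                                         # (num_codebooks, T) pad rows
    prompt = torch.cat((tokens, pad_rows), dim=0)             # (1 + num_codebooks, T)
    print(prompt.shape)                                       # torch.Size([5, 3])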
@@ -331,13 +334,23 @@ def encode_tokens(
         )
         data = data[:num_codebooks]
 
+    # Add eos token for each codebook
+    data = torch.cat(
+        (
+            data,
+            torch.ones((data.size(0), 1), dtype=torch.int, device=device)
+            * CODEBOOK_EOS_TOKEN_ID,
+        ),
+        dim=1,
+    )
+
     # Since 1.0, we use <|semantic|>
     s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
-
-
-        dtype=torch.int,
-        device=device,
+    end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+    main_token_ids = (
+        torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
     )
+    main_token_ids[0, -1] = end_token_id
 
     data = torch.cat((main_token_ids, data), dim=0)
     prompt = torch.cat((prompt, data), dim=1)
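
When codebook prompt tokens are passed in, each codebook row now gets an explicit EOS column, and the matching "main" text row is <|semantic|> at every position except the last, which becomes <|im_end|>. A shape-level sketch with placeholder token ids (the real ids come from the tokenizer and fish_speech.datasets.text):

    import torch

    CODEBOOK_EOS_TOKEN_ID = 1             # placeholder for illustration
    s0_token_id, end_token_id = 100, 101  # stand-ins for <|semantic|> and <|im_end|>

    data = torch.zeros((4, 6), dtype=torch.int)  # (num_codebooks, T) codebook prompt
    # Append one EOS column across every codebook -> (4, 7)
    data = torch.cat(
        (data, torch.ones((data.size(0), 1), dtype=torch.int) * CODEBOOK_EOS_TOKEN_ID),
        dim=1,
    )
    # Main row: <|semantic|> everywhere, <|im_end|> over the EOS column -> (1, 7)
    main_token_ids = torch.ones((1, data.size(1)), dtype=torch.int) * s0_token_id
    main_token_ids[0, -1] = end_token_id
    data = torch.cat((main_token_ids, data), dim=0)  # (1 + num_codebooks, 7)
    print(data.shape)                                # torch.Size([5, 7])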
@@ -450,6 +463,20 @@ def generate_long(
     use_prompt = prompt_text is not None and prompt_tokens is not None
     encoded = []
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
+
+    if use_prompt:
+        encoded.append(
+            encode_tokens(
+                tokenizer,
+                prompt_text,
+                prompt_tokens=prompt_tokens,
+                bos=True,
+                device=device,
+                speaker=speaker,
+                num_codebooks=model.config.num_codebooks,
+            )
+        )
+
     for idx, text in enumerate(texts):
         encoded.append(
             encode_tokens(
@@ -457,25 +484,12 @@ def generate_long(
                 string=text,
                 bos=idx == 0 and not use_prompt,
                 device=device,
-                speaker=
+                speaker=speaker,
                 num_codebooks=model.config.num_codebooks,
             )
         )
         logger.info(f"Encoded text: {text}")
 
-    if use_prompt:
-        encoded_prompt = encode_tokens(
-            tokenizer,
-            prompt_text,
-            prompt_tokens=prompt_tokens,
-            bos=True,
-            device=device,
-            speaker=speaker,
-            num_codebooks=model.config.num_codebooks,
-        )
-
-        encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
-
     for sample_idx in range(num_samples):
         torch.cuda.synchronize()
         global_encoded = []
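
Taken together, the two generate_long hunks change how the reference prompt is injected: instead of being encoded after the text chunks and concatenated onto encoded[0], it is encoded first and pushed onto the encoded list as its own leading segment, and every chunk is encoded with the same speaker. A condensed sketch of the new ordering (names follow the diff; encode_tokens is stubbed out for illustration):

    # Sketch: the reference prompt becomes the first encoded segment, followed by chunks.
    def encode_tokens(label, **kwargs):  # stub standing in for the real encoder
        return label

    def plan_segments(texts, prompt_text=None, use_prompt=False, speaker=None):
        encoded = []
        if use_prompt:
            encoded.append(encode_tokens(f"prompt:{prompt_text}", bos=True, speaker=speaker))
        for idx, text in enumerate(texts):
            encoded.append(
                encode_tokens(f"chunk:{text}", bos=idx == 0 and not use_prompt, speaker=speaker)
            )
        return encoded

    print(plan_segments(["a", "b"], prompt_text="ref", use_prompt=True))
    # ['prompt:ref', 'chunk:a', 'chunk:b']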

