initial commit
- README.md +117 -3
 - amplify.py +238 -0
 - config.json +38 -0
 - model.safetensors +3 -0
 - rmsnorm.py +34 -0
 - rotary.py +80 -0
 - special_tokens_map.json +7 -0
 - tokenizer.json +154 -0
 - tokenizer_config.json +58 -0
 
    	
README.md CHANGED
@@ -1,3 +1,117 @@
---
license: mit
datasets:
  - drug-discovery/UR100P
language:
  - en
tags:
  - biology
---

## AMPLIFY

AMPLIFY is an efficient, state-of-the-art protein language model pre-trained using masked language modeling on UniRef100, OAS, and SCOP ([UR100P](https://huggingface.co/datasets/drug-discovery/UR100P)). AMPLIFY can generate residue and protein embeddings, suggest mutations, differentiate disordered proteins from non-protein sequences, and much more. AMPLIFY is available in two sizes, 120M and 350M parameters, with the `_base` models not extended beyond 512 residues (Stage 1). The model architecture and pre-training procedure are detailed below. For more details, please refer to the [accompanying paper](https://www.biorxiv.org/content/10.1101/2024.09.23.614603v1).

- [`AMPLIFY_350M`](https://huggingface.co/drug-discovery/AMPLIFY_350M)
- [`AMPLIFY_350M_base`](https://huggingface.co/drug-discovery/AMPLIFY_350M_base)
- [`AMPLIFY_120M`](https://huggingface.co/drug-discovery/AMPLIFY_120M)
- [`AMPLIFY_120M_base`](https://huggingface.co/drug-discovery/AMPLIFY_120M_base)

### Model Description

|                                | AMPLIFY 120M | AMPLIFY 350M |
| :----------------------------- | -----------: | -----------: |
| `hidden-size`                  |          640 |          960 |
| `num-hidden-layers`            |           24 |           32 |
| `num-attention-heads`          |           10 |           15 |
| `intermediate-size`            |         2560 |         3840 |
| `max-position-embeddings`      |         2048 |         2048 |
| `vocab-size`                   |           27 |           27 |
| `rope-theta`                   |        10000 |        10000 |
| `dropout-prob`                 |            0 |            0 |
| `embedding-init-range`         |         0.02 |         0.02 |
| `norm-eps`                     |      1.0e-05 |      1.0e-05 |
| `hidden-act`                   |       swiglu |       swiglu |
| `pre-activation-layer-norm`    |         true |         true |
| `layer-norm-after-embedding`   |        false |        false |
| `layer-norm-before-last-layer` |         true |         true |
| `rms-norm`                     |         true |         true |
| `ffn-bias`                     |        false |        false |
| `attn-bias`                    |        false |        false |

### Training Description

|                     |     Stage 1 |                      Stage 2 |
| :------------------ | ----------: | ---------------------------: |
| `dataset`           |      UR100P |                       UR100P |
| `max-steps`         |     1000000 | 25000 (120M) or 50000 (350M) |
| `max-length`        |         512 |                         2048 |
| `optimizer`         |       adamw |                        adamw |
| `lr`                |       0.001 |                        0.001 |
| `betas`             | (0.9, 0.95) |                  (0.9, 0.95) |
| `eps`               |     1.0e-08 |                      1.0e-08 |
| `weight-decay`      |        0.01 |                         0.01 |
| `scheduler`         | cosinedecay |                         none |
| `warmup-steps`      |       1,000 |                         none |
| `final-step`        |     900,000 |                         none |
| `gradient-clipping` |         1.0 |                          1.0 |
| `tf32`              |        true |                         true |
| `mixed-precision`   |        bf16 |                         bf16 |
| `padding`           |  max-length |                   max-length |
| `random-truncate`   |        true |                         true |
| `mask-probability`  |        0.15 |                         0.15 |
| `total-batch-size`  |        4096 |                         4096 |
| `deepspeed`         |        true |                         true |
| `zero-stage`        |           3 |                            3 |

## Get Started

```python
from transformers import AutoModel
from transformers import AutoTokenizer
from datasets import load_dataset

# Load AMPLIFY and tokenizer
model = AutoModel.from_pretrained("drug-discovery/AMPLIFY_350M", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("drug-discovery/AMPLIFY_350M", trust_remote_code=True)

# Move the model to GPU (required due to Flash Attention)
model = model.to("cuda")

# Load the UniProt validation set
dataset = load_dataset("drug-discovery/UR100P", data_dir="UniProt", split="test")

for sample in dataset:
    # Protein
    print("Sample: ", sample["name"], sample["sequence"])

    # Tokenize the protein
    input = tokenizer.encode(sample["sequence"], return_tensors="pt")
    print("Input: ", input)

    # Move to the GPU and make a prediction
    input = input.to("cuda")
    output = model(input)
    print("Output: ", output)

    break
```
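
The same forward pass also exposes residue embeddings and a simple way to score point mutations. The following is a minimal sketch, not part of the original release: it assumes the model and tokenizer loaded above, that single residues map to single tokens and that the tokenizer adds one start and one end special token (so sequence position `i` maps to token position `i + 1`), and it uses a simple wild-type-marginal score computed from the masked-LM logits; the example sequence and mutation are arbitrary.

```python
import torch

sequence = "MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFG"  # arbitrary example sequence
inputs = tokenizer.encode(sequence, return_tensors="pt").to("cuda")

with torch.no_grad():
    output = model(inputs, output_hidden_states=True)

# Residue embeddings from the last transformer layer, dropping the two special tokens
residue_embeddings = output.hidden_states[-1][0, 1:-1]  # (length, hidden_size)
protein_embedding = residue_embeddings.mean(dim=0)      # (hidden_size,)

# Wild-type-marginal mutation score: log-odds of the proposed residue vs. the wild type
log_probs = output.logits[0].log_softmax(dim=-1)
pos, alt = 10, "W"  # 0-based position in `sequence`, proposed residue
wt_id = tokenizer.convert_tokens_to_ids(sequence[pos])
alt_id = tokenizer.convert_tokens_to_ids(alt)
score = log_probs[pos + 1, alt_id] - log_probs[pos + 1, wt_id]  # > 0 favors the mutation
print(float(score))
```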

## Citations

If you find the models useful in your research, we ask that you cite the paper:

```bibtex
@article{Fournier2024.09.23.614603,
	title        = {Protein Language Models: Is Scaling Necessary?},
	author       = {Fournier, Quentin and Vernon, Robert M. and van der Sloot, Almer and Schulz, Benjamin and Chandar, Sarath and Langmead, Christopher James},
	year         = {2024},
	journal      = {bioRxiv},
	publisher    = {Cold Spring Harbor Laboratory},
	doi          = {10.1101/2024.09.23.614603},
	url          = {https://www.biorxiv.org/content/early/2024/09/23/2024.09.23.614603},
	elocation-id = {2024.09.23.614603},
	eprint       = {https://www.biorxiv.org/content/early/2024/09/23/2024.09.23.614603.full.pdf}
}
```
amplify.py ADDED
@@ -0,0 +1,238 @@
# From https://stackoverflow.com/a/23689767
# From https://github.com/pytorch/pytorch/issues/97899
# From https://github.com/facebookresearch/llama/blob/main/llama/model.py

import torch
from torch import nn

from xformers.ops import SwiGLU, memory_efficient_attention

from .rmsnorm import RMSNorm
from .rotary import precompute_freqs_cis, apply_rotary_emb

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput

class DotDict(dict):
    """Dictionary that supports the dot notation to access attributes (similarly to HuggingFace)."""

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class AMPLIFYConfig(PretrainedConfig):
    model_type = "AMPLIFY"
    # All config parameters must have a default value.
    def __init__(
        self,
        hidden_size: int = 960,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 15,
        intermediate_size: int = 3840,
        dropout_prob: float = 0,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        rms_norm: bool = True,
        norm_eps: float = 1e-05,
        hidden_act: str = "SwiGLU",
        layer_norm_after_embedding: bool = False,
        layer_norm_before_last_layer: bool = True,
        vocab_size: int = 27,
        ffn_bias: bool = False,
        att_bias: bool = False,
        pad_token_id: int = 0,
        max_length: int = 2048,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout_prob = dropout_prob
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.rms_norm = rms_norm
        self.norm_eps = norm_eps
        self.hidden_act = hidden_act
        self.layer_norm_after_embedding = layer_norm_after_embedding
        self.layer_norm_before_last_layer = layer_norm_before_last_layer
        self.vocab_size = vocab_size
        self.ffn_bias = ffn_bias
        self.att_bias = att_bias
        self.pad_token_id = pad_token_id
        self.max_length = max_length


class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(self, config: AMPLIFYConfig):
        """Initialize an EncoderBlock.

        Args:
            config (AMPLIFYConfig): model configuration. The block reads hidden_size, num_attention_heads,
                intermediate_size, dropout_prob, hidden_act, rms_norm, norm_eps, ffn_bias, and att_bias from it.
        """
        super().__init__()

        self.config = config
        self.d_head = config.hidden_size // config.num_attention_heads

        # Attention
        self.q = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
        self.k = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
        self.v = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
        self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=config.att_bias)
        self.resid_dropout = nn.Dropout(config.dropout_prob)

        # Feedforward network
        match config.hidden_act.lower():
            case "swiglu":
                # To keep the number of parameters and the amount of computation constant, we reduce the number of
                # hidden units by a factor of 2/3 (https://arxiv.org/pdf/2002.05202.pdf) and make it a multiple of 8 to
                # avoid RuntimeError due to misaligned operand
                multiple_of = 8
                intermediate_size = int(2 * config.intermediate_size / 3)
                intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
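                # Worked example: with intermediate_size=3840 (350M), 2 * 3840 / 3 = 2560, which is already a
                # multiple of 8; with intermediate_size=2560 (120M), 2 * 2560 / 3 = 1706, rounded up to 1712.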
                self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=config.ffn_bias)
            case "relu":
                self.ffn = nn.Sequential(
                    nn.Linear(config.hidden_size, config.intermediate_size, bias=config.ffn_bias),
                    nn.ReLU(),
                    nn.Linear(config.intermediate_size, config.hidden_size, bias=config.ffn_bias),
                )
            case "gelu":
                self.ffn = nn.Sequential(
                    nn.Linear(config.hidden_size, config.intermediate_size, bias=config.ffn_bias),
                    nn.GELU(),
                    nn.Linear(config.intermediate_size, config.hidden_size, bias=config.ffn_bias),
                )

        self.attention_norm = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)

        self.ffn_dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
        attn, contact = self._att_block(self.attention_norm(x), pad_mask, freqs_cis, output_attentions)
        x = x + attn
        x = x + self._ff_block(self.ffn_norm(x))
        return x, contact

    def _att_block(self, x: torch.Tensor, pad_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
        batch_size, seq_len, _ = x.shape
        xq, xk, xv = self.q(x), self.k(x), self.v(x)

        # Reshape for rotary embeddings
        xq = xq.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
        xk = xk.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
        xv = xv.view(batch_size, seq_len, self.config.num_attention_heads, self.d_head)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        attn = memory_efficient_attention(
            query=xq,
            key=xk,
            value=xv,
            attn_bias=pad_mask,
            p=self.config.dropout_prob if self.training else 0,
        )

        # memory_efficient_attention does not return the attention weights, so recompute them
        # explicitly when they are requested.
        _attn = None
        if output_attentions:
            _attn = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
            if pad_mask is not None:
                _attn = _attn + pad_mask
            _attn = _attn.softmax(-1)

        return self.resid_dropout(self.wo(attn.view(batch_size, seq_len, self.config.num_attention_heads * self.d_head))), _attn

    def _ff_block(self, x: torch.Tensor):
        return self.ffn_dropout(self.ffn(x))


class AMPLIFYPreTrainedModel(PreTrainedModel):
    config_class = AMPLIFYConfig

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)


class AMPLIFY(AMPLIFYPreTrainedModel):
    """The main model class.

    Args:
        config (amplify.model.amplify.AMPLIFYConfig): model configuration, usually defined from the Hydra configuration.
    """
    def __init__(self, config: AMPLIFYConfig, **kwargs):
        super().__init__(config)

        self.config = config

        self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        if config.layer_norm_after_embedding:
            self.layer_norm_1 = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)

        self.transformer_encoder = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.transformer_encoder.append(EncoderBlock(config))

        if config.layer_norm_before_last_layer:
            self.layer_norm_2 = RMSNorm(config.hidden_size, config.norm_eps) if config.rms_norm else nn.LayerNorm(config.hidden_size, config.norm_eps)

        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        self.freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, src, pad_mask=None, output_hidden_states=False, output_attentions=False):
        # Initialize
        hidden_states, attentions = [], []

        # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
        if pad_mask is not None and not torch.all(pad_mask == 0):
            pad_mask = pad_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, pad_mask.size(-1), 1)
        else:
            pad_mask = None

        # RoPE
        self.freqs_cis = self.freqs_cis.to(src.device, non_blocking=True)
        freqs_cis = self.freqs_cis[: src.shape[1]]

        # Embedding
        x = self.encoder(src)
        if self.config.layer_norm_after_embedding:
            x = self.layer_norm_1(x)

        # Transformer encoder
        for layer in self.transformer_encoder:
            x, attn = layer(x, pad_mask, freqs_cis, output_attentions)
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        # Classification head with layer norm
        logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)

        # Return the prediction logits along with any collected hidden states and attention maps
        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
config.json ADDED
@@ -0,0 +1,38 @@
{
  "_name_": "AMPLIFY",
  "architectures": [
    "AMPLIFY"
  ],
  "att_bias": false,
  "auto_map": {
    "AutoConfig": "amplify.AMPLIFYConfig",
    "AutoModel": "amplify.AMPLIFY"
  },
  "bias": false,
  "bos_token_id": 3,
  "decoder_init_range": 0.02,
  "dropout_prob": 0,
  "embedding_init_range": 0.02,
  "eos_token_id": 4,
  "ffn_bias": false,
  "hidden_act": "SwiGLU",
  "hidden_size": 640,
  "intermediate_size": 2560,
  "layer_norm_after_embedding": false,
  "layer_norm_before_last_layer": true,
  "mask_token_id": 2,
  "max_length": 2048,
  "model_type": "AMPLIFY",
  "norm_eps": 1e-05,
  "num_attention_heads": 10,
  "num_hidden_layers": 24,
  "other_special_token_ids": null,
  "pad_token_id": 0,
  "pre_activation_layer_norm": true,
  "rms_norm": true,
  "torch_dtype": "float32",
  "transformers_version": "4.38.2",
  "unk_token_id": 1,
  "vocab_path": "conf/tokenizer/amplify_vocab.txt",
  "vocab_size": 27
}
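
This configuration matches the 120M architecture from the README table (hidden size 640, 24 layers, 10 heads). The `auto_map` block is what lets the `transformers` Auto classes resolve the custom code shipped in this repository: with `trust_remote_code=True`, `AutoConfig` is routed to `amplify.AMPLIFYConfig` and `AutoModel` to `amplify.AMPLIFY`. A minimal sketch (the repository id below is an assumption for illustration):

```python
from transformers import AutoConfig, AutoModel

# auto_map routes these calls to AMPLIFYConfig / AMPLIFY defined in amplify.py
config = AutoConfig.from_pretrained("drug-discovery/AMPLIFY_120M", trust_remote_code=True)
print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)  # 640 24 10

model = AutoModel.from_pretrained("drug-discovery/AMPLIFY_120M", trust_remote_code=True)
```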
    	
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5cdd05fcfa647ed4861c13fc5bb6f94c49acf0c0510dbc5ea75a10aaec558170
size 473126988
rmsnorm.py ADDED
@@ -0,0 +1,34 @@
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
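        # Normalize by the reciprocal root-mean-square over the last dimension, then apply the learned scale.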
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
rotary.py ADDED
@@ -0,0 +1,80 @@
import torch
from typing import Tuple


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64
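
# For example, precompute_freqs_cis(64, 2048) returns a (2048, 32) complex64 tensor: one unit-modulus
# phase per (position, frequency pair), later broadcast against query/key heads split into dim // 2 pairs.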


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """

    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
            +
                    xq (torch.Tensor): Query tensor to apply rotary embeddings.
         
     | 
| 69 | 
         
            +
                    xk (torch.Tensor): Key tensor to apply rotary embeddings.
         
     | 
| 70 | 
         
            +
                    freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                Returns:
         
     | 
| 73 | 
         
            +
                    Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
         
     | 
| 74 | 
         
            +
                """
         
     | 
| 75 | 
         
            +
                xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
         
     | 
| 76 | 
         
            +
                xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
         
     | 
| 77 | 
         
            +
                freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
         
     | 
| 78 | 
         
            +
                xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
         
     | 
| 79 | 
         
            +
                xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
         
     | 
| 80 | 
         
            +
                return xq_out.type_as(xq), xk_out.type_as(xk)
         
     | 
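The three helpers above form the usual rotary position embedding (RoPE) pipeline: precompute one complex rotation per (position, feature pair), then rotate each query/key feature pair in the complex plane. Below is a minimal usage sketch, not part of the repository; the (batch, seq_len, n_heads, head_dim) layout is an assumption about how the attention module calls these helpers.

# Usage sketch for the helpers above (assumed tensor layout: batch, seq_len, n_heads, head_dim).
import torch

batch, seq_len, n_heads, head_dim = 2, 16, 4, 64

# One complex rotation per (position, feature pair): shape (seq_len, head_dim // 2).
freqs_cis = precompute_freqs_cis(head_dim, seq_len)

xq = torch.randn(batch, seq_len, n_heads, head_dim)
xk = torch.randn(batch, seq_len, n_heads, head_dim)

# Shapes are preserved; only the orientation of each 2-D feature pair changes.
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape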
    	
    special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "bos_token": "<bos>",
  "eos_token": "<eos>",
  "mask_token": "<mask>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
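For illustration only, a hedged sketch of how these five special tokens are used once the tokenizer is loaded through transformers: <bos> and <eos> frame each sequence, <pad> fills out batches, and <mask> hides residues for the masked-language-modeling objective. The checkpoint path below is a placeholder, not the actual repository id.

# Sketch of the special-token roles; "path/to/amplify-checkpoint" is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/amplify-checkpoint")
enc = tokenizer("MSVVGIDLGFQSCYVAVARAGGIETIANEYSDRCTPACISFGPKNR", return_tensors="pt")

input_ids = enc["input_ids"].clone()
input_ids[0, 5] = tokenizer.mask_token_id  # hide one residue for the model to predict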
    	
    tokenizer.json
ADDED
@@ -0,0 +1,154 @@
{
  "version": "1.0",
  "truncation": null,
  "padding": null,
  "added_tokens": [
    {
      "id": 0,
      "content": "<pad>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 1,
      "content": "<unk>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 2,
      "content": "<mask>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 3,
      "content": "<bos>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    },
    {
      "id": 4,
      "content": "<eos>",
      "single_word": false,
      "lstrip": false,
      "rstrip": false,
      "normalized": false,
      "special": true
    }
  ],
  "normalizer": null,
  "pre_tokenizer": {
    "type": "Split",
    "pattern": {
      "String": ""
    },
    "behavior": "Removed",
    "invert": false
  },
  "post_processor": {
    "type": "TemplateProcessing",
    "single": [
      {
        "SpecialToken": {
          "id": "<bos>",
          "type_id": 0
        }
      },
      {
        "Sequence": {
          "id": "A",
          "type_id": 0
        }
      },
      {
        "SpecialToken": {
          "id": "<eos>",
          "type_id": 0
        }
      }
    ],
    "pair": [
      {
        "Sequence": {
          "id": "A",
          "type_id": 0
        }
      },
      {
        "Sequence": {
          "id": "B",
          "type_id": 1
        }
      }
    ],
    "special_tokens": {
      "<bos>": {
        "id": "<bos>",
        "ids": [
          3
        ],
        "tokens": [
          "<bos>"
        ]
      },
      "<eos>": {
        "id": "<eos>",
        "ids": [
          4
        ],
        "tokens": [
          "<eos>"
        ]
      }
    }
  },
  "decoder": null,
  "model": {
    "type": "WordPiece",
    "unk_token": "<unk>",
    "continuing_subword_prefix": "##",
    "max_input_chars_per_word": 100,
    "vocab": {
      "<pad>": 0,
      "<unk>": 1,
      "<mask>": 2,
      "<bos>": 3,
      "<eos>": 4,
      "|": 5,
      "L": 6,
      "A": 7,
      "G": 8,
      "V": 9,
      "S": 10,
      "E": 11,
      "R": 12,
      "T": 13,
      "I": 14,
      "D": 15,
      "P": 16,
      "K": 17,
      "Q": 18,
      "N": 19,
      "F": 20,
      "Y": 21,
      "M": 22,
      "H": 23,
      "W": 24,
      "C": 25,
      "B": 26
    }
  }
}
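Taken together, the Split pre-tokenizer with an empty-string pattern splits a protein sequence into single residues, the WordPiece vocabulary maps each residue (plus the five special tokens and the "|" character) to an id, and the TemplateProcessing post-processor wraps single sequences in <bos> ... <eos>. A small sketch, assuming tokenizer.json has been downloaded locally; the values in the comments are what the vocabulary above implies.

# Sketch: character-level encoding with the tokenizers library (local tokenizer.json assumed).
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")
out = tok.encode("MKTV")
print(out.tokens)  # expected: ['<bos>', 'M', 'K', 'T', 'V', '<eos>']
print(out.ids)     # expected: [3, 22, 17, 13, 9, 4]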
    	
    tokenizer_config.json
ADDED
@@ -0,0 +1,58 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "<mask>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<bos>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "<eos>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<bos>",
  "clean_up_tokenization_spaces": true,
  "eos_token": "<eos>",
  "mask_token": "<mask>",
  "model_input_names": [
    "input_ids",
    "attention_mask"
  ],
  "model_max_length": 2048,
  "pad_token": "<pad>",
  "padding_side": "right",
  "tokenizer_class": "PreTrainedTokenizerFast",
  "truncation_side": "right",
  "unk_token": "<unk>"
}
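This configuration tells transformers to rebuild the tokenizer as a PreTrainedTokenizerFast with right-sided padding and truncation and a 2048-token model_max_length. A brief sketch of the resulting batching behaviour, assuming the files above are available in a local checkpoint directory (the path is a placeholder).

# Sketch of padding/truncation behaviour implied by tokenizer_config.json; "./amplify-checkpoint" is a placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./amplify-checkpoint")
batch = tokenizer(
    ["MKTV", "MKTVLLA"],
    padding=True,      # pads the shorter sequence on the right with <pad> (id 0)
    truncation=True,   # truncates on the right beyond model_max_length (2048)
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # expected: torch.Size([2, 9]) -> <bos> + 7 residues + <eos>
print(batch["attention_mask"])   # 1 for real tokens, 0 for padding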