alpindale committed (verified)
Commit c689221 · 1 Parent(s): 5aa9ea3

Upload folder using huggingface_hub
7B_1T_1/consolidated.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f0fd7688fea5a3edb4613fb40597715b7db1abbc30afee49a77c088e74fbe3b
+ size 26953703248
7B_1T_1/params.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "dim": 4096,
+     "ffn_dim_multiplier": 1.0,
+     "multiple_of": 256,
+     "n_future_tokens": 1,
+     "n_heads": 32,
+     "n_kv_heads": 32,
+     "n_layers": 32,
+     "norm_eps": 1e-05,
+     "rope_theta": 10000.0,
+     "vocab_size": 32000
+ }
7B_1T_4/consolidated.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1dd2a79e8283ec51955a914b0a416650a62ea6d20efc75bc6bdc8513fb8a55b5
+ size 26953703376
7B_1T_4/params.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "dim": 4096,
+     "ffn_dim_multiplier": 1.0,
+     "multiple_of": 256,
+     "n_future_tokens": 4,
+     "n_heads": 32,
+     "n_kv_heads": 32,
+     "n_layers": 32,
+     "norm_eps": 1e-05,
+     "rope_theta": 10000.0,
+     "vocab_size": 32000
+ }
7B_200B_1/consolidated.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:239c8662dae921cc26d4060b62eee086fcf5c6d6611253ead120ed260cfce42a
+ size 26953703039
7B_200B_1/params.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "dim": 4096,
+     "ffn_dim_multiplier": 1.0,
+     "multiple_of": 256,
+     "n_future_tokens": 1,
+     "n_heads": 32,
+     "n_kv_heads": 32,
+     "n_layers": 32,
+     "norm_eps": 1e-05,
+     "rope_theta": 10000.0,
+     "vocab_size": 32000
+ }
7B_200B_4/consolidated.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4f3209280186ad279b9f1421fc8a4bdd0f76a6ed18b14b277c9007c7cec8a0e
+ size 26953703376
7B_200B_4/params.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "dim": 4096,
+     "ffn_dim_multiplier": 1.0,
+     "multiple_of": 256,
+     "n_future_tokens": 4,
+     "n_heads": 32,
+     "n_kv_heads": 32,
+     "n_layers": 32,
+     "norm_eps": 1e-05,
+     "rope_theta": 10000.0,
+     "vocab_size": 32000
+ }
LICENSE ADDED
@@ -0,0 +1,65 @@
+ Multi-token Prediction Research License 18th June 2024
+
+ This Multi-token Prediction Research License (“Agreement”) contains the terms and conditions that govern your access and use of the Materials (as defined below). You may not use the Materials if you do not accept this Agreement. By clicking “I Accept” to accept, or accessing, using, or distributing any portion or element of the Materials you hereby agree to be bound by the terms of this Agreement. If you are agreeing to be bound by the Agreement on behalf of your employer or other entity, you represent and warrant to Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland) (“Meta”) that you have full legal authority to bind your employer or such entity to this Agreement. If you do not have requisite authority, you may not accept the Agreement or access the Materials on behalf of your employer or other entity.
+
+ This Agreement is effective upon the earlier of the date that you first access the Materials or accept this Agreement (“Effective Date”), and is entered into by and between Meta, and you, or if you are entering into this Agreement on behalf of your employer or other entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules, or regulations to provide legal consent and, your employer or other entity and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf (“Licensee” or “You”).
+
+ 1. Definitions.
+
+ a. “Documentation” means the specifications, manuals and documentation accompanying this release distributed by Meta at https://huggingface.co/facebook/multi-token-prediction.
+
+ b. “Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
+
+ c. “Materials” means, collectively, Documentation and the models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta at https://huggingface.co/facebook/multi-token-prediction and made available under this Agreement.
+
+ d. “Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
+
+ e. “Acceptable Use Policy” means the LLaMA Acceptable Use Policy applicable to Materials that is incorporated into this Agreement.
+
+ 2. License Rights and Redistribution. Subject to Your compliance with the terms and conditions of this Agreement, Meta hereby grants you the following:
+
+ a. Grant of Rights. You are hereby granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials solely for Noncommercial Research Uses.
+
+ b. Redistribution and Use.
+
+ i. Distribution of Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
+
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Materials, you must acknowledge the use of Materials in your publication.
+
+ iii. You must retain in all copies of the Materials that you distribute and include the following attribution notice within a “Notice” text file distributed as a part of such copies: “Materials are licensed under the Multi-token Prediction Research License, Copyright © Meta Platforms, Inc. All Rights Reserved.”
+
+ iv. Your use of the Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the LLaMA Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
+
+ v. You agree to validate and confirm LLaMA outputs for compliance with the LLaMA Acceptable Use Policy, including before relying on LLaMA outputs in any way as part of research activities or incorporating these outputs in research, studies, and papers.
+
+ vi. You agree to report any violation of this Multi-token Prediction Research License or the Acceptable Use Policy as outlined in the LLaMA Acceptable Use Policy.
+
+ 3. Restrictions. You will not, and will not permit, assist or cause any third party to:
+
+ a. use the Materials or any outputs or results of the Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
+
+ b. disguise your or their location through IP proxying or other methods;
+
+ c. use or download Materials if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) will use Materials for any purpose prohibited by Trade Control Laws; or
+
+ d. directly or indirectly export, re-export, provide, or otherwise transfer Materials: (a) to any individual, entity, or country prohibited by Trade Control Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Trade Control Laws, including nuclear, chemical or biological weapons, or missile technology applications.
+
+ 4. User Support. Your Noncommercial Research Use of the Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
+
+ 5. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE MATERIALS AND ANY OUTPUT AND RESULTS.
+
+ 6. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+
+ 7. Intellectual Property.
+
+ a. No trademark licenses are granted under this Agreement, and in connection with the Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Materials.
+
+ b. Subject to Meta’s ownership of Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
+
+ c. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Materials or outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses and rights granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
+
+ 8. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Materials. Sections 3, 4, 5, 6, 7, 8 and 9 shall survive the termination of this Agreement.
+
+ 9. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
+
+ 10. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at https://huggingface.co/facebook/multi-token-prediction/LICENSE; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no other modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
README.md ADDED
@@ -0,0 +1,123 @@
+ ---
+ license: other
+ extra_gated_prompt: >-
+   ### MULTI-TOKEN PREDICTION RESEARCH LICENSE AGREEMENT 18th June 2024
+
+   This Multi-token Prediction Research License (“Agreement”) contains the terms and conditions that govern your access and use of the Materials (as defined below). You may not use the Materials if you do not accept this Agreement. By clicking "submit" below to accept, or accessing, using, or distributing any portion or element of the Materials you hereby agree to be bound by the terms of this Agreement. If you are agreeing to be bound by the Agreement on behalf of your employer or other entity, you represent and warrant to Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland) (“Meta”) that you have full legal authority to bind your employer or such entity to this Agreement. If you do not have requisite authority, you may not accept the Agreement or access the Materials on behalf of your employer or other entity.
+
+   This Agreement is effective upon the earlier of the date that you first access the Materials or accept this Agreement (“Effective Date”), and is entered into by and between Meta, and you, or if you are entering into this Agreement on behalf of your employer or other entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules, or regulations to provide legal consent and, your employer or other entity and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf (“Licensee” or “You”).
+
+   1. Definitions.
+
+   a. “Documentation” means the specifications, manuals and documentation accompanying this release distributed by Meta at https://huggingface.co/facebook/multi-token-prediction.
+
+   b. “Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
+
+   c. “Materials” means, collectively, Documentation and the models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta at https://huggingface.co/facebook/multi-token-prediction and made available under this Agreement.
+
+   d. “Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
+
+   e. “Acceptable Use Policy” means the [LLaMA Acceptable Use Policy](https://ai.meta.com/llama/use-policy/) applicable to Materials that is incorporated into this Agreement.
+
+   2. License Rights and Redistribution. Subject to Your compliance with the terms and conditions of this Agreement, Meta hereby grants you the following:
+
+   a. Grant of Rights. You are hereby granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Materials solely for Noncommercial Research Uses.
+
+   b. Redistribution and Use.
+
+   i. Distribution of Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
+
+   ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Materials, you must acknowledge the use of Materials in your publication.
+
+   iii. You must retain in all copies of the Materials that you distribute and include the following attribution notice within a “Notice” text file distributed as a part of such copies: “Materials are licensed under the Multi-token Prediction Research License, Copyright © Meta Platforms, Inc. All Rights Reserved.”
+
+   iv. Your use of the Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the LLaMA Acceptable Use Policy, which is hereby incorporated by reference into this Agreement.
+
+   v. You agree to validate and confirm LLaMA outputs for compliance with the LLaMA Acceptable Use Policy, including before relying on LLaMA outputs in any way as part of research activities or incorporating these outputs in research, studies, and papers.
+
+   vi. You agree to report any violation of this Multi-token Prediction Research License or the Acceptable Use Policy, as outlined in the LLaMA Acceptable Use Policy.
+
+   3. Restrictions. You will not, and will not permit, assist or cause any third party to:
+
+   a. use the Materials or any outputs or results of the Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
+
+   b. disguise your or their location through IP proxying or other methods;
+
+   c. use or download Materials if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) will use Materials for any purpose prohibited by Trade Control Laws; or
+
+   d. directly or indirectly export, re-export, provide, or otherwise transfer Materials: (a) to any individual, entity, or country prohibited by Trade Control Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Trade Control Laws, including nuclear, chemical or biological weapons, or missile technology applications.
+
+   4. User Support. Your Noncommercial Research Use of the Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
+
+   5. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE MATERIALS AND ANY OUTPUT AND RESULTS.
+
+   6. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+
+   7. Intellectual Property.
+
+   a. No trademark licenses are granted under this Agreement, and in connection with the Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Materials.
+
+   b. Subject to Meta’s ownership of Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
+
+   c. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Materials or outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses and rights granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Materials.
+
+   8. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Materials. Sections 3, 4, 5, 6, 7, 8 and 9 shall survive the termination of this Agreement.
+
+   9. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
+
+   10. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at https://huggingface.co/facebook/multi-token-prediction/LICENSE; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no other modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
+
+ extra_gated_fields:
+   First Name: text
+   Last Name: text
+   Date of birth: date_picker
+   Country: country
+   Affiliation: text
+   geo: ip_location
+   By clicking Submit below I accept the terms of the license and acknowledge that the information I provide will be collected, stored, processed and shared in accordance with the Meta Privacy Policy: checkbox
+ extra_gated_description: The information you provide will be collected, stored, processed and shared in accordance with the [Meta Privacy Policy](https://www.facebook.com/privacy/policy/).
+ extra_gated_button_content: Submit
+ ---
+
+ # **Multi-token prediction models and baselines**
+
+ Models accompanying the research paper "Better & Faster Large Language Models via Multi-token Prediction" (https://arxiv.org/abs/2404.19737).
+
+ Included are the following four 7B parameter models trained on code:
+ - baseline model (`n=1`) trained on 200B tokens of code: [7B_200B_1/](7B_200B_1/)
+ - multi-token prediction model (`n=4`) trained on 200B tokens of code: [7B_200B_4/](7B_200B_4/)
+ - baseline model (`n=1`) trained on 1T tokens of code: [7B_1T_1/](7B_1T_1/)
+ - multi-token prediction model (`n=4`) trained on 1T tokens of code: [7B_1T_4/](7B_1T_4/)
+
+ Tokenizer: standard Llama 2 SentencePiece tokenizer in [tokenizer.model](tokenizer.model).
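+
+ The tokenizer can also be loaded on its own as a quick sanity check. A minimal sketch using the standard `sentencepiece` Python API (this snippet is illustrative and not part of the repository):
+ ```python
+ # Load the Llama 2 SentencePiece tokenizer shipped with this repo.
+ from sentencepiece import SentencePieceProcessor
+
+ sp = SentencePieceProcessor(model_file="tokenizer.model")
+ print(sp.vocab_size())  # expected: 32000, matching vocab_size in params.json
+ ids = sp.encode("def fizzbuzz(n: int):")
+ print(sp.decode(ids))  # round-trips to the original string
+ ```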
+
+ ## *Quickstart*
+
+ Install `torch`, `fairscale`, `fire` and `sentencepiece`, then run
+ ```
+ torchrun --nproc_per_node 1 example_completion.py --ckpt_dir 7B_200B_4/ --tokenizer_path tokenizer.model --max_seq_len 128 --max_batch_size 2
+ ```
+ replacing `7B_200B_4` with the desired checkpoint directory.
+
+ ## *Format*
+
+ The PyTorch `state_dicts` are compatible with the Llama format: the layers of the shared trunk and the next-token prediction head layer are numbered contiguously. Additional prediction heads for tokens further in the future are named `extra_heads` and can be ignored for standard autoregressive inference.
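+
+ For example, a checkpoint can be reduced to a standard Llama 2 layout by dropping the extra heads. A minimal sketch, assuming the checkpoint has been downloaded locally and that the additional heads are stored under keys prefixed with `extra_heads`:
+ ```python
+ import torch
+
+ # Strip the extra prediction heads so that the remaining state_dict
+ # matches a standard Llama 2 7B layout.
+ ckpt = torch.load("7B_200B_4/consolidated.pth", map_location="cpu")
+ trunk_only = {k: v for k, v in ckpt.items() if not k.startswith("extra_heads")}
+ # trunk_only can now be loaded into a vanilla Llama 2 model for
+ # standard autoregressive inference.
+ ```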
+
+ The implementation of `forward()` in [llama/model.py](llama/model.py) provides an additional argument `return_all_heads`. If set, the additional prediction heads are called and the logits are returned in shape `(batch_size, seq_len, n_future_tokens, vocab_size)`. Otherwise, the logits have shape `(batch_size, seq_len, 1, vocab_size)`.
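+
+ A sketch of what this looks like at call time, assuming a `generator` built as in the Quickstart from the `7B_200B_4` checkpoint (`n_future_tokens=4`); the token ids below are placeholders:
+ ```python
+ import torch
+
+ tokens = torch.tensor([[1, 822, 285]], device="cuda")  # (batch_size=1, seq_len=3)
+ # Query all prediction heads at once:
+ logits = generator.model.forward(tokens, 0, return_all_heads=True)
+ print(logits.shape)  # (1, 3, 4, 32000)
+
+ # Default: only the next-token head is evaluated.
+ logits = generator.model.forward(tokens, 0)
+ print(logits.shape)  # (1, 3, 1, 32000)
+ ```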
+
+ ## *Citation*
+
+ Gloeckle, F., Idrissi, B. Y., Rozière, B., Lopez-Paz, D., & Synnaeve, G. (2024). Better & faster large language models via multi-token prediction. arXiv preprint arXiv:2404.19737.
+
+ Bibtex entry:
+ ```
+ @article{gloeckle2024better,
+   title={Better \& faster large language models via multi-token prediction},
+   author={Gloeckle, Fabian and Idrissi, Badr Youbi and Rozi{\`e}re, Baptiste and Lopez-Paz, David and Synnaeve, Gabriel},
+   journal={arXiv preprint arXiv:2404.19737},
+   year={2024}
+ }
+ ```
+
+ ## Feedback and comments
+ Please report risks as indicated in the Acceptable Use Policy and address bugs and any other comments to the corresponding authors as indicated in the research paper.
checklist.chk ADDED
@@ -0,0 +1,9 @@
+ c91d557db711a63af14a1c6e50fda2b3 7B_1T_1/consolidated.pth
+ 2646a33bf77c20f9219351b40f183f5e 7B_1T_1/params.json
+ 4d42458cc14f8cdd1b13a59106ddff75 7B_1T_4/consolidated.pth
+ ad853f2e42dd19f7163c04d82093ac28 7B_1T_4/params.json
+ 8d117da90ce11aaf03b4fe4d8cbc9ff2 7B_200B_1/consolidated.pth
+ 2646a33bf77c20f9219351b40f183f5e 7B_200B_1/params.json
+ b7c2764583dd99aee35e6109100f53c5 7B_200B_4/consolidated.pth
+ ad853f2e42dd19f7163c04d82093ac28 7B_200B_4/params.json
+ eeec4125e9c7560836b4873b6f8e3025 tokenizer.model
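
The entries above are MD5 digests of the distributed files. A minimal sketch for verifying the downloads against them using only the Python standard library (run from the repository root):
```python
import hashlib

def md5(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 digest of a file, reading it in 1 MiB chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        while block := f.read(chunk_size):
            h.update(block)
    return h.hexdigest()

with open("checklist.chk") as f:
    for line in f:
        expected, path = line.split()
        status = "OK" if md5(path) == expected else "MISMATCH"
        print(f"{path}: {status}")
```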
example_completion.py ADDED
@@ -0,0 +1,53 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ from typing import Optional
+
+ import fire
+
+ from llama import Llama
+
+
+ def main(
+     ckpt_dir: str,
+     tokenizer_path: str,
+     temperature: float = 0.2,
+     top_p: float = 0.9,
+     max_seq_len: int = 256,
+     max_batch_size: int = 4,
+     max_gen_len: Optional[int] = None,
+ ):
+     generator = Llama.build(
+         ckpt_dir=ckpt_dir,
+         tokenizer_path=tokenizer_path,
+         max_seq_len=max_seq_len,
+         max_batch_size=max_batch_size,
+     )
+
+     prompts = [
+         # For these prompts, the expected answer is the natural continuation of the prompt
+         """\
+ def fizzbuzz(n: int):""",
+         """\
+ import argparse
+
+ def main(string: str):
+     print(string)
+     print(string[::-1])
+
+ if __name__ == "__main__":"""
+     ]
+     results = generator.text_completion(
+         prompts,
+         max_gen_len=max_gen_len,
+         temperature=temperature,
+         top_p=top_p,
+     )
+     for prompt, result in zip(prompts, results):
+         print(prompt)
+         print(f"> {result['generation']}")
+         print("\n==================================\n")
+
+
+ if __name__ == "__main__":
+     fire.Fire(main)
llama/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ from .generation import Llama, Dialog
+ from .model import ModelArgs, Transformer
+ from .tokenizer import Tokenizer
llama/generation.py ADDED
@@ -0,0 +1,421 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+ import json
+ import os
+ import sys
+ import time
+ from pathlib import Path
+ from typing import List, Literal, Optional, Tuple, TypedDict
+
+ import torch
+ import torch.nn.functional as F
+ from fairscale.nn.model_parallel.initialize import (
+     get_model_parallel_rank,
+     initialize_model_parallel,
+     model_parallel_is_initialized,
+ )
+
+ from llama.model import ModelArgs, Transformer
+ from llama.tokenizer import Tokenizer
+
+ Role = Literal["system", "user", "assistant"]
+
+
+ class Message(TypedDict):
+     role: Role
+     content: str
+
+
+ class CompletionPrediction(TypedDict, total=False):
+     generation: str
+     tokens: List[str]  # not required
+     logprobs: List[float]  # not required
+
+
+ class ChatPrediction(TypedDict, total=False):
+     generation: Message
+     tokens: List[str]  # not required
+     logprobs: List[float]  # not required
+
+
+ Dialog = List[Message]
+
+ B_INST, E_INST = "[INST]", "[/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+ SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
+ UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
+
+
+ class Llama:
+     @staticmethod
+     def build(
+         ckpt_dir: str,
+         tokenizer_path: str,
+         max_seq_len: int,
+         max_batch_size: int,
+         model_parallel_size: Optional[int] = None,
+         seed: int = 1,
+     ) -> "Llama":
+         """
+         Build a Llama instance by initializing and loading a pre-trained model.
+
+         Args:
+             ckpt_dir (str): Path to the directory containing checkpoint files.
+             tokenizer_path (str): Path to the tokenizer file.
+             max_seq_len (int): Maximum sequence length for input text.
+             max_batch_size (int): Maximum batch size for inference.
+             model_parallel_size (Optional[int], optional): Number of model parallel processes.
+                 If not provided, it's determined from the environment. Defaults to None.
+
+         Returns:
+             Llama: An instance of the Llama class with the loaded model and tokenizer.
+
+         Raises:
+             AssertionError: If there are no checkpoint files in the specified directory,
+                 or if the model parallel size does not match the number of checkpoint files.
+
+         Note:
+             This method initializes the distributed process group, sets the device to CUDA,
+             and loads the pre-trained model and tokenizer.
+
+         """
+         if not torch.distributed.is_initialized():
+             torch.distributed.init_process_group("nccl")
+         if not model_parallel_is_initialized():
+             if model_parallel_size is None:
+                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
+             initialize_model_parallel(model_parallel_size)
+
+         local_rank = int(os.environ.get("LOCAL_RANK", 0))
+         torch.cuda.set_device(local_rank)
+
+         # seed must be the same in all processes
+         torch.manual_seed(seed)
+
+         if local_rank > 0:
+             sys.stdout = open(os.devnull, "w")
+
+         start_time = time.time()
+         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
+         assert model_parallel_size == len(
+             checkpoints
+         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+         ckpt_path = checkpoints[get_model_parallel_rank()]
+         checkpoint = torch.load(ckpt_path, map_location="cpu")
+         with open(Path(ckpt_dir) / "params.json", "r") as f:
+             params = json.loads(f.read())
+
+         model_args: ModelArgs = ModelArgs(
+             max_seq_len=max_seq_len,
+             max_batch_size=max_batch_size,
+             **params,
+         )
+         tokenizer = Tokenizer(model_path=tokenizer_path)
+         model_args.vocab_size = tokenizer.n_words
+         torch.set_default_tensor_type(torch.cuda.HalfTensor)
+         model = Transformer(model_args)
+         model.load_state_dict(checkpoint, strict=False)
+         print(f"Loaded in {time.time() - start_time:.2f} seconds")
+
+         return Llama(model, tokenizer)
+
+     def __init__(self, model: Transformer, tokenizer: Tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         prompt_tokens: List[List[int]],
+         max_gen_len: int,
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         logprobs: bool = False,
+         echo: bool = False,
+     ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
+         """
+         Generate text sequences based on provided prompts using the language generation model.
+
+         Args:
+             prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
+             max_gen_len (int): Maximum length of the generated text sequence.
+             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+             logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+             echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+         Returns:
+             Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
+
+         Note:
+             This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
+             If logprobs is True, token log probabilities are computed for each generated token.
+
+         """
+         params = self.model.params
+         bsz = len(prompt_tokens)
+         assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+         min_prompt_len = min(len(t) for t in prompt_tokens)
+         max_prompt_len = max(len(t) for t in prompt_tokens)
+         assert max_prompt_len <= params.max_seq_len
+         total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
+
+         pad_id = self.tokenizer.pad_id
+         tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+         for k, t in enumerate(prompt_tokens):
+             tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+         if logprobs:
+             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+
+         prev_pos = 0
+         eos_reached = torch.tensor([False] * bsz, device="cuda")
+         input_text_mask = tokens != pad_id
+         if min_prompt_len == total_len:
+             logits = self.model.forward(tokens, prev_pos).squeeze(2)
+             token_logprobs = -F.cross_entropy(
+                 input=logits.transpose(1, -1),
+                 target=tokens,
+                 reduction="none",
+                 ignore_index=pad_id,
+             )
+
+         for cur_pos in range(min_prompt_len, total_len):
+             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos).squeeze(2)
+             if temperature > 0:
+                 probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+                 next_token = sample_top_p(probs, top_p)
+             else:
+                 next_token = torch.argmax(logits[:, -1], dim=-1)
+
+             next_token = next_token.reshape(-1)
+             # only replace token if prompt has already been generated
+             next_token = torch.where(
+                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+             )
+             tokens[:, cur_pos] = next_token
+             if logprobs:
+                 token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
+                     input=logits.transpose(1, -1),
+                     target=tokens[:, prev_pos + 1 : cur_pos + 1],
+                     reduction="none",
+                     ignore_index=pad_id,
+                 )
+             eos_reached |= (~input_text_mask[:, cur_pos]) & (
+                 next_token == self.tokenizer.eos_id
+             )
+             prev_pos = cur_pos
+             if all(eos_reached):
+                 break
+
+         if logprobs:
+             token_logprobs = token_logprobs.tolist()
+         out_tokens, out_logprobs = [], []
+         for i, toks in enumerate(tokens.tolist()):
+             # cut to max gen len
+             start = 0 if echo else len(prompt_tokens[i])
+             toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
+             probs = None
+             if logprobs:
+                 probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
+             # cut to eos tok if any
+             if self.tokenizer.eos_id in toks:
+                 eos_idx = toks.index(self.tokenizer.eos_id)
+                 toks = toks[:eos_idx]
+                 probs = probs[:eos_idx] if logprobs else None
+             out_tokens.append(toks)
+             out_logprobs.append(probs)
+         return (out_tokens, out_logprobs if logprobs else None)
+
+     def text_completion(
+         self,
+         prompts: List[str],
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         max_gen_len: Optional[int] = None,
+         logprobs: bool = False,
+         echo: bool = False,
+     ) -> List[CompletionPrediction]:
+         """
+         Perform text completion for a list of prompts using the language generation model.
+
+         Args:
+             prompts (List[str]): List of text prompts for completion.
+             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+             max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
+                 If not provided, it's set to the model's maximum sequence length minus 1.
+             logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+             echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+         Returns:
+             List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.
+
+         Note:
+             This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
+             If logprobs is True, token log probabilities are computed for each generated token.
+
+         """
+         if max_gen_len is None:
+             max_gen_len = self.model.params.max_seq_len - 1
+         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+         generation_tokens, generation_logprobs = self.generate(
+             prompt_tokens=prompt_tokens,
+             max_gen_len=max_gen_len,
+             temperature=temperature,
+             top_p=top_p,
+             logprobs=logprobs,
+             echo=echo,
+         )
+         if logprobs:
+             return [
+                 {
+                     "generation": self.tokenizer.decode(t),
+                     "tokens": [self.tokenizer.decode(x) for x in t],
+                     "logprobs": logprobs_i,
+                 }
+                 for t, logprobs_i in zip(generation_tokens, generation_logprobs)
+             ]
+         return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
+
+     def chat_completion(
+         self,
+         dialogs: List[Dialog],
+         temperature: float = 0.6,
+         top_p: float = 0.9,
+         max_gen_len: Optional[int] = None,
+         logprobs: bool = False,
+     ) -> List[ChatPrediction]:
+         """
+         Generate assistant responses for a list of conversational dialogs using the language generation model.
+
+         Args:
+             dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
+             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+             max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
+                 If not provided, it's set to the model's maximum sequence length minus 1.
+             logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+
+         Returns:
+             List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
+
+         Raises:
+             AssertionError: If the last message in a dialog is not from the user.
+             AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
+
+         Note:
+             This method generates assistant responses for the provided conversational dialogs.
+             It employs nucleus sampling to introduce controlled randomness in text generation.
+             If logprobs is True, token log probabilities are computed for each generated token.
+
+         """
+         if max_gen_len is None:
+             max_gen_len = self.model.params.max_seq_len - 1
+         prompt_tokens = []
+         unsafe_requests = []
+         for dialog in dialogs:
+             unsafe_requests.append(
+                 any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
+             )
+             if dialog[0]["role"] == "system":
+                 dialog = [
+                     {
+                         "role": dialog[1]["role"],
+                         "content": B_SYS
+                         + dialog[0]["content"]
+                         + E_SYS
+                         + dialog[1]["content"],
+                     }
+                 ] + dialog[2:]
+             assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+                 [msg["role"] == "assistant" for msg in dialog[1::2]]
+             ), (
+                 "model only supports 'system', 'user' and 'assistant' roles, "
+                 "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
+             )
+             dialog_tokens: List[int] = sum(
+                 [
+                     self.tokenizer.encode(
+                         f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                         bos=True,
+                         eos=True,
+                     )
+                     for prompt, answer in zip(
+                         dialog[::2],
+                         dialog[1::2],
+                     )
+                 ],
+                 [],
+             )
+             assert (
+                 dialog[-1]["role"] == "user"
+             ), f"Last message must be from user, got {dialog[-1]['role']}"
+             dialog_tokens += self.tokenizer.encode(
+                 f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+                 bos=True,
+                 eos=False,
+             )
+             prompt_tokens.append(dialog_tokens)
+
+         generation_tokens, generation_logprobs = self.generate(
+             prompt_tokens=prompt_tokens,
+             max_gen_len=max_gen_len,
+             temperature=temperature,
+             top_p=top_p,
+             logprobs=logprobs,
+         )
+         if logprobs:
+             return [
+                 {
+                     "generation": {
+                         "role": "assistant",
+                         "content": self.tokenizer.decode(t)
+                         if not unsafe
+                         else UNSAFE_ERROR,
+                     },
+                     "tokens": [self.tokenizer.decode(x) for x in t],
+                     "logprobs": logprobs_i,
+                 }
+                 for t, logprobs_i, unsafe in zip(
+                     generation_tokens, generation_logprobs, unsafe_requests
+                 )
+             ]
+         return [
+             {
+                 "generation": {
+                     "role": "assistant",
+                     "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
+                 }
+             }
+             for t, unsafe in zip(generation_tokens, unsafe_requests)
+         ]
+
+
+ def sample_top_p(probs, p):
+     """
+     Perform top-p (nucleus) sampling on a probability distribution.
+
+     Args:
+         probs (torch.Tensor): Probability distribution tensor.
+         p (float): Probability threshold for top-p sampling.
+
+     Returns:
+         torch.Tensor: Sampled token indices.
+
+     Note:
+         Top-p sampling selects the smallest set of tokens whose cumulative probability mass
+         exceeds the threshold p. The distribution is renormalized based on the selected tokens.
+
+     """
+     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+     probs_sum = torch.cumsum(probs_sort, dim=-1)
+     mask = probs_sum - probs_sort > p
+     probs_sort[mask] = 0.0
+     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+     next_token = torch.multinomial(probs_sort, num_samples=1)
+     next_token = torch.gather(probs_idx, -1, next_token)
+     return next_token
llama/model.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple
7
+
8
+ import fairscale.nn.model_parallel.initialize as fs_init
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from fairscale.nn.model_parallel.layers import (
12
+ ColumnParallelLinear,
13
+ ParallelEmbedding,
14
+ RowParallelLinear,
15
+ )
16
+ from torch import nn
17
+
18
+
19
+ @dataclass
20
+ class ModelArgs:
21
+ dim: int = 4096
22
+ n_layers: int = 32
23
+ n_heads: int = 32
24
+ n_kv_heads: Optional[int] = None
25
+ vocab_size: int = -1 # defined later by tokenizer
26
+ multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27
+ ffn_dim_multiplier: Optional[float] = None
28
+ norm_eps: float = 1e-5
29
+ n_future_tokens: int = 1
30
+ rope_theta: float = 10000.0
31
+
32
+ max_batch_size: int = 32
33
+ max_seq_len: int = 2048
34
+
35
+
36
+ class RMSNorm(torch.nn.Module):
37
+ def __init__(self, dim: int, eps: float = 1e-6):
38
+ """
39
+ Initialize the RMSNorm normalization layer.
40
+
41
+ Args:
42
+ dim (int): The dimension of the input tensor.
43
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
44
+
45
+ Attributes:
46
+ eps (float): A small value added to the denominator for numerical stability.
47
+ weight (nn.Parameter): Learnable scaling parameter.
48
+
49
+ """
50
+ super().__init__()
51
+ self.eps = eps
52
+ self.weight = nn.Parameter(torch.ones(dim))
53
+
54
+ def _norm(self, x):
55
+ """
56
+ Apply the RMSNorm normalization to the input tensor.
57
+
58
+ Args:
59
+ x (torch.Tensor): The input tensor.
60
+
61
+ Returns:
62
+ torch.Tensor: The normalized tensor.
63
+
64
+ """
65
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
66
+
67
+ def forward(self, x):
68
+ """
69
+ Forward pass through the RMSNorm layer.
70
+
71
+ Args:
72
+ x (torch.Tensor): The input tensor.
73
+
74
+ Returns:
75
+ torch.Tensor: The output tensor after applying RMSNorm.
76
+
77
+ """
78
+ output = self._norm(x.float()).type_as(x)
79
+ return output * self.weight
80
+
81
+
82
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
83
+ """
84
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
85
+
86
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
87
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
88
+ The returned tensor contains complex values in complex64 data type.
89
+
90
+ Args:
91
+ dim (int): Dimension of the frequency tensor.
92
+ end (int): End index for precomputing frequencies.
93
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
94
+
95
+ Returns:
96
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
97
+
98
+
99
+
100
+
101
+ """
102
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
103
+ t = torch.arange(end, device=freqs.device) # type: ignore
104
+ freqs = torch.outer(t, freqs).float() # type: ignore
105
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
106
+ return freqs_cis
107
+
108
+
109
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
110
+ """
111
+ Reshape frequency tensor for broadcasting it with another tensor.
112
+
113
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
114
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
115
+
116
+ Args:
117
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
118
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
119
+
120
+ Returns:
121
+ torch.Tensor: Reshaped frequency tensor.
122
+
123
+ Raises:
124
+ AssertionError: If the frequency tensor doesn't match the expected shape.
125
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
126
+ """
127
+ ndim = x.ndim
128
+ assert 0 <= 1 < ndim
129
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
130
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
131
+ return freqs_cis.view(*shape)
132
+
133
+
134
+ def apply_rotary_emb(
135
+ xq: torch.Tensor,
136
+ xk: torch.Tensor,
137
+ freqs_cis: torch.Tensor,
138
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
139
+ """
140
+ Apply rotary embeddings to input tensors using the given frequency tensor.
141
+
142
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
143
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
144
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
145
+ returned as real tensors.
146
+
147
+ Args:
148
+ xq (torch.Tensor): Query tensor to apply rotary embeddings.
149
+ xk (torch.Tensor): Key tensor to apply rotary embeddings.
150
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
151
+
152
+ Returns:
153
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
154
+
155
+
156
+
157
+ """
158
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
159
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
160
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
161
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
162
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
163
+ return xq_out.type_as(xq), xk_out.type_as(xk)
164
+
165
+
166
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
167
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
168
+ bs, slen, n_kv_heads, head_dim = x.shape
169
+ if n_rep == 1:
170
+ return x
171
+ return (
172
+ x[:, :, :, None, :]
173
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
174
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
175
+ )
176
+
177
+
178
+ class Attention(nn.Module):
179
+ """Multi-head attention module."""
180
+ def __init__(self, args: ModelArgs):
181
+ """
182
+ Initialize the Attention module.
183
+
184
+ Args:
185
+ args (ModelArgs): Model configuration parameters.
186
+
187
+ Attributes:
188
+ n_kv_heads (int): Number of key and value heads.
189
+ n_local_heads (int): Number of local query heads.
190
+ n_local_kv_heads (int): Number of local key and value heads.
191
+ n_rep (int): Number of repetitions for local heads.
192
+ head_dim (int): Dimension size of each attention head.
193
+ wq (ColumnParallelLinear): Linear transformation for queries.
194
+ wk (ColumnParallelLinear): Linear transformation for keys.
195
+ wv (ColumnParallelLinear): Linear transformation for values.
196
+ wo (RowParallelLinear): Linear transformation for output.
197
+ cache_k (torch.Tensor): Cached keys for attention.
198
+ cache_v (torch.Tensor): Cached values for attention.
199
+
200
+ """
201
+ super().__init__()
202
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
203
+ model_parallel_size = fs_init.get_model_parallel_world_size()
204
+ self.n_local_heads = args.n_heads // model_parallel_size
205
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
206
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
207
+ self.head_dim = args.dim // args.n_heads
208
+
209
+ self.wq = ColumnParallelLinear(
210
+ args.dim,
211
+ args.n_heads * self.head_dim,
212
+ bias=False,
213
+ gather_output=False,
214
+ init_method=lambda x: x,
215
+ )
216
+ self.wk = ColumnParallelLinear(
217
+ args.dim,
218
+ self.n_kv_heads * self.head_dim,
219
+ bias=False,
220
+ gather_output=False,
221
+ init_method=lambda x: x,
222
+ )
223
+ self.wv = ColumnParallelLinear(
224
+ args.dim,
225
+ self.n_kv_heads * self.head_dim,
226
+ bias=False,
227
+ gather_output=False,
228
+ init_method=lambda x: x,
229
+ )
230
+ self.wo = RowParallelLinear(
231
+ args.n_heads * self.head_dim,
232
+ args.dim,
233
+ bias=False,
234
+ input_is_parallel=True,
235
+ init_method=lambda x: x,
236
+ )
237
+
238
+ self.cache_k = torch.zeros(
239
+ (
240
+ args.max_batch_size,
241
+ args.max_seq_len,
242
+ self.n_local_kv_heads,
243
+ self.head_dim,
244
+ )
245
+ ).cuda()
246
+ self.cache_v = torch.zeros(
247
+ (
248
+ args.max_batch_size,
249
+ args.max_seq_len,
250
+ self.n_local_kv_heads,
251
+ self.head_dim,
252
+ )
253
+ ).cuda()
254
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Forward pass of the attention module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for caching.
+            freqs_cis (torch.Tensor): Precomputed frequency tensor.
+            mask (torch.Tensor, optional): Attention mask tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after attention.
+
+        """
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        self.cache_k = self.cache_k.to(xq)
+        self.cache_v = self.cache_v.to(xq)
+
+        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+        keys = self.cache_k[:bsz, : start_pos + seqlen]
+        values = self.cache_v[:bsz, : start_pos + seqlen]
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if mask is not None:
+            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
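The repeat_kv helper is defined earlier in model.py and not shown in this hunk. For reference, a minimal sketch consistent with the shape comments above, expanding each KV head n_rep times so grouped-query attention can reuse keys and values across query heads (equivalent to torch.repeat_interleave along the head dimension):

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )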
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+        self.w2 = RowParallelLinear(
+            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        )
+        self.w3 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
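Worked through with example 7B-style hyperparameters (dim=4096, multiple_of=256, ffn_dim_multiplier=1.0), the rounding in the constructor above yields the familiar Llama FFN width:

dim, multiple_of, ffn_dim_multiplier = 4096, 256, 1.0  # example 7B-style values
hidden_dim = 4 * dim                    # 16384, as passed in by TransformerBlock
hidden_dim = int(2 * hidden_dim / 3)    # 10922
hidden_dim = int(ffn_dim_multiplier * hidden_dim)  # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)                       # 11008, rounded up to a multiple of 256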
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        """
+        Initialize a TransformerBlock.
+
+        Args:
+            layer_id (int): Identifier for the layer.
+            args (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            n_heads (int): Number of attention heads.
+            dim (int): Dimension size of the model.
+            head_dim (int): Dimension size of each attention head.
+            attention (Attention): Attention module.
+            feed_forward (FeedForward): FeedForward module.
+            layer_id (int): Identifier for the layer.
+            attention_norm (RMSNorm): Layer normalization for attention output.
+            ffn_norm (RMSNorm): Layer normalization for feedforward output.
+
+        """
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=4 * args.dim,
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Perform a forward pass through the TransformerBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for attention caching.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+
+        Returns:
+            torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+        """
+        h = x + self.attention(
+            self.attention_norm(x), start_pos, freqs_cis, mask
+        )
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
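RMSNorm is likewise defined earlier in model.py and not shown in this hunk. A minimal sketch of the usual definition, consistent with how it is constructed above (a dim plus an eps of args.norm_eps); the authoritative version is the one in the file:

class RMSNorm(torch.nn.Module):
    # Sketch of the standard RMSNorm used by Llama-style models.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the last dimension, then scale.
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x) * self.weight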
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        """
+        Initialize a Transformer model.
+
+        Args:
+            params (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            params (ModelArgs): Model configuration parameters.
+            vocab_size (int): Vocabulary size.
+            n_layers (int): Total number of layers in the model (including extra heads).
+            n_future_tokens (int): Number of prediction heads in the model (= 1 + `len(extra_heads)`).
+            tok_embeddings (ParallelEmbedding): Token embeddings.
+            layers (torch.nn.ModuleList): List of Transformer blocks (trunk + next-token head).
+            extra_heads (torch.nn.ModuleList): List of Transformer blocks
+                (additional prediction heads for multi-token prediction).
+            norm (RMSNorm): Layer normalization for the model output.
+            output (ColumnParallelLinear): Linear layer for final output.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+        """
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+        self.n_future_tokens = params.n_future_tokens
+
+        self.tok_embeddings = ParallelEmbedding(
+            params.vocab_size, params.dim, init_method=lambda x: x
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers - self.n_future_tokens + 1):
+            self.layers.append(TransformerBlock(layer_id, params))
+
+        # Additional prediction heads for multi-token prediction.
+        # `layer_id` counts contiguously from the first Transformer block.
+        self.extra_heads = torch.nn.ModuleList()
+        for layer_id in range(self.n_layers - self.n_future_tokens + 1, self.n_layers):
+            self.extra_heads.append(TransformerBlock(layer_id, params))
+
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = ColumnParallelLinear(
+            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2, theta=self.params.rope_theta
+        )
+
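The split above keeps the total block count equal to n_layers. With illustrative values n_layers=32 and n_future_tokens=4 (the 4-token 7B setting):

n_layers, n_future_tokens = 32, 4            # example values
len_layers = n_layers - n_future_tokens + 1  # 29 blocks in self.layers
len_extra_heads = n_layers - len_layers      # 3 blocks in self.extra_heads
assert len_layers + len_extra_heads == n_layers
# forward() runs self.layers[:-1] (28 blocks) as the shared trunk, then treats
# self.layers[-1] plus the 3 extra heads as 4 parallel prediction heads.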
+    @torch.inference_mode()
+    def forward(self, tokens: torch.Tensor, start_pos: int, return_all_heads: bool = False):
+        """
+        Perform a forward pass through the Transformer model.
+
+        Args:
+            tokens (torch.Tensor): Input token indices.
+            start_pos (int): Starting position for attention caching.
+            return_all_heads (bool, optional): Whether to return logits
+                for all prediction heads. Defaults to False.
+
+        Returns:
+            torch.Tensor: Output logits after applying the Transformer model,
+                of shape (batch_size, seq_len, n_heads_used, vocab_size), where
+                n_heads_used is n_future_tokens if return_all_heads else 1.
+
+        Note:
+            If return_all_heads is False, the head dimension is a singleton, so
+            the output logits broadcast to (batch_size, seq_len, vocab_size) and
+            are compatible with standard decoding.
+        """
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)
+
+        # Model trunk.
+        for layer in self.layers[:-1]:
+            h = layer(h, start_pos, freqs_cis, mask)
+        h_trunk = h
+
+        # Prediction heads.
+        latents = []
+        n_heads_to_use = self.n_future_tokens if return_all_heads else 1
+        prediction_heads = [self.layers[-1]] + list(self.extra_heads)
+        for layer in prediction_heads[:n_heads_to_use]:
+            h = layer(h_trunk, start_pos, freqs_cis, mask)
+            latents.append(h)
+
+        h = torch.stack(latents, dim=-2)  # (_bsz, seqlen, n_heads_to_use, dim)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
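A usage sketch for this forward pass. The names model and tok are hypothetical (a constructed Transformer and the Tokenizer below), and model-parallel initialization is omitted:

import torch

# Hypothetical objects: `model` is a Transformer, `tok` is a Tokenizer.
tokens = torch.tensor([tok.encode("The capital of France is", bos=True, eos=False)]).cuda()

# Default path: one head, logits broadcastable to (bsz, seqlen, vocab_size).
logits = model(tokens, start_pos=0)          # (1, seqlen, 1, vocab_size)
next_id = logits[0, -1, 0].argmax().item()   # standard next-token decoding

# Multi-token path: head k predicts token t + k + 1, which is what
# self-speculative decoding schemes consume.
logits_all = model(tokens, start_pos=0, return_all_heads=True)
# logits_all.shape == (1, seqlen, model.n_future_tokens, vocab_size)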
llama/tokenizer.py ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+from logging import getLogger
+from typing import List
+
+from sentencepiece import SentencePieceProcessor
+
+
+logger = getLogger()
+
+
+class Tokenizer:
+    """tokenizing and encoding/decoding text using SentencePiece."""
+    def __init__(self, model_path: str):
+        """
+        Initializes the Tokenizer with a SentencePiece model.
+
+        Args:
+            model_path (str): The path to the SentencePiece model file.
+        """
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+
+        Returns:
+            List[int]: A list of token IDs.
+        """
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.sp_model.decode(t)
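A round-trip sketch; the path refers to the tokenizer.model file added below, and the exact token IDs depend on that SentencePiece model:

tok = Tokenizer(model_path="tokenizer.model")
ids = tok.encode("Hello world", bos=True, eos=False)  # [tok.bos_id, ...]
assert tok.decode(ids) == "Hello world"  # control tokens such as BOS are dropped on decode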
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+size 499723