trhacknon pseudotensor committed on
Commit d7185d6 · 0 Parent(s)

Duplicate from h2oai/h2ogpt-chatbot

Co-authored-by: Jonathan McKinney <[email protected]>

Files changed (11)
  1. .gitattributes +34 -0
  2. LICENSE +201 -0
  3. README.md +14 -0
  4. app.py +0 -0
  5. client_test.py +121 -0
  6. finetune.py +932 -0
  7. h2o-logo.svg +1 -0
  8. prompter.py +106 -0
  9. requirements.txt +50 -0
  10. stopping.py +139 -0
  11. utils.py +186 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: H2ogpt Chatbot
3
+ emoji: 📚
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: h2oai/h2ogpt-chatbot
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
The diff for this file is too large to render. See raw diff
 
client_test.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ Client test.
3
+
4
+ Run server:
5
+
6
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
7
+
8
+ NOTE: For private models, add --use_auth_token=True
9
+
10
+ NOTE: --infer_devices=True (default) must be used for multi-GPU if you see failures with cuda:x cuda:y mismatches.
11
+ Currently, this will force the model onto a single GPU.
12
+
13
+ Then run this client as:
14
+
15
+ python client_test.py
16
+
17
+
18
+
19
+ For HF spaces:
20
+
21
+ HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py
22
+
23
+ Result:
24
+
25
+ Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
26
+ {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.'}
27
+
28
+
29
+ For demo:
30
+
31
+ HOST="https://gpt.h2o.ai" python client_test.py
32
+
33
+ Result:
34
+
35
+ Loaded as API: https://gpt.h2o.ai ✔
36
+ {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
37
+
38
+ """
39
+
40
+ debug = False
41
+
42
+ import os
43
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
44
+
45
+
46
+ def get_client():
47
+ from gradio_client import Client
48
+
49
+ client = Client(os.getenv('HOST', "http://localhost:7860"))
50
+ if debug:
51
+ print(client.view_api(all_endpoints=True))
52
+ return client
53
+
54
+
55
+ def test_client_basic():
56
+ instruction = '' # only for chat=True
57
+ iinput = '' # only for chat=True
58
+ context = ''
59
+ # streaming output is supported, loops over and outputs each generation in streaming mode
60
+ # but leave stream_output=False for simple input/output mode
61
+ stream_output = False
62
+ prompt_type = 'human_bot'
63
+ temperature = 0.1
64
+ top_p = 0.75
65
+ top_k = 40
66
+ num_beams = 1
67
+ max_new_tokens = 50
68
+ min_new_tokens = 0
69
+ early_stopping = False
70
+ max_time = 20
71
+ repetition_penalty = 1.0
72
+ num_return_sequences = 1
73
+ do_sample = True
74
+ # only these 2 below used if pass chat=False
75
+ chat = False
76
+ instruction_nochat = "Who are you?"
77
+ iinput_nochat = ''
78
+
79
+ args = [instruction,
80
+ iinput,
81
+ context,
82
+ stream_output,
83
+ prompt_type,
84
+ temperature,
85
+ top_p,
86
+ top_k,
87
+ num_beams,
88
+ max_new_tokens,
89
+ min_new_tokens,
90
+ early_stopping,
91
+ max_time,
92
+ repetition_penalty,
93
+ num_return_sequences,
94
+ do_sample,
95
+ chat,
96
+ instruction_nochat,
97
+ iinput_nochat,
98
+ ]
99
+ api_name = '/submit_nochat'
100
+ client = get_client()
101
+ res = client.predict(
102
+ *tuple(args),
103
+ api_name=api_name,
104
+ )
105
+ res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
106
+ print(res_dict)
107
+ return res_dict
108
+
109
+
110
+ import markdown # pip install markdown
111
+ from bs4 import BeautifulSoup # pip install beautifulsoup4
112
+
113
+
114
+ def md_to_text(md):
115
+ html = markdown.markdown(md)
116
+ soup = BeautifulSoup(html, features='html.parser')
117
+ return soup.get_text()
118
+
119
+
120
+ if __name__ == '__main__':
121
+ test_client_basic()
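
For reference, a minimal sketch of the same request made directly with gradio_client (not part of the commit; it assumes a server is already running at HOST and exposes the /submit_nochat endpoint used above):

    import os
    from gradio_client import Client

    client = Client(os.getenv('HOST', "http://localhost:7860"))
    # positional arguments follow the same order as the args list in test_client_basic
    res = client.predict('', '', '', False, 'human_bot', 0.1, 0.75, 40, 1, 50, 0,
                         False, 20, 1.0, 1, True, False, "Who are you?", '',
                         api_name='/submit_nochat')
    print(res)
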
finetune.py ADDED
@@ -0,0 +1,932 @@
1
+ import os
2
+ import pathlib
3
+ import random
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from datetime import datetime
9
+ from typing import List, Union
10
+ import fire
11
+ import numpy as np
12
+ import torch
13
+ from datasets import load_dataset, concatenate_datasets
14
+ import transformers
15
+ import torch.distributed as dist
16
+
17
+ from peft import (
18
+ prepare_model_for_int8_training,
19
+ LoraConfig,
20
+ get_peft_model,
21
+ get_peft_model_state_dict,
22
+ set_peft_model_state_dict,
23
+ )
24
+
25
+ from peft import mapping
26
+ lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
27
+
28
+
29
+ def log(*args, **kwargs):
30
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
31
+ print(*args, **kwargs)
32
+
33
+
34
+ try:
35
+ import neptune
36
+ from transformers.integrations import NeptuneCallback
37
+
38
+ neptune_run = neptune.init_run(
39
+ source_files=[],
40
+ )
41
+ log("Connected to Neptune.")
42
+ except ImportError:
43
+ neptune_run = None
44
+ log("Please pip install neptune for tracking.")
45
+ except neptune.exceptions.NeptuneMissingApiTokenException:
46
+ neptune_run = None
47
+ os.environ["NEPTUNE_MODE"] = 'debug'
48
+ log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
49
+
50
+ from enum import Enum
51
+
52
+
53
+ class PromptType(Enum):
54
+ plain = 0
55
+ instruct = 1
56
+ quality = 2
57
+ human_bot = 3
58
+ dai_faq = 4
59
+ summarize = 5
60
+ simple_instruct = 6
61
+ instruct_vicuna = 7
62
+ instruct_with_end = 8
63
+ human_bot_orig = 9
64
+
65
+
66
+ prompt_type_to_model_name = {
67
+ 'plain': [
68
+ 'EleutherAI/gpt-j-6B',
69
+ 'EleutherAI/pythia-6.9b',
70
+ 'EleutherAI/pythia-12b',
71
+ 'EleutherAI/pythia-12b-deduped',
72
+ 'EleutherAI/gpt-neox-20b',
73
+ 'decapoda-research/llama-7b-hf',
74
+ 'decapoda-research/llama-13b-hf',
75
+ 'decapoda-research/llama-30b-hf',
76
+ 'decapoda-research/llama-65b-hf',
77
+ 'facebook/mbart-large-50-many-to-many-mmt',
78
+ 'philschmid/bart-large-cnn-samsum',
79
+ 'philschmid/flan-t5-base-samsum',
80
+ 'gpt2',
81
+ 'distilgpt2',
82
+ ],
83
+ 'instruct': [],
84
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
85
+ 'quality': [],
86
+ 'human_bot': [
87
+ 'h2oai/h2ogpt-oasst1-512-12b',
88
+ 'h2oai/h2ogpt-oasst1-512-20b',
89
+ 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
90
+ ],
91
+ 'dai_faq': [],
92
+ 'summarize': [],
93
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
94
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
95
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
96
+ }
97
+
98
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
99
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
100
+
101
+ human = '<human>:'
102
+ bot = "<bot>:"
103
+
104
+ prompt_types_strings = []
105
+ for p in PromptType:
106
+ prompt_types_strings.extend([p.name])
107
+
108
+
109
+ prompt_types = []
110
+ for p in PromptType:
111
+ prompt_types.extend([p.name, p.value, str(p.value)])
112
+
113
+
114
+ # supported by huggingface evaluate
115
+ supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
116
+
117
+
118
+ def train(
119
+ save_code: bool = False,
120
+ run_id: int = None,
121
+
122
+ base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
123
+ # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
124
+ # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
125
+ # base_model: str = 'EleutherAI/gpt-neox-20b',
126
+ # base_model: str = 'EleutherAI/pythia-12b-deduped',
127
+ # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
128
+ # base_model: str = 'decapoda-research/llama-7b-hf',
129
+ # base_model: str = 'decapoda-research/llama-13b-hf',
130
+ # base_model: str = 'decapoda-research/llama-30b-hf',
131
+ # base_model: str = 'EleutherAI/gpt-j-6B',
132
+
133
+ # only needed if base_model is self-exported HF state without tokenizer
134
+ tokenizer_base_model: str = None,
135
+ # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
136
+
137
+ data_path: str = None,
138
+ data_col_dict: dict = None,
139
+ # data_path: str = "./dai_docs.train.json",
140
+ prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
141
+
142
+ valid_path: str = None,
143
+ # valid_path: str = "./dai_docs.valid.json",
144
+
145
+ # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
146
+ data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
147
+ data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
148
+ data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
149
+ data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
150
+
151
+ output_dir: str = None,
152
+
153
+ # LoRA checkpoint continuation
154
+ lora_weights: str = "",
155
+
156
+ # batching training hyperparams
157
+ batch_size: int = 128,
158
+ micro_batch_size: int = 4,
159
+ gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
160
+ fp16=True,
161
+
162
+ # general training hyperparams
163
+ num_epochs: float = 1,
164
+ learning_rate: float = 3e-4,
165
+
166
+ # validation settings
167
+ val_set_size: int = None,
168
+ val_metrics: List[str] = [],
169
+ eval_steps: int = None, # to control eval steps via steps
170
+ eval_epochs: float = None, # to control eval steps via epochs
171
+
172
+ # lora hyperparams
173
+ lora_r: int = 8,
174
+ lora_alpha: int = 16,
175
+ lora_dropout: float = 0.05,
176
+ lora_target_modules: List[str] = None,
177
+ llama_type: bool = None,
178
+
179
+ # llm hyperparams
180
+ train_on_inputs: bool = True, # if False, masks out inputs in loss
181
+ group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
182
+ resume_from_checkpoint: str = None, # either training checkpoint or final adapter
183
+ cutoff_len: int = 1024, # Good default, especially when you have high-quality, non-trivial data
184
+
185
+ # torch training params
186
+ ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
187
+ local_files_only: bool = False, # else will download new versions, normally unwanted
188
+ resume_download: bool = True,
189
+ use_auth_token: Union[str, bool] = False, # True requires having run huggingface-cli login beforehand
190
+ warmup_steps: int = 100,
191
+ logging_steps: int = 1,
192
+ save_steps: int = None, # must be round multiple of eval_steps
193
+ add_eos_token: bool = False,
194
+ ):
195
+ # allow set token directly
196
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
197
+
198
+ prompt_type = str(prompt_type) # migration from integers
199
+ assert prompt_type in prompt_types
200
+
201
+ world_size = int(os.getenv("WORLD_SIZE", 1))
202
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
203
+ rank = int(os.getenv("RANK", 0))
204
+ print(f"local_rank: {local_rank}")
205
+ print(f"global rank: {rank}")
206
+
207
+ gpus = max(world_size, torch.cuda.device_count())
208
+ run_id = run_id or 0
209
+ if not data_path:
210
+ raise ValueError("No data_path provided")
211
+ if not output_dir:
212
+ output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
213
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
214
+ raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
215
+ else:
216
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
217
+ raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
218
+ device_map = "auto"
219
+
220
+ if save_code:
221
+ copy_code(run_id)
222
+ if tokenizer_base_model is None:
223
+ tokenizer_base_model = base_model
224
+ if llama_type is None:
225
+ llama_type = "llama" in base_model.lower()
226
+ assert (
227
+ base_model
228
+ ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
229
+ gradient_accumulation_steps = batch_size // micro_batch_size
230
+ assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
231
+
232
+ device_map = "auto"
233
+
234
+ locals_dict = locals()
235
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
236
+ log(f"Training model with params:\n{locals_print}")
237
+ log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
238
+
239
+ max_memory = None
240
+ if gpus > 1:
241
+ if ddp:
242
+ log("Distributed: data parallel")
243
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
244
+ gradient_accumulation_steps = gradient_accumulation_steps // world_size
245
+ else:
246
+ free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
247
+ max_memory = f"{free_in_GB - 2}GB"
248
+ max_memory = {i: max_memory for i in range(gpus)}
249
+ log("world_size: %d" % world_size)
250
+ log("num_gpus: %d" % gpus)
251
+ log("max mem: %s" % max_memory)
252
+
253
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
254
+
255
+ model = model_loader.from_pretrained(
256
+ base_model,
257
+ load_in_8bit=True,
258
+ device_map=device_map,
259
+ torch_dtype=torch.float16,
260
+ max_memory=max_memory,
261
+ local_files_only=local_files_only,
262
+ resume_download=resume_download,
263
+ use_auth_token=use_auth_token,
264
+ )
265
+ if gpus > 1:
266
+ if not ddp:
267
+ log("model parallel")
268
+ model.is_parallelizable = True
269
+ model.model_parallel = True
270
+
271
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
272
+ local_files_only=local_files_only,
273
+ resume_download=resume_download,
274
+ use_auth_token=use_auth_token)
275
+
276
+ tokenizer.pad_token_id = 0 # different from the eos token
277
+ # when generating, we will use the logits of right-most token to predict the next token
278
+ # so the padding should be on the left,
279
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
280
+ tokenizer.padding_side = "left" # Allow batched inference
281
+
282
+ def tokenize(prompt, add_eos_token=True):
283
+ # there's probably a way to do this with the tokenizer settings
284
+ # but again, gotta move fast
285
+ result = tokenizer(
286
+ prompt,
287
+ truncation=True,
288
+ max_length=cutoff_len,
289
+ padding=False,
290
+ return_tensors=None,
291
+ )
292
+ if (
293
+ result["input_ids"][-1] != tokenizer.eos_token_id
294
+ and len(result["input_ids"]) < cutoff_len
295
+ and add_eos_token
296
+ ):
297
+ result["input_ids"].append(tokenizer.eos_token_id)
298
+ result["attention_mask"].append(1)
299
+
300
+ result["labels"] = result["input_ids"].copy()
301
+
302
+ return result
303
+
304
+ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
305
+ full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
306
+ tokenized_full_prompt = tokenize(full_prompt)
307
+ if not train_on_inputs:
308
+ user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
309
+ tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
310
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
311
+ if add_eos:
312
+ user_prompt_len -= 1
313
+
314
+ # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
315
+ tokenized_full_prompt["labels"] = [
316
+ -100
317
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
318
+ user_prompt_len:
319
+ ] # could be sped up, probably
320
+ return tokenized_full_prompt
321
+
322
+ if "gpt-neox" not in base_model or True:
323
+ model = prepare_model_for_int8_training(model)
324
+ else:
325
+ model = prepare_model_for_int8_training(
326
+ model,
327
+ output_embedding_layer_name="embed_out", # keep output logits in float32
328
+ layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
329
+ )
330
+ if lora_weights:
331
+ from peft import PeftModel
332
+ model = PeftModel.from_pretrained(
333
+ model,
334
+ lora_weights,
335
+ torch_dtype=torch.float16,
336
+ device_map=device_map,
337
+ local_files_only=local_files_only,
338
+ resume_download=resume_download,
339
+ use_auth_token=use_auth_token,
340
+ )
341
+ else:
342
+ if lora_target_modules is None:
343
+ base_model_lower = base_model.lower()
344
+ if base_model_lower in lora_mappings:
345
+ lora_target_modules_cand = [lora_mappings[base_model_lower]]
346
+ else:
347
+ lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
348
+ else:
349
+ lora_target_modules_cand = [lora_target_modules]
350
+
351
+ for lora_target_modules in lora_target_modules_cand:
352
+ try:
353
+ config = LoraConfig(
354
+ r=lora_r,
355
+ lora_alpha=lora_alpha,
356
+ target_modules=lora_target_modules,
357
+ lora_dropout=lora_dropout,
358
+ bias="none",
359
+ task_type="CAUSAL_LM",
360
+ )
361
+ model = get_peft_model(model, config)
362
+ break
363
+ except ValueError as e:
364
+ if "Target modules" in str(e) and "not found" in str(e):
365
+ continue
366
+ else:
367
+ raise
368
+ from peft import PeftModel
369
+ assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
370
+ if resume_from_checkpoint:
371
+ # Check the available weights and load them
372
+ checkpoint_name = os.path.join(
373
+ resume_from_checkpoint, "pytorch_model.bin"
374
+ ) # Full checkpoint
375
+ if not os.path.exists(checkpoint_name):
376
+ checkpoint_name = os.path.join(
377
+ resume_from_checkpoint, "adapter_model.bin"
378
+ ) # only LoRA model - LoRA config above has to fit
379
+ resume_from_checkpoint = False # So the trainer won't try loading its state
380
+ # The two files above have a different name depending on how they were saved, but are actually the same.
381
+ if os.path.exists(checkpoint_name):
382
+ log(f"Restarting from {checkpoint_name}")
383
+ adapters_weights = torch.load(checkpoint_name)
384
+ model = set_peft_model_state_dict(model, adapters_weights)
385
+ else:
386
+ log(f"Checkpoint {checkpoint_name} not found")
387
+
388
+ print(model)
389
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
390
+
391
+ metrics = {}
392
+ for name in supported_metrics:
393
+ if name in val_metrics:
394
+ import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
395
+ metrics[name] = evaluate.load(name)
396
+ log("Using Validation Metrics: %s" % str(list(metrics.keys())))
397
+ log("Supported Metrics: %s" % supported_metrics)
398
+
399
+ if val_set_size is None:
400
+ if len(metrics) == 0:
401
+ val_set_size = 1000
402
+ else:
403
+ val_set_size = 100
404
+ log("Auto set val_set_size %s" % val_set_size)
405
+ elif val_set_size < 1.0 and val_set_size != 0:
406
+ raise RuntimeError("Fractional validation size not supported.")
407
+
408
+ if valid_path:
409
+ data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
410
+ else:
411
+ if "json" in data_path:
412
+ data = load_dataset("json", data_files={"train": data_path})
413
+ else:
414
+ data = load_dataset(data_path)
415
+ data = data.rename_columns(data_col_dict or {})
416
+
417
+ valid_data = None
418
+ train_data_mix_in = None
419
+ valid_data_mix_in = None
420
+
421
+ if data_mix_in_path and data_mix_in_factor > 0:
422
+ # get mix-in training/validation data - to keep model "sane"
423
+ num_rows = data["train"].num_rows
424
+ log("Loading mix-in dataset: %s" % data_mix_in_path)
425
+ if "json" in data_mix_in_path:
426
+ data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
427
+ else:
428
+ data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
429
+ data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
430
+
431
+ # only get as much as we need to balance
432
+ valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
433
+ train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
434
+ mixin_small = data_mix_in.train_test_split(
435
+ test_size=train_size + valid_size,
436
+ shuffle=True, seed=np.random.randint(10000),
437
+ )["test"]
438
+ if valid_size:
439
+ mixin_train_test = mixin_small.train_test_split(
440
+ test_size=valid_size, shuffle=False,
441
+ )
442
+ train_data_mix_in = mixin_train_test["train"]
443
+ valid_data_mix_in = mixin_train_test["test"]
444
+ else:
445
+ train_data_mix_in = mixin_small
446
+
447
+ if "prompt_type" not in train_data_mix_in.column_names:
448
+ train_data_mix_in = train_data_mix_in.add_column(
449
+ "prompt_type",
450
+ [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
451
+ )
452
+ log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
453
+ if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
454
+ valid_data_mix_in = valid_data_mix_in.add_column(
455
+ "prompt_type",
456
+ [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
457
+ )
458
+ log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
459
+ log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
460
+
461
+ # get our own training/validation data - for fine-tuning
462
+ if val_set_size > 0 and not valid_path and not data_mix_in_path:
463
+ # create valid split from train
464
+ train_val = data["train"].train_test_split(
465
+ test_size=val_set_size, shuffle=True, seed=42
466
+ )
467
+ train_data = train_val["train"]
468
+ valid_data = train_val["test"]
469
+ else:
470
+ train_data = data["train"]
471
+ if valid_path:
472
+ # use given valid split, has priority over data_mix_in_path
473
+ valid_data = data["valid"]
474
+ if "prompt_type" not in train_data.column_names:
475
+ train_data = train_data.add_column(
476
+ "prompt_type",
477
+ [prompt_type] * train_data.num_rows,
478
+ )
479
+ log("Added prompt type %s to training data" % prompt_type)
480
+ if valid_data and "prompt_type" not in valid_data.column_names:
481
+ valid_data = valid_data.add_column(
482
+ "prompt_type",
483
+ [prompt_type] * valid_data.num_rows,
484
+ )
485
+ log("Added prompt type %s to validation data" % prompt_type)
486
+
487
+ assert train_data is not None
488
+
489
+ # shuffle and tokenize data
490
+ if train_data_mix_in:
491
+ train_data = concatenate_datasets([train_data, train_data_mix_in])
492
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
493
+ train_set_size = len(train_data)
494
+
495
+ if valid_data and valid_data_mix_in:
496
+ valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
497
+ elif valid_data_mix_in:
498
+ valid_data = valid_data_mix_in
499
+
500
+ if valid_data:
501
+ valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
502
+ val_set_size = len(valid_data)
503
+ else:
504
+ val_set_size = 0
505
+ log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
506
+ sample_row_dict = train_data[:1]
507
+ del sample_row_dict['input_ids']
508
+ del sample_row_dict['attention_mask']
509
+ del sample_row_dict['labels']
510
+ log("Sample input: %s" % sample_row_dict)
511
+
512
+ if neptune_run:
513
+ neptune_callback = NeptuneCallback(run=neptune_run)
514
+ callbacks = [neptune_callback]
515
+ else:
516
+ from transformers.integrations import TensorBoardCallback, is_tensorboard_available
517
+ if is_tensorboard_available:
518
+ # tensorboard --logdir=runs/
519
+ from torch.utils.tensorboard import SummaryWriter
520
+ tb_writer = SummaryWriter()
521
+ callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
522
+ else:
523
+ callbacks = []
524
+
525
+ expected_steps = (train_set_size * num_epochs) // batch_size
526
+ if eval_steps is None and eval_epochs is None:
527
+ # 20 evaluations for a run
528
+ eval_steps = max(1, int(expected_steps / 20))
529
+ log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
530
+ elif eval_steps is None and eval_epochs is not None:
531
+ eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
532
+ log("Auto converted eval_epochs=%s to eval_steps %s"
533
+ " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
534
+ if save_steps is None:
535
+ save_steps = eval_steps
536
+ log("Auto set save_steps to %s" % save_steps)
537
+ elif save_steps > eval_steps:
538
+ # save steps must be round multiple of eval_steps
539
+ save_steps0 = save_steps
540
+ save_steps = max(1, (save_steps//eval_steps)) * eval_steps
541
+ if save_steps0 != save_steps:
542
+ log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
543
+
544
+ def compute_metrics(eval_preds):
545
+ # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
546
+ inputs = eval_preds.inputs
547
+ label_ids = eval_preds.label_ids
548
+ predictions = eval_preds.predictions
549
+
550
+ #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
551
+ #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
552
+ #decoded_inputs = [pred.strip() for pred in decoded_inputs]
553
+
554
+ label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
555
+ # tokenizer behavior like generate time
556
+ decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
557
+ clean_up_tokenization_spaces=True)
558
+ decoded_labels = [pred.strip() for pred in decoded_labels]
559
+
560
+ predictions = np.argmax(predictions, -1)
561
+ predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
562
+ # tokenizer behavior like generate time
563
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
564
+ clean_up_tokenization_spaces=True)
565
+ decoded_predictions = [pred.strip() for pred in decoded_predictions]
566
+
567
+ result = {}
568
+ for metric in metrics.values():
569
+ result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
570
+ # get rid of lists, for precision etc., for now
571
+ numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
572
+ result.update(numeric_results)
573
+ return result
574
+
575
+ # the callback that computes metrics of interest
576
+ if val_metrics:
577
+ trainer_kwargs = dict(compute_metrics=compute_metrics)
578
+ else:
579
+ trainer_kwargs = dict()
580
+
581
+ trainer = transformers.Trainer(
582
+ model=model,
583
+ tokenizer=tokenizer,
584
+ train_dataset=train_data,
585
+ eval_dataset=valid_data,
586
+ # NOTE: CausalLM does not use the Seq2Seq-specific arguments of Seq2SeqTrainingArguments, but they are not incompatible
587
+ args=transformers.Seq2SeqTrainingArguments(
588
+ per_device_train_batch_size=micro_batch_size,
589
+ per_device_eval_batch_size=1,
590
+ eval_accumulation_steps=10,
591
+ # predict_with_generate=True, # SEQ2SEQ only
592
+ include_inputs_for_metrics=True,
593
+ gradient_accumulation_steps=gradient_accumulation_steps,
594
+ warmup_steps=warmup_steps,
595
+ num_train_epochs=num_epochs,
596
+ learning_rate=learning_rate,
597
+ gradient_checkpointing=gradient_checkpointing,
598
+ fp16=fp16,
599
+ # cosnider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
600
+ optim="adamw_torch", # consider "adafactor" to save memory
601
+ logging_steps=logging_steps,
602
+ logging_strategy="steps",
603
+ evaluation_strategy="steps" if val_set_size > 0 else "no",
604
+ save_strategy="steps",
605
+ eval_steps=eval_steps if val_set_size > 0 else None,
606
+ save_steps=save_steps,
607
+ output_dir=output_dir,
608
+ save_total_limit=3,
609
+ load_best_model_at_end=True if val_set_size > 0 else False,
610
+ ddp_find_unused_parameters=False if ddp else None,
611
+ group_by_length=group_by_length,
612
+ #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
613
+ #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
614
+ report_to='tensorboard' if not neptune_run else 'neptune',
615
+ ),
616
+ data_collator=transformers.DataCollatorForSeq2Seq(
617
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
618
+ ),
619
+ callbacks=callbacks,
620
+ **trainer_kwargs,
621
+ )
622
+ model.config.use_cache = False
623
+
624
+ old_state_dict = model.state_dict
625
+ model.state_dict = (
626
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
627
+ ).__get__(model, type(model))
628
+
629
+ if torch.__version__ >= "2" and sys.platform != "win32":
630
+ model = torch.compile(model)
631
+ # WIP (not generally replacing layers until pytorch 2.1)
632
+ torch.backends.cuda.enable_flash_sdp(True)
633
+
634
+ if gpus > 1 and not ddp:
635
+ assert trainer.is_model_parallel
636
+ else:
637
+ assert not trainer.is_model_parallel
638
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
639
+
640
+ model.save_pretrained(output_dir)
641
+
642
+ log("\n If there's a warning about missing keys above, please disregard :)")
643
+
644
+
645
+ def get_loaders(llama_type, model_name, reward_type):
646
+ # NOTE: Some models need specific new prompt_type
647
+ # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
648
+ if llama_type:
649
+ from transformers import LlamaForCausalLM, LlamaTokenizer
650
+ model_loader = LlamaForCausalLM
651
+ tokenizer_loader = LlamaTokenizer
652
+ elif 'gpt2' in model_name.lower():
653
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
654
+ return GPT2LMHeadModel, GPT2Tokenizer
655
+ elif 'mbart-' in model_name.lower():
656
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
657
+ return MBartForConditionalGeneration, MBart50TokenizerFast
658
+ elif 't5' == model_name.lower() or \
659
+ 't5-' in model_name.lower() or \
660
+ 'flan-' in model_name.lower():
661
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
662
+ return T5ForConditionalGeneration, AutoTokenizer
663
+ elif 'bigbird' in model_name:
664
+ from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
665
+ return BigBirdPegasusForConditionalGeneration, AutoTokenizer
666
+ elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
667
+ from transformers import pipeline
668
+ return pipeline, "summarization"
669
+ elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
670
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
671
+ return AutoModelForSequenceClassification, AutoTokenizer
672
+ else:
673
+ from transformers import AutoTokenizer, AutoModelForCausalLM
674
+ model_loader = AutoModelForCausalLM
675
+ tokenizer_loader = AutoTokenizer
676
+ return model_loader, tokenizer_loader
677
+
678
+
679
+ def get_githash():
680
+ try:
681
+ githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
682
+ except:
683
+ githash = ''
684
+ return githash
685
+
686
+
687
+ def copy_code(run_id):
688
+ """
689
+ copy code to track changes
690
+ :param run_id:
691
+ :return:
692
+ """
693
+ rnd_num = str(random.randint(0, 2 ** 31))
694
+ run_id = 'run_' + str(run_id)
695
+ os.makedirs(run_id, exist_ok=True)
696
+ me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
697
+ me_file = os.path.basename(__file__)
698
+ new_me = os.path.join(run_id, me_file + '_' + get_githash())
699
+ if os.path.isfile(new_me):
700
+ new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
701
+ shutil.copy(me_full, new_me)
702
+ else:
703
+ shutil.copy(me_full, new_me)
704
+
705
+
706
+ def get_prompt(prompt_type, chat, context, reduced):
707
+ if prompt_type in [-1, "-1", "plain"]:
708
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
709
+ terminate_response = []
710
+ elif prompt_type == 'simple_instruct':
711
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
712
+ terminate_response = []
713
+ elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
714
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
715
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
716
+
717
+ PreInstruct = """
718
+ ### Instruction:
719
+ """
720
+
721
+ PreInput = """
722
+ ### Input:
723
+ """
724
+
725
+ PreResponse = """
726
+ ### Response:
727
+ """
728
+ if prompt_type in [7, "7", "instruct_with_end"]:
729
+ terminate_response = ['### End']
730
+ else:
731
+ terminate_response = None
732
+ elif prompt_type in [1, "1", "quality"]:
733
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
734
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
735
+
736
+ PreInstruct = """
737
+ ### Instruction:
738
+ """
739
+
740
+ PreInput = """
741
+ ### Input:
742
+ """
743
+
744
+ PreResponse = """
745
+ ### Response:
746
+ """
747
+ terminate_response = None
748
+ elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
749
+ if reduced or context or prompt_type in [2, "2", "human_bot"]:
750
+ preprompt = ''
751
+ else:
752
+ cur_date = time.strftime('%Y-%m-%d')
753
+ cur_time = time.strftime('%H:%M:%S %p %Z')
754
+
755
+ PRE_PROMPT = """\
756
+ Current Date: {}
757
+ Current Time: {}
758
+
759
+ """
760
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
761
+ start = human
762
+ promptB = promptA = '%s%s ' % (preprompt, start)
763
+
764
+ PreInstruct = ""
765
+
766
+ PreInput = None
767
+
768
+ PreResponse = bot
769
+
770
+ terminate_response = [start, PreResponse]
771
+ elif prompt_type in [3, "3", "dai_faq"]:
772
+ promptA = ''
773
+ promptB = 'Answer the following Driverless AI question.\n'
774
+
775
+ PreInstruct = """
776
+ ### Driverless AI frequently asked question:
777
+ """
778
+
779
+ PreInput = None
780
+
781
+ PreResponse = """
782
+ ### Driverless AI documentation answer:
783
+ """
784
+ terminate_response = ['\n\n']
785
+ elif prompt_type in [5, "5", "summarize"]:
786
+ promptA = promptB = PreInput = ''
787
+ PreInstruct = '## Main Text\n\n'
788
+ PreResponse = '\n\n## Summary\n\n'
789
+ terminate_response = None
790
+ elif prompt_type in [6, "6", "instruct_vicuna"]:
791
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
792
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
793
+
794
+ PreInstruct = """
795
+ ### Human:
796
+ """
797
+
798
+ PreInput = None
799
+
800
+ PreResponse = """
801
+ ### Assistant:
802
+ """
803
+ terminate_response = ['### Human:'] # only allow termination after the prompt is found correctly, else can't terminate
804
+ else:
805
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
806
+
807
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
808
+
809
+
810
+ def generate_prompt(data_point, prompt_type, chat, reduced):
811
+ context = data_point.get('context')
812
+ if context is None:
813
+ context = ''
814
+ instruction = data_point.get('instruction')
815
+ input = data_point.get('input')
816
+ output = data_point.get('output')
817
+ prompt_type = data_point.get('prompt_type', prompt_type)
818
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
819
+ promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
820
+
821
+ prompt = context
822
+
823
+ if input and promptA:
824
+ prompt += f"""{promptA}"""
825
+ elif promptB:
826
+ prompt += f"""{promptB}"""
827
+
828
+ if instruction and PreInstruct is not None and input and PreInput is not None:
829
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
830
+ prompt = inject_newline(prompt_type, prompt)
831
+ elif instruction and input and PreInstruct is None and PreInput is not None:
832
+ prompt += f"""{PreInput}{instruction}
833
+ {input}"""
834
+ prompt = inject_newline(prompt_type, prompt)
835
+ elif input and instruction and PreInput is None and PreInstruct is not None:
836
+ prompt += f"""{PreInstruct}{instruction}
837
+ {input}"""
838
+ prompt = inject_newline(prompt_type, prompt)
839
+ elif instruction and PreInstruct is not None:
840
+ prompt += f"""{PreInstruct}{instruction}"""
841
+ prompt = inject_newline(prompt_type, prompt)
842
+ elif input and PreInput is not None:
843
+ prompt += f"""{PreInput}{input}"""
844
+ prompt = inject_newline(prompt_type, prompt)
845
+ elif input and instruction and PreInput is not None:
846
+ prompt += f"""{PreInput}{instruction}{input}"""
847
+ prompt = inject_newline(prompt_type, prompt)
848
+ elif input and instruction and PreInstruct is not None:
849
+ prompt += f"""{PreInstruct}{instruction}{input}"""
850
+ prompt = inject_newline(prompt_type, prompt)
851
+ elif input and instruction:
852
+ # i.e. for simple_instruct
853
+ prompt += f"""{instruction}: {input}"""
854
+ prompt = inject_newline(prompt_type, prompt)
855
+ elif input:
856
+ prompt += f"""{input}"""
857
+ prompt = inject_newline(prompt_type, prompt)
858
+ elif instruction:
859
+ prompt += f"""{instruction}"""
860
+ prompt = inject_newline(prompt_type, prompt)
861
+
862
+ if PreResponse is not None:
863
+ prompt += f"""{PreResponse}"""
864
+ pre_response = PreResponse # Don't use strip
865
+ else:
866
+ pre_response = ''
867
+
868
+ if output:
869
+ prompt += f"""{output}"""
870
+
871
+ return prompt, pre_response, terminate_response
872
+
873
+
874
+ def inject_newline(prompt_type, prompt):
875
+ if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
876
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
877
+ prompt += '\n'
878
+ return prompt
879
+
880
+
881
+ example_data_point0 = dict(instruction="Summarize",
882
+ input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
883
+ output="Ducks eat and swim at the lake.")
884
+
885
+ example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
886
+ output="Einstein.")
887
+
888
+ example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
889
+ output="Einstein.")
890
+
891
+ example_data_points = [example_data_point0, example_data_point1, example_data_point2]
892
+
893
+
894
+ def test_train_prompt(prompt_type='instruct', data_point=0):
895
+ example_data_point = example_data_points[data_point]
896
+ return generate_prompt(example_data_point, prompt_type, False, False)
897
+
898
+
899
+ def test_debug():
900
+ fire.Fire(train)
901
+
902
+
903
+ if __name__ == "__main__":
904
+ CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
905
+ CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
906
+ log(f"""
907
+ Example runs on 4 GPUs:
908
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
909
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
910
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
911
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
912
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 --data_mix_in_path='' &> 13.log
913
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
914
+
915
+ All metrics:
916
+ CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
917
+
918
+ # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
919
+ rippa>
920
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
921
+ ova>
922
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
923
+ timemachine>
924
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
925
+
926
+ """, flush=True)
927
+
928
+ if os.environ.get("LOCAL_RANK") is None:
929
+ # not launched via torchrun, so no distributed training; require CUDA_VISIBLE_DEVICES to be set
930
+ assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run the script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to a single GPU"
931
+
932
+ fire.Fire(train)
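A quick way to sanity-check the prompt templates is to call generate_prompt directly on one of the bundled example data points. This is a minimal sketch, assuming finetune.py and its dependencies are importable from the working directory; 'instruct' is simply the prompt type that test_train_prompt defaults to.

# Sketch: inspect the prompt built for example_data_point0 (assumes finetune.py is importable)
from finetune import generate_prompt, example_data_point0

prompt, pre_response, terminate_response = generate_prompt(example_data_point0, 'instruct', False, False)
print(prompt)
print("pre_response:", repr(pre_response))
print("terminate_response:", terminate_response)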
h2o-logo.svg ADDED
prompter.py ADDED
@@ -0,0 +1,106 @@
1
+ from finetune import generate_prompt
2
+
3
+
4
+ class Prompter(object):
5
+ def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
6
+ allowed_repeat_line_length=10):
7
+ self.prompt_type = prompt_type
8
+ data_point = dict(instruction='', input='', output='')
9
+ _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
10
+ self.debug = debug
11
+ self.chat = chat
12
+ self.stream_output = stream_output
13
+ self.repeat_penalty = repeat_penalty
14
+ self.allowed_repeat_line_length = allowed_repeat_line_length
15
+
16
+ def generate_prompt(self, data_point):
17
+ reduced = False
18
+ prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
19
+ if self.debug:
20
+ print("prompt: ", prompt, flush=True)
21
+ self.prompt = prompt
22
+ return prompt
23
+
24
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
25
+ if isinstance(outputs, str):
26
+ outputs = [outputs]
27
+ if self.debug:
28
+ print("output: ", '\n\n'.join(outputs), flush=True)
29
+ if prompt is not None:
30
+ self.prompt = prompt
+ prompt = self.prompt  # fall back to the prompt stored by generate_prompt if none was passed in
31
+
32
+ def clean_response(response):
33
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
34
+ for word in meaningless_words:
35
+ response = response.replace(word, "")
36
+ if sanitize_bot_response:
37
+ from better_profanity import profanity
38
+ response = profanity.censor(response)
39
+ response = response.strip("\n")
40
+ return response
41
+
42
+ def clean_repeats(response):
43
+ lines = response.split('\n')
44
+ new_lines = []
45
+ [new_lines.append(line) for line in lines if
46
+ line not in new_lines or len(line) < self.allowed_repeat_line_length]
47
+ if self.debug and len(lines) != len(new_lines):
48
+ print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
49
+ response = '\n'.join(new_lines)
50
+ return response
51
+
52
+ multi_output = len(outputs) > 1
53
+
54
+ for oi, output in enumerate(outputs):
55
+ if self.prompt_type in [0, '0', 'plain']:
56
+ output = clean_response(output)
57
+ else:
58
+ # find first instance of pre_response
59
+ # prompt sometimes has odd characters that mutate its length,
60
+ # so we can't go by length alone
61
+ if self.pre_response:
62
+ outputi = output.find(prompt)
63
+ if outputi >= 0:
64
+ output = output[outputi + len(prompt):]
65
+ allow_terminate = True
66
+ else:
67
+ # subtraction is risky due to space offsets sometimes, so only do if necessary
68
+ output = output[len(prompt) - len(self.pre_response):]
69
+ # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
70
+ if self.pre_response in output:
71
+ output = output.split(self.pre_response)[1]
72
+ allow_terminate = True
73
+ else:
74
+ print("Failure of parsing: %s" % output, flush=True)
75
+ allow_terminate = False
76
+ else:
77
+ allow_terminate = True
78
+ output = output[len(prompt):]
79
+ # clean after subtract prompt out, so correct removal of pre_response
80
+ output = clean_response(output).strip()
81
+ if self.repeat_penalty:
82
+ output = clean_repeats(output).strip()
83
+ if self.terminate_response and allow_terminate:
84
+ finds = []
85
+ for term in self.terminate_response:
86
+ finds.append(output.find(term))
87
+ finds = [x for x in finds if x >= 0]
88
+ if len(finds) > 0:
89
+ termi = finds[0]
90
+ output = output[:termi].strip()
91
+ else:
92
+ output = output.strip()
93
+ else:
94
+ output = output.strip()
95
+ if multi_output:
96
+ # prefix with output counter
97
+ output = "\n=========== Output %d\n\n" % (1 + oi) + output
98
+ if oi > 0:
99
+ # postfix outputs with separator
100
+ output += '\n'
101
+ outputs[oi] = output
102
+ # join all outputs, only one extra new line between outputs
103
+ output = '\n'.join(outputs)
104
+ if self.debug:
105
+ print("outputclean: ", '\n\n'.join(outputs), flush=True)
106
+ return output
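The class above is normally paired with a model's raw decoded output: generate_prompt builds the model input, and get_response strips the prompt, the pre_response marker, repeated lines, and terminate strings from the generation. A minimal usage sketch, assuming prompter.py and finetune.py are importable and using a stand-in string in place of a real model generation:

from prompter import Prompter

prompter = Prompter('instruct', debug=True, chat=False)
data_point = dict(instruction="Summarize", input="Ducks eat seeds by the lake.", output='')
prompt = prompter.generate_prompt(data_point)
# stand-in for tokenizer.decode(model.generate(...)) output, which echoes the prompt for decoder-only models
raw_output = prompt + "Ducks eat at the lake."
print(prompter.get_response(raw_output, prompt=prompt))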
requirements.txt ADDED
@@ -0,0 +1,50 @@
1
+ # for generate (gradio server) and finetune
2
+ datasets==2.11.0
3
+ sentencepiece==0.1.97
4
+ accelerate==0.18.0
5
+ gradio==3.27.0
6
+ huggingface_hub==0.13.4
7
+ appdirs==1.4.4
8
+ fire==0.5.0
9
+ docutils==0.19
10
+ torch==2.0.0
11
+ evaluate==0.4.0
12
+ rouge_score==0.1.2
13
+ sacrebleu==2.3.1
14
+ scikit-learn==1.2.2
15
+ alt-profanity-check==1.2.2
16
+ better-profanity==0.6.1
17
+ numpy==1.24.2
18
+ pandas==2.0.0
19
+ matplotlib==3.7.1
20
+ loralib==0.1.1
21
+ bitsandbytes==0.38.1
22
+ git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
23
+ transformers==4.28.1
24
+ tokenizers==0.13.3
25
+
26
+ # optional for generate
27
+ pynvml==11.5.0
28
+ psutil==5.9.4
29
+ boto3==1.26.101
30
+ botocore==1.29.101
31
+
32
+ # optional for finetune
33
+ tensorboard==2.12.1
34
+ neptune==1.1.1
35
+
36
+ # for gradio client
37
+ gradio_client==0.1.3
38
+ beautifulsoup4==4.12.2
39
+ markdown==3.4.1
40
+
41
+ # data and testing
42
+ pytest==7.2.2
43
+ pytest-xdist==3.2.1
44
+ nltk==3.8.1
45
+ textstat==0.7.3
46
+ pandoc==2.3
47
+ pypandoc==1.11
48
+ openpyxl==3.1.2
49
+ lm_dataformat==0.0.20
50
+ bioc==2.0
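These pins can drift from what is actually installed in a given environment. A small sketch to spot-check a few of them, assuming the distribution names match the pins above (the selection here is arbitrary):

from importlib.metadata import version, PackageNotFoundError

pins = {"transformers": "4.28.1", "gradio": "3.27.0", "torch": "2.0.0"}
for name, expected in pins.items():
    try:
        found = version(name)
        print(f"{name}: expected {expected}, found {found}")
    except PackageNotFoundError:
        print(f"{name}: not installed")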
stopping.py ADDED
@@ -0,0 +1,139 @@
1
+ import traceback
2
+ from queue import Queue
3
+ from threading import Thread
4
+ import collections.abc
5
+
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+
9
+
10
+ class StoppingCriteriaSub(StoppingCriteria):
11
+
12
+ def __init__(self, stops=[], encounters=[]):
13
+ super().__init__()
14
+ assert len(stops) % len(encounters) == 0, "Number of stops must be a multiple of the number of encounters"
15
+ self.encounters = encounters
16
+ self.stops = [stop.to("cuda") for stop in stops]
17
+ self.num_stops = [0] * len(stops)
18
+
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
+ for stopi, stop in enumerate(self.stops):
21
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
22
+ self.num_stops[stopi] += 1
23
+ if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
24
+ return True
25
+ # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
+ # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
+ return False
28
+
29
+
30
+ class Stream(StoppingCriteria):
31
+ """
32
+ This class can be used to run a callback during generation. Keep
33
+ in mind that for decoder-only transformers, the passed tokens include the initial prompt tokens.
34
+
35
+ Args:
36
+ func (`callable`):
37
+ A callable applied to the first sequence in the batch at every generation step
38
+ """
39
+
40
+ def __init__(self, func=None):
41
+ self.func = func
42
+
43
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
+ if self.func is not None:
45
+ # only consume first of multiple responses
46
+ self.func(input_ids[0])
47
+ return False
48
+
49
+
50
+ class CallbackToGenerator(collections.abc.Generator):
51
+ """
52
+ A generator wrapper for a function that invokes a callback multiple times.
53
+
54
+ Calling `send` on the generator passes a value back to the pending callback and
55
+ returns the value supplied to the next callback invocation.
56
+
57
+ Note this starts a background thread
58
+ """
59
+
60
+ def __init__(self, func, *args, callback=None, **kwargs):
61
+ self.func = func
62
+ self.args = args
63
+ self.kwargs = kwargs
64
+ self.callback = callback
65
+
66
+ self._ready_queue = Queue(1)
67
+ self._done_queue = Queue(1)
68
+ self._done_holder = [False]
69
+
70
+ # local to avoid reference cycles
71
+ ready_queue = self._ready_queue
72
+ done_queue = self._done_queue
73
+ done_holder = self._done_holder
74
+
75
+ def val_callback(value):
76
+ done_queue.put((False, value))
77
+ cmd, val = ready_queue.get()
78
+ if cmd == 'send':
79
+ return val
80
+ elif cmd == 'throw':
81
+ raise val
82
+ else:
83
+ assert False # pragma: no cover
84
+
85
+ def thread_func():
86
+ while True:
87
+ cmd, val = ready_queue.get()
88
+ if cmd == 'send' and val is not None:
89
+ done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
90
+ continue
91
+ break
92
+ try:
93
+ if cmd == 'throw':
94
+ raise val
95
+ ret = func(callback=val_callback, **self.kwargs)
96
+ raise StopIteration(ret) if ret is not None else StopIteration
97
+ except BaseException as e:
98
+ done_holder[0] = True
99
+ done_queue.put((True, e))
100
+
101
+ self._thread = Thread(target=thread_func)
102
+ self._thread.start()
103
+
104
+ def _put(self, *args):
105
+ if self._done_holder[0]:
106
+ raise StopIteration
107
+ self._ready_queue.put(args)
108
+ is_exception, val = self._done_queue.get()
109
+ if is_exception:
110
+ try:
111
+ raise val
112
+ finally:
113
+ # prevent val's traceback containing a reference cycle
114
+ del val
115
+ else:
116
+ return val
117
+
118
+ def send(self, value):
119
+ return self._put('send', value)
120
+
121
+ def throw(self, exc):
122
+ return self._put('throw', exc)
123
+
124
+ def close(self):
125
+ try:
126
+ self.throw(GeneratorExit)
127
+ except StopIteration:
128
+ self._thread.join()
129
+ except GeneratorExit:
130
+ self._thread.join()
131
+ except BaseException:
132
+ self._thread.join()
133
+ raise
134
+ else:
135
+ # yielded again, can't clean up the thread
136
+ raise RuntimeError('Task with callback ignored GeneratorExit')
137
+
138
+ def __del__(self):
139
+ self.close()
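StoppingCriteriaSub plugs into transformers generation via StoppingCriteriaList. A minimal sketch, assuming a CUDA device is available (the class moves stop tensors to "cuda") and an illustrative small model; the "<human>:" stop phrase mirrors the convention used elsewhere in this repo:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from stopping import StoppingCriteriaSub

model_name = "EleutherAI/gpt-neo-125M"  # illustrative; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

# stop generation the first time the model emits "<human>:"
stop_ids = tokenizer("<human>:", return_tensors="pt").input_ids[0]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[stop_ids], encounters=[1])])

inputs = tokenizer("<human>: Hello, who are you?\n<bot>:", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0]))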
utils.py ADDED
@@ -0,0 +1,186 @@
1
+ import os
2
+ import gc
3
+ import random
4
+ import time
5
+ import traceback
6
+ import zipfile
7
+ from datetime import datetime
8
+ import filelock
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+
13
+
14
+ def set_seed(seed: int):
15
+ """
16
+ Sets the random seeds for numpy, random, and torch so results are the same on every run.
17
+ This is for REPRODUCIBILITY.
18
+ """
19
+ np.random.seed(seed)
20
+ random_state = np.random.RandomState(seed)
21
+ random.seed(seed)
22
+ torch.manual_seed(seed)
23
+ torch.cuda.manual_seed(seed)
24
+ torch.backends.cudnn.deterministic = True
25
+ torch.backends.cudnn.benchmark = False
26
+ os.environ['PYTHONHASHSEED'] = str(seed)
27
+ return random_state
28
+
29
+
30
+ def flatten_list(lis):
31
+ """Given a list, possibly nested to any level, return it flattened."""
32
+ new_lis = []
33
+ for item in lis:
34
+ if isinstance(item, list):
35
+ new_lis.extend(flatten_list(item))
36
+ else:
37
+ new_lis.append(item)
38
+ return new_lis
39
+
40
+
41
+ def clear_torch_cache():
42
+ if torch.cuda.is_available():
43
+ torch.cuda.empty_cache()
44
+ torch.cuda.ipc_collect()
45
+ gc.collect()
46
+
47
+
48
+ def system_info():
49
+ import psutil
50
+
51
+ system = {}
52
+ # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
53
+ # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
54
+ temps = psutil.sensors_temperatures(fahrenheit=False)
55
+ if 'coretemp' in temps:
56
+ coretemp = temps['coretemp']
57
+ temp_dict = {k.label: k.current for k in coretemp}
58
+ for k, v in temp_dict.items():
59
+ system['CPU_C/%s' % k] = v
60
+
61
+ # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
62
+ from pynvml.smi import nvidia_smi
63
+ nvsmi = nvidia_smi.getInstance()
64
+
65
+ gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
66
+ enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
67
+ for k, v in gpu_power_dict.items():
68
+ system['GPU_W/%s' % k] = v
69
+
70
+ gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
71
+ enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
72
+ for k, v in gpu_temp_dict.items():
73
+ system['GPU_C/%s' % k] = v
74
+
75
+ gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
76
+ enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
77
+ gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
78
+ enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
79
+ gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
80
+ for k, v in gpu_memory_frac_dict.items():
81
+ system['GPU_M/%s' % k] = v
82
+
83
+ return system
84
+
85
+
86
+ def system_info_print():
87
+ try:
88
+ df = pd.DataFrame.from_dict(system_info(), orient='index')
89
+ # avoid slamming GPUs
90
+ time.sleep(1)
91
+ return df.to_markdown()
92
+ except Exception as e:
93
+ return "Error: %s" % str(e)
94
+
95
+
96
+ def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
97
+ try:
98
+ return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
99
+ except Exception as e:
100
+ traceback.print_exc()
101
+ print('Exception in zipping: %s' % str(e))
102
+
103
+
104
+ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
105
+ if zip_file is None:
106
+ datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
107
+ host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
108
+ zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
109
+ assert root_dirs is not None
110
+
111
+ with zipfile.ZipFile(zip_file, "w") as expt_zip:
112
+ for root_dir in root_dirs:
113
+ if root_dir is None:
114
+ continue
115
+ for root, d, files in os.walk(root_dir):
116
+ for file in files:
117
+ file_to_archive = os.path.join(root, file)
118
+ assert os.path.exists(file_to_archive)
119
+ path_to_archive = os.path.relpath(file_to_archive, base_dir)
120
+ expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
121
+ return zip_file, zip_file
122
+
123
+
124
+ def save_generate_output(output=None, base_model=None, save_dir=None):
125
+ try:
126
+ return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
127
+ except Exception as e:
128
+ traceback.print_exc()
129
+ print('Exception in saving: %s' % str(e))
130
+
131
+
132
+ def _save_generate_output(output=None, base_model=None, save_dir=None):
133
+ """
134
+ Save conversation to .json, row by row.
135
+ save_dir is the directory holding the output history.json; it is created if it does not exist.
136
+ Appends if the file already exists.
137
+ """
138
+ assert save_dir, "save_dir must be provided"
139
+ if os.path.exists(save_dir) and not os.path.isdir(save_dir):
140
+ raise RuntimeError("save_dir already exists and is not a directory!")
141
+ os.makedirs(save_dir, exist_ok=True)
142
+ import json
143
+ if output[-10:] == '\n\n<human>:':
144
+ # remove trailing <human>:
145
+ output = output[:-10]
146
+ with filelock.FileLock("save_dir.lock"):
147
+ # lock logging in case have concurrency
148
+ with open(os.path.join(save_dir, "history.json"), "a") as f:
149
+ # just add [ at start, and ] at end, and have proper JSON dataset
150
+ f.write(
151
+ " " + json.dumps(
152
+ dict(text=output, time=time.ctime(), base_model=base_model)
153
+ ) + ",\n"
154
+ )
155
+
156
+
157
+ def s3up(filename):
158
+ try:
159
+ return _s3up(filename)
160
+ except Exception as e:
161
+ traceback.print_exc()
162
+ print('Exception for file %s in s3up: %s' % (filename, str(e)))
163
+ return "Failed to upload %s: Error: %s" % (filename, str(e))
164
+
165
+
166
+ def _s3up(filename):
167
+ import boto3
168
+
169
+ aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY')
170
+ aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY')
171
+ bucket = os.getenv('AWS_BUCKET')
172
+ assert aws_access_key_id, "Set AWS key"
173
+ assert aws_secret_access_key, "Set AWS secret"
174
+ assert bucket, "Set AWS Bucket"
175
+
176
+ s3 = boto3.client('s3',
177
+ aws_access_key_id=aws_access_key_id,
178
+ aws_secret_access_key=aws_secret_access_key,
179
+ )
180
+ ret = s3.upload_file(
181
+ Filename=filename,
182
+ Bucket=bucket,
183
+ Key=filename,
184
+ )
185
+ if ret in [None, '']:
186
+ return "Successfully uploaded %s" % filename