Commit d7185d6: Duplicate from h2oai/h2ogpt-chatbot
Co-authored-by: Jonathan McKinney <[email protected]>
- .gitattributes +34 -0
- LICENSE +201 -0
- README.md +14 -0
- app.py +0 -0
- client_test.py +121 -0
- finetune.py +932 -0
- h2o-logo.svg +1 -0
- prompter.py +106 -0
- requirements.txt +50 -0
- stopping.py +139 -0
- utils.py +186 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: H2ogpt Chatbot
emoji: 📚
colorFrom: yellow
colorTo: yellow
sdk: gradio
sdk_version: 3.27.0
app_file: app.py
pinned: false
license: apache-2.0
duplicated_from: h2oai/h2ogpt-chatbot
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
The diff for this file is too large to render.
See raw diff
client_test.py
ADDED
@@ -0,0 +1,121 @@
"""
Client test.

Run server:

python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b

NOTE: For private models, add --use-auth_token=True

NOTE: --infer_devices=True (default) must be used for multi-GPU in case you see failures with cuda:x cuda:y mismatches.
Currently, this will force the model to be on a single GPU.

Then run this client as:

python client_test.py


For HF spaces:

HOST="https://h2oai-h2ogpt-chatbot.hf.space" python client_test.py

Result:

Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.'}


For demo:

HOST="https://gpt.h2o.ai" python client_test.py

Result:

Loaded as API: https://gpt.h2o.ai ✔
{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}

"""

debug = False

import os
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'


def get_client():
    from gradio_client import Client

    client = Client(os.getenv('HOST', "http://localhost:7860"))
    if debug:
        print(client.view_api(all_endpoints=True))
    return client


def test_client_basic():
    instruction = ''  # only for chat=True
    iinput = ''  # only for chat=True
    context = ''
    # streaming output is supported, loops over and outputs each generation in streaming mode
    # but leave stream_output=False for simple input/output mode
    stream_output = False
    prompt_type = 'human_bot'
    temperature = 0.1
    top_p = 0.75
    top_k = 40
    num_beams = 1
    max_new_tokens = 50
    min_new_tokens = 0
    early_stopping = False
    max_time = 20
    repetition_penalty = 1.0
    num_return_sequences = 1
    do_sample = True
    # only these 2 below used if pass chat=False
    chat = False
    instruction_nochat = "Who are you?"
    iinput_nochat = ''

    args = [instruction,
            iinput,
            context,
            stream_output,
            prompt_type,
            temperature,
            top_p,
            top_k,
            num_beams,
            max_new_tokens,
            min_new_tokens,
            early_stopping,
            max_time,
            repetition_penalty,
            num_return_sequences,
            do_sample,
            chat,
            instruction_nochat,
            iinput_nochat,
            ]
    api_name = '/submit_nochat'
    client = get_client()
    res = client.predict(
        *tuple(args),
        api_name=api_name,
    )
    res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
    print(res_dict)
    return res_dict


import markdown  # pip install markdown
from bs4 import BeautifulSoup  # pip install beautifulsoup4


def md_to_text(md):
    html = markdown.markdown(md)
    soup = BeautifulSoup(html, features='html.parser')
    return soup.get_text()


if __name__ == '__main__':
    test_client_basic()
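Note (not part of the committed file): client_test.py makes a single blocking predict() call. Since its docstring mentions that streaming output is also supported, the sketch below is a minimal, hypothetical illustration of polling for partial results with gradio_client's job API. It assumes the same /submit_nochat endpoint and argument order used above, and that the server actually yields intermediate outputs for this endpoint.

# Hypothetical streaming sketch; argument order follows client_test.py above.
import os
import time
from gradio_client import Client

def stream_nochat(question="Who are you?"):
    client = Client(os.getenv('HOST', "http://localhost:7860"))
    args = ['', '', '', True, 'human_bot', 0.1, 0.75, 40, 1, 50, 0, False, 20, 1.0, 1, True,
            False, question, '']  # stream_output=True, chat=False
    job = client.submit(*args, api_name='/submit_nochat')  # non-blocking; returns a Job
    while not job.done():
        partial = job.outputs()  # outputs produced so far, if the endpoint yields them
        if partial:
            print(partial[-1])
        time.sleep(1)
    return job.result()  # final markdown string, same as client.predict(...)

If the endpoint does not yield intermediate outputs, job.outputs() stays empty and only job.result() returns the final response, which behaves exactly like the blocking call in test_client_basic().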
finetune.py
ADDED
@@ -0,0 +1,932 @@
import os
import pathlib
import random
import shutil
import subprocess
import sys
import time
from datetime import datetime
from typing import List, Union
import fire
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets
import transformers
import torch.distributed as dist

from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

from peft import mapping
lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING


def log(*args, **kwargs):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(*args, **kwargs)


try:
    import neptune
    from transformers.integrations import NeptuneCallback

    neptune_run = neptune.init_run(
        source_files=[],
    )
    log("Connected to Neptune.")
except ImportError:
    neptune_run = None
    log("Please pip install neptune for tracking.")
except neptune.exceptions.NeptuneMissingApiTokenException:
    neptune_run = None
    os.environ["NEPTUNE_MODE"] = 'debug'
    log("No neptune configured, set NEPTUNE_API_TOKEN env var.")

from enum import Enum


class PromptType(Enum):
    plain = 0
    instruct = 1
    quality = 2
    human_bot = 3
    dai_faq = 4
    summarize = 5
    simple_instruct = 6
    instruct_vicuna = 7
    instruct_with_end = 8
    human_bot_orig = 9


prompt_type_to_model_name = {
    'plain': [
        'EleutherAI/gpt-j-6B',
        'EleutherAI/pythia-6.9b',
        'EleutherAI/pythia-12b',
        'EleutherAI/pythia-12b-deduped',
        'EleutherAI/gpt-neox-20b',
        'decapoda-research/llama-7b-hf',
        'decapoda-research/llama-13b-hf',
        'decapoda-research/llama-30b-hf',
        'decapoda-research/llama-65b-hf',
        'facebook/mbart-large-50-many-to-many-mmt',
        'philschmid/bart-large-cnn-samsum',
        'philschmid/flan-t5-base-samsum',
        'gpt2',
        'distilgpt2',
    ],
    'instruct': [],
    'instruct_with_end': ['databricks/dolly-v2-12b'],
    'quality': [],
    'human_bot': [
        'h2oai/h2ogpt-oasst1-512-12b',
        'h2oai/h2ogpt-oasst1-512-20b',
        'h2oai/h2ogpt-oig-oasst1-512-6.9b',
    ],
    'dai_faq': [],
    'summarize': [],
    'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
    'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
    'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
}

inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}

human = '<human>:'
bot = "<bot>:"

prompt_types_strings = []
for p in PromptType:
    prompt_types_strings.extend([p.name])


prompt_types = []
for p in PromptType:
    prompt_types.extend([p.name, p.value, str(p.value)])


# supported by huggingface evaluate
supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']


def train(
        save_code: bool = False,
        run_id: int = None,

        base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
        # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
        # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
        # base_model: str = 'EleutherAI/gpt-neox-20b',
        # base_model: str = 'EleutherAI/pythia-12b-deduped',
        # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
        # base_model: str = 'decapoda-research/llama-7b-hf',
        # base_model: str = 'decapoda-research/llama-13b-hf',
        # base_model: str = 'decapoda-research/llama-30b-hf',
        # base_model: str = 'EleutherAI/gpt-j-6B',

        # only needed if base_model is self-exported HF state without tokenizer
        tokenizer_base_model: str = None,
        # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',

        data_path: str = None,
        data_col_dict: dict = None,
        # data_path: str = "./dai_docs.train.json",
        prompt_type: Union[str, int] = "plain",  # "plain", "instruct", "quality", "human_bot", "dai_faq"

        valid_path: str = None,
        # valid_path: str = "./dai_docs.valid.json",

        # data_mix_in_path: str = "laion/OIG",  # way too big, medium quality
        data_mix_in_path: str = "0-hero/OIG-small-chip2",  # high quality, 50 MB, good enough for now
        data_mix_in_factor: float = 0.0,  # >1: more mix-in data, <1: more of data_path data
        data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
        data_mix_in_prompt_type: str = "instruct",  # just instruction->output, same as instruct

        output_dir: str = None,

        # LoRA checkpoint continuation
        lora_weights: str = "",

        # batching training hyperparams
        batch_size: int = 128,
        micro_batch_size: int = 4,
        gradient_checkpointing=False,  # unnecessary with gradient accumulation enabled
        fp16=True,

        # general training hyperparams
        num_epochs: float = 1,
        learning_rate: float = 3e-4,

        # validation settings
        val_set_size: int = None,
        val_metrics: List[str] = [],
        eval_steps: int = None,  # to control eval steps via steps
        eval_epochs: float = None,  # to control eval steps via epochs

        # lora hyperparams
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: List[str] = None,
        llama_type: bool = None,

        # llm hyperparams
        train_on_inputs: bool = True,  # if False, masks out inputs in loss
        group_by_length: bool = False,  # if True, faster, but produces an odd training loss curve
        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
        cutoff_len: int = 1024,  # Good default, especially when have high quality non-trivial data

        # torch training params
        ddp: bool = True,  # set to False if OOM with True, for multi-GPU model parallelism
        local_files_only: bool = False,  # else will download new versions, normally unwanted
        resume_download: bool = True,
        use_auth_token: Union[str, bool] = False,  # True requires CLI did huggingface-cli login before running
        warmup_steps: int = 100,
        logging_steps: int = 1,
        save_steps: int = None,  # must be round multiple of eval_steps
        add_eos_token: bool = False,
):
    # allow set token directly
    use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)

    prompt_type = str(prompt_type)  # migration from integers
    assert prompt_type in prompt_types

    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    rank = int(os.getenv("RANK", 0))
    print(f"local_rank: {local_rank}")
    print(f"global rank: {rank}")

    gpus = max(world_size, torch.cuda.device_count())
    run_id = run_id or 0
    if not data_path:
        raise ValueError("No data_path provided")
    if not output_dir:
        output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
        if os.path.exists(output_dir) and not resume_from_checkpoint:
            raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
    else:
        if os.path.exists(output_dir) and not resume_from_checkpoint:
            raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
    device_map = "auto"

    if save_code:
        copy_code(run_id)
    if tokenizer_base_model is None:
        tokenizer_base_model = base_model
    if llama_type is None:
        llama_type = "llama" in base_model.lower()
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size
    assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"

    device_map = "auto"

    locals_dict = locals()
    locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
    log(f"Training model with params:\n{locals_print}")
    log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))

    max_memory = None
    if gpus > 1:
        if ddp:
            log("Distributed: data parallel")
            device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
            gradient_accumulation_steps = gradient_accumulation_steps // world_size
        else:
            free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
            max_memory = f"{free_in_GB - 2}GB"
            max_memory = {i: max_memory for i in range(gpus)}
            log("world_size: %d" % world_size)
            log("num_gpus: %d" % gpus)
            log("max mem: %s" % max_memory)

    model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)

    model = model_loader.from_pretrained(
        base_model,
        load_in_8bit=True,
        device_map=device_map,
        torch_dtype=torch.float16,
        max_memory=max_memory,
        local_files_only=local_files_only,
        resume_download=resume_download,
        use_auth_token=use_auth_token,
    )
    if gpus > 1:
        if not ddp:
            log("model parallel")
            model.is_parallelizable = True
            model.model_parallel = True

    tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
                                                 local_files_only=local_files_only,
                                                 resume_download=resume_download,
                                                 use_auth_token=use_auth_token)

    tokenizer.pad_token_id = 0  # different from the eos token
    # when generating, we will use the logits of right-most token to predict the next token
    # so the padding should be on the left,
    # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
                result["input_ids"][-1] != tokenizer.eos_token_id
                and len(result["input_ids"]) < cutoff_len
                and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
        full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])
            if add_eos:
                user_prompt_len -= 1

            # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    if "gpt-neox" not in base_model or True:
        model = prepare_model_for_int8_training(model)
    else:
        model = prepare_model_for_int8_training(
            model,
            output_embedding_layer_name="embed_out",  # keep output logits in float32
            layer_norm_names=["layer_norm", "layernorm"],  # keep all layer norms in higher precision
        )
    if lora_weights:
        from peft import PeftModel
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            torch_dtype=torch.float16,
            device_map=device_map,
            local_files_only=local_files_only,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
        )
    else:
        if lora_target_modules is None:
            base_model_lower = base_model.lower()
            if base_model_lower in lora_mappings:
                lora_target_modules_cand = [lora_mappings[base_model_lower]]
            else:
                lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
        else:
            lora_target_modules_cand = [lora_target_modules]

        for lora_target_modules in lora_target_modules_cand:
            try:
                config = LoraConfig(
                    r=lora_r,
                    lora_alpha=lora_alpha,
                    target_modules=lora_target_modules,
                    lora_dropout=lora_dropout,
                    bias="none",
                    task_type="CAUSAL_LM",
                )
                model = get_peft_model(model, config)
                break
            except ValueError as e:
                if "Target modules" in str(e) and "not found" in str(e):
                    continue
                else:
                    raise
        from peft import PeftModel
        assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = False  # So the trainer won't try loading its state
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            log(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            model = set_peft_model_state_dict(model, adapters_weights)
        else:
            log(f"Checkpoint {checkpoint_name} not found")

    print(model)
    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    metrics = {}
    for name in supported_metrics:
        if name in val_metrics:
            import evaluate  # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
            metrics[name] = evaluate.load(name)
    log("Using Validation Metrics: %s" % str(list(metrics.keys())))
    log("Supported Metrics: %s" % supported_metrics)

    if val_set_size is None:
        if len(metrics) == 0:
            val_set_size = 1000
        else:
            val_set_size = 100
        log("Auto set val_set_size %s" % val_set_size)
    elif val_set_size < 1.0 and val_set_size != 0:
        raise RuntimeError("Fractional validation size not supported.")

    if valid_path:
        data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
    else:
        if "json" in data_path:
            data = load_dataset("json", data_files={"train": data_path})
        else:
            data = load_dataset(data_path)
    data = data.rename_columns(data_col_dict or {})

    valid_data = None
    train_data_mix_in = None
    valid_data_mix_in = None

    if data_mix_in_path and data_mix_in_factor > 0:
        # get mix-in training/validation data - to keep model "sane"
        num_rows = data["train"].num_rows
        log("Loading mix-in dataset: %s" % data_mix_in_path)
        if "json" in data_mix_in_path:
            data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
        else:
            data_mix_in = load_dataset(data_mix_in_path)["train"]  # can be large
        data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})

        # only get as much as we need to balance
        valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
        train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
        mixin_small = data_mix_in.train_test_split(
            test_size=train_size + valid_size,
            shuffle=True, seed=np.random.randint(10000),
        )["test"]
        if valid_size:
            mixin_train_test = mixin_small.train_test_split(
                test_size=valid_size, shuffle=False,
            )
            train_data_mix_in = mixin_train_test["train"]
            valid_data_mix_in = mixin_train_test["test"]
        else:
            train_data_mix_in = mixin_small

        if "prompt_type" not in train_data_mix_in.column_names:
            train_data_mix_in = train_data_mix_in.add_column(
                "prompt_type",
                [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
            )
            log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
        if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
            valid_data_mix_in = valid_data_mix_in.add_column(
                "prompt_type",
                [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
            )
            log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
        log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))

    # get our own training/validation data - for fine-tuning
    if val_set_size > 0 and not valid_path and not data_mix_in_path:
        # create valid split from train
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"]
        valid_data = train_val["test"]
    else:
        train_data = data["train"]
        if valid_path:
            # use given valid split, has priority over data_mix_in_path
            valid_data = data["valid"]
    if "prompt_type" not in train_data.column_names:
        train_data = train_data.add_column(
            "prompt_type",
            [prompt_type] * train_data.num_rows,
        )
        log("Added prompt type %s to training data" % prompt_type)
    if valid_data and "prompt_type" not in valid_data.column_names:
        valid_data = valid_data.add_column(
            "prompt_type",
            [prompt_type] * valid_data.num_rows,
        )
        log("Added prompt type %s to validation data" % prompt_type)

    assert train_data is not None

    # shuffle and tokenize data
    if train_data_mix_in:
        train_data = concatenate_datasets([train_data, train_data_mix_in])
    train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
    train_set_size = len(train_data)

    if valid_data and valid_data_mix_in:
        valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
    elif valid_data_mix_in:
        valid_data = valid_data_mix_in

    if valid_data:
        valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
        val_set_size = len(valid_data)
    else:
        val_set_size = 0
    log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
    sample_row_dict = train_data[:1]
    del sample_row_dict['input_ids']
    del sample_row_dict['attention_mask']
    del sample_row_dict['labels']
    log("Sample input: %s" % sample_row_dict)

    if neptune_run:
        neptune_callback = NeptuneCallback(run=neptune_run)
        callbacks = [neptune_callback]
    else:
        from transformers.integrations import TensorBoardCallback, is_tensorboard_available
        if is_tensorboard_available():
            # tensorboard --logdir=runs/
            from torch.utils.tensorboard import SummaryWriter
            tb_writer = SummaryWriter()
            callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
        else:
            callbacks = []

    expected_steps = (train_set_size * num_epochs) // batch_size
    if eval_steps is None and eval_epochs is None:
        # 20 evaluations for a run
        eval_steps = max(1, int(expected_steps / 20))
        log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
    elif eval_steps is None and eval_epochs is not None:
        eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
        log("Auto converted eval_epochs=%s to eval_steps %s"
            " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
    if save_steps is None:
        save_steps = eval_steps
        log("Auto step save_steps to %s" % save_steps)
    elif save_steps > eval_steps:
        # save steps must be round multiple of eval_steps
        save_steps0 = save_steps
        save_steps = max(1, (save_steps // eval_steps)) * eval_steps
        if save_steps0 != save_steps:
            log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))

    def compute_metrics(eval_preds):
        # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
        inputs = eval_preds.inputs
        label_ids = eval_preds.label_ids
        predictions = eval_preds.predictions

        # inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
        # decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
        # decoded_inputs = [pred.strip() for pred in decoded_inputs]

        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
        # tokenizer behavior like generate time
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
                                                clean_up_tokenization_spaces=True)
        decoded_labels = [pred.strip() for pred in decoded_labels]

        predictions = np.argmax(predictions, -1)
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        # tokenizer behavior like generate time
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
                                                     clean_up_tokenization_spaces=True)
        decoded_predictions = [pred.strip() for pred in decoded_predictions]

        result = {}
        for metric in metrics.values():
            result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
            # get rid of lists, for precision etc., for now
            numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
            result.update(numeric_results)
        return result

    # the callback that computes metrics of interest
    if val_metrics:
        trainer_kwargs = dict(compute_metrics=compute_metrics)
    else:
        trainer_kwargs = dict()

    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        eval_dataset=valid_data,
        # NOTE: CausalLM does not use the Seq2Seq-specific arguments, but Seq2SeqTrainingArguments is not incompatible
        args=transformers.Seq2SeqTrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            per_device_eval_batch_size=1,
            eval_accumulation_steps=10,
            # predict_with_generate=True,  # SEQ2SEQ only
            include_inputs_for_metrics=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            gradient_checkpointing=gradient_checkpointing,
            fp16=fp16,
            # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
            optim="adamw_torch",  # consider "adafactor" to save memory
            logging_steps=logging_steps,
            logging_strategy="steps",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=eval_steps if val_set_size > 0 else None,
            save_steps=save_steps,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
            # fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
            report_to='tensorboard' if not neptune_run else 'neptune',
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        callbacks=callbacks,
        **trainer_kwargs,
    )
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
        # WIP (not generally replacing layers until pytorch 2.1)
        torch.backends.cuda.enable_flash_sdp(True)

    if gpus > 1 and not ddp:
        assert trainer.is_model_parallel
    else:
        assert not trainer.is_model_parallel
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    log("\n If there's a warning about missing keys above, please disregard :)")


def get_loaders(llama_type, model_name, reward_type):
    # NOTE: Some models need specific new prompt_type
    # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
    if llama_type:
        from transformers import LlamaForCausalLM, LlamaTokenizer
        model_loader = LlamaForCausalLM
        tokenizer_loader = LlamaTokenizer
    elif 'gpt2' in model_name.lower():
        from transformers import GPT2LMHeadModel, GPT2Tokenizer
        return GPT2LMHeadModel, GPT2Tokenizer
    elif 'mbart-' in model_name.lower():
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
        return MBartForConditionalGeneration, MBart50TokenizerFast
    elif 't5' == model_name.lower() or \
            't5-' in model_name.lower() or \
            'flan-' in model_name.lower():
        from transformers import AutoTokenizer, T5ForConditionalGeneration
        return T5ForConditionalGeneration, AutoTokenizer
    elif 'bigbird' in model_name:
        from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
        return BigBirdPegasusForConditionalGeneration, AutoTokenizer
    elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
        from transformers import pipeline
        return pipeline, "summarization"
    elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        return AutoModelForSequenceClassification, AutoTokenizer
    else:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_loader = AutoModelForCausalLM
        tokenizer_loader = AutoTokenizer
    return model_loader, tokenizer_loader


def get_githash():
    try:
        githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
    except:
        githash = ''
    return githash


def copy_code(run_id):
    """
    copy code to track changes
    :param run_id:
    :return:
    """
    rnd_num = str(random.randint(0, 2 ** 31))
    run_id = 'run_' + str(run_id)
    os.makedirs(run_id, exist_ok=True)
    me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
    me_file = os.path.basename(__file__)
    new_me = os.path.join(run_id, me_file + '_' + get_githash())
    if os.path.isfile(new_me):
        new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
        shutil.copy(me_full, new_me)
    else:
        shutil.copy(me_full, new_me)


def get_prompt(prompt_type, chat, context, reduced):
    if prompt_type in [-1, "-1", "plain"]:
        promptA = promptB = PreInstruct = PreInput = PreResponse = ''
        terminate_response = []
    elif prompt_type == 'simple_instruct':
        promptA = promptB = PreInstruct = PreInput = PreResponse = None
        terminate_response = []
    elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
        promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
        promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''

        PreInstruct = """
### Instruction:
"""

        PreInput = """
### Input:
"""

        PreResponse = """
### Response:
"""
        if prompt_type in [7, "7", "instruct_with_end"]:
            terminate_response = ['### End']
        else:
            terminate_response = None
    elif prompt_type in [1, "1", "quality"]:
        promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
        promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''

        PreInstruct = """
### Instruction:
"""

        PreInput = """
### Input:
"""

        PreResponse = """
### Response:
"""
        terminate_response = None
    elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
        if reduced or context or prompt_type in [2, "2", "human_bot"]:
            preprompt = ''
        else:
            cur_date = time.strftime('%Y-%m-%d')
            cur_time = time.strftime('%H:%M:%S %p %Z')

            PRE_PROMPT = """\
Current Date: {}
Current Time: {}

"""
            preprompt = PRE_PROMPT.format(cur_date, cur_time)
        start = human
        promptB = promptA = '%s%s ' % (preprompt, start)

        PreInstruct = ""

        PreInput = None

        PreResponse = bot

        terminate_response = [start, PreResponse]
    elif prompt_type in [3, "3", "dai_faq"]:
        promptA = ''
        promptB = 'Answer the following Driverless AI question.\n'

        PreInstruct = """
### Driverless AI frequently asked question:
"""

        PreInput = None

        PreResponse = """
### Driverless AI documentation answer:
"""
        terminate_response = ['\n\n']
    elif prompt_type in [5, "5", "summarize"]:
        promptA = promptB = PreInput = ''
        PreInstruct = '## Main Text\n\n'
        PreResponse = '\n\n## Summary\n\n'
        terminate_response = None
    elif prompt_type in [6, "6", "instruct_vicuna"]:
        promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
                            "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''

        PreInstruct = """
### Human:
"""

        PreInput = None

        PreResponse = """
### Assistant:
"""
        terminate_response = ['### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
    else:
        raise RuntimeError("No such prompt_type=%s" % prompt_type)

    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response


def generate_prompt(data_point, prompt_type, chat, reduced):
    context = data_point.get('context')
    if context is None:
        context = ''
    instruction = data_point.get('instruction')
    input = data_point.get('input')
    output = data_point.get('output')
    prompt_type = data_point.get('prompt_type', prompt_type)
    assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
    promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)

    prompt = context

    if input and promptA:
        prompt += f"""{promptA}"""
    elif promptB:
        prompt += f"""{promptB}"""

    if instruction and PreInstruct is not None and input and PreInput is not None:
        prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif instruction and input and PreInstruct is None and PreInput is not None:
        prompt += f"""{PreInput}{instruction}
{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and instruction and PreInput is None and PreInstruct is not None:
        prompt += f"""{PreInstruct}{instruction}
{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif instruction and PreInstruct is not None:
        prompt += f"""{PreInstruct}{instruction}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and PreInput is not None:
        prompt += f"""{PreInput}{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and instruction and PreInput is not None:
        prompt += f"""{PreInput}{instruction}{input}"""
|
847 |
+
prompt = inject_newline(prompt_type, prompt)
|
848 |
+
elif input and instruction and PreInstruct is not None:
|
849 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
850 |
+
prompt = inject_newline(prompt_type, prompt)
|
851 |
+
elif input and instruction:
|
852 |
+
# i.e. for simple_instruct
|
853 |
+
prompt += f"""{instruction}: {input}"""
|
854 |
+
prompt = inject_newline(prompt_type, prompt)
|
855 |
+
elif input:
|
856 |
+
prompt += f"""{input}"""
|
857 |
+
prompt = inject_newline(prompt_type, prompt)
|
858 |
+
elif instruction:
|
859 |
+
prompt += f"""{instruction}"""
|
860 |
+
prompt = inject_newline(prompt_type, prompt)
|
861 |
+
|
862 |
+
if PreResponse is not None:
|
863 |
+
prompt += f"""{PreResponse}"""
|
864 |
+
pre_response = PreResponse # Don't use strip
|
865 |
+
else:
|
866 |
+
pre_response = ''
|
867 |
+
|
868 |
+
if output:
|
869 |
+
prompt += f"""{output}"""
|
870 |
+
|
871 |
+
return prompt, pre_response, terminate_response
|
872 |
+
|
873 |
+
|
874 |
+
def inject_newline(prompt_type, prompt):
|
875 |
+
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
876 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
877 |
+
prompt += '\n'
|
878 |
+
return prompt
|
879 |
+
|
880 |
+
|
881 |
+
example_data_point0 = dict(instruction="Summarize",
|
882 |
+
input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
|
883 |
+
output="Ducks eat and swim at the lake.")
|
884 |
+
|
885 |
+
example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
|
886 |
+
output="Einstein.")
|
887 |
+
|
888 |
+
example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
|
889 |
+
output="Einstein.")
|
890 |
+
|
891 |
+
example_data_points = [example_data_point0, example_data_point1, example_data_point2]
|
892 |
+
|
893 |
+
|
894 |
+
def test_train_prompt(prompt_type='instruct', data_point=0):
|
895 |
+
example_data_point = example_data_points[data_point]
|
896 |
+
return generate_prompt(example_data_point, prompt_type, False, False)
|
897 |
+
|
898 |
+
|
899 |
+
def test_debug():
|
900 |
+
fire.Fire(train)
|
901 |
+
|
902 |
+
|
903 |
+
if __name__ == "__main__":
|
904 |
+
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
905 |
+
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
906 |
+
log(f"""
|
907 |
+
Example runs on 4 GPUs:
|
908 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
|
909 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
|
910 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
|
911 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
|
912 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
|
913 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
|
914 |
+
|
915 |
+
All metrics:
|
916 |
+
CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
|
917 |
+
|
918 |
+
# Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
|
919 |
+
rippa>
|
920 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
|
921 |
+
ova>
|
922 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
|
923 |
+
timemachine>
|
924 |
+
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
|
925 |
+
|
926 |
+
""", flush=True)
|
927 |
+
|
928 |
+
if os.environ.get("LOCAL_RANK") is None:
|
929 |
+
# then not using torchrun, so can't do distributed, ensure CVD set
|
930 |
+
assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
931 |
+
|
932 |
+
fire.Fire(train)
|
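Note: a minimal sketch of how the prompt assembly above behaves for the 'instruct' prompt type (assuming finetune.py is importable and 'instruct' is registered in its prompt_types list; the data_point values are illustrative only):

    # sketch: build an 'instruct' training/inference prompt from a data point
    from finetune import generate_prompt

    data_point = dict(instruction="Summarize",
                      input="Ducks eat seeds by the lake.",
                      output="")
    prompt, pre_response, terminate_response = generate_prompt(data_point, 'instruct', chat=False, reduced=False)
    # prompt now contains promptA followed by the "### Instruction:", "### Input:" and "### Response:" sections;
    # pre_response is the "### Response:" marker and terminate_response is None for plain 'instruct'.
    print(prompt)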
h2o-logo.svg
ADDED
prompter.py
ADDED
@@ -0,0 +1,106 @@
from finetune import generate_prompt


class Prompter(object):
    def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
                 allowed_repeat_line_length=10):
        self.prompt_type = prompt_type
        data_point = dict(instruction='', input='', output='')
        _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
        self.debug = debug
        self.chat = chat
        self.stream_output = stream_output
        self.repeat_penalty = repeat_penalty
        self.allowed_repeat_line_length = allowed_repeat_line_length

    def generate_prompt(self, data_point):
        reduced = False
        prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
        if self.debug:
            print("prompt: ", prompt, flush=True)
        self.prompt = prompt
        return prompt

    def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
        if isinstance(outputs, str):
            outputs = [outputs]
        if self.debug:
            print("output: ", '\n\n'.join(outputs), flush=True)
        if prompt is not None:
            self.prompt = prompt

        def clean_response(response):
            meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
            for word in meaningless_words:
                response = response.replace(word, "")
            if sanitize_bot_response:
                from better_profanity import profanity
                response = profanity.censor(response)
            response = response.strip("\n")
            return response

        def clean_repeats(response):
            lines = response.split('\n')
            new_lines = []
            [new_lines.append(line) for line in lines if
             line not in new_lines or len(line) < self.allowed_repeat_line_length]
            if self.debug and len(lines) != len(new_lines):
                print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
            response = '\n'.join(new_lines)
            return response

        multi_output = len(outputs) > 1

        for oi, output in enumerate(outputs):
            if self.prompt_type in [0, '0', 'plain']:
                output = clean_response(output)
            else:
                # find first instance of pre_response
                # prompt sometimes has odd characters that mutate length,
                # so can't go by length alone
                if self.pre_response:
                    outputi = output.find(prompt)
                    if outputi >= 0:
                        output = output[outputi + len(prompt):]
                        allow_terminate = True
                    else:
                        # subtraction is risky due to space offsets sometimes, so only do if necessary
                        output = output[len(prompt) - len(self.pre_response):]
                        # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
                        if self.pre_response in output:
                            output = output.split(self.pre_response)[1]
                            allow_terminate = True
                        else:
                            print("Failure of parsing: %s" % output, flush=True)
                            allow_terminate = False
                else:
                    allow_terminate = True
                    output = output[len(prompt):]
                # clean after subtracting prompt out, so correct removal of pre_response
                output = clean_response(output).strip()
                if self.repeat_penalty:
                    output = clean_repeats(output).strip()
                if self.terminate_response and allow_terminate:
                    finds = []
                    for term in self.terminate_response:
                        finds.append(output.find(term))
                    finds = [x for x in finds if x >= 0]
                    if len(finds) > 0:
                        termi = finds[0]
                        output = output[:termi].strip()
                    else:
                        output = output.strip()
                else:
                    output = output.strip()
            if multi_output:
                # prefix with output counter
                output = "\n=========== Output %d\n\n" % (1 + oi) + output
                if oi > 0:
                    # postfix outputs with separator
                    output += '\n'
            outputs[oi] = output
        # join all outputs, only one extra new line between outputs
        output = '\n'.join(outputs)
        if self.debug:
            print("outputclean: ", '\n\n'.join(outputs), flush=True)
        return output
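Note: a minimal usage sketch for Prompter (assumes the requirements above are installed, including better_profanity; the fake model_output string is a hypothetical raw generation used only for illustration):

    # sketch: wrap prompt construction and response parsing around a raw model generation
    from prompter import Prompter

    prompter = Prompter('instruct', debug=True)
    prompt = prompter.generate_prompt(dict(instruction="Summarize",
                                           input="Ducks eat seeds by the lake.",
                                           output=""))
    model_output = prompt + " Ducks eat near the lake."  # stand-in for model.generate() text
    response = prompter.get_response(model_output, prompt=prompt)
    # response is the cleaned continuation with the prompt and special tokens stripped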
requirements.txt
ADDED
@@ -0,0 +1,50 @@
# for generate (gradio server) and finetune
datasets==2.11.0
sentencepiece==0.1.97
accelerate==0.18.0
gradio==3.27.0
huggingface_hub==0.13.4
appdirs==1.4.4
fire==0.5.0
docutils==0.19
torch==2.0.0
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
scikit-learn==1.2.2
alt-profanity-check==1.2.2
better-profanity==0.6.1
numpy==1.24.2
pandas==2.0.0
matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.38.1
git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
transformers==4.28.1
tokenizers==0.13.3

# optional for generate
pynvml==11.5.0
psutil==5.9.4
boto3==1.26.101
botocore==1.29.101

# optional for finetune
tensorboard==2.12.1
neptune==1.1.1

# for gradio client
gradio_client==0.1.3
beautifulsoup4==4.12.2
markdown==3.4.1

# data and testing
pytest==7.2.2
pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
pandoc==2.3
pypandoc==1.11
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0
stopping.py
ADDED
@@ -0,0 +1,139 @@
import traceback
from queue import Queue
from threading import Thread
import collections.abc

import torch
from transformers import StoppingCriteria


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=[]):
        super().__init__()
        assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
        self.encounters = encounters
        self.stops = [stop.to("cuda") for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stopi, stop in enumerate(self.stops):
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
        return False


class Stream(StoppingCriteria):
    """
    This class can be used to callback during generation. Keep
    in mind for decoder-only type of transformers, this will include the initial prompted tokens.

    Args:
        func (`callable`):
            A callable function to apply on first input in list every iteration of generation
    """

    def __init__(self, func=None):
        self.func = func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.func is not None:
            # only consume first of multiple responses
            self.func(input_ids[0])
        return False


class CallbackToGenerator(collections.abc.Generator):
    """
    A generator wrapper for a function that invokes a callback multiple times.

    Calling `send` on the generator emits a value from one callback, and returns
    the next.

    Note this starts a background thread
    """

    def __init__(self, func, *args, callback=None, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.callback = callback

        self._ready_queue = Queue(1)
        self._done_queue = Queue(1)
        self._done_holder = [False]

        # local to avoid reference cycles
        ready_queue = self._ready_queue
        done_queue = self._done_queue
        done_holder = self._done_holder

        def val_callback(value):
            done_queue.put((False, value))
            cmd, val = ready_queue.get()
            if cmd == 'send':
                return val
            elif cmd == 'throw':
                raise val
            else:
                assert False  # pragma: no cover

        def thread_func():
            while True:
                cmd, val = ready_queue.get()
                if cmd == 'send' and val is not None:
                    done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
                    continue
                break
            try:
                if cmd == 'throw':
                    raise val
                ret = func(callback=val_callback, **self.kwargs)
                raise StopIteration(ret) if ret is not None else StopIteration
            except BaseException as e:
                done_holder[0] = True
                done_queue.put((True, e))

        self._thread = Thread(target=thread_func)
        self._thread.start()

    def _put(self, *args):
        if self._done_holder[0]:
            raise StopIteration
        self._ready_queue.put(args)
        is_exception, val = self._done_queue.get()
        if is_exception:
            try:
                raise val
            finally:
                # prevent val's traceback containing a reference cycle
                del val
        else:
            return val

    def send(self, value):
        return self._put('send', value)

    def throw(self, exc):
        return self._put('throw', exc)

    def close(self):
        try:
            self.throw(GeneratorExit)
        except StopIteration:
            self._thread.join()
        except GeneratorExit:
            self._thread.join()
        except BaseException:
            self._thread.join()
            raise
        else:
            # yielded again, can't clean up the thread
            raise RuntimeError('Task with callback ignored GeneratorExit')

    def __del__(self):
        self.close()
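Note: a sketch of wiring StoppingCriteriaSub into transformers generation (assumes a CUDA-capable setup, since StoppingCriteriaSub moves stop-token tensors to "cuda"; the model name and stop strings are illustrative, not the app's exact configuration):

    # sketch: stop generation when either stop phrase has been seen enough times
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
    from stopping import StoppingCriteriaSub

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to("cuda")

    # encode each stop phrase to token ids as tensors, matching what StoppingCriteriaSub compares against
    stop_words = ["<human>:", "<bot>:"]
    stops = [torch.tensor(tokenizer(w, add_special_tokens=False)["input_ids"]).to("cuda") for w in stop_words]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stops, encounters=[1, 2])])

    inputs = tokenizer("<human>: Tell me a joke.\n<bot>:", return_tensors="pt").to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
    print(tokenizer.decode(outputs[0]))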
utils.py
ADDED
@@ -0,0 +1,186 @@
import os
import gc
import random
import time
import traceback
import zipfile
from datetime import datetime
import filelock
import numpy as np
import pandas as pd
import torch


def set_seed(seed: int):
    """
    Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.
    """
    np.random.seed(seed)
    random_state = np.random.RandomState(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    return random_state


def flatten_list(lis):
    """Given a list, possibly nested to any level, return it flattened."""
    new_lis = []
    for item in lis:
        if type(item) == type([]):
            new_lis.extend(flatten_list(item))
        else:
            new_lis.append(item)
    return new_lis


def clear_torch_cache():
    # note: is_available() must be called; the bare attribute is always truthy
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        gc.collect()


def system_info():
    import psutil

    system = {}
    # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
    # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
    temps = psutil.sensors_temperatures(fahrenheit=False)
    if 'coretemp' in temps:
        coretemp = temps['coretemp']
        temp_dict = {k.label: k.current for k in coretemp}
        for k, v in temp_dict.items():
            system['CPU_C/%s' % k] = v

    # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
    from pynvml.smi import nvidia_smi
    nvsmi = nvidia_smi.getInstance()

    gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
                      enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
    for k, v in gpu_power_dict.items():
        system['GPU_W/%s' % k] = v

    gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
                     enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
    for k, v in gpu_temp_dict.items():
        system['GPU_C/%s' % k] = v

    gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
                            enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
    gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
                             enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
    gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
    for k, v in gpu_memory_frac_dict.items():
        system['GPU_M/%s' % k] = v

    return system


def system_info_print():
    try:
        df = pd.DataFrame.from_dict(system_info(), orient='index')
        # avoid slamming GPUs
        time.sleep(1)
        return df.to_markdown()
    except Exception as e:
        return "Error: %s" % str(e)


def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
    try:
        return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
    except Exception as e:
        traceback.print_exc()
        print('Exception in zipping: %s' % str(e))


def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
    if zip_file is None:
        datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
        host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
        zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
    assert root_dirs is not None

    with zipfile.ZipFile(zip_file, "w") as expt_zip:
        for root_dir in root_dirs:
            if root_dir is None:
                continue
            for root, d, files in os.walk(root_dir):
                for file in files:
                    file_to_archive = os.path.join(root, file)
                    assert os.path.exists(file_to_archive)
                    path_to_archive = os.path.relpath(file_to_archive, base_dir)
                    expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
    return zip_file, zip_file


def save_generate_output(output=None, base_model=None, save_dir=None):
    try:
        return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
    except Exception as e:
        traceback.print_exc()
        print('Exception in saving: %s' % str(e))


def _save_generate_output(output=None, base_model=None, save_dir=None):
    """
    Save conversation to .json, row by row.
    json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
    Appends if file exists
    """
    assert save_dir, "save_dir must be provided"
    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
        raise RuntimeError("save_dir already exists and is not a directory!")
    os.makedirs(save_dir, exist_ok=True)
    import json
    if output[-10:] == '\n\n<human>:':
        # remove trailing <human>:
        output = output[:-10]
    with filelock.FileLock("save_dir.lock"):
        # lock logging in case have concurrency
        with open(os.path.join(save_dir, "history.json"), "a") as f:
            # just add [ at start, and ] at end, and have proper JSON dataset
            f.write(
                " " + json.dumps(
                    dict(text=output, time=time.ctime(), base_model=base_model)
                ) + ",\n"
            )


def s3up(filename):
    try:
        return _s3up(filename)
    except Exception as e:
        traceback.print_exc()
        print('Exception for file %s in s3up: %s' % (filename, str(e)))
        return "Failed to upload %s: Error: %s" % (filename, str(e))


def _s3up(filename):
    import boto3

    aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY')
    aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY')
    bucket = os.getenv('AWS_BUCKET')
    assert aws_access_key_id, "Set AWS key"
    assert aws_secret_access_key, "Set AWS secret"
    assert bucket, "Set AWS Bucket"

    s3 = boto3.client('s3',
                      aws_access_key_id=os.getenv('AWS_SERVER_PUBLIC_KEY'),
                      aws_secret_access_key=os.getenv('AWS_SERVER_SECRET_KEY'),
                      )
    ret = s3.upload_file(
        Filename=filename,
        Bucket=os.getenv('AWS_BUCKET'),
        Key=filename,
    )
    if ret in [None, '']:
        return "Successfully uploaded %s" % filename
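Note: a minimal usage sketch for the helpers above (paths, seed, and the example conversation text are illustrative; save_generate_output writes save_dir/history.json and zip_data bundles whole directories):

    # sketch: seed a run, log one generation, and bundle the saved history for download
    from utils import set_seed, save_generate_output, zip_data

    set_seed(1234)
    save_generate_output(output="<human>: hi\n<bot>: hello",
                         base_model="EleutherAI/gpt-j-6B",
                         save_dir="save")
    zipped = zip_data(root_dirs=["save"], zip_file="run_artifacts.zip", base_dir=".")  # returns (zip_file, zip_file) on success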