IC4T committed
Commit: 6997035
Parent(s): ea27bb1
Commit message: update

Files changed:
- training/__init__.py +0 -0
- training/consts.py +74 -0
- training/generate.py +239 -0
- training/trainer.py +330 -0
training/__init__.py
ADDED
File without changes
training/consts.py
ADDED
@@ -0,0 +1,74 @@
DEFAULT_INPUT_MODEL = "EleutherAI/pythia-6.9b"
SUGGESTED_INPUT_MODELS = [
    "EleutherAI/pythia-2.8b",
    "EleutherAI/pythia-6.9b",
    "EleutherAI/pythia-12b",
    "EleutherAI/gpt-j-6B",
]
INTRO_BLURB = (
    "Below is an instruction that describes a task. Write a response that appropriately completes the request."
)
INSTRUCTION_KEY = "### Instruction:"
INPUT_KEY = "Input:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
DEFAULT_SEED = 42

# This is a training prompt that does not contain an input string. The instruction by itself has enough information
# to respond. For example, the instruction might ask for the year a historic figure was born.
PROMPT_NO_INPUT_FORMAT = """{intro}

{instruction_key}
{instruction}

{response_key}
{response}

{end_key}""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
    response="{response}",
    end_key=END_KEY,
)

# This is a training prompt that contains an input string that serves as context for the instruction. For example,
# the input might be a passage from Wikipedia and the instruction is to extract some information from it.
PROMPT_WITH_INPUT_FORMAT = """{intro}

{instruction_key}
{instruction}

{input_key}
{input}

{response_key}
{response}

{end_key}""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    input_key=INPUT_KEY,
    input="{input}",
    response_key=RESPONSE_KEY,
    response="{response}",
    end_key=END_KEY,
)

# This is the prompt that is used for generating responses using an already trained model. It ends with the response
# key, where the job of the model is to provide the completion that follows it (i.e. the response itself).
PROMPT_FOR_GENERATION_FORMAT = """{intro}

{instruction_key}
{instruction}

{response_key}
""".format(
    intro=INTRO_BLURB,
    instruction_key=INSTRUCTION_KEY,
    instruction="{instruction}",
    response_key=RESPONSE_KEY,
)
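
The prompt templates are built in two passes: the .format() calls above fill in the fixed keys immediately while re-inserting literal "{instruction}", "{input}", and "{response}" placeholders, which are filled per record later. A minimal sketch of how they are used (not part of the commit; assumes the training package is importable):

from training.consts import PROMPT_FOR_GENERATION_FORMAT, PROMPT_NO_INPUT_FORMAT

# A training example: both the instruction and the target response are filled in.
print(PROMPT_NO_INPUT_FORMAT.format(
    instruction="In what year was Ada Lovelace born?",
    response="Ada Lovelace was born in 1815.",
))

# An inference prompt: it ends right after "### Response:\n" so the model supplies the completion.
print(PROMPT_FOR_GENERATION_FORMAT.format(
    instruction="In what year was Ada Lovelace born?",
))
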
training/generate.py
ADDED
@@ -0,0 +1,239 @@
import logging
import re
from typing import List, Tuple

import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Pipeline,
    PreTrainedModel,
    PreTrainedTokenizer,
)

from transformers.utils import is_tf_available

if is_tf_available():
    import tensorflow as tf

from .consts import END_KEY, PROMPT_FOR_GENERATION_FORMAT, RESPONSE_KEY

logger = logging.getLogger(__name__)


def load_model_tokenizer_for_generate(
    pretrained_model_name_or_path: str,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Loads the model and tokenizer so that they can be used for generating responses.

    Args:
        pretrained_model_name_or_path (str): name or path for model

    Returns:
        Tuple[PreTrainedModel, PreTrainedTokenizer]: model and tokenizer
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left", cache_dir="/media/siiva/DataStore/LLMs/cache/dollyV2")
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path, device_map="auto", trust_remote_code=True, cache_dir="/media/siiva/DataStore/LLMs/cache/dollyV2"
    )
    return model, tokenizer


def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
    """Gets the token ID for a given string that has been added to the tokenizer as a special token.

    When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
    treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.

    Args:
        tokenizer (PreTrainedTokenizer): the tokenizer
        key (str): the key to convert to a single token

    Raises:
        ValueError: if more than one ID was generated

    Returns:
        int: the token ID for the given key
    """
    token_ids = tokenizer.encode(key)
    if len(token_ids) > 1:
        raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
    return token_ids[0]


class InstructionTextGenerationPipeline(Pipeline):
    def __init__(
        self, *args, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs
    ):
        """Initialize the pipeline

        Args:
            do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
            max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 256.
            top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with
                probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
            top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering.
                Defaults to 0.
        """
        super().__init__(*args, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k,
                         **kwargs)

    def _sanitize_parameters(self,
                             return_full_text: bool = None,
                             **generate_kwargs):
        preprocess_params = {}

        # newer versions of the tokenizer configure the response key as a special token. newer versions still may
        # append a newline to yield a single token. find whatever token is configured for the response key.
        tokenizer_response_key = next(
            (token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None
        )

        response_key_token_id = None
        end_key_token_id = None
        if tokenizer_response_key:
            try:
                response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
                end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)

                # Ensure generation stops once it generates "### End"
                generate_kwargs["eos_token_id"] = end_key_token_id
            except ValueError:
                pass

        forward_params = generate_kwargs
        postprocess_params = {
            "response_key_token_id": response_key_token_id,
            "end_key_token_id": end_key_token_id
        }

        if return_full_text is not None:
            postprocess_params["return_full_text"] = return_full_text

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, instruction_text, **generate_kwargs):
        prompt_text = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text)
        inputs = self.tokenizer(
            prompt_text,
            return_tensors="pt",
        )
        inputs["prompt_text"] = prompt_text
        inputs["instruction_text"] = instruction_text
        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)

        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]

        generated_sequence = self.model.generate(
            input_ids=input_ids.to(self.model.device),
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            **generate_kwargs,
        )

        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))

        instruction_text = model_inputs.pop("instruction_text")
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}

    def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_full_text: bool = False):

        generated_sequence = model_outputs["generated_sequence"][0]
        instruction_text = model_outputs["instruction_text"]

        generated_sequence: List[List[int]] = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:

            # The response will be set to this variable if we can identify it.
            decoded = None

            # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
            if response_key_token_id and end_key_token_id:
                # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
                # prompt, we should definitely find it. We will return the tokens found after this token.
                try:
                    response_pos = sequence.index(response_key_token_id)
                except ValueError:
                    logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}")
                    response_pos = None

                if response_pos:
                    # Next find where "### End" is located. The model has been trained to end its responses with this
                    # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
                    # this token, as the response could be truncated. If we don't find it then just return everything
                    # to the end. Note that even though we set eos_token_id, we still see this token at the end.
                    try:
                        end_pos = sequence.index(end_key_token_id)
                    except ValueError:
                        end_pos = None

                    decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()

            if not decoded:
                # Otherwise we'll decode everything and use a regex to find the response and end.

                fully_decoded = self.tokenizer.decode(sequence)

                # The response appears after "### Response:". The model has been trained to append "### End" at the
                # end.
                m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)

                if m:
                    decoded = m.group(1).strip()
                else:
                    # The model might not generate the "### End" sequence before reaching the max tokens. In this
                    # case, return everything after "### Response:".
                    m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
                    if m:
                        decoded = m.group(1).strip()
                    else:
                        logger.warn(f"Failed to find response in:\n{fully_decoded}")

            # If the full text is requested, then append the decoded text to the original instruction.
            # This technically isn't the full text, as we format the instruction in the prompt the model has been
            # trained on, but to the client it will appear to be the full text.
            if return_full_text:
                decoded = f"{instruction_text}\n{decoded}"

            rec = {"generated_text": decoded}

            records.append(rec)

        return records


def generate_response(
    instruction: str,
    *,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    **kwargs,
) -> str:
    """Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in
    the instruction format that the model was fine-tuned on.

    Args:
        instruction (str): the instruction to generate a response for
        model (PreTrainedModel): the model to use
        tokenizer (PreTrainedTokenizer): the tokenizer to use

    Returns:
        str: response
    """

    generation_pipeline = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer, **kwargs)
    return generation_pipeline(instruction)[0]["generated_text"]
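
A minimal usage sketch (not part of the commit). The model name below is only an illustration; any checkpoint fine-tuned on this prompt format should work, and loading assumes the hard-coded cache_dir above exists and is writable:

from training.generate import generate_response, load_model_tokenizer_for_generate

model, tokenizer = load_model_tokenizer_for_generate("databricks/dolly-v2-3b")

# Formats the instruction with PROMPT_FOR_GENERATION_FORMAT, runs generation,
# and returns only the text between "### Response:" and "### End".
print(generate_response(
    "Explain instruction fine-tuning in one sentence.",
    model=model,
    tokenizer=tokenizer,
))
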
training/trainer.py
ADDED
@@ -0,0 +1,330 @@
# Copyright 2023 Databricks, Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import click
import numpy as np
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

from .consts import (
    DEFAULT_INPUT_MODEL,
    DEFAULT_SEED,
    PROMPT_WITH_INPUT_FORMAT,
    PROMPT_NO_INPUT_FORMAT,
    END_KEY,
    INSTRUCTION_KEY,
    RESPONSE_KEY_NL,
)

logger = logging.getLogger(__name__)
ROOT_PATH = Path(__file__).parent.parent
DATABRICKS_DOLLY_15K_PATH = ROOT_PATH / "data" / "databricks-dolly-15k.jsonl"


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. We encode this and then try to find it in the
        # sequence of tokens. This should just be a single token.
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):

            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is None:
                raise RuntimeError(
                    f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
                )

            response_token_ids_end_idx = response_token_ids_start_idx + 1

            # Make pytorch loss function ignore all tokens up through the end of the response key
            labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels

        return batch


def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_length: int) -> dict:
    return tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
    )


def load_training_dataset() -> Dataset:
    logger.info(f"Loading dataset from {DATABRICKS_DOLLY_15K_PATH}")
    dataset = load_dataset("json", data_files=str(DATABRICKS_DOLLY_15K_PATH))["train"]
    logger.info("Found %d rows", dataset.num_rows)

    def _add_text(rec):
        instruction = rec["instruction"]
        response = rec["response"]
        context = rec.get("context")

        if not instruction:
            raise ValueError(f"Expected an instruction in: {rec}")

        if not response:
            raise ValueError(f"Expected a response in: {rec}")

        # For some instructions there is an input that goes along with the instruction, providing context for the
        # instruction. For example, the input might be a passage from Wikipedia and the instruction says to extract
        # some piece of information from it. The response is that information to extract. In other cases there is
        # no input. For example, the instruction might be open QA such as asking what year some historic figure was
        # born.
        if context:
            rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context)
        else:
            rec["text"] = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)
        return rec

    dataset = dataset.map(_add_text)

    return dataset


def load_tokenizer(pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL) -> PreTrainedTokenizer:
    logger.info(f"Loading tokenizer for {pretrained_model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
    return tokenizer


def load_model(
    pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
) -> AutoModelForCausalLM:
    logger.info(f"Loading model for {pretrained_model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True, use_cache=False if gradient_checkpointing else True
    )
    return model


def get_model_tokenizer(
    pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
    tokenizer = load_tokenizer(pretrained_model_name_or_path)
    model = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)
    model.resize_token_embeddings(len(tokenizer))

    return model, tokenizer


def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED) -> Dataset:
    """Loads the training dataset and tokenizes it so it is ready for training.

    Args:
        tokenizer (AutoTokenizer): Tokenizer tied to the model.
        max_length (int): Maximum number of tokens to emit from tokenizer.

    Returns:
        Dataset: HuggingFace dataset
    """

    dataset = load_training_dataset()

    logger.info("Preprocessing dataset")
    _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
    dataset = dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "context", "response", "text", "category"],
    )

    # Make sure we don't have any truncated records, as this would mean the end keyword is missing.
    logger.info("Processed dataset has %d rows", dataset.num_rows)
    dataset = dataset.filter(lambda rec: len(rec["input_ids"]) < max_length)
    logger.info("Processed dataset has %d rows after filtering for truncated records", dataset.num_rows)

    logger.info("Shuffling dataset")
    dataset = dataset.shuffle(seed=seed)

    logger.info("Done preprocessing")

    return dataset


def train(
    *,
    input_model: str,
    local_output_dir: str,
    dbfs_output_dir: str,
    epochs: int,
    per_device_train_batch_size: int,
    per_device_eval_batch_size: int,
    lr: float,
    seed: int,
    deepspeed: str,
    gradient_checkpointing: bool,
    local_rank: str,
    bf16: bool,
    logging_steps: int,
    save_steps: int,
    eval_steps: int,
    test_size: Union[float, int],
    save_total_limit: int,
    warmup_steps: int,
):
    set_seed(seed)

    model, tokenizer = get_model_tokenizer(
        pretrained_model_name_or_path=input_model, gradient_checkpointing=gradient_checkpointing
    )

    # Use the same max length that the model supports. Fall back to 1024 if the setting can't be found.
    # The configuration for the length can be stored under different names depending on the model. Here we attempt
    # a few possible names we've encountered.
    conf = model.config
    max_length = None
    for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
        max_length = getattr(model.config, length_setting, None)
        if max_length:
            logger.info(f"Found max length: {max_length}")
            break
    if not max_length:
        max_length = 1024
        logger.info(f"Using default max length: {max_length}")

    processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)

    split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)

    logger.info("Train data size: %d", split_dataset["train"].num_rows)
    logger.info("Test data size: %d", split_dataset["test"].num_rows)

    data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
    )

    if not dbfs_output_dir:
        logger.warn("Will NOT save to DBFS")

    training_args = TrainingArguments(
        output_dir=local_output_dir,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        fp16=False,
        bf16=bf16,
        learning_rate=lr,
        num_train_epochs=epochs,
        deepspeed=deepspeed,
        gradient_checkpointing=gradient_checkpointing,
        logging_dir=f"{local_output_dir}/runs",
        logging_strategy="steps",
        logging_steps=logging_steps,
        evaluation_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=save_total_limit,
        load_best_model_at_end=False,
        report_to="tensorboard",
        disable_tqdm=True,
        remove_unused_columns=False,
        local_rank=local_rank,
        warmup_steps=warmup_steps,
    )

    logger.info("Instantiating Trainer")

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
        data_collator=data_collator,
    )

    logger.info("Training")
    trainer.train()

    logger.info(f"Saving Model to {local_output_dir}")
    trainer.save_model(output_dir=local_output_dir)

    if dbfs_output_dir:
        logger.info(f"Saving Model to {dbfs_output_dir}")
        trainer.save_model(output_dir=dbfs_output_dir)

    logger.info("Done.")


@click.command()
@click.option("--input-model", type=str, help="Input model to fine tune", default=DEFAULT_INPUT_MODEL)
@click.option("--local-output-dir", type=str, help="Write directly to this local path", required=True)
@click.option("--dbfs-output-dir", type=str, help="Sync data to this path on DBFS")
@click.option("--epochs", type=int, default=3, help="Number of epochs to train for.")
@click.option("--per-device-train-batch-size", type=int, default=8, help="Batch size to use for training.")
@click.option("--per-device-eval-batch-size", type=int, default=8, help="Batch size to use for evaluation.")
@click.option(
    "--test-size", type=int, default=1000, help="Number of test records for evaluation, or ratio of test records."
)
@click.option("--warmup-steps", type=int, default=None, help="Number of steps to warm up to learning rate")
@click.option("--logging-steps", type=int, default=10, help="How often to log")
@click.option("--eval-steps", type=int, default=50, help="How often to run evaluation on test records")
@click.option("--save-steps", type=int, default=400, help="How often to checkpoint the model")
@click.option("--save-total-limit", type=int, default=10, help="Maximum number of checkpoints to keep on disk")
@click.option("--lr", type=float, default=1e-5, help="Learning rate to use for training.")
@click.option("--seed", type=int, default=DEFAULT_SEED, help="Seed to use for training.")
@click.option("--deepspeed", type=str, default=None, help="Path to deepspeed config file.")
@click.option(
    "--gradient-checkpointing/--no-gradient-checkpointing",
    is_flag=True,
    default=True,
    help="Use gradient checkpointing?",
)
@click.option(
    "--local_rank",
    type=str,
    default=True,
    help="Provided by deepspeed to identify which instance this process is when performing multi-GPU training.",
)
@click.option("--bf16", type=bool, default=True, help="Whether to use bf16 (preferred on A100's).")
def main(**kwargs):
    train(**kwargs)


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
    )
    try:
        main()
    except Exception:
        logger.exception("main failed")
        raise
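
DataCollatorForCompletionOnlyLM is what turns ordinary causal-LM batches into completion-only training: every label up to and including the RESPONSE_KEY_NL token is set to -100, which the cross-entropy loss ignores, so only the response tokens contribute to the loss. A standalone sketch of that masking step (not part of the commit; the token IDs below are made up for illustration):

import torch

RESPONSE_KEY_TOKEN_ID = 50280  # hypothetical ID assigned to "### Response:\n"

# One example: [prompt tokens ..., response key, response tokens ...]
labels = torch.tensor([[101, 102, 103, RESPONSE_KEY_TOKEN_ID, 201, 202, 203]])

for i in range(labels.shape[0]):
    matches = (labels[i] == RESPONSE_KEY_TOKEN_ID).nonzero(as_tuple=True)[0]
    if len(matches) == 0:
        raise RuntimeError("Could not find the response key in the labels")
    end_idx = int(matches[0]) + 1
    labels[i, :end_idx] = -100  # ignored by the loss

print(labels)  # tensor([[-100, -100, -100, -100,  201,  202,  203]])
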