Upload train.py with huggingface_hub
train.py CHANGED
@@ -41,9 +41,6 @@ log_and_write(log_dir, f'training data: {data_dir}')
 # -----------------------------------------------------------------------------
 
 
-# various inits, derived attributes, I/O setup
-# ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
-
 ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
 if ddp:
     init_process_group(backend=backend)
@@ -53,13 +50,10 @@ if ddp:
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-    seed_offset = ddp_rank
-    # world_size number of processes will be training simultaneously, so we can scale
-    # down the desired gradient accumulation iterations per process proportionally
+    seed_offset = ddp_rank
     assert gradient_accumulation_steps % ddp_world_size == 0
     gradient_accumulation_steps //= ddp_world_size
 else:
-    # if not ddp, we are running on a single gpu, and one process
     master_process = True
     seed_offset = 0
     ddp_world_size = 1
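The hunk above keeps the per-rank scaling of gradient accumulation under DDP. A quick sketch of that arithmetic, not part of the commit (the variable names follow this file, the concrete numbers are made up):

# Sketch only: how the DDP scaling above divides work across ranks.
# Names follow train.py; the values are illustrative.
gradient_accumulation_steps = 40  # total micro-steps desired per optimizer step
ddp_world_size = 8                # number of processes launched, e.g. by torchrun
assert gradient_accumulation_steps % ddp_world_size == 0
gradient_accumulation_steps //= ddp_world_size
print(gradient_accumulation_steps)  # 5 micro-steps per rank; 8 ranks x 5 = the original 40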
@@ -85,7 +79,6 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torc
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 # data loader
-# data_dir = os.path.join('data', dataset)
 train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
 val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
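The data loader above memory-maps train.bin and val.bin as flat uint16 token streams. A minimal standalone sketch of producing and reading back such a file, not this project's data prep script (the file name and token ids are illustrative):

# Standalone sketch: write a flat uint16 token file and read it back the way
# the np.memmap loader above does. Token ids below are arbitrary.
import numpy as np

tokens = np.array([50256, 15496, 11, 995, 50256], dtype=np.uint16)
tokens.tofile('toy_train.bin')  # raw uint16 bytes, no header

data = np.memmap('toy_train.bin', dtype=np.uint16, mode='r')
print(len(data), data[:3])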
@@ -100,7 +93,6 @@ def get_batch(split):
         x, y = x.to(device), y.to(device)
     return x, y
 
-# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
 iter_num = 0
 best_val_loss = 1e9
 
@@ -127,7 +119,6 @@ if init_from == 'scratch':
 elif init_from == 'resume':
     print(f"Resuming training from {out_dir}")
     # resume training from a checkpoint.
-    # ckpt_path = os.path.join(out_dir, 'ckpt.pt')
     checkpoint = torch.load(ckpt_path, map_location=device)
     checkpoint_model_args = checkpoint['model_args']
     # force these config attributes to be equal otherwise we can't even resume training
@@ -138,8 +129,6 @@ elif init_from == 'resume':
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
     state_dict = checkpoint['model']
-    # fix the keys of the state dictionary :(
-    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
     unwanted_prefix = '_orig_mod.'
     for k,v in list(state_dict.items()):
         if k.startswith(unwanted_prefix):
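The '_orig_mod.' prefix handled above is what torch.compile adds to state_dict keys, since the compiled wrapper keeps the original network as a submodule named _orig_mod. A standalone sketch with a toy module (not the GPT class used here), assuming PyTorch 2.x:

# Standalone sketch: why checkpoints saved from a torch.compile()'d model carry
# an '_orig_mod.' key prefix, and how stripping it lets a plain module load them.
import torch
import torch.nn as nn

net = torch.compile(nn.Linear(4, 4))
state_dict = net.state_dict()
print(list(state_dict.keys()))  # keys look like '_orig_mod.weight', '_orig_mod.bias'

unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
nn.Linear(4, 4).load_state_dict(state_dict)  # loads cleanly without the prefix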
@@ -147,14 +136,6 @@ elif init_from == 'resume':
     model.load_state_dict(state_dict)
     iter_num = checkpoint['iter_num']
     best_val_loss = checkpoint['best_val_loss']
-elif init_from.startswith('gpt2'):
-    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
-    # initialize from OpenAI GPT-2 weights
-    override_args = dict(dropout=dropout)
-    model = GPT.from_pretrained(init_from, override_args)
-    # read off the created config params, so we can store them into checkpoint correctly
-    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
-        model_args[k] = getattr(model.config, k)
 # crop down the model block size if desired, using model surgery
 if block_size < model.config.block_size:
     model.crop_block_size(block_size)
@@ -188,7 +169,7 @@ def estimate_loss():
     model.eval()
     for split in ['train', 'val']:
         losses = torch.zeros(eval_iters)
-        total_loss = 0
+        total_loss = 0
         for k in range(eval_iters):
             X, Y = get_batch(split)
             with ctx:
@@ -197,7 +178,7 @@ def estimate_loss():
             total_loss += loss.item()
         avg_loss = losses.mean()
         out[split] = avg_loss
-        perplexities[split] = torch.exp(avg_loss)
+        perplexities[split] = torch.exp(avg_loss)
     model.train()
     return out, perplexities
 
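estimate_loss now returns perplexities alongside the mean losses, where perplexity is the exponential of the mean cross-entropy. A tiny standalone check of that relationship (the per-iteration loss values are made up):

# Standalone sketch: perplexity as exp(mean cross-entropy), matching the
# torch.exp(avg_loss) line above. The losses below are invented.
import torch

losses = torch.tensor([3.21, 3.18, 3.25, 3.19])
avg_loss = losses.mean()
perplexity = torch.exp(avg_loss)
print(f"avg loss {avg_loss:.4f} -> perplexity {perplexity:.2f}")  # ~24.7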
@@ -235,19 +216,20 @@ while True:
         log_and_write(log_dir, f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f},train perplexity: {perplexities['train']:.4f}, val perplexity: {perplexities['val']:.4f}")
         if iter_num % 200 == 0:
             print_gpu_memory_usage()
-        if
-
-
-
-
-
-
-
-
-
-
-
-
+        if always_save_checkpoint:
+            if losses['val'] < best_val_loss or always_save_checkpoint:
+                best_val_loss = losses['val']
+                if iter_num > 0:
+                    checkpoint = {
+                        'model': raw_model.state_dict(),
+                        'optimizer': optimizer.state_dict(),
+                        'model_args': model_args,
+                        'iter_num': iter_num,
+                        'best_val_loss': best_val_loss,
+                        'config': config,
+                    }
+                    log_and_write(log_dir, f"saving checkpoint to {out_dir}")
+                    torch.save(checkpoint, os.path.join(out_dir, f'ckpt_{iter_num}.pt'))
     if iter_num == 0 and eval_only:
         break
 
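The checkpoints written above store the same keys the init_from == 'resume' branch earlier in this diff reads back. A hedged sketch of resuming from one of the ckpt_{iter_num}.pt files, assuming the usual nanoGPT-style import of GPT/GPTConfig from model.py and an out_dir matching the run (the iteration number is illustrative):

# Sketch, not part of the commit: load one of the per-iteration checkpoints
# saved above, mirroring the resume branch shown earlier in this diff.
import os
import torch
from model import GPTConfig, GPT  # assumed module layout, as in nanoGPT's train.py

out_dir = 'out'                                    # assumed output directory
ckpt_path = os.path.join(out_dir, 'ckpt_2000.pt')  # illustrative iteration number
checkpoint = torch.load(ckpt_path, map_location='cpu')

model = GPT(GPTConfig(**checkpoint['model_args']))
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'                     # strip torch.compile's prefix if present
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']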