Spaces:
Running
Running
File size: 5,088 Bytes
9df4cc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
import random
import numpy as np
from transformers import set_seed
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
def print_rank_0(msg, rank=0):
    """Print ``msg`` only when ``rank`` is zero or negative.

    Args:
        msg: Object handed to ``print``.
        rank: Rank of the calling process; any value ``<= 0`` prints.
    """
    if not rank <= 0:
        return
    print(msg)
def to_device(batch, device):
    """Return a copy of ``batch`` with every movable value sent to ``device``.

    Args:
        batch: Mapping whose values may or may not provide ``.to(device)``.
        device: Target passed unchanged to each value's ``.to`` call.

    Returns:
        A new dict with the same keys. A value whose ``.to(device)`` call
        succeeds is replaced by that call's result; any other value is
        stored unchanged.
    """
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except Exception:
            # Best effort: values without a usable ``.to`` are kept as-is.
            # ``Exception`` (rather than a bare ``except``) lets
            # KeyboardInterrupt / SystemExit propagate.
            output[k] = v
    return output
class MovingAverage:
    """Cumulative mean of every value passed to :meth:`update`."""

    def __init__(self):
        # Number of samples seen, their running sum, and the current mean.
        self.count = 0
        self.total = 0
        self.mean = 0

    def update(self, num):
        """Add ``num`` to the running sum and return the updated mean."""
        self.total += num
        self.count += 1
        current_mean = self.total / self.count
        self.mean = current_mean
        return current_mean
def save_hf_format(model, tokenizer, args, sub_folder=""):
    """Write model weights, config and vocabulary into one directory.

    The files are written to ``os.path.join(args.output_dir, sub_folder)``:
    ``pytorch_model.bin`` (state dict without any key containing "lora"),
    ``config.json`` and whatever ``tokenizer.save_vocabulary`` produces.

    Args:
        model: Model to save; if it has a ``module`` attribute, that inner
            object is saved instead.
        tokenizer: Object providing ``save_vocabulary(directory)``.
        args: Object with an ``output_dir`` attribute.
        sub_folder: Optional sub-directory below ``args.output_dir``.
    """
    # used to save huggingface format, so we can use it for hf.from_pretrained
    model_to_save = model.module if hasattr(model, 'module') else model
    CONFIG_NAME = "config.json"
    WEIGHTS_NAME = "pytorch_model.bin"
    output_dir = os.path.join(args.output_dir, sub_folder)
    # exist_ok avoids the check-then-create race when several processes
    # reach this point at the same time.
    os.makedirs(output_dir, exist_ok=True)
    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
    output_config_file = os.path.join(output_dir, CONFIG_NAME)
    save_dict = model_to_save.state_dict()
    # Iterate over a snapshot of the keys because entries are deleted
    # from the dict inside the loop.
    for key in list(save_dict.keys()):
        if "lora" in key:
            del save_dict[key]
    torch.save(save_dict, output_model_file)
    model_to_save.config.to_json_file(output_config_file)
    tokenizer.save_vocabulary(output_dir)
def set_random_seed(seed):
    """Seed transformers, ``random``, NumPy and torch (CPU and all CUDA).

    Does nothing when ``seed`` is ``None``.
    """
    if seed is None:
        return
    seeders = (
        set_seed,
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def get_all_reduce_mean(tensor):
    """Sum ``tensor`` across processes and return it divided by world size.

    ``tensor`` is passed to ``torch.distributed.all_reduce`` with
    ``ReduceOp.SUM``; the returned value is the result of the division.
    """
    dist = torch.distributed
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor / dist.get_world_size()
def get_optimizer_grouped_parameters(model,
                                     weight_decay,
                                     no_decay_name_list=("bias",
                                                         "LayerNorm.weight")):
    """Split trainable parameters into weight-decay and no-decay groups.

    Args:
        model: Object providing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        weight_decay: Weight-decay value assigned to the first group.
        no_decay_name_list: Substrings; a parameter whose name contains any
            of them goes into the second group.

    Returns:
        A list of two dicts with keys ``"params"`` and ``"weight_decay"``.
        The first holds the remaining trainable parameters with
        ``weight_decay``; the second holds the name-matched ones with
        ``0.0``. Parameters with a falsy ``requires_grad`` are in neither.
    """
    decay_params = []
    no_decay_params = []
    # One pass over the parameters fills both groups, keeping the original
    # iteration order within each group.
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay_name_list):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {
            "params": decay_params,
            "weight_decay": weight_decay,
        },
        {
            "params": no_decay_params,
            "weight_decay": 0.0,
        },
    ]
def _z3_params_to_fetch(param_list):
    """Return the entries of ``param_list`` that still have to be fetched.

    An entry is kept when it has a ``ds_id`` attribute and its ``ds_status``
    equals ``ZeroParamStatus.NOT_AVAILABLE``.
    """
    pending = []
    for candidate in param_list:
        if not hasattr(candidate, 'ds_id'):
            continue
        if candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pending.append(candidate)
    return pending
def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
    """Update ``model_ema``'s parameters in place from ``model``'s.

    Parameters are paired positionally via ``zip``. Each EMA parameter is
    overwritten with ``torch.lerp(param, param_ema, beta)``.

    Args:
        model: Source of the current parameter values.
        model_ema: Model whose parameters are overwritten.
        beta: Interpolation weight passed to ``torch.lerp``.
        device: If not ``None``, the source data is moved there first.
        zero_stage: When equal to 3, parameters reported by
            ``_z3_params_to_fetch`` are wrapped in
            ``deepspeed.zero.GatheredParameters`` during the update.
    """
    using_zero3 = zero_stage == 3
    with torch.no_grad():
        for src, ema in zip(model.parameters(), model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            if using_zero3:
                to_gather = _z3_params_to_fetch([src, ema])
            else:
                to_gather = []
            # The context is only enabled when there is something to fetch.
            with deepspeed.zero.GatheredParameters(to_gather,
                                                   enabled=bool(to_gather)):
                current = src.data if device is None else src.data.to(device)
                ema.data.copy_(torch.lerp(current, ema.data, beta))
def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save a model's weights to ``save_dir/pytorch_model.bin`` from rank 0.

    Args:
        model_ema: Model to save; if it has a ``module`` attribute, that
            inner object is saved instead.
        global_rank: Rank of the calling process; only rank 0 writes.
        save_dir: Destination directory, created if missing.
        zero_stage: When equal to 3, parameters are collected one by one
            (skipping names containing "lora"); otherwise the plain
            ``state_dict()`` is saved.
    """
    uses_zero3 = zero_stage == 3
    is_rank_zero = global_rank == 0
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, "pytorch_model.bin")
    target = getattr(model_ema, 'module', model_ema)

    if not uses_zero3:
        if is_rank_zero:
            torch.save(target.state_dict(), checkpoint_path)
        return

    collected = {}
    for name, param in target.named_parameters():
        # Every rank runs this step; only rank 0 keeps the CPU copy below.
        if hasattr(param, 'ds_id'):
            with deepspeed.zero.GatheredParameters(
                    _z3_params_to_fetch([param]), enabled=uses_zero3):
                cpu_copy = param.data.cpu()
        else:
            cpu_copy = param.cpu()
        if is_rank_zero and "lora" not in name:
            collected[name] = cpu_copy
    if is_rank_zero:
        torch.save(collected, checkpoint_path)
    del collected
|