diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/checkpoint-10/config.json b/checkpoint-10/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-10/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-10/generation_config.json b/checkpoint-10/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bc00b333fdf0ba3611d022ddfdaeaf527fab8da0 --- /dev/null +++ b/checkpoint-10/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 2, + "transformers_version": "4.38.2" +} diff --git a/checkpoint-10/global_step10/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-10/global_step10/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b0f261ca1ace0d9ceaf6ffecfd2ec127229ca728 --- /dev/null +++ b/checkpoint-10/global_step10/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 
+1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ed0df787a125d81094d867c743d5ffebb34766f877f42e3b5da9ad0e51fbf81e +size 973946896 diff --git a/checkpoint-10/global_step10/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-10/global_step10/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..e880aa67a87dc43e7d483114be384f110465422e --- /dev/null +++ b/checkpoint-10/global_step10/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f6768e7d81521ef703a230f41460e04972f79ded2de3b754bcd0fb7a96f4ff0 +size 973946832 diff --git a/checkpoint-10/global_step10/mp_rank_00_model_states.pt b/checkpoint-10/global_step10/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..675c740ca5aad38baa852ba4edd6f70c1f44c4c8 --- /dev/null +++ b/checkpoint-10/global_step10/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7589ffd4359a842efe83dc7cc5ce4ff6a161f1fce829498ade432719a3c40d91 +size 324689964 diff --git a/checkpoint-10/latest b/checkpoint-10/latest new file mode 100644 index 0000000000000000000000000000000000000000..e23122de54baa1dd9f514f25b8a62c6026b72e10 --- /dev/null +++ b/checkpoint-10/latest @@ -0,0 +1 @@ +global_step10 \ No newline at end of file diff --git a/checkpoint-10/model.safetensors b/checkpoint-10/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..9324eb0b3d3d7f07394235981c53c889bcefa1c4 --- /dev/null +++ b/checkpoint-10/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6ba4e69094dffe72d300db52887b581d271c61d1a95702fdea83bf7d6d5697f +size 324662984 diff --git a/checkpoint-10/rng_state_0.pth b/checkpoint-10/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..cad18ac770da4331076b9ef49fc91a7f9a5989c3 --- /dev/null +++ b/checkpoint-10/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0bb7d2ecdd48fd7d0be1e75b0e3f29004064381052fa203ed926e88b90ef530 +size 14512 diff --git a/checkpoint-10/rng_state_1.pth b/checkpoint-10/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..197bac5f7fe92d301270b1f25b8fa7a07b568293 --- /dev/null +++ b/checkpoint-10/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177d534a379bd6b276474c2cb140e318dc65db4457b6c1b6f25a1a9dd563af82 +size 14512 diff --git a/checkpoint-10/trainer_state.json b/checkpoint-10/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..bc3dce9f0d1ca98aab28bebcb08a1a58a8dc8743 --- /dev/null +++ b/checkpoint-10/trainer_state.json @@ -0,0 +1,91 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 0.009995002498750625, + "eval_steps": 500, + "global_step": 10, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.0, + "grad_norm": 3.3340553137590683, + "learning_rate": 0.0, + "loss": 11.0, + "step": 1 + }, + { + "epoch": 0.0, + "grad_norm": 2.398799707355898, + "learning_rate": 5.9999999999999995e-05, + "loss": 10.125, + "step": 2 + }, + { + "epoch": 0.0, + "grad_norm": 2.3943029297945575, + "learning_rate": 0.00011999999999999999, + "loss": 10.1172, + "step": 3 + }, + { + "epoch": 0.0, + "grad_norm": 1.9959117709404242, + "learning_rate": 0.00017999999999999998, + 
"loss": 9.875, + "step": 4 + }, + { + "epoch": 0.0, + "grad_norm": 1.8270696218303057, + "learning_rate": 0.00023999999999999998, + "loss": 9.6641, + "step": 5 + }, + { + "epoch": 0.01, + "grad_norm": 1.7854351602113614, + "learning_rate": 0.0003, + "loss": 9.4844, + "step": 6 + }, + { + "epoch": 0.01, + "grad_norm": 1.7194174424274788, + "learning_rate": 0.00035999999999999997, + "loss": 9.3281, + "step": 7 + }, + { + "epoch": 0.01, + "grad_norm": 1.463772638994466, + "learning_rate": 0.00041999999999999996, + "loss": 9.2109, + "step": 8 + }, + { + "epoch": 0.01, + "grad_norm": 1.439323678271545, + "learning_rate": 0.00047999999999999996, + "loss": 8.9453, + "step": 9 + }, + { + "epoch": 0.01, + "grad_norm": 1.2936126396494727, + "learning_rate": 0.00054, + "loss": 8.7109, + "step": 10 + } + ], + "logging_steps": 1, + "max_steps": 1000, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 10, + "total_flos": 0.0, + "train_batch_size": 32, + "trial_name": null, + "trial_params": null +} diff --git a/checkpoint-10/training_args.bin b/checkpoint-10/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..12fdb7967b1254c497de146410ac3cd352b2b9c7 --- /dev/null +++ b/checkpoint-10/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95cc4290cc90782d57f7376defd26743b3a36943fc93e80e2734385bc57e8b78 +size 6520 diff --git a/checkpoint-10/zero_to_fp32.py b/checkpoint-10/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-10/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + 
state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + 
frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. 
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-100/config.json b/checkpoint-100/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-100/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-100/generation_config.json b/checkpoint-100/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bc00b333fdf0ba3611d022ddfdaeaf527fab8da0 --- /dev/null +++ b/checkpoint-100/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 2, + "transformers_version": "4.38.2" +} diff --git a/checkpoint-100/global_step100/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-100/global_step100/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ecdd6fd2fa56484b74bba91246a451e0adef01b1 --- /dev/null +++ b/checkpoint-100/global_step100/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f13244aa35b694472ea5bd7fde3bcee569b6a0b7251253f818b7e5ae6a048797 +size 973946896 diff --git a/checkpoint-100/global_step100/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-100/global_step100/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 
0000000000000000000000000000000000000000..8e7fe8649241838d677a68ae576e4ac80b433b38 --- /dev/null +++ b/checkpoint-100/global_step100/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1039a6c333c1647ef2ddfc010c1b966f38d1001aff73ac8b3cdede7858acdc7 +size 973946832 diff --git a/checkpoint-100/global_step100/mp_rank_00_model_states.pt b/checkpoint-100/global_step100/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7d48c4a24978aa645f94e4efa846ec4ca0e361e8 --- /dev/null +++ b/checkpoint-100/global_step100/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26fe4253d323b90c0fa9e0d5da8cb64819f86baf288d72f171c73184ca877562 +size 324689964 diff --git a/checkpoint-100/latest b/checkpoint-100/latest new file mode 100644 index 0000000000000000000000000000000000000000..744ae7dbad571b6f37ec6c7066549494261bb59e --- /dev/null +++ b/checkpoint-100/latest @@ -0,0 +1 @@ +global_step100 \ No newline at end of file diff --git a/checkpoint-100/model.safetensors b/checkpoint-100/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..9fa79b4f67f024efc247dbe7f86dc19f93f9353c --- /dev/null +++ b/checkpoint-100/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9a5c8eae560953cb2afb2b3bf5d1330d089af788edaa1bf1d3f77b218837bc +size 324662984 diff --git a/checkpoint-100/rng_state_0.pth b/checkpoint-100/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..cad18ac770da4331076b9ef49fc91a7f9a5989c3 --- /dev/null +++ b/checkpoint-100/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0bb7d2ecdd48fd7d0be1e75b0e3f29004064381052fa203ed926e88b90ef530 +size 14512 diff --git a/checkpoint-100/rng_state_1.pth b/checkpoint-100/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..197bac5f7fe92d301270b1f25b8fa7a07b568293 --- /dev/null +++ b/checkpoint-100/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177d534a379bd6b276474c2cb140e318dc65db4457b6c1b6f25a1a9dd563af82 +size 14512 diff --git a/checkpoint-100/trainer_state.json b/checkpoint-100/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..24d042a91824d3191100387ef9316497f75dbc39 --- /dev/null +++ b/checkpoint-100/trainer_state.json @@ -0,0 +1,721 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 0.09995002498750624, + "eval_steps": 500, + "global_step": 100, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.0, + "grad_norm": 3.3340563149001086, + "learning_rate": 0.0, + "loss": 11.0, + "step": 1 + }, + { + "epoch": 0.0, + "grad_norm": 2.398812329952019, + "learning_rate": 5.9999999999999995e-05, + "loss": 10.125, + "step": 2 + }, + { + "epoch": 0.0, + "grad_norm": 2.394322446895115, + "learning_rate": 0.00011999999999999999, + "loss": 10.1172, + "step": 3 + }, + { + "epoch": 0.0, + "grad_norm": 1.9958816684399585, + "learning_rate": 0.00017999999999999998, + "loss": 9.875, + "step": 4 + }, + { + "epoch": 0.0, + "grad_norm": 1.8270465897882062, + "learning_rate": 0.00023999999999999998, + "loss": 9.6641, + "step": 5 + }, + { + "epoch": 0.01, + "grad_norm": 1.7854046471397795, + "learning_rate": 0.0003, + "loss": 9.4844, + "step": 6 + }, + { + "epoch": 0.01, + 
"grad_norm": 1.719416749115252, + "learning_rate": 0.00035999999999999997, + "loss": 9.3281, + "step": 7 + }, + { + "epoch": 0.01, + "grad_norm": 1.4637825746112274, + "learning_rate": 0.00041999999999999996, + "loss": 9.2109, + "step": 8 + }, + { + "epoch": 0.01, + "grad_norm": 1.4393631015406718, + "learning_rate": 0.00047999999999999996, + "loss": 8.9453, + "step": 9 + }, + { + "epoch": 0.01, + "grad_norm": 1.2936734586915988, + "learning_rate": 0.00054, + "loss": 8.7109, + "step": 10 + }, + { + "epoch": 0.01, + "grad_norm": 1.0756922378227356, + "learning_rate": 0.0005999986405514987, + "loss": 8.4609, + "step": 11 + }, + { + "epoch": 0.01, + "grad_norm": 0.9277829127413892, + "learning_rate": 0.0005999945622196846, + "loss": 8.2344, + "step": 12 + }, + { + "epoch": 0.01, + "grad_norm": 0.8084581786682467, + "learning_rate": 0.0005999877650456265, + "loss": 8.125, + "step": 13 + }, + { + "epoch": 0.01, + "grad_norm": 0.7635084596900947, + "learning_rate": 0.000599978249097772, + "loss": 7.9766, + "step": 14 + }, + { + "epoch": 0.01, + "grad_norm": 0.9186699644247788, + "learning_rate": 0.0005999660144719463, + "loss": 7.8555, + "step": 15 + }, + { + "epoch": 0.02, + "grad_norm": 0.6609504256551479, + "learning_rate": 0.0005999510612913519, + "loss": 7.7734, + "step": 16 + }, + { + "epoch": 0.02, + "grad_norm": 0.7086232844782971, + "learning_rate": 0.0005999333897065673, + "loss": 7.7148, + "step": 17 + }, + { + "epoch": 0.02, + "grad_norm": 16.38048851691348, + "learning_rate": 0.0005999129998955453, + "loss": 8.4844, + "step": 18 + }, + { + "epoch": 0.02, + "grad_norm": 1.3057527590449889, + "learning_rate": 0.0005998898920636111, + "loss": 7.7539, + "step": 19 + }, + { + "epoch": 0.02, + "grad_norm": 0.6966048242948986, + "learning_rate": 0.00059986406644346, + "loss": 7.75, + "step": 20 + }, + { + "epoch": 0.02, + "grad_norm": 0.6348089115348993, + "learning_rate": 0.0005998355232951559, + "loss": 7.7031, + "step": 21 + }, + { + "epoch": 0.02, + "grad_norm": 0.7829163518610293, + "learning_rate": 0.0005998042629061279, + "loss": 7.6992, + "step": 22 + }, + { + "epoch": 0.02, + "grad_norm": 0.5900591778980369, + "learning_rate": 0.0005997702855911678, + "loss": 7.6016, + "step": 23 + }, + { + "epoch": 0.02, + "grad_norm": 0.4655170213064256, + "learning_rate": 0.0005997335916924268, + "loss": 7.5977, + "step": 24 + }, + { + "epoch": 0.02, + "grad_norm": 0.6287348258915756, + "learning_rate": 0.0005996941815794121, + "loss": 7.5586, + "step": 25 + }, + { + "epoch": 0.03, + "grad_norm": 0.6137321903884564, + "learning_rate": 0.0005996520556489831, + "loss": 7.5898, + "step": 26 + }, + { + "epoch": 0.03, + "grad_norm": 0.44962562710631065, + "learning_rate": 0.0005996072143253473, + "loss": 7.4336, + "step": 27 + }, + { + "epoch": 0.03, + "grad_norm": 0.46130046454703316, + "learning_rate": 0.0005995596580600566, + "loss": 7.4023, + "step": 28 + }, + { + "epoch": 0.03, + "grad_norm": 0.4686712675731326, + "learning_rate": 0.0005995093873320018, + "loss": 7.3789, + "step": 29 + }, + { + "epoch": 0.03, + "grad_norm": 0.4672147564288997, + "learning_rate": 0.0005994564026474087, + "loss": 7.3711, + "step": 30 + }, + { + "epoch": 0.03, + "grad_norm": 0.40408354581233474, + "learning_rate": 0.0005994007045398324, + "loss": 7.3672, + "step": 31 + }, + { + "epoch": 0.03, + "grad_norm": 0.46032146732584733, + "learning_rate": 0.0005993422935701524, + "loss": 7.3477, + "step": 32 + }, + { + "epoch": 0.03, + "grad_norm": 0.4765534634593268, + "learning_rate": 0.0005992811703265664, + "loss": 
7.3555, + "step": 33 + }, + { + "epoch": 0.03, + "grad_norm": 0.46208489386235113, + "learning_rate": 0.0005992173354245849, + "loss": 7.3047, + "step": 34 + }, + { + "epoch": 0.03, + "grad_norm": 0.2956144524964961, + "learning_rate": 0.0005991507895070244, + "loss": 7.3125, + "step": 35 + }, + { + "epoch": 0.04, + "grad_norm": 0.4834645389868856, + "learning_rate": 0.0005990815332440017, + "loss": 7.207, + "step": 36 + }, + { + "epoch": 0.04, + "grad_norm": 0.4411831350968505, + "learning_rate": 0.0005990095673329266, + "loss": 7.1758, + "step": 37 + }, + { + "epoch": 0.04, + "grad_norm": 0.24809297748968667, + "learning_rate": 0.0005989348924984951, + "loss": 7.2188, + "step": 38 + }, + { + "epoch": 0.04, + "grad_norm": 0.39402988416840584, + "learning_rate": 0.0005988575094926817, + "loss": 7.1953, + "step": 39 + }, + { + "epoch": 0.04, + "grad_norm": 0.3868345222189167, + "learning_rate": 0.0005987774190947328, + "loss": 7.1641, + "step": 40 + }, + { + "epoch": 0.04, + "grad_norm": 0.3777261230135448, + "learning_rate": 0.0005986946221111575, + "loss": 7.1328, + "step": 41 + }, + { + "epoch": 0.04, + "grad_norm": 0.4687511444077827, + "learning_rate": 0.0005986091193757206, + "loss": 7.0898, + "step": 42 + }, + { + "epoch": 0.04, + "grad_norm": 0.34935796211612463, + "learning_rate": 0.0005985209117494337, + "loss": 7.1367, + "step": 43 + }, + { + "epoch": 0.04, + "grad_norm": 0.38764476686849886, + "learning_rate": 0.0005984300001205466, + "loss": 7.125, + "step": 44 + }, + { + "epoch": 0.04, + "grad_norm": 0.3956487898882936, + "learning_rate": 0.0005983363854045386, + "loss": 7.1094, + "step": 45 + }, + { + "epoch": 0.05, + "grad_norm": 0.31140257544677513, + "learning_rate": 0.0005982400685441084, + "loss": 7.0898, + "step": 46 + }, + { + "epoch": 0.05, + "grad_norm": 0.3664476570531787, + "learning_rate": 0.0005981410505091662, + "loss": 7.0664, + "step": 47 + }, + { + "epoch": 0.05, + "grad_norm": 0.31891741142945207, + "learning_rate": 0.0005980393322968223, + "loss": 7.0273, + "step": 48 + }, + { + "epoch": 0.05, + "grad_norm": 0.4533529037337155, + "learning_rate": 0.0005979349149313778, + "loss": 7.0586, + "step": 49 + }, + { + "epoch": 0.05, + "grad_norm": 0.30532331638835586, + "learning_rate": 0.0005978277994643147, + "loss": 7.0195, + "step": 50 + }, + { + "epoch": 0.05, + "grad_norm": 0.6501991746260075, + "learning_rate": 0.0005977179869742844, + "loss": 6.9648, + "step": 51 + }, + { + "epoch": 0.05, + "grad_norm": 0.43904455901717926, + "learning_rate": 0.0005976054785670975, + "loss": 6.9805, + "step": 52 + }, + { + "epoch": 0.05, + "grad_norm": 0.4826001598483571, + "learning_rate": 0.0005974902753757124, + "loss": 6.9297, + "step": 53 + }, + { + "epoch": 0.05, + "grad_norm": 0.2924998027034648, + "learning_rate": 0.000597372378560224, + "loss": 6.8984, + "step": 54 + }, + { + "epoch": 0.05, + "grad_norm": 0.4439033666380787, + "learning_rate": 0.0005972517893078517, + "loss": 6.8945, + "step": 55 + }, + { + "epoch": 0.06, + "grad_norm": 0.6135914255073411, + "learning_rate": 0.0005971285088329284, + "loss": 6.9727, + "step": 56 + }, + { + "epoch": 0.06, + "grad_norm": 0.5575686565598483, + "learning_rate": 0.0005970025383768866, + "loss": 6.9219, + "step": 57 + }, + { + "epoch": 0.06, + "grad_norm": 0.4820951675994578, + "learning_rate": 0.0005968738792082478, + "loss": 6.8516, + "step": 58 + }, + { + "epoch": 0.06, + "grad_norm": 0.40164190019465584, + "learning_rate": 0.0005967425326226082, + "loss": 6.7734, + "step": 59 + }, + { + "epoch": 0.06, + "grad_norm": 
0.46129863945181293, + "learning_rate": 0.0005966084999426265, + "loss": 6.8125, + "step": 60 + }, + { + "epoch": 0.06, + "grad_norm": 0.33322355827118677, + "learning_rate": 0.0005964717825180101, + "loss": 6.7891, + "step": 61 + }, + { + "epoch": 0.06, + "grad_norm": 0.3847525153855558, + "learning_rate": 0.0005963323817255024, + "loss": 6.8242, + "step": 62 + }, + { + "epoch": 0.06, + "grad_norm": 0.3384433591375982, + "learning_rate": 0.0005961902989688674, + "loss": 6.707, + "step": 63 + }, + { + "epoch": 0.06, + "grad_norm": 0.3937003195165685, + "learning_rate": 0.000596045535678877, + "loss": 6.8203, + "step": 64 + }, + { + "epoch": 0.06, + "grad_norm": 0.35423488053528107, + "learning_rate": 0.0005958980933132962, + "loss": 6.7383, + "step": 65 + }, + { + "epoch": 0.07, + "grad_norm": 0.36005939745315396, + "learning_rate": 0.0005957479733568675, + "loss": 6.7109, + "step": 66 + }, + { + "epoch": 0.07, + "grad_norm": 0.3499278317706933, + "learning_rate": 0.0005955951773212976, + "loss": 6.7266, + "step": 67 + }, + { + "epoch": 0.07, + "grad_norm": 0.3708385192137018, + "learning_rate": 0.0005954397067452407, + "loss": 6.7617, + "step": 68 + }, + { + "epoch": 0.07, + "grad_norm": 0.3775657656205869, + "learning_rate": 0.0005952815631942839, + "loss": 6.7148, + "step": 69 + }, + { + "epoch": 0.07, + "grad_norm": 0.3040083750375816, + "learning_rate": 0.0005951207482609307, + "loss": 6.5938, + "step": 70 + }, + { + "epoch": 0.07, + "grad_norm": 0.3443020808841468, + "learning_rate": 0.0005949572635645861, + "loss": 6.6523, + "step": 71 + }, + { + "epoch": 0.07, + "grad_norm": 0.3520066316939, + "learning_rate": 0.0005947911107515389, + "loss": 6.6211, + "step": 72 + }, + { + "epoch": 0.07, + "grad_norm": 0.3739040572679613, + "learning_rate": 0.0005946222914949462, + "loss": 6.5547, + "step": 73 + }, + { + "epoch": 0.07, + "grad_norm": 0.34890731989025553, + "learning_rate": 0.000594450807494816, + "loss": 6.5859, + "step": 74 + }, + { + "epoch": 0.07, + "grad_norm": 0.40910932350136514, + "learning_rate": 0.0005942766604779903, + "loss": 6.5547, + "step": 75 + }, + { + "epoch": 0.08, + "grad_norm": 0.5698342865852906, + "learning_rate": 0.0005940998521981274, + "loss": 6.457, + "step": 76 + }, + { + "epoch": 0.08, + "grad_norm": 0.5179452709555474, + "learning_rate": 0.0005939203844356852, + "loss": 6.5547, + "step": 77 + }, + { + "epoch": 0.08, + "grad_norm": 0.5222512938673792, + "learning_rate": 0.0005937382589979016, + "loss": 6.5039, + "step": 78 + }, + { + "epoch": 0.08, + "grad_norm": 0.5682332793686307, + "learning_rate": 0.0005935534777187781, + "loss": 6.5547, + "step": 79 + }, + { + "epoch": 0.08, + "grad_norm": 0.3869287710460676, + "learning_rate": 0.0005933660424590598, + "loss": 6.5156, + "step": 80 + }, + { + "epoch": 0.08, + "grad_norm": 0.3078211032807607, + "learning_rate": 0.000593175955106218, + "loss": 6.4258, + "step": 81 + }, + { + "epoch": 0.08, + "grad_norm": 0.3611357511872241, + "learning_rate": 0.00059298321757443, + "loss": 6.4727, + "step": 82 + }, + { + "epoch": 0.08, + "grad_norm": 0.29633467844266953, + "learning_rate": 0.0005927878318045608, + "loss": 6.3281, + "step": 83 + }, + { + "epoch": 0.08, + "grad_norm": 0.3257574200776832, + "learning_rate": 0.0005925897997641426, + "loss": 6.3203, + "step": 84 + }, + { + "epoch": 0.08, + "grad_norm": 0.2824054533852328, + "learning_rate": 0.0005923891234473562, + "loss": 6.4062, + "step": 85 + }, + { + "epoch": 0.09, + "grad_norm": 0.3056199770204573, + "learning_rate": 0.0005921858048750097, + "loss": 
6.3984, + "step": 86 + }, + { + "epoch": 0.09, + "grad_norm": 0.2966438824341908, + "learning_rate": 0.000591979846094519, + "loss": 6.3555, + "step": 87 + }, + { + "epoch": 0.09, + "grad_norm": 0.32782438676663733, + "learning_rate": 0.0005917712491798866, + "loss": 6.4023, + "step": 88 + }, + { + "epoch": 0.09, + "grad_norm": 0.3538316399620157, + "learning_rate": 0.0005915600162316811, + "loss": 6.2812, + "step": 89 + }, + { + "epoch": 0.09, + "grad_norm": 0.375858298192913, + "learning_rate": 0.0005913461493770162, + "loss": 6.3086, + "step": 90 + }, + { + "epoch": 0.09, + "grad_norm": 0.5189251339815161, + "learning_rate": 0.0005911296507695284, + "loss": 6.2812, + "step": 91 + }, + { + "epoch": 0.09, + "grad_norm": 0.6304909542669104, + "learning_rate": 0.0005909105225893564, + "loss": 6.2969, + "step": 92 + }, + { + "epoch": 0.09, + "grad_norm": 0.4655662819622591, + "learning_rate": 0.0005906887670431187, + "loss": 6.1953, + "step": 93 + }, + { + "epoch": 0.09, + "grad_norm": 0.39035390983920965, + "learning_rate": 0.000590464386363891, + "loss": 6.2617, + "step": 94 + }, + { + "epoch": 0.09, + "grad_norm": 0.4918417851770978, + "learning_rate": 0.0005902373828111843, + "loss": 6.2148, + "step": 95 + }, + { + "epoch": 0.1, + "grad_norm": 0.35670770889552555, + "learning_rate": 0.0005900077586709219, + "loss": 6.2461, + "step": 96 + }, + { + "epoch": 0.1, + "grad_norm": 0.4177985869939347, + "learning_rate": 0.0005897755162554163, + "loss": 6.1797, + "step": 97 + }, + { + "epoch": 0.1, + "grad_norm": 0.3742471130708234, + "learning_rate": 0.000589540657903346, + "loss": 6.1406, + "step": 98 + }, + { + "epoch": 0.1, + "grad_norm": 0.28627666723978284, + "learning_rate": 0.0005893031859797322, + "loss": 6.2031, + "step": 99 + }, + { + "epoch": 0.1, + "grad_norm": 0.32238563846046103, + "learning_rate": 0.0005890631028759143, + "loss": 6.0625, + "step": 100 + } + ], + "logging_steps": 1, + "max_steps": 1000, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 100, + "total_flos": 0.0, + "train_batch_size": 32, + "trial_name": null, + "trial_params": null +} diff --git a/checkpoint-100/training_args.bin b/checkpoint-100/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..9362a9e736fc862ece575b9f1b9d54b14c10d0b5 --- /dev/null +++ b/checkpoint-100/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36ce7ab48fa86ef42491eaad3583773d2b60353997a5e7b6fb4ffc1414828749 +size 6520 diff --git a/checkpoint-100/zero_to_fp32.py b/checkpoint-100/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-100/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . 
pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = 
zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return 
_get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. 
Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes 
= zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. 
e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example, the ``model`` will no longer be usable in the deepspeed context of the same + application, i.e., you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. 
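As a rough sketch of that offline route, using the layout of this repository (``checkpoint-20`` with tag ``global_step20``) and a generic ``transformers`` model class, the conversion plus reload might look like the following; note that the paths, the ``AutoModelForCausalLM`` choice, and ``strict=False`` are illustrative assumptions rather than part of the original script ::

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM
    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    # consolidate the ZeRO shards under checkpoint-20/global_step20 into one fp32 file
    convert_zero_checkpoint_to_fp32_state_dict("checkpoint-20",
                                               "checkpoint-20/pytorch_model.bin",
                                               tag="global_step20")

    # rebuild the model from the config.json saved with the checkpoint, then load the weights
    config = AutoConfig.from_pretrained("checkpoint-20")
    model = AutoModelForCausalLM.from_config(config)
    state_dict = torch.load("checkpoint-20/pytorch_model.bin", map_location="cpu")
    model.load_state_dict(state_dict, strict=False)  # strict=False matches what this helper itself uses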
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-20/config.json b/checkpoint-20/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-20/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-20/generation_config.json b/checkpoint-20/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bc00b333fdf0ba3611d022ddfdaeaf527fab8da0 --- /dev/null +++ b/checkpoint-20/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 2, + "transformers_version": "4.38.2" +} diff --git a/checkpoint-20/global_step20/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-20/global_step20/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..6bfb90110247f55ab160f67d26b22ceea539e114 --- /dev/null +++ b/checkpoint-20/global_step20/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 
+oid sha256:3d9619b0d3ceeed4c45a4c6fd5be8946b1be76a5801c466123be2ca841e5e337 +size 973946896 diff --git a/checkpoint-20/global_step20/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-20/global_step20/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..297f068c376d3445af4da4235e076870c1eb4638 --- /dev/null +++ b/checkpoint-20/global_step20/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c51b9f289ed8d24931448d39d8478de2e97527186251a29c92850b1de7562ea +size 973946832 diff --git a/checkpoint-20/global_step20/mp_rank_00_model_states.pt b/checkpoint-20/global_step20/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..170878a2a94a51a5d303842c2132bf13084951e0 --- /dev/null +++ b/checkpoint-20/global_step20/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06da85a865332b0c223e8dc427f16b0c81330388f43d5d7b854b9b349b9e2f89 +size 324689964 diff --git a/checkpoint-20/latest b/checkpoint-20/latest new file mode 100644 index 0000000000000000000000000000000000000000..11e5c63223cdf01f44f9f3129915f9de3d647f31 --- /dev/null +++ b/checkpoint-20/latest @@ -0,0 +1 @@ +global_step20 \ No newline at end of file diff --git a/checkpoint-20/model.safetensors b/checkpoint-20/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..591c2ddd08e3428fd39bffd37e5a33b4903ffe98 --- /dev/null +++ b/checkpoint-20/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:99984de686ddb1e07f6ea7778c91ab6e3547271af7b06526b184e6578f6bf40d +size 324662984 diff --git a/checkpoint-20/rng_state_0.pth b/checkpoint-20/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..cad18ac770da4331076b9ef49fc91a7f9a5989c3 --- /dev/null +++ b/checkpoint-20/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0bb7d2ecdd48fd7d0be1e75b0e3f29004064381052fa203ed926e88b90ef530 +size 14512 diff --git a/checkpoint-20/rng_state_1.pth b/checkpoint-20/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..197bac5f7fe92d301270b1f25b8fa7a07b568293 --- /dev/null +++ b/checkpoint-20/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177d534a379bd6b276474c2cb140e318dc65db4457b6c1b6f25a1a9dd563af82 +size 14512 diff --git a/checkpoint-20/trainer_state.json b/checkpoint-20/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..2118c2a7fefe174872efc2845c011857f50e5f5a --- /dev/null +++ b/checkpoint-20/trainer_state.json @@ -0,0 +1,161 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 0.01999000499750125, + "eval_steps": 500, + "global_step": 20, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.0, + "grad_norm": 3.3340553137590683, + "learning_rate": 0.0, + "loss": 11.0, + "step": 1 + }, + { + "epoch": 0.0, + "grad_norm": 2.398799707355898, + "learning_rate": 5.9999999999999995e-05, + "loss": 10.125, + "step": 2 + }, + { + "epoch": 0.0, + "grad_norm": 2.3943029297945575, + "learning_rate": 0.00011999999999999999, + "loss": 10.1172, + "step": 3 + }, + { + "epoch": 0.0, + "grad_norm": 1.9959117709404242, + "learning_rate": 0.00017999999999999998, + "loss": 9.875, + "step": 4 + }, + { + "epoch": 0.0, + 
"grad_norm": 1.8270696218303057, + "learning_rate": 0.00023999999999999998, + "loss": 9.6641, + "step": 5 + }, + { + "epoch": 0.01, + "grad_norm": 1.7854351602113614, + "learning_rate": 0.0003, + "loss": 9.4844, + "step": 6 + }, + { + "epoch": 0.01, + "grad_norm": 1.7194174424274788, + "learning_rate": 0.00035999999999999997, + "loss": 9.3281, + "step": 7 + }, + { + "epoch": 0.01, + "grad_norm": 1.463772638994466, + "learning_rate": 0.00041999999999999996, + "loss": 9.2109, + "step": 8 + }, + { + "epoch": 0.01, + "grad_norm": 1.439323678271545, + "learning_rate": 0.00047999999999999996, + "loss": 8.9453, + "step": 9 + }, + { + "epoch": 0.01, + "grad_norm": 1.2936126396494727, + "learning_rate": 0.00054, + "loss": 8.7109, + "step": 10 + }, + { + "epoch": 0.01, + "grad_norm": 1.0757761814549318, + "learning_rate": 0.0005999986405514987, + "loss": 8.4609, + "step": 11 + }, + { + "epoch": 0.01, + "grad_norm": 0.9278570154341632, + "learning_rate": 0.0005999945622196846, + "loss": 8.2344, + "step": 12 + }, + { + "epoch": 0.01, + "grad_norm": 0.8086775215724974, + "learning_rate": 0.0005999877650456265, + "loss": 8.125, + "step": 13 + }, + { + "epoch": 0.01, + "grad_norm": 0.7630413213242441, + "learning_rate": 0.000599978249097772, + "loss": 7.9766, + "step": 14 + }, + { + "epoch": 0.01, + "grad_norm": 0.9172017565891333, + "learning_rate": 0.0005999660144719463, + "loss": 7.8555, + "step": 15 + }, + { + "epoch": 0.02, + "grad_norm": 0.6610052304024877, + "learning_rate": 0.0005999510612913519, + "loss": 7.7734, + "step": 16 + }, + { + "epoch": 0.02, + "grad_norm": 0.7091485456070775, + "learning_rate": 0.0005999333897065673, + "loss": 7.7148, + "step": 17 + }, + { + "epoch": 0.02, + "grad_norm": 16.771353248766836, + "learning_rate": 0.0005999129998955453, + "loss": 8.5078, + "step": 18 + }, + { + "epoch": 0.02, + "grad_norm": 1.3123969082989795, + "learning_rate": 0.0005998898920636111, + "loss": 7.7539, + "step": 19 + }, + { + "epoch": 0.02, + "grad_norm": 0.6992078172905232, + "learning_rate": 0.00059986406644346, + "loss": 7.75, + "step": 20 + } + ], + "logging_steps": 1, + "max_steps": 1000, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 10, + "total_flos": 0.0, + "train_batch_size": 32, + "trial_name": null, + "trial_params": null +} diff --git a/checkpoint-20/training_args.bin b/checkpoint-20/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..12fdb7967b1254c497de146410ac3cd352b2b9c7 --- /dev/null +++ b/checkpoint-20/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95cc4290cc90782d57f7376defd26743b3a36943fc93e80e2734385bc57e8b78 +size 6520 diff --git a/checkpoint-20/zero_to_fp32.py b/checkpoint-20/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-20/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . 
pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = 
zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return 
_get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. 
Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes 
= zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. 
e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example, the ``model`` will no longer be usable in the deepspeed context of the same + application, i.e., you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. 
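Because the whole fp32 ``state_dict`` is materialised in host RAM, a quick free-memory check before calling this function can avoid an abrupt failure. A minimal sketch of such a guard, assuming ``psutil`` is installed (it is not a dependency of this script), that ``model`` is the module being restored, and budgeting 4 bytes per fp32 parameter plus headroom ::

    import psutil

    def has_enough_cpu_memory(model, safety_factor=2.0):
        # rough upper bound: every parameter held once in fp32, with extra headroom
        needed_bytes = sum(p.numel() for p in model.parameters()) * 4 * safety_factor
        return psutil.virtual_memory().available > needed_bytes

    if has_enough_cpu_memory(model):
        model = load_state_dict_from_zero_checkpoint(model, "checkpoint-20")
    else:
        # fall back to the offline CLI: python zero_to_fp32.py checkpoint-20 pytorch_model.bin
        print("not enough free CPU memory for in-process conversion")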
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-200/config.json b/checkpoint-200/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-200/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-200/generation_config.json b/checkpoint-200/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bc00b333fdf0ba3611d022ddfdaeaf527fab8da0 --- /dev/null +++ b/checkpoint-200/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 2, + "transformers_version": "4.38.2" +} diff --git a/checkpoint-200/global_step200/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-200/global_step200/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..6de199a1d646205134a3db3536fa08e4d46c02fb --- /dev/null +++ b/checkpoint-200/global_step200/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:cc073673977d2260d446a5288c182308ae00d243deb8aa1287eed2bdc0ca55eb +size 973946896 diff --git a/checkpoint-200/global_step200/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-200/global_step200/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..76e6c760b2c1ea38488bfd86de0d5f3b944329b8 --- /dev/null +++ b/checkpoint-200/global_step200/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dd88f92e1b9309a12e7505328b079538457daedef3bcf8d59f5cfcef86aca6d +size 973946832 diff --git a/checkpoint-200/global_step200/mp_rank_00_model_states.pt b/checkpoint-200/global_step200/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..79a32bb7f22145d96bc8e0f8d7099ea41c47e0fd --- /dev/null +++ b/checkpoint-200/global_step200/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea030c07bc1914d14db0809ed9a4fc3f9c76b4d11dbdabd30acf57c7c93e685c +size 324689964 diff --git a/checkpoint-200/latest b/checkpoint-200/latest new file mode 100644 index 0000000000000000000000000000000000000000..753e24e10f3a2489150f458205cf759fd8b6081f --- /dev/null +++ b/checkpoint-200/latest @@ -0,0 +1 @@ +global_step200 \ No newline at end of file diff --git a/checkpoint-200/model.safetensors b/checkpoint-200/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..1659e7f1283c8c39d8352c935c1baa41b2800cee --- /dev/null +++ b/checkpoint-200/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd81f3f4d4881c9ad52c0835b8eedd34cd74ec334ef06ba64c40c85ac825a476 +size 324662984 diff --git a/checkpoint-200/rng_state_0.pth b/checkpoint-200/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..cad18ac770da4331076b9ef49fc91a7f9a5989c3 --- /dev/null +++ b/checkpoint-200/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0bb7d2ecdd48fd7d0be1e75b0e3f29004064381052fa203ed926e88b90ef530 +size 14512 diff --git a/checkpoint-200/rng_state_1.pth b/checkpoint-200/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..197bac5f7fe92d301270b1f25b8fa7a07b568293 --- /dev/null +++ b/checkpoint-200/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177d534a379bd6b276474c2cb140e318dc65db4457b6c1b6f25a1a9dd563af82 +size 14512 diff --git a/checkpoint-200/trainer_state.json b/checkpoint-200/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..2352ecb5f7adffe7d35666c768e419bf44d5e9d1 --- /dev/null +++ b/checkpoint-200/trainer_state.json @@ -0,0 +1,1421 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 0.19990004997501248, + "eval_steps": 500, + "global_step": 200, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.0, + "grad_norm": 3.3340563149001086, + "learning_rate": 0.0, + "loss": 11.0, + "step": 1 + }, + { + "epoch": 0.0, + "grad_norm": 2.398812329952019, + "learning_rate": 5.9999999999999995e-05, + "loss": 10.125, + "step": 2 + }, + { + "epoch": 0.0, + "grad_norm": 2.394322446895115, + "learning_rate": 0.00011999999999999999, + "loss": 10.1172, + "step": 3 + }, + { + "epoch": 0.0, + "grad_norm": 1.9958816684399585, + "learning_rate": 
0.00017999999999999998, + "loss": 9.875, + "step": 4 + }, + { + "epoch": 0.0, + "grad_norm": 1.8270465897882062, + "learning_rate": 0.00023999999999999998, + "loss": 9.6641, + "step": 5 + }, + { + "epoch": 0.01, + "grad_norm": 1.7854046471397795, + "learning_rate": 0.0003, + "loss": 9.4844, + "step": 6 + }, + { + "epoch": 0.01, + "grad_norm": 1.719416749115252, + "learning_rate": 0.00035999999999999997, + "loss": 9.3281, + "step": 7 + }, + { + "epoch": 0.01, + "grad_norm": 1.4637825746112274, + "learning_rate": 0.00041999999999999996, + "loss": 9.2109, + "step": 8 + }, + { + "epoch": 0.01, + "grad_norm": 1.4393631015406718, + "learning_rate": 0.00047999999999999996, + "loss": 8.9453, + "step": 9 + }, + { + "epoch": 0.01, + "grad_norm": 1.2936734586915988, + "learning_rate": 0.00054, + "loss": 8.7109, + "step": 10 + }, + { + "epoch": 0.01, + "grad_norm": 1.0756922378227356, + "learning_rate": 0.0005999986405514987, + "loss": 8.4609, + "step": 11 + }, + { + "epoch": 0.01, + "grad_norm": 0.9277829127413892, + "learning_rate": 0.0005999945622196846, + "loss": 8.2344, + "step": 12 + }, + { + "epoch": 0.01, + "grad_norm": 0.8084581786682467, + "learning_rate": 0.0005999877650456265, + "loss": 8.125, + "step": 13 + }, + { + "epoch": 0.01, + "grad_norm": 0.7635084596900947, + "learning_rate": 0.000599978249097772, + "loss": 7.9766, + "step": 14 + }, + { + "epoch": 0.01, + "grad_norm": 0.9186699644247788, + "learning_rate": 0.0005999660144719463, + "loss": 7.8555, + "step": 15 + }, + { + "epoch": 0.02, + "grad_norm": 0.6609504256551479, + "learning_rate": 0.0005999510612913519, + "loss": 7.7734, + "step": 16 + }, + { + "epoch": 0.02, + "grad_norm": 0.7086232844782971, + "learning_rate": 0.0005999333897065673, + "loss": 7.7148, + "step": 17 + }, + { + "epoch": 0.02, + "grad_norm": 16.38048851691348, + "learning_rate": 0.0005999129998955453, + "loss": 8.4844, + "step": 18 + }, + { + "epoch": 0.02, + "grad_norm": 1.3057527590449889, + "learning_rate": 0.0005998898920636111, + "loss": 7.7539, + "step": 19 + }, + { + "epoch": 0.02, + "grad_norm": 0.6966048242948986, + "learning_rate": 0.00059986406644346, + "loss": 7.75, + "step": 20 + }, + { + "epoch": 0.02, + "grad_norm": 0.6348089115348993, + "learning_rate": 0.0005998355232951559, + "loss": 7.7031, + "step": 21 + }, + { + "epoch": 0.02, + "grad_norm": 0.7829163518610293, + "learning_rate": 0.0005998042629061279, + "loss": 7.6992, + "step": 22 + }, + { + "epoch": 0.02, + "grad_norm": 0.5900591778980369, + "learning_rate": 0.0005997702855911678, + "loss": 7.6016, + "step": 23 + }, + { + "epoch": 0.02, + "grad_norm": 0.4655170213064256, + "learning_rate": 0.0005997335916924268, + "loss": 7.5977, + "step": 24 + }, + { + "epoch": 0.02, + "grad_norm": 0.6287348258915756, + "learning_rate": 0.0005996941815794121, + "loss": 7.5586, + "step": 25 + }, + { + "epoch": 0.03, + "grad_norm": 0.6137321903884564, + "learning_rate": 0.0005996520556489831, + "loss": 7.5898, + "step": 26 + }, + { + "epoch": 0.03, + "grad_norm": 0.44962562710631065, + "learning_rate": 0.0005996072143253473, + "loss": 7.4336, + "step": 27 + }, + { + "epoch": 0.03, + "grad_norm": 0.46130046454703316, + "learning_rate": 0.0005995596580600566, + "loss": 7.4023, + "step": 28 + }, + { + "epoch": 0.03, + "grad_norm": 0.4686712675731326, + "learning_rate": 0.0005995093873320018, + "loss": 7.3789, + "step": 29 + }, + { + "epoch": 0.03, + "grad_norm": 0.4672147564288997, + "learning_rate": 0.0005994564026474087, + "loss": 7.3711, + "step": 30 + }, + { + "epoch": 0.03, + "grad_norm": 
0.40408354581233474, + "learning_rate": 0.0005994007045398324, + "loss": 7.3672, + "step": 31 + }, + { + "epoch": 0.03, + "grad_norm": 0.46032146732584733, + "learning_rate": 0.0005993422935701524, + "loss": 7.3477, + "step": 32 + }, + { + "epoch": 0.03, + "grad_norm": 0.4765534634593268, + "learning_rate": 0.0005992811703265664, + "loss": 7.3555, + "step": 33 + }, + { + "epoch": 0.03, + "grad_norm": 0.46208489386235113, + "learning_rate": 0.0005992173354245849, + "loss": 7.3047, + "step": 34 + }, + { + "epoch": 0.03, + "grad_norm": 0.2956144524964961, + "learning_rate": 0.0005991507895070244, + "loss": 7.3125, + "step": 35 + }, + { + "epoch": 0.04, + "grad_norm": 0.4834645389868856, + "learning_rate": 0.0005990815332440017, + "loss": 7.207, + "step": 36 + }, + { + "epoch": 0.04, + "grad_norm": 0.4411831350968505, + "learning_rate": 0.0005990095673329266, + "loss": 7.1758, + "step": 37 + }, + { + "epoch": 0.04, + "grad_norm": 0.24809297748968667, + "learning_rate": 0.0005989348924984951, + "loss": 7.2188, + "step": 38 + }, + { + "epoch": 0.04, + "grad_norm": 0.39402988416840584, + "learning_rate": 0.0005988575094926817, + "loss": 7.1953, + "step": 39 + }, + { + "epoch": 0.04, + "grad_norm": 0.3868345222189167, + "learning_rate": 0.0005987774190947328, + "loss": 7.1641, + "step": 40 + }, + { + "epoch": 0.04, + "grad_norm": 0.3777261230135448, + "learning_rate": 0.0005986946221111575, + "loss": 7.1328, + "step": 41 + }, + { + "epoch": 0.04, + "grad_norm": 0.4687511444077827, + "learning_rate": 0.0005986091193757206, + "loss": 7.0898, + "step": 42 + }, + { + "epoch": 0.04, + "grad_norm": 0.34935796211612463, + "learning_rate": 0.0005985209117494337, + "loss": 7.1367, + "step": 43 + }, + { + "epoch": 0.04, + "grad_norm": 0.38764476686849886, + "learning_rate": 0.0005984300001205466, + "loss": 7.125, + "step": 44 + }, + { + "epoch": 0.04, + "grad_norm": 0.3956487898882936, + "learning_rate": 0.0005983363854045386, + "loss": 7.1094, + "step": 45 + }, + { + "epoch": 0.05, + "grad_norm": 0.31140257544677513, + "learning_rate": 0.0005982400685441084, + "loss": 7.0898, + "step": 46 + }, + { + "epoch": 0.05, + "grad_norm": 0.3664476570531787, + "learning_rate": 0.0005981410505091662, + "loss": 7.0664, + "step": 47 + }, + { + "epoch": 0.05, + "grad_norm": 0.31891741142945207, + "learning_rate": 0.0005980393322968223, + "loss": 7.0273, + "step": 48 + }, + { + "epoch": 0.05, + "grad_norm": 0.4533529037337155, + "learning_rate": 0.0005979349149313778, + "loss": 7.0586, + "step": 49 + }, + { + "epoch": 0.05, + "grad_norm": 0.30532331638835586, + "learning_rate": 0.0005978277994643147, + "loss": 7.0195, + "step": 50 + }, + { + "epoch": 0.05, + "grad_norm": 0.6501991746260075, + "learning_rate": 0.0005977179869742844, + "loss": 6.9648, + "step": 51 + }, + { + "epoch": 0.05, + "grad_norm": 0.43904455901717926, + "learning_rate": 0.0005976054785670975, + "loss": 6.9805, + "step": 52 + }, + { + "epoch": 0.05, + "grad_norm": 0.4826001598483571, + "learning_rate": 0.0005974902753757124, + "loss": 6.9297, + "step": 53 + }, + { + "epoch": 0.05, + "grad_norm": 0.2924998027034648, + "learning_rate": 0.000597372378560224, + "loss": 6.8984, + "step": 54 + }, + { + "epoch": 0.05, + "grad_norm": 0.4439033666380787, + "learning_rate": 0.0005972517893078517, + "loss": 6.8945, + "step": 55 + }, + { + "epoch": 0.06, + "grad_norm": 0.6135914255073411, + "learning_rate": 0.0005971285088329284, + "loss": 6.9727, + "step": 56 + }, + { + "epoch": 0.06, + "grad_norm": 0.5575686565598483, + "learning_rate": 0.0005970025383768866, 
+ "loss": 6.9219, + "step": 57 + }, + { + "epoch": 0.06, + "grad_norm": 0.4820951675994578, + "learning_rate": 0.0005968738792082478, + "loss": 6.8516, + "step": 58 + }, + { + "epoch": 0.06, + "grad_norm": 0.40164190019465584, + "learning_rate": 0.0005967425326226082, + "loss": 6.7734, + "step": 59 + }, + { + "epoch": 0.06, + "grad_norm": 0.46129863945181293, + "learning_rate": 0.0005966084999426265, + "loss": 6.8125, + "step": 60 + }, + { + "epoch": 0.06, + "grad_norm": 0.33322355827118677, + "learning_rate": 0.0005964717825180101, + "loss": 6.7891, + "step": 61 + }, + { + "epoch": 0.06, + "grad_norm": 0.3847525153855558, + "learning_rate": 0.0005963323817255024, + "loss": 6.8242, + "step": 62 + }, + { + "epoch": 0.06, + "grad_norm": 0.3384433591375982, + "learning_rate": 0.0005961902989688674, + "loss": 6.707, + "step": 63 + }, + { + "epoch": 0.06, + "grad_norm": 0.3937003195165685, + "learning_rate": 0.000596045535678877, + "loss": 6.8203, + "step": 64 + }, + { + "epoch": 0.06, + "grad_norm": 0.35423488053528107, + "learning_rate": 0.0005958980933132962, + "loss": 6.7383, + "step": 65 + }, + { + "epoch": 0.07, + "grad_norm": 0.36005939745315396, + "learning_rate": 0.0005957479733568675, + "loss": 6.7109, + "step": 66 + }, + { + "epoch": 0.07, + "grad_norm": 0.3499278317706933, + "learning_rate": 0.0005955951773212976, + "loss": 6.7266, + "step": 67 + }, + { + "epoch": 0.07, + "grad_norm": 0.3708385192137018, + "learning_rate": 0.0005954397067452407, + "loss": 6.7617, + "step": 68 + }, + { + "epoch": 0.07, + "grad_norm": 0.3775657656205869, + "learning_rate": 0.0005952815631942839, + "loss": 6.7148, + "step": 69 + }, + { + "epoch": 0.07, + "grad_norm": 0.3040083750375816, + "learning_rate": 0.0005951207482609307, + "loss": 6.5938, + "step": 70 + }, + { + "epoch": 0.07, + "grad_norm": 0.3443020808841468, + "learning_rate": 0.0005949572635645861, + "loss": 6.6523, + "step": 71 + }, + { + "epoch": 0.07, + "grad_norm": 0.3520066316939, + "learning_rate": 0.0005947911107515389, + "loss": 6.6211, + "step": 72 + }, + { + "epoch": 0.07, + "grad_norm": 0.3739040572679613, + "learning_rate": 0.0005946222914949462, + "loss": 6.5547, + "step": 73 + }, + { + "epoch": 0.07, + "grad_norm": 0.34890731989025553, + "learning_rate": 0.000594450807494816, + "loss": 6.5859, + "step": 74 + }, + { + "epoch": 0.07, + "grad_norm": 0.40910932350136514, + "learning_rate": 0.0005942766604779903, + "loss": 6.5547, + "step": 75 + }, + { + "epoch": 0.08, + "grad_norm": 0.5698342865852906, + "learning_rate": 0.0005940998521981274, + "loss": 6.457, + "step": 76 + }, + { + "epoch": 0.08, + "grad_norm": 0.5179452709555474, + "learning_rate": 0.0005939203844356852, + "loss": 6.5547, + "step": 77 + }, + { + "epoch": 0.08, + "grad_norm": 0.5222512938673792, + "learning_rate": 0.0005937382589979016, + "loss": 6.5039, + "step": 78 + }, + { + "epoch": 0.08, + "grad_norm": 0.5682332793686307, + "learning_rate": 0.0005935534777187781, + "loss": 6.5547, + "step": 79 + }, + { + "epoch": 0.08, + "grad_norm": 0.3869287710460676, + "learning_rate": 0.0005933660424590598, + "loss": 6.5156, + "step": 80 + }, + { + "epoch": 0.08, + "grad_norm": 0.3078211032807607, + "learning_rate": 0.000593175955106218, + "loss": 6.4258, + "step": 81 + }, + { + "epoch": 0.08, + "grad_norm": 0.3611357511872241, + "learning_rate": 0.00059298321757443, + "loss": 6.4727, + "step": 82 + }, + { + "epoch": 0.08, + "grad_norm": 0.29633467844266953, + "learning_rate": 0.0005927878318045608, + "loss": 6.3281, + "step": 83 + }, + { + "epoch": 0.08, + "grad_norm": 
0.3257574200776832, + "learning_rate": 0.0005925897997641426, + "loss": 6.3203, + "step": 84 + }, + { + "epoch": 0.08, + "grad_norm": 0.2824054533852328, + "learning_rate": 0.0005923891234473562, + "loss": 6.4062, + "step": 85 + }, + { + "epoch": 0.09, + "grad_norm": 0.3056199770204573, + "learning_rate": 0.0005921858048750097, + "loss": 6.3984, + "step": 86 + }, + { + "epoch": 0.09, + "grad_norm": 0.2966438824341908, + "learning_rate": 0.000591979846094519, + "loss": 6.3555, + "step": 87 + }, + { + "epoch": 0.09, + "grad_norm": 0.32782438676663733, + "learning_rate": 0.0005917712491798866, + "loss": 6.4023, + "step": 88 + }, + { + "epoch": 0.09, + "grad_norm": 0.3538316399620157, + "learning_rate": 0.0005915600162316811, + "loss": 6.2812, + "step": 89 + }, + { + "epoch": 0.09, + "grad_norm": 0.375858298192913, + "learning_rate": 0.0005913461493770162, + "loss": 6.3086, + "step": 90 + }, + { + "epoch": 0.09, + "grad_norm": 0.5189251339815161, + "learning_rate": 0.0005911296507695284, + "loss": 6.2812, + "step": 91 + }, + { + "epoch": 0.09, + "grad_norm": 0.6304909542669104, + "learning_rate": 0.0005909105225893564, + "loss": 6.2969, + "step": 92 + }, + { + "epoch": 0.09, + "grad_norm": 0.4655662819622591, + "learning_rate": 0.0005906887670431187, + "loss": 6.1953, + "step": 93 + }, + { + "epoch": 0.09, + "grad_norm": 0.39035390983920965, + "learning_rate": 0.000590464386363891, + "loss": 6.2617, + "step": 94 + }, + { + "epoch": 0.09, + "grad_norm": 0.4918417851770978, + "learning_rate": 0.0005902373828111843, + "loss": 6.2148, + "step": 95 + }, + { + "epoch": 0.1, + "grad_norm": 0.35670770889552555, + "learning_rate": 0.0005900077586709219, + "loss": 6.2461, + "step": 96 + }, + { + "epoch": 0.1, + "grad_norm": 0.4177985869939347, + "learning_rate": 0.0005897755162554163, + "loss": 6.1797, + "step": 97 + }, + { + "epoch": 0.1, + "grad_norm": 0.3742471130708234, + "learning_rate": 0.000589540657903346, + "loss": 6.1406, + "step": 98 + }, + { + "epoch": 0.1, + "grad_norm": 0.28627666723978284, + "learning_rate": 0.0005893031859797322, + "loss": 6.2031, + "step": 99 + }, + { + "epoch": 0.1, + "grad_norm": 0.32238563846046103, + "learning_rate": 0.0005890631028759143, + "loss": 6.0625, + "step": 100 + }, + { + "epoch": 0.1, + "grad_norm": 0.2556625657587849, + "learning_rate": 0.0005888204110095265, + "loss": 6.1797, + "step": 101 + }, + { + "epoch": 0.1, + "grad_norm": 0.35463629701710253, + "learning_rate": 0.0005885751128244734, + "loss": 6.125, + "step": 102 + }, + { + "epoch": 0.1, + "grad_norm": 0.31975770214936095, + "learning_rate": 0.0005883272107909048, + "loss": 6.1836, + "step": 103 + }, + { + "epoch": 0.1, + "grad_norm": 0.3464621815245048, + "learning_rate": 0.0005880767074051915, + "loss": 6.125, + "step": 104 + }, + { + "epoch": 0.1, + "grad_norm": 0.3663428920796654, + "learning_rate": 0.0005878236051898998, + "loss": 6.0781, + "step": 105 + }, + { + "epoch": 0.11, + "grad_norm": 0.31594460565215293, + "learning_rate": 0.0005875679066937664, + "loss": 6.082, + "step": 106 + }, + { + "epoch": 0.11, + "grad_norm": 0.3552617109396582, + "learning_rate": 0.000587309614491672, + "loss": 6.1016, + "step": 107 + }, + { + "epoch": 0.11, + "grad_norm": 0.307016409692456, + "learning_rate": 0.0005870487311846164, + "loss": 6.1406, + "step": 108 + }, + { + "epoch": 0.11, + "grad_norm": 0.32188902148474213, + "learning_rate": 0.0005867852593996914, + "loss": 6.0039, + "step": 109 + }, + { + "epoch": 0.11, + "grad_norm": 0.25501199715105083, + "learning_rate": 0.0005865192017900551, + 
"loss": 6.0938, + "step": 110 + }, + { + "epoch": 0.11, + "grad_norm": 0.3416203070024056, + "learning_rate": 0.0005862505610349049, + "loss": 6.0234, + "step": 111 + }, + { + "epoch": 0.11, + "grad_norm": 0.3562508875852537, + "learning_rate": 0.0005859793398394498, + "loss": 6.0469, + "step": 112 + }, + { + "epoch": 0.11, + "grad_norm": 0.4443953757302568, + "learning_rate": 0.0005857055409348845, + "loss": 5.9766, + "step": 113 + }, + { + "epoch": 0.11, + "grad_norm": 0.42023839332714596, + "learning_rate": 0.0005854291670783607, + "loss": 6.0781, + "step": 114 + }, + { + "epoch": 0.11, + "grad_norm": 0.4618323255809241, + "learning_rate": 0.0005851502210529604, + "loss": 5.9727, + "step": 115 + }, + { + "epoch": 0.12, + "grad_norm": 0.379195014798667, + "learning_rate": 0.0005848687056676668, + "loss": 5.9922, + "step": 116 + }, + { + "epoch": 0.12, + "grad_norm": 0.3931552573296799, + "learning_rate": 0.0005845846237573366, + "loss": 5.9492, + "step": 117 + }, + { + "epoch": 0.12, + "grad_norm": 0.2567080044949908, + "learning_rate": 0.0005842979781826717, + "loss": 6.0273, + "step": 118 + }, + { + "epoch": 0.12, + "grad_norm": 0.4190305965377807, + "learning_rate": 0.0005840087718301895, + "loss": 6.0391, + "step": 119 + }, + { + "epoch": 0.12, + "grad_norm": 0.3996803869430228, + "learning_rate": 0.0005837170076121951, + "loss": 5.9531, + "step": 120 + }, + { + "epoch": 0.12, + "grad_norm": 0.478219248015785, + "learning_rate": 0.000583422688466751, + "loss": 6.0586, + "step": 121 + }, + { + "epoch": 0.12, + "grad_norm": 0.40869844309811526, + "learning_rate": 0.0005831258173576474, + "loss": 6.0117, + "step": 122 + }, + { + "epoch": 0.12, + "grad_norm": 0.3728598080697978, + "learning_rate": 0.0005828263972743733, + "loss": 5.9375, + "step": 123 + }, + { + "epoch": 0.12, + "grad_norm": 0.3560055462882015, + "learning_rate": 0.0005825244312320856, + "loss": 5.9531, + "step": 124 + }, + { + "epoch": 0.12, + "grad_norm": 0.40446932887864323, + "learning_rate": 0.0005822199222715787, + "loss": 5.9609, + "step": 125 + }, + { + "epoch": 0.13, + "grad_norm": 0.38514065739946723, + "learning_rate": 0.000581912873459255, + "loss": 5.8594, + "step": 126 + }, + { + "epoch": 0.13, + "grad_norm": 0.35367576386319416, + "learning_rate": 0.0005816032878870921, + "loss": 5.9023, + "step": 127 + }, + { + "epoch": 0.13, + "grad_norm": 0.3341681995122829, + "learning_rate": 0.0005812911686726135, + "loss": 5.9062, + "step": 128 + }, + { + "epoch": 0.13, + "grad_norm": 0.3387022688975784, + "learning_rate": 0.0005809765189588563, + "loss": 5.8945, + "step": 129 + }, + { + "epoch": 0.13, + "grad_norm": 0.31638659898934757, + "learning_rate": 0.0005806593419143395, + "loss": 5.8242, + "step": 130 + }, + { + "epoch": 0.13, + "grad_norm": 0.3229678508227436, + "learning_rate": 0.0005803396407330325, + "loss": 5.8516, + "step": 131 + }, + { + "epoch": 0.13, + "grad_norm": 0.35499490868584455, + "learning_rate": 0.0005800174186343226, + "loss": 5.9258, + "step": 132 + }, + { + "epoch": 0.13, + "grad_norm": 0.40753171542848754, + "learning_rate": 0.0005796926788629828, + "loss": 5.8242, + "step": 133 + }, + { + "epoch": 0.13, + "grad_norm": 0.3625374018348824, + "learning_rate": 0.0005793654246891389, + "loss": 5.832, + "step": 134 + }, + { + "epoch": 0.13, + "grad_norm": 0.3583489573569317, + "learning_rate": 0.000579035659408237, + "loss": 5.8398, + "step": 135 + }, + { + "epoch": 0.14, + "grad_norm": 0.39657706318861896, + "learning_rate": 0.0005787033863410095, + "loss": 5.8633, + "step": 136 + }, + { + 
"epoch": 0.14, + "grad_norm": 0.3965837889564036, + "learning_rate": 0.0005783686088334428, + "loss": 5.8633, + "step": 137 + }, + { + "epoch": 0.14, + "grad_norm": 0.29496474301865566, + "learning_rate": 0.0005780313302567424, + "loss": 5.8203, + "step": 138 + }, + { + "epoch": 0.14, + "grad_norm": 0.44637192639243695, + "learning_rate": 0.0005776915540073001, + "loss": 5.8477, + "step": 139 + }, + { + "epoch": 0.14, + "grad_norm": 0.39605473508683114, + "learning_rate": 0.0005773492835066587, + "loss": 5.7383, + "step": 140 + }, + { + "epoch": 0.14, + "grad_norm": 0.3008962634266945, + "learning_rate": 0.0005770045222014786, + "loss": 5.7617, + "step": 141 + }, + { + "epoch": 0.14, + "grad_norm": 0.36915495506607826, + "learning_rate": 0.0005766572735635022, + "loss": 5.7695, + "step": 142 + }, + { + "epoch": 0.14, + "grad_norm": 0.3282300349560706, + "learning_rate": 0.0005763075410895193, + "loss": 5.8281, + "step": 143 + }, + { + "epoch": 0.14, + "grad_norm": 0.2747449814083844, + "learning_rate": 0.0005759553283013323, + "loss": 5.7812, + "step": 144 + }, + { + "epoch": 0.14, + "grad_norm": 0.28905882704179764, + "learning_rate": 0.00057560063874572, + "loss": 5.7344, + "step": 145 + }, + { + "epoch": 0.15, + "grad_norm": 0.280625988867192, + "learning_rate": 0.000575243475994402, + "loss": 5.7773, + "step": 146 + }, + { + "epoch": 0.15, + "grad_norm": 0.41061863948012467, + "learning_rate": 0.0005748838436440035, + "loss": 5.7578, + "step": 147 + }, + { + "epoch": 0.15, + "grad_norm": 0.4920152483870267, + "learning_rate": 0.0005745217453160183, + "loss": 5.7305, + "step": 148 + }, + { + "epoch": 0.15, + "grad_norm": 0.5463207978955044, + "learning_rate": 0.0005741571846567725, + "loss": 5.7383, + "step": 149 + }, + { + "epoch": 0.15, + "grad_norm": 0.3986359831157306, + "learning_rate": 0.0005737901653373878, + "loss": 5.668, + "step": 150 + }, + { + "epoch": 0.15, + "grad_norm": 0.37908758170100293, + "learning_rate": 0.0005734206910537447, + "loss": 5.6875, + "step": 151 + }, + { + "epoch": 0.15, + "grad_norm": 0.35929793070492694, + "learning_rate": 0.0005730487655264451, + "loss": 5.7188, + "step": 152 + }, + { + "epoch": 0.15, + "grad_norm": 0.4217799574145456, + "learning_rate": 0.0005726743925007751, + "loss": 5.7305, + "step": 153 + }, + { + "epoch": 0.15, + "grad_norm": 0.4024411981587195, + "learning_rate": 0.0005722975757466667, + "loss": 5.6289, + "step": 154 + }, + { + "epoch": 0.15, + "grad_norm": 0.3472391905877033, + "learning_rate": 0.0005719183190586606, + "loss": 5.6523, + "step": 155 + }, + { + "epoch": 0.16, + "grad_norm": 0.31752956812138816, + "learning_rate": 0.0005715366262558675, + "loss": 5.6172, + "step": 156 + }, + { + "epoch": 0.16, + "grad_norm": 0.3170152384332457, + "learning_rate": 0.0005711525011819294, + "loss": 5.6172, + "step": 157 + }, + { + "epoch": 0.16, + "grad_norm": 0.40520629326601837, + "learning_rate": 0.0005707659477049818, + "loss": 5.625, + "step": 158 + }, + { + "epoch": 0.16, + "grad_norm": 0.3965976910198806, + "learning_rate": 0.0005703769697176137, + "loss": 5.6562, + "step": 159 + }, + { + "epoch": 0.16, + "grad_norm": 0.40422960541801994, + "learning_rate": 0.0005699855711368293, + "loss": 5.6836, + "step": 160 + }, + { + "epoch": 0.16, + "grad_norm": 0.3780813184050647, + "learning_rate": 0.0005695917559040079, + "loss": 5.5938, + "step": 161 + }, + { + "epoch": 0.16, + "grad_norm": 0.36917638857736573, + "learning_rate": 0.0005691955279848645, + "loss": 5.668, + "step": 162 + }, + { + "epoch": 0.16, + "grad_norm": 
0.37769176081037814, + "learning_rate": 0.0005687968913694098, + "loss": 5.4961, + "step": 163 + }, + { + "epoch": 0.16, + "grad_norm": 0.3255116524991148, + "learning_rate": 0.0005683958500719103, + "loss": 5.5117, + "step": 164 + }, + { + "epoch": 0.16, + "grad_norm": 0.31897629016848805, + "learning_rate": 0.0005679924081308471, + "loss": 5.5664, + "step": 165 + }, + { + "epoch": 0.17, + "grad_norm": 0.2869064236553046, + "learning_rate": 0.0005675865696088764, + "loss": 5.5391, + "step": 166 + }, + { + "epoch": 0.17, + "grad_norm": 0.29226729022634845, + "learning_rate": 0.0005671783385927873, + "loss": 5.5586, + "step": 167 + }, + { + "epoch": 0.17, + "grad_norm": 0.2534117210955766, + "learning_rate": 0.0005667677191934618, + "loss": 5.5312, + "step": 168 + }, + { + "epoch": 0.17, + "grad_norm": 0.289828484125484, + "learning_rate": 0.0005663547155458326, + "loss": 5.6484, + "step": 169 + }, + { + "epoch": 0.17, + "grad_norm": 0.2717242930342115, + "learning_rate": 0.0005659393318088419, + "loss": 5.5352, + "step": 170 + }, + { + "epoch": 0.17, + "grad_norm": 0.3595538109137759, + "learning_rate": 0.0005655215721653993, + "loss": 5.5742, + "step": 171 + }, + { + "epoch": 0.17, + "grad_norm": 0.4255054350471108, + "learning_rate": 0.0005651014408223398, + "loss": 5.5469, + "step": 172 + }, + { + "epoch": 0.17, + "grad_norm": 0.3670561941219979, + "learning_rate": 0.0005646789420103814, + "loss": 5.5078, + "step": 173 + }, + { + "epoch": 0.17, + "grad_norm": 0.40280130904983164, + "learning_rate": 0.0005642540799840822, + "loss": 5.5, + "step": 174 + }, + { + "epoch": 0.17, + "grad_norm": 0.41159472035983025, + "learning_rate": 0.0005638268590217984, + "loss": 5.5039, + "step": 175 + }, + { + "epoch": 0.18, + "grad_norm": 0.4316778037513652, + "learning_rate": 0.0005633972834256401, + "loss": 5.5352, + "step": 176 + }, + { + "epoch": 0.18, + "grad_norm": 0.5674781128363939, + "learning_rate": 0.000562965357521429, + "loss": 5.4336, + "step": 177 + }, + { + "epoch": 0.18, + "grad_norm": 0.41654662151365446, + "learning_rate": 0.0005625310856586541, + "loss": 5.6211, + "step": 178 + }, + { + "epoch": 0.18, + "grad_norm": 0.5159976364107484, + "learning_rate": 0.0005620944722104282, + "loss": 5.4844, + "step": 179 + }, + { + "epoch": 0.18, + "grad_norm": 0.34364678177014185, + "learning_rate": 0.0005616555215734438, + "loss": 5.4922, + "step": 180 + }, + { + "epoch": 0.18, + "grad_norm": 0.3708077784459011, + "learning_rate": 0.0005612142381679289, + "loss": 5.5234, + "step": 181 + }, + { + "epoch": 0.18, + "grad_norm": 0.3620051253453866, + "learning_rate": 0.0005607706264376028, + "loss": 5.4961, + "step": 182 + }, + { + "epoch": 0.18, + "grad_norm": 0.34735585210929654, + "learning_rate": 0.0005603246908496305, + "loss": 5.4453, + "step": 183 + }, + { + "epoch": 0.18, + "grad_norm": 0.37719874705792217, + "learning_rate": 0.0005598764358945783, + "loss": 5.4844, + "step": 184 + }, + { + "epoch": 0.18, + "grad_norm": 0.3749130664831207, + "learning_rate": 0.0005594258660863689, + "loss": 5.4648, + "step": 185 + }, + { + "epoch": 0.19, + "grad_norm": 0.40951353306235827, + "learning_rate": 0.0005589729859622351, + "loss": 5.5039, + "step": 186 + }, + { + "epoch": 0.19, + "grad_norm": 0.40146882563949804, + "learning_rate": 0.0005585178000826745, + "loss": 5.3672, + "step": 187 + }, + { + "epoch": 0.19, + "grad_norm": 0.4062987628428303, + "learning_rate": 0.0005580603130314043, + "loss": 5.3984, + "step": 188 + }, + { + "epoch": 0.19, + "grad_norm": 0.35626322654799136, + 
"learning_rate": 0.0005576005294153138, + "loss": 5.3984, + "step": 189 + }, + { + "epoch": 0.19, + "grad_norm": 0.3140647930801716, + "learning_rate": 0.0005571384538644188, + "loss": 5.3906, + "step": 190 + }, + { + "epoch": 0.19, + "grad_norm": 0.2990060538353662, + "learning_rate": 0.0005566740910318153, + "loss": 5.3711, + "step": 191 + }, + { + "epoch": 0.19, + "grad_norm": 0.3337525907515936, + "learning_rate": 0.0005562074455936315, + "loss": 5.4023, + "step": 192 + }, + { + "epoch": 0.19, + "grad_norm": 0.3381587051014816, + "learning_rate": 0.000555738522248982, + "loss": 5.4414, + "step": 193 + }, + { + "epoch": 0.19, + "grad_norm": 0.2954008999469894, + "learning_rate": 0.0005552673257199197, + "loss": 5.418, + "step": 194 + }, + { + "epoch": 0.19, + "grad_norm": 0.3242310900810155, + "learning_rate": 0.0005547938607513882, + "loss": 5.418, + "step": 195 + }, + { + "epoch": 0.2, + "grad_norm": 0.3149021804393678, + "learning_rate": 0.0005543181321111747, + "loss": 5.4375, + "step": 196 + }, + { + "epoch": 0.2, + "grad_norm": 0.32859412218218814, + "learning_rate": 0.0005538401445898612, + "loss": 5.4492, + "step": 197 + }, + { + "epoch": 0.2, + "grad_norm": 0.2960282598050701, + "learning_rate": 0.0005533599030007768, + "loss": 5.3867, + "step": 198 + }, + { + "epoch": 0.2, + "grad_norm": 0.2866762878199755, + "learning_rate": 0.0005528774121799489, + "loss": 5.3789, + "step": 199 + }, + { + "epoch": 0.2, + "grad_norm": 0.34865216327038784, + "learning_rate": 0.0005523926769860549, + "loss": 5.3711, + "step": 200 + } + ], + "logging_steps": 1, + "max_steps": 1000, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 100, + "total_flos": 0.0, + "train_batch_size": 32, + "trial_name": null, + "trial_params": null +} diff --git a/checkpoint-200/training_args.bin b/checkpoint-200/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..09b35ec8ac2a16eb45febe1d655d456e47af68d1 --- /dev/null +++ b/checkpoint-200/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbd9a6067cf818494e2505097746a1cad30533fc72eb13916de34f7671e3e456 +size 6520 diff --git a/checkpoint-200/zero_to_fp32.py b/checkpoint-200/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-200/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + 
state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + 
frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. 
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-25/config.json b/checkpoint-25/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-25/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-25/model.safetensors b/checkpoint-25/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..08b3394d6c995bc8ac96781e497c9d4f84aa5bab --- /dev/null +++ b/checkpoint-25/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e0bd0a6b0ef632c251c6e794edc382636fe78ad048636c064c92ee8c123ccd3 +size 324662984 diff --git a/checkpoint-25/training_args.bin b/checkpoint-25/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..9362a9e736fc862ece575b9f1b9d54b14c10d0b5 --- /dev/null +++ b/checkpoint-25/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36ce7ab48fa86ef42491eaad3583773d2b60353997a5e7b6fb4ffc1414828749 +size 6520 diff --git a/checkpoint-300/config.json b/checkpoint-300/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-300/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + 
"bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-300/generation_config.json b/checkpoint-300/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bc00b333fdf0ba3611d022ddfdaeaf527fab8da0 --- /dev/null +++ b/checkpoint-300/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 2, + "transformers_version": "4.38.2" +} diff --git a/checkpoint-300/global_step300/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-300/global_step300/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f46ebd42a21d20b4b1a29b08f0d5d7098d7f9dd --- /dev/null +++ b/checkpoint-300/global_step300/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71c5be6373fb827d208e43746b2e985a638d20b096d1935fa0b92c3b73969e87 +size 973946896 diff --git a/checkpoint-300/global_step300/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-300/global_step300/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..f81ff09c77842a2032984769f82ab87e080e4f9c --- /dev/null +++ b/checkpoint-300/global_step300/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:132bca4d07aadd3f56b2aef04a2d5892b4f02f0873f1d8b5f27d942fa900e224 +size 973946832 diff --git a/checkpoint-300/global_step300/mp_rank_00_model_states.pt b/checkpoint-300/global_step300/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..43d8a60a7b2687da827688d4e265b92fd1ff35b6 --- /dev/null +++ b/checkpoint-300/global_step300/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e986775e89f677510fd5c4f357f85e9e46ce0ae4e92db51ebb800efc1a64486 +size 324689964 diff --git a/checkpoint-300/latest b/checkpoint-300/latest new file mode 100644 index 0000000000000000000000000000000000000000..6761b575fffac7f1984044dcb6446b3a51da04c8 --- /dev/null +++ b/checkpoint-300/latest @@ -0,0 +1 @@ +global_step300 \ No newline at end of file diff --git a/checkpoint-300/model.safetensors b/checkpoint-300/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..839c981bd06ed3c4b3d086dd3792367f5d319d52 --- /dev/null +++ b/checkpoint-300/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba8550b5d8f29acb5ae68f5f221f4f61f56e0eecb0a274c9047640dfea097117 +size 324662984 diff --git a/checkpoint-300/rng_state_0.pth b/checkpoint-300/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..cad18ac770da4331076b9ef49fc91a7f9a5989c3 --- /dev/null +++ b/checkpoint-300/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:d0bb7d2ecdd48fd7d0be1e75b0e3f29004064381052fa203ed926e88b90ef530 +size 14512 diff --git a/checkpoint-300/rng_state_1.pth b/checkpoint-300/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..197bac5f7fe92d301270b1f25b8fa7a07b568293 --- /dev/null +++ b/checkpoint-300/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177d534a379bd6b276474c2cb140e318dc65db4457b6c1b6f25a1a9dd563af82 +size 14512 diff --git a/checkpoint-300/trainer_state.json b/checkpoint-300/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..16c46764e434baf3d6f4e6aec47b930ebb0afc75 --- /dev/null +++ b/checkpoint-300/trainer_state.json @@ -0,0 +1,2121 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 0.29985007496251875, + "eval_steps": 500, + "global_step": 300, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.0, + "grad_norm": 3.3340563149001086, + "learning_rate": 0.0, + "loss": 11.0, + "step": 1 + }, + { + "epoch": 0.0, + "grad_norm": 2.398812329952019, + "learning_rate": 5.9999999999999995e-05, + "loss": 10.125, + "step": 2 + }, + { + "epoch": 0.0, + "grad_norm": 2.394322446895115, + "learning_rate": 0.00011999999999999999, + "loss": 10.1172, + "step": 3 + }, + { + "epoch": 0.0, + "grad_norm": 1.9958816684399585, + "learning_rate": 0.00017999999999999998, + "loss": 9.875, + "step": 4 + }, + { + "epoch": 0.0, + "grad_norm": 1.8270465897882062, + "learning_rate": 0.00023999999999999998, + "loss": 9.6641, + "step": 5 + }, + { + "epoch": 0.01, + "grad_norm": 1.7854046471397795, + "learning_rate": 0.0003, + "loss": 9.4844, + "step": 6 + }, + { + "epoch": 0.01, + "grad_norm": 1.719416749115252, + "learning_rate": 0.00035999999999999997, + "loss": 9.3281, + "step": 7 + }, + { + "epoch": 0.01, + "grad_norm": 1.4637825746112274, + "learning_rate": 0.00041999999999999996, + "loss": 9.2109, + "step": 8 + }, + { + "epoch": 0.01, + "grad_norm": 1.4393631015406718, + "learning_rate": 0.00047999999999999996, + "loss": 8.9453, + "step": 9 + }, + { + "epoch": 0.01, + "grad_norm": 1.2936734586915988, + "learning_rate": 0.00054, + "loss": 8.7109, + "step": 10 + }, + { + "epoch": 0.01, + "grad_norm": 1.0756922378227356, + "learning_rate": 0.0005999986405514987, + "loss": 8.4609, + "step": 11 + }, + { + "epoch": 0.01, + "grad_norm": 0.9277829127413892, + "learning_rate": 0.0005999945622196846, + "loss": 8.2344, + "step": 12 + }, + { + "epoch": 0.01, + "grad_norm": 0.8084581786682467, + "learning_rate": 0.0005999877650456265, + "loss": 8.125, + "step": 13 + }, + { + "epoch": 0.01, + "grad_norm": 0.7635084596900947, + "learning_rate": 0.000599978249097772, + "loss": 7.9766, + "step": 14 + }, + { + "epoch": 0.01, + "grad_norm": 0.9186699644247788, + "learning_rate": 0.0005999660144719463, + "loss": 7.8555, + "step": 15 + }, + { + "epoch": 0.02, + "grad_norm": 0.6609504256551479, + "learning_rate": 0.0005999510612913519, + "loss": 7.7734, + "step": 16 + }, + { + "epoch": 0.02, + "grad_norm": 0.7086232844782971, + "learning_rate": 0.0005999333897065673, + "loss": 7.7148, + "step": 17 + }, + { + "epoch": 0.02, + "grad_norm": 16.38048851691348, + "learning_rate": 0.0005999129998955453, + "loss": 8.4844, + "step": 18 + }, + { + "epoch": 0.02, + "grad_norm": 1.3057527590449889, + "learning_rate": 0.0005998898920636111, + "loss": 7.7539, + "step": 19 + }, + { + "epoch": 0.02, + "grad_norm": 0.6966048242948986, + "learning_rate": 
0.00059986406644346, + "loss": 7.75, + "step": 20 + }, + { + "epoch": 0.02, + "grad_norm": 0.6348089115348993, + "learning_rate": 0.0005998355232951559, + "loss": 7.7031, + "step": 21 + }, + { + "epoch": 0.02, + "grad_norm": 0.7829163518610293, + "learning_rate": 0.0005998042629061279, + "loss": 7.6992, + "step": 22 + }, + { + "epoch": 0.02, + "grad_norm": 0.5900591778980369, + "learning_rate": 0.0005997702855911678, + "loss": 7.6016, + "step": 23 + }, + { + "epoch": 0.02, + "grad_norm": 0.4655170213064256, + "learning_rate": 0.0005997335916924268, + "loss": 7.5977, + "step": 24 + }, + { + "epoch": 0.02, + "grad_norm": 0.6287348258915756, + "learning_rate": 0.0005996941815794121, + "loss": 7.5586, + "step": 25 + }, + { + "epoch": 0.03, + "grad_norm": 0.6137321903884564, + "learning_rate": 0.0005996520556489831, + "loss": 7.5898, + "step": 26 + }, + { + "epoch": 0.03, + "grad_norm": 0.44962562710631065, + "learning_rate": 0.0005996072143253473, + "loss": 7.4336, + "step": 27 + }, + { + "epoch": 0.03, + "grad_norm": 0.46130046454703316, + "learning_rate": 0.0005995596580600566, + "loss": 7.4023, + "step": 28 + }, + { + "epoch": 0.03, + "grad_norm": 0.4686712675731326, + "learning_rate": 0.0005995093873320018, + "loss": 7.3789, + "step": 29 + }, + { + "epoch": 0.03, + "grad_norm": 0.4672147564288997, + "learning_rate": 0.0005994564026474087, + "loss": 7.3711, + "step": 30 + }, + { + "epoch": 0.03, + "grad_norm": 0.40408354581233474, + "learning_rate": 0.0005994007045398324, + "loss": 7.3672, + "step": 31 + }, + { + "epoch": 0.03, + "grad_norm": 0.46032146732584733, + "learning_rate": 0.0005993422935701524, + "loss": 7.3477, + "step": 32 + }, + { + "epoch": 0.03, + "grad_norm": 0.4765534634593268, + "learning_rate": 0.0005992811703265664, + "loss": 7.3555, + "step": 33 + }, + { + "epoch": 0.03, + "grad_norm": 0.46208489386235113, + "learning_rate": 0.0005992173354245849, + "loss": 7.3047, + "step": 34 + }, + { + "epoch": 0.03, + "grad_norm": 0.2956144524964961, + "learning_rate": 0.0005991507895070244, + "loss": 7.3125, + "step": 35 + }, + { + "epoch": 0.04, + "grad_norm": 0.4834645389868856, + "learning_rate": 0.0005990815332440017, + "loss": 7.207, + "step": 36 + }, + { + "epoch": 0.04, + "grad_norm": 0.4411831350968505, + "learning_rate": 0.0005990095673329266, + "loss": 7.1758, + "step": 37 + }, + { + "epoch": 0.04, + "grad_norm": 0.24809297748968667, + "learning_rate": 0.0005989348924984951, + "loss": 7.2188, + "step": 38 + }, + { + "epoch": 0.04, + "grad_norm": 0.39402988416840584, + "learning_rate": 0.0005988575094926817, + "loss": 7.1953, + "step": 39 + }, + { + "epoch": 0.04, + "grad_norm": 0.3868345222189167, + "learning_rate": 0.0005987774190947328, + "loss": 7.1641, + "step": 40 + }, + { + "epoch": 0.04, + "grad_norm": 0.3777261230135448, + "learning_rate": 0.0005986946221111575, + "loss": 7.1328, + "step": 41 + }, + { + "epoch": 0.04, + "grad_norm": 0.4687511444077827, + "learning_rate": 0.0005986091193757206, + "loss": 7.0898, + "step": 42 + }, + { + "epoch": 0.04, + "grad_norm": 0.34935796211612463, + "learning_rate": 0.0005985209117494337, + "loss": 7.1367, + "step": 43 + }, + { + "epoch": 0.04, + "grad_norm": 0.38764476686849886, + "learning_rate": 0.0005984300001205466, + "loss": 7.125, + "step": 44 + }, + { + "epoch": 0.04, + "grad_norm": 0.3956487898882936, + "learning_rate": 0.0005983363854045386, + "loss": 7.1094, + "step": 45 + }, + { + "epoch": 0.05, + "grad_norm": 0.31140257544677513, + "learning_rate": 0.0005982400685441084, + "loss": 7.0898, + "step": 46 + }, + { + 
"epoch": 0.05, + "grad_norm": 0.3664476570531787, + "learning_rate": 0.0005981410505091662, + "loss": 7.0664, + "step": 47 + }, + { + "epoch": 0.05, + "grad_norm": 0.31891741142945207, + "learning_rate": 0.0005980393322968223, + "loss": 7.0273, + "step": 48 + }, + { + "epoch": 0.05, + "grad_norm": 0.4533529037337155, + "learning_rate": 0.0005979349149313778, + "loss": 7.0586, + "step": 49 + }, + { + "epoch": 0.05, + "grad_norm": 0.30532331638835586, + "learning_rate": 0.0005978277994643147, + "loss": 7.0195, + "step": 50 + }, + { + "epoch": 0.05, + "grad_norm": 0.6501991746260075, + "learning_rate": 0.0005977179869742844, + "loss": 6.9648, + "step": 51 + }, + { + "epoch": 0.05, + "grad_norm": 0.43904455901717926, + "learning_rate": 0.0005976054785670975, + "loss": 6.9805, + "step": 52 + }, + { + "epoch": 0.05, + "grad_norm": 0.4826001598483571, + "learning_rate": 0.0005974902753757124, + "loss": 6.9297, + "step": 53 + }, + { + "epoch": 0.05, + "grad_norm": 0.2924998027034648, + "learning_rate": 0.000597372378560224, + "loss": 6.8984, + "step": 54 + }, + { + "epoch": 0.05, + "grad_norm": 0.4439033666380787, + "learning_rate": 0.0005972517893078517, + "loss": 6.8945, + "step": 55 + }, + { + "epoch": 0.06, + "grad_norm": 0.6135914255073411, + "learning_rate": 0.0005971285088329284, + "loss": 6.9727, + "step": 56 + }, + { + "epoch": 0.06, + "grad_norm": 0.5575686565598483, + "learning_rate": 0.0005970025383768866, + "loss": 6.9219, + "step": 57 + }, + { + "epoch": 0.06, + "grad_norm": 0.4820951675994578, + "learning_rate": 0.0005968738792082478, + "loss": 6.8516, + "step": 58 + }, + { + "epoch": 0.06, + "grad_norm": 0.40164190019465584, + "learning_rate": 0.0005967425326226082, + "loss": 6.7734, + "step": 59 + }, + { + "epoch": 0.06, + "grad_norm": 0.46129863945181293, + "learning_rate": 0.0005966084999426265, + "loss": 6.8125, + "step": 60 + }, + { + "epoch": 0.06, + "grad_norm": 0.33322355827118677, + "learning_rate": 0.0005964717825180101, + "loss": 6.7891, + "step": 61 + }, + { + "epoch": 0.06, + "grad_norm": 0.3847525153855558, + "learning_rate": 0.0005963323817255024, + "loss": 6.8242, + "step": 62 + }, + { + "epoch": 0.06, + "grad_norm": 0.3384433591375982, + "learning_rate": 0.0005961902989688674, + "loss": 6.707, + "step": 63 + }, + { + "epoch": 0.06, + "grad_norm": 0.3937003195165685, + "learning_rate": 0.000596045535678877, + "loss": 6.8203, + "step": 64 + }, + { + "epoch": 0.06, + "grad_norm": 0.35423488053528107, + "learning_rate": 0.0005958980933132962, + "loss": 6.7383, + "step": 65 + }, + { + "epoch": 0.07, + "grad_norm": 0.36005939745315396, + "learning_rate": 0.0005957479733568675, + "loss": 6.7109, + "step": 66 + }, + { + "epoch": 0.07, + "grad_norm": 0.3499278317706933, + "learning_rate": 0.0005955951773212976, + "loss": 6.7266, + "step": 67 + }, + { + "epoch": 0.07, + "grad_norm": 0.3708385192137018, + "learning_rate": 0.0005954397067452407, + "loss": 6.7617, + "step": 68 + }, + { + "epoch": 0.07, + "grad_norm": 0.3775657656205869, + "learning_rate": 0.0005952815631942839, + "loss": 6.7148, + "step": 69 + }, + { + "epoch": 0.07, + "grad_norm": 0.3040083750375816, + "learning_rate": 0.0005951207482609307, + "loss": 6.5938, + "step": 70 + }, + { + "epoch": 0.07, + "grad_norm": 0.3443020808841468, + "learning_rate": 0.0005949572635645861, + "loss": 6.6523, + "step": 71 + }, + { + "epoch": 0.07, + "grad_norm": 0.3520066316939, + "learning_rate": 0.0005947911107515389, + "loss": 6.6211, + "step": 72 + }, + { + "epoch": 0.07, + "grad_norm": 0.3739040572679613, + "learning_rate": 
0.0005946222914949462, + "loss": 6.5547, + "step": 73 + }, + { + "epoch": 0.07, + "grad_norm": 0.34890731989025553, + "learning_rate": 0.000594450807494816, + "loss": 6.5859, + "step": 74 + }, + { + "epoch": 0.07, + "grad_norm": 0.40910932350136514, + "learning_rate": 0.0005942766604779903, + "loss": 6.5547, + "step": 75 + }, + { + "epoch": 0.08, + "grad_norm": 0.5698342865852906, + "learning_rate": 0.0005940998521981274, + "loss": 6.457, + "step": 76 + }, + { + "epoch": 0.08, + "grad_norm": 0.5179452709555474, + "learning_rate": 0.0005939203844356852, + "loss": 6.5547, + "step": 77 + }, + { + "epoch": 0.08, + "grad_norm": 0.5222512938673792, + "learning_rate": 0.0005937382589979016, + "loss": 6.5039, + "step": 78 + }, + { + "epoch": 0.08, + "grad_norm": 0.5682332793686307, + "learning_rate": 0.0005935534777187781, + "loss": 6.5547, + "step": 79 + }, + { + "epoch": 0.08, + "grad_norm": 0.3869287710460676, + "learning_rate": 0.0005933660424590598, + "loss": 6.5156, + "step": 80 + }, + { + "epoch": 0.08, + "grad_norm": 0.3078211032807607, + "learning_rate": 0.000593175955106218, + "loss": 6.4258, + "step": 81 + }, + { + "epoch": 0.08, + "grad_norm": 0.3611357511872241, + "learning_rate": 0.00059298321757443, + "loss": 6.4727, + "step": 82 + }, + { + "epoch": 0.08, + "grad_norm": 0.29633467844266953, + "learning_rate": 0.0005927878318045608, + "loss": 6.3281, + "step": 83 + }, + { + "epoch": 0.08, + "grad_norm": 0.3257574200776832, + "learning_rate": 0.0005925897997641426, + "loss": 6.3203, + "step": 84 + }, + { + "epoch": 0.08, + "grad_norm": 0.2824054533852328, + "learning_rate": 0.0005923891234473562, + "loss": 6.4062, + "step": 85 + }, + { + "epoch": 0.09, + "grad_norm": 0.3056199770204573, + "learning_rate": 0.0005921858048750097, + "loss": 6.3984, + "step": 86 + }, + { + "epoch": 0.09, + "grad_norm": 0.2966438824341908, + "learning_rate": 0.000591979846094519, + "loss": 6.3555, + "step": 87 + }, + { + "epoch": 0.09, + "grad_norm": 0.32782438676663733, + "learning_rate": 0.0005917712491798866, + "loss": 6.4023, + "step": 88 + }, + { + "epoch": 0.09, + "grad_norm": 0.3538316399620157, + "learning_rate": 0.0005915600162316811, + "loss": 6.2812, + "step": 89 + }, + { + "epoch": 0.09, + "grad_norm": 0.375858298192913, + "learning_rate": 0.0005913461493770162, + "loss": 6.3086, + "step": 90 + }, + { + "epoch": 0.09, + "grad_norm": 0.5189251339815161, + "learning_rate": 0.0005911296507695284, + "loss": 6.2812, + "step": 91 + }, + { + "epoch": 0.09, + "grad_norm": 0.6304909542669104, + "learning_rate": 0.0005909105225893564, + "loss": 6.2969, + "step": 92 + }, + { + "epoch": 0.09, + "grad_norm": 0.4655662819622591, + "learning_rate": 0.0005906887670431187, + "loss": 6.1953, + "step": 93 + }, + { + "epoch": 0.09, + "grad_norm": 0.39035390983920965, + "learning_rate": 0.000590464386363891, + "loss": 6.2617, + "step": 94 + }, + { + "epoch": 0.09, + "grad_norm": 0.4918417851770978, + "learning_rate": 0.0005902373828111843, + "loss": 6.2148, + "step": 95 + }, + { + "epoch": 0.1, + "grad_norm": 0.35670770889552555, + "learning_rate": 0.0005900077586709219, + "loss": 6.2461, + "step": 96 + }, + { + "epoch": 0.1, + "grad_norm": 0.4177985869939347, + "learning_rate": 0.0005897755162554163, + "loss": 6.1797, + "step": 97 + }, + { + "epoch": 0.1, + "grad_norm": 0.3742471130708234, + "learning_rate": 0.000589540657903346, + "loss": 6.1406, + "step": 98 + }, + { + "epoch": 0.1, + "grad_norm": 0.28627666723978284, + "learning_rate": 0.0005893031859797322, + "loss": 6.2031, + "step": 99 + }, + { + "epoch": 
0.1, + "grad_norm": 0.32238563846046103, + "learning_rate": 0.0005890631028759143, + "loss": 6.0625, + "step": 100 + }, + { + "epoch": 0.1, + "grad_norm": 0.2556625657587849, + "learning_rate": 0.0005888204110095265, + "loss": 6.1797, + "step": 101 + }, + { + "epoch": 0.1, + "grad_norm": 0.35463629701710253, + "learning_rate": 0.0005885751128244734, + "loss": 6.125, + "step": 102 + }, + { + "epoch": 0.1, + "grad_norm": 0.31975770214936095, + "learning_rate": 0.0005883272107909048, + "loss": 6.1836, + "step": 103 + }, + { + "epoch": 0.1, + "grad_norm": 0.3464621815245048, + "learning_rate": 0.0005880767074051915, + "loss": 6.125, + "step": 104 + }, + { + "epoch": 0.1, + "grad_norm": 0.3663428920796654, + "learning_rate": 0.0005878236051898998, + "loss": 6.0781, + "step": 105 + }, + { + "epoch": 0.11, + "grad_norm": 0.31594460565215293, + "learning_rate": 0.0005875679066937664, + "loss": 6.082, + "step": 106 + }, + { + "epoch": 0.11, + "grad_norm": 0.3552617109396582, + "learning_rate": 0.000587309614491672, + "loss": 6.1016, + "step": 107 + }, + { + "epoch": 0.11, + "grad_norm": 0.307016409692456, + "learning_rate": 0.0005870487311846164, + "loss": 6.1406, + "step": 108 + }, + { + "epoch": 0.11, + "grad_norm": 0.32188902148474213, + "learning_rate": 0.0005867852593996914, + "loss": 6.0039, + "step": 109 + }, + { + "epoch": 0.11, + "grad_norm": 0.25501199715105083, + "learning_rate": 0.0005865192017900551, + "loss": 6.0938, + "step": 110 + }, + { + "epoch": 0.11, + "grad_norm": 0.3416203070024056, + "learning_rate": 0.0005862505610349049, + "loss": 6.0234, + "step": 111 + }, + { + "epoch": 0.11, + "grad_norm": 0.3562508875852537, + "learning_rate": 0.0005859793398394498, + "loss": 6.0469, + "step": 112 + }, + { + "epoch": 0.11, + "grad_norm": 0.4443953757302568, + "learning_rate": 0.0005857055409348845, + "loss": 5.9766, + "step": 113 + }, + { + "epoch": 0.11, + "grad_norm": 0.42023839332714596, + "learning_rate": 0.0005854291670783607, + "loss": 6.0781, + "step": 114 + }, + { + "epoch": 0.11, + "grad_norm": 0.4618323255809241, + "learning_rate": 0.0005851502210529604, + "loss": 5.9727, + "step": 115 + }, + { + "epoch": 0.12, + "grad_norm": 0.379195014798667, + "learning_rate": 0.0005848687056676668, + "loss": 5.9922, + "step": 116 + }, + { + "epoch": 0.12, + "grad_norm": 0.3931552573296799, + "learning_rate": 0.0005845846237573366, + "loss": 5.9492, + "step": 117 + }, + { + "epoch": 0.12, + "grad_norm": 0.2567080044949908, + "learning_rate": 0.0005842979781826717, + "loss": 6.0273, + "step": 118 + }, + { + "epoch": 0.12, + "grad_norm": 0.4190305965377807, + "learning_rate": 0.0005840087718301895, + "loss": 6.0391, + "step": 119 + }, + { + "epoch": 0.12, + "grad_norm": 0.3996803869430228, + "learning_rate": 0.0005837170076121951, + "loss": 5.9531, + "step": 120 + }, + { + "epoch": 0.12, + "grad_norm": 0.478219248015785, + "learning_rate": 0.000583422688466751, + "loss": 6.0586, + "step": 121 + }, + { + "epoch": 0.12, + "grad_norm": 0.40869844309811526, + "learning_rate": 0.0005831258173576474, + "loss": 6.0117, + "step": 122 + }, + { + "epoch": 0.12, + "grad_norm": 0.3728598080697978, + "learning_rate": 0.0005828263972743733, + "loss": 5.9375, + "step": 123 + }, + { + "epoch": 0.12, + "grad_norm": 0.3560055462882015, + "learning_rate": 0.0005825244312320856, + "loss": 5.9531, + "step": 124 + }, + { + "epoch": 0.12, + "grad_norm": 0.40446932887864323, + "learning_rate": 0.0005822199222715787, + "loss": 5.9609, + "step": 125 + }, + { + "epoch": 0.13, + "grad_norm": 0.38514065739946723, + 
"learning_rate": 0.000581912873459255, + "loss": 5.8594, + "step": 126 + }, + { + "epoch": 0.13, + "grad_norm": 0.35367576386319416, + "learning_rate": 0.0005816032878870921, + "loss": 5.9023, + "step": 127 + }, + { + "epoch": 0.13, + "grad_norm": 0.3341681995122829, + "learning_rate": 0.0005812911686726135, + "loss": 5.9062, + "step": 128 + }, + { + "epoch": 0.13, + "grad_norm": 0.3387022688975784, + "learning_rate": 0.0005809765189588563, + "loss": 5.8945, + "step": 129 + }, + { + "epoch": 0.13, + "grad_norm": 0.31638659898934757, + "learning_rate": 0.0005806593419143395, + "loss": 5.8242, + "step": 130 + }, + { + "epoch": 0.13, + "grad_norm": 0.3229678508227436, + "learning_rate": 0.0005803396407330325, + "loss": 5.8516, + "step": 131 + }, + { + "epoch": 0.13, + "grad_norm": 0.35499490868584455, + "learning_rate": 0.0005800174186343226, + "loss": 5.9258, + "step": 132 + }, + { + "epoch": 0.13, + "grad_norm": 0.40753171542848754, + "learning_rate": 0.0005796926788629828, + "loss": 5.8242, + "step": 133 + }, + { + "epoch": 0.13, + "grad_norm": 0.3625374018348824, + "learning_rate": 0.0005793654246891389, + "loss": 5.832, + "step": 134 + }, + { + "epoch": 0.13, + "grad_norm": 0.3583489573569317, + "learning_rate": 0.000579035659408237, + "loss": 5.8398, + "step": 135 + }, + { + "epoch": 0.14, + "grad_norm": 0.39657706318861896, + "learning_rate": 0.0005787033863410095, + "loss": 5.8633, + "step": 136 + }, + { + "epoch": 0.14, + "grad_norm": 0.3965837889564036, + "learning_rate": 0.0005783686088334428, + "loss": 5.8633, + "step": 137 + }, + { + "epoch": 0.14, + "grad_norm": 0.29496474301865566, + "learning_rate": 0.0005780313302567424, + "loss": 5.8203, + "step": 138 + }, + { + "epoch": 0.14, + "grad_norm": 0.44637192639243695, + "learning_rate": 0.0005776915540073001, + "loss": 5.8477, + "step": 139 + }, + { + "epoch": 0.14, + "grad_norm": 0.39605473508683114, + "learning_rate": 0.0005773492835066587, + "loss": 5.7383, + "step": 140 + }, + { + "epoch": 0.14, + "grad_norm": 0.3008962634266945, + "learning_rate": 0.0005770045222014786, + "loss": 5.7617, + "step": 141 + }, + { + "epoch": 0.14, + "grad_norm": 0.36915495506607826, + "learning_rate": 0.0005766572735635022, + "loss": 5.7695, + "step": 142 + }, + { + "epoch": 0.14, + "grad_norm": 0.3282300349560706, + "learning_rate": 0.0005763075410895193, + "loss": 5.8281, + "step": 143 + }, + { + "epoch": 0.14, + "grad_norm": 0.2747449814083844, + "learning_rate": 0.0005759553283013323, + "loss": 5.7812, + "step": 144 + }, + { + "epoch": 0.14, + "grad_norm": 0.28905882704179764, + "learning_rate": 0.00057560063874572, + "loss": 5.7344, + "step": 145 + }, + { + "epoch": 0.15, + "grad_norm": 0.280625988867192, + "learning_rate": 0.000575243475994402, + "loss": 5.7773, + "step": 146 + }, + { + "epoch": 0.15, + "grad_norm": 0.41061863948012467, + "learning_rate": 0.0005748838436440035, + "loss": 5.7578, + "step": 147 + }, + { + "epoch": 0.15, + "grad_norm": 0.4920152483870267, + "learning_rate": 0.0005745217453160183, + "loss": 5.7305, + "step": 148 + }, + { + "epoch": 0.15, + "grad_norm": 0.5463207978955044, + "learning_rate": 0.0005741571846567725, + "loss": 5.7383, + "step": 149 + }, + { + "epoch": 0.15, + "grad_norm": 0.3986359831157306, + "learning_rate": 0.0005737901653373878, + "loss": 5.668, + "step": 150 + }, + { + "epoch": 0.15, + "grad_norm": 0.37908758170100293, + "learning_rate": 0.0005734206910537447, + "loss": 5.6875, + "step": 151 + }, + { + "epoch": 0.15, + "grad_norm": 0.35929793070492694, + "learning_rate": 0.0005730487655264451, 
+ "loss": 5.7188, + "step": 152 + }, + { + "epoch": 0.15, + "grad_norm": 0.4217799574145456, + "learning_rate": 0.0005726743925007751, + "loss": 5.7305, + "step": 153 + }, + { + "epoch": 0.15, + "grad_norm": 0.4024411981587195, + "learning_rate": 0.0005722975757466667, + "loss": 5.6289, + "step": 154 + }, + { + "epoch": 0.15, + "grad_norm": 0.3472391905877033, + "learning_rate": 0.0005719183190586606, + "loss": 5.6523, + "step": 155 + }, + { + "epoch": 0.16, + "grad_norm": 0.31752956812138816, + "learning_rate": 0.0005715366262558675, + "loss": 5.6172, + "step": 156 + }, + { + "epoch": 0.16, + "grad_norm": 0.3170152384332457, + "learning_rate": 0.0005711525011819294, + "loss": 5.6172, + "step": 157 + }, + { + "epoch": 0.16, + "grad_norm": 0.40520629326601837, + "learning_rate": 0.0005707659477049818, + "loss": 5.625, + "step": 158 + }, + { + "epoch": 0.16, + "grad_norm": 0.3965976910198806, + "learning_rate": 0.0005703769697176137, + "loss": 5.6562, + "step": 159 + }, + { + "epoch": 0.16, + "grad_norm": 0.40422960541801994, + "learning_rate": 0.0005699855711368293, + "loss": 5.6836, + "step": 160 + }, + { + "epoch": 0.16, + "grad_norm": 0.3780813184050647, + "learning_rate": 0.0005695917559040079, + "loss": 5.5938, + "step": 161 + }, + { + "epoch": 0.16, + "grad_norm": 0.36917638857736573, + "learning_rate": 0.0005691955279848645, + "loss": 5.668, + "step": 162 + }, + { + "epoch": 0.16, + "grad_norm": 0.37769176081037814, + "learning_rate": 0.0005687968913694098, + "loss": 5.4961, + "step": 163 + }, + { + "epoch": 0.16, + "grad_norm": 0.3255116524991148, + "learning_rate": 0.0005683958500719103, + "loss": 5.5117, + "step": 164 + }, + { + "epoch": 0.16, + "grad_norm": 0.31897629016848805, + "learning_rate": 0.0005679924081308471, + "loss": 5.5664, + "step": 165 + }, + { + "epoch": 0.17, + "grad_norm": 0.2869064236553046, + "learning_rate": 0.0005675865696088764, + "loss": 5.5391, + "step": 166 + }, + { + "epoch": 0.17, + "grad_norm": 0.29226729022634845, + "learning_rate": 0.0005671783385927873, + "loss": 5.5586, + "step": 167 + }, + { + "epoch": 0.17, + "grad_norm": 0.2534117210955766, + "learning_rate": 0.0005667677191934618, + "loss": 5.5312, + "step": 168 + }, + { + "epoch": 0.17, + "grad_norm": 0.289828484125484, + "learning_rate": 0.0005663547155458326, + "loss": 5.6484, + "step": 169 + }, + { + "epoch": 0.17, + "grad_norm": 0.2717242930342115, + "learning_rate": 0.0005659393318088419, + "loss": 5.5352, + "step": 170 + }, + { + "epoch": 0.17, + "grad_norm": 0.3595538109137759, + "learning_rate": 0.0005655215721653993, + "loss": 5.5742, + "step": 171 + }, + { + "epoch": 0.17, + "grad_norm": 0.4255054350471108, + "learning_rate": 0.0005651014408223398, + "loss": 5.5469, + "step": 172 + }, + { + "epoch": 0.17, + "grad_norm": 0.3670561941219979, + "learning_rate": 0.0005646789420103814, + "loss": 5.5078, + "step": 173 + }, + { + "epoch": 0.17, + "grad_norm": 0.40280130904983164, + "learning_rate": 0.0005642540799840822, + "loss": 5.5, + "step": 174 + }, + { + "epoch": 0.17, + "grad_norm": 0.41159472035983025, + "learning_rate": 0.0005638268590217984, + "loss": 5.5039, + "step": 175 + }, + { + "epoch": 0.18, + "grad_norm": 0.4316778037513652, + "learning_rate": 0.0005633972834256401, + "loss": 5.5352, + "step": 176 + }, + { + "epoch": 0.18, + "grad_norm": 0.5674781128363939, + "learning_rate": 0.000562965357521429, + "loss": 5.4336, + "step": 177 + }, + { + "epoch": 0.18, + "grad_norm": 0.41654662151365446, + "learning_rate": 0.0005625310856586541, + "loss": 5.6211, + "step": 178 + }, + { 
+ "epoch": 0.18, + "grad_norm": 0.5159976364107484, + "learning_rate": 0.0005620944722104282, + "loss": 5.4844, + "step": 179 + }, + { + "epoch": 0.18, + "grad_norm": 0.34364678177014185, + "learning_rate": 0.0005616555215734438, + "loss": 5.4922, + "step": 180 + }, + { + "epoch": 0.18, + "grad_norm": 0.3708077784459011, + "learning_rate": 0.0005612142381679289, + "loss": 5.5234, + "step": 181 + }, + { + "epoch": 0.18, + "grad_norm": 0.3620051253453866, + "learning_rate": 0.0005607706264376028, + "loss": 5.4961, + "step": 182 + }, + { + "epoch": 0.18, + "grad_norm": 0.34735585210929654, + "learning_rate": 0.0005603246908496305, + "loss": 5.4453, + "step": 183 + }, + { + "epoch": 0.18, + "grad_norm": 0.37719874705792217, + "learning_rate": 0.0005598764358945783, + "loss": 5.4844, + "step": 184 + }, + { + "epoch": 0.18, + "grad_norm": 0.3749130664831207, + "learning_rate": 0.0005594258660863689, + "loss": 5.4648, + "step": 185 + }, + { + "epoch": 0.19, + "grad_norm": 0.40951353306235827, + "learning_rate": 0.0005589729859622351, + "loss": 5.5039, + "step": 186 + }, + { + "epoch": 0.19, + "grad_norm": 0.40146882563949804, + "learning_rate": 0.0005585178000826745, + "loss": 5.3672, + "step": 187 + }, + { + "epoch": 0.19, + "grad_norm": 0.4062987628428303, + "learning_rate": 0.0005580603130314043, + "loss": 5.3984, + "step": 188 + }, + { + "epoch": 0.19, + "grad_norm": 0.35626322654799136, + "learning_rate": 0.0005576005294153138, + "loss": 5.3984, + "step": 189 + }, + { + "epoch": 0.19, + "grad_norm": 0.3140647930801716, + "learning_rate": 0.0005571384538644188, + "loss": 5.3906, + "step": 190 + }, + { + "epoch": 0.19, + "grad_norm": 0.2990060538353662, + "learning_rate": 0.0005566740910318153, + "loss": 5.3711, + "step": 191 + }, + { + "epoch": 0.19, + "grad_norm": 0.3337525907515936, + "learning_rate": 0.0005562074455936315, + "loss": 5.4023, + "step": 192 + }, + { + "epoch": 0.19, + "grad_norm": 0.3381587051014816, + "learning_rate": 0.000555738522248982, + "loss": 5.4414, + "step": 193 + }, + { + "epoch": 0.19, + "grad_norm": 0.2954008999469894, + "learning_rate": 0.0005552673257199197, + "loss": 5.418, + "step": 194 + }, + { + "epoch": 0.19, + "grad_norm": 0.3242310900810155, + "learning_rate": 0.0005547938607513882, + "loss": 5.418, + "step": 195 + }, + { + "epoch": 0.2, + "grad_norm": 0.3149021804393678, + "learning_rate": 0.0005543181321111747, + "loss": 5.4375, + "step": 196 + }, + { + "epoch": 0.2, + "grad_norm": 0.32859412218218814, + "learning_rate": 0.0005538401445898612, + "loss": 5.4492, + "step": 197 + }, + { + "epoch": 0.2, + "grad_norm": 0.2960282598050701, + "learning_rate": 0.0005533599030007768, + "loss": 5.3867, + "step": 198 + }, + { + "epoch": 0.2, + "grad_norm": 0.2866762878199755, + "learning_rate": 0.0005528774121799489, + "loss": 5.3789, + "step": 199 + }, + { + "epoch": 0.2, + "grad_norm": 0.34865216327038784, + "learning_rate": 0.0005523926769860549, + "loss": 5.3711, + "step": 200 + }, + { + "epoch": 0.2, + "grad_norm": 0.4043023482242469, + "learning_rate": 0.0005519057023003725, + "loss": 5.3906, + "step": 201 + }, + { + "epoch": 0.2, + "grad_norm": 0.4069960968887199, + "learning_rate": 0.0005514164930267316, + "loss": 5.2773, + "step": 202 + }, + { + "epoch": 0.2, + "grad_norm": 0.4051152667506829, + "learning_rate": 0.0005509250540914641, + "loss": 5.3242, + "step": 203 + }, + { + "epoch": 0.2, + "grad_norm": 0.375026562862574, + "learning_rate": 0.0005504313904433546, + "loss": 5.4258, + "step": 204 + }, + { + "epoch": 0.2, + "grad_norm": 
0.3326184185943848, + "learning_rate": 0.0005499355070535906, + "loss": 5.375, + "step": 205 + }, + { + "epoch": 0.21, + "grad_norm": 0.3695014522224558, + "learning_rate": 0.0005494374089157123, + "loss": 5.3984, + "step": 206 + }, + { + "epoch": 0.21, + "grad_norm": 0.2793258171824813, + "learning_rate": 0.0005489371010455625, + "loss": 5.2891, + "step": 207 + }, + { + "epoch": 0.21, + "grad_norm": 0.2879966080096621, + "learning_rate": 0.0005484345884812357, + "loss": 5.3867, + "step": 208 + }, + { + "epoch": 0.21, + "grad_norm": 0.32599687735840654, + "learning_rate": 0.0005479298762830281, + "loss": 5.3203, + "step": 209 + }, + { + "epoch": 0.21, + "grad_norm": 0.31305226164510963, + "learning_rate": 0.0005474229695333857, + "loss": 5.3281, + "step": 210 + }, + { + "epoch": 0.21, + "grad_norm": 0.3514527997420013, + "learning_rate": 0.000546913873336854, + "loss": 5.3008, + "step": 211 + }, + { + "epoch": 0.21, + "grad_norm": 0.38188707638514424, + "learning_rate": 0.0005464025928200261, + "loss": 5.3086, + "step": 212 + }, + { + "epoch": 0.21, + "grad_norm": 0.3865148796842015, + "learning_rate": 0.0005458891331314909, + "loss": 5.2656, + "step": 213 + }, + { + "epoch": 0.21, + "grad_norm": 0.4304784604066023, + "learning_rate": 0.0005453734994417819, + "loss": 5.3125, + "step": 214 + }, + { + "epoch": 0.21, + "grad_norm": 0.40269356862192995, + "learning_rate": 0.0005448556969433247, + "loss": 5.2617, + "step": 215 + }, + { + "epoch": 0.22, + "grad_norm": 0.30541089575928587, + "learning_rate": 0.0005443357308503845, + "loss": 5.2422, + "step": 216 + }, + { + "epoch": 0.22, + "grad_norm": 0.29104576978792596, + "learning_rate": 0.0005438136063990142, + "loss": 5.2109, + "step": 217 + }, + { + "epoch": 0.22, + "grad_norm": 0.291891354913362, + "learning_rate": 0.0005432893288470012, + "loss": 5.2617, + "step": 218 + }, + { + "epoch": 0.22, + "grad_norm": 0.3301944866145271, + "learning_rate": 0.0005427629034738149, + "loss": 5.2188, + "step": 219 + }, + { + "epoch": 0.22, + "grad_norm": 0.33824328942983417, + "learning_rate": 0.0005422343355805525, + "loss": 5.293, + "step": 220 + }, + { + "epoch": 0.22, + "grad_norm": 0.3539026997032359, + "learning_rate": 0.0005417036304898872, + "loss": 5.2695, + "step": 221 + }, + { + "epoch": 0.22, + "grad_norm": 0.38720918633148693, + "learning_rate": 0.0005411707935460132, + "loss": 5.2227, + "step": 222 + }, + { + "epoch": 0.22, + "grad_norm": 0.4539797383631105, + "learning_rate": 0.0005406358301145925, + "loss": 5.2539, + "step": 223 + }, + { + "epoch": 0.22, + "grad_norm": 0.40620115793500733, + "learning_rate": 0.0005400987455827012, + "loss": 5.2852, + "step": 224 + }, + { + "epoch": 0.22, + "grad_norm": 0.3680272948713411, + "learning_rate": 0.0005395595453587743, + "loss": 5.2617, + "step": 225 + }, + { + "epoch": 0.23, + "grad_norm": 0.3919096232059878, + "learning_rate": 0.0005390182348725522, + "loss": 5.2305, + "step": 226 + }, + { + "epoch": 0.23, + "grad_norm": 0.3783288666206609, + "learning_rate": 0.0005384748195750255, + "loss": 5.2031, + "step": 227 + }, + { + "epoch": 0.23, + "grad_norm": 0.34519921770570766, + "learning_rate": 0.0005379293049383802, + "loss": 5.2227, + "step": 228 + }, + { + "epoch": 0.23, + "grad_norm": 0.3548414963147158, + "learning_rate": 0.0005373816964559426, + "loss": 5.2891, + "step": 229 + }, + { + "epoch": 0.23, + "grad_norm": 0.36291865229291537, + "learning_rate": 0.000536831999642124, + "loss": 5.2266, + "step": 230 + }, + { + "epoch": 0.23, + "grad_norm": 0.313916097271022, + "learning_rate": 
0.0005362802200323654, + "loss": 5.1055, + "step": 231 + }, + { + "epoch": 0.23, + "grad_norm": 0.29232836352032804, + "learning_rate": 0.0005357263631830811, + "loss": 5.1406, + "step": 232 + }, + { + "epoch": 0.23, + "grad_norm": 0.34482143058503106, + "learning_rate": 0.0005351704346716036, + "loss": 5.2305, + "step": 233 + }, + { + "epoch": 0.23, + "grad_norm": 0.3079065808428287, + "learning_rate": 0.0005346124400961267, + "loss": 5.2031, + "step": 234 + }, + { + "epoch": 0.23, + "grad_norm": 0.2869436862887739, + "learning_rate": 0.0005340523850756497, + "loss": 5.2539, + "step": 235 + }, + { + "epoch": 0.24, + "grad_norm": 0.27208356804470046, + "learning_rate": 0.0005334902752499204, + "loss": 5.1484, + "step": 236 + }, + { + "epoch": 0.24, + "grad_norm": 0.27768753128858653, + "learning_rate": 0.0005329261162793785, + "loss": 5.1758, + "step": 237 + }, + { + "epoch": 0.24, + "grad_norm": 0.2701859056468535, + "learning_rate": 0.0005323599138450985, + "loss": 5.1562, + "step": 238 + }, + { + "epoch": 0.24, + "grad_norm": 0.2940215458662745, + "learning_rate": 0.0005317916736487328, + "loss": 5.1406, + "step": 239 + }, + { + "epoch": 0.24, + "grad_norm": 0.29636403080234647, + "learning_rate": 0.0005312214014124536, + "loss": 5.1719, + "step": 240 + }, + { + "epoch": 0.24, + "grad_norm": 0.3513688083715198, + "learning_rate": 0.0005306491028788964, + "loss": 5.0664, + "step": 241 + }, + { + "epoch": 0.24, + "grad_norm": 0.455104024911365, + "learning_rate": 0.0005300747838111007, + "loss": 5.1289, + "step": 242 + }, + { + "epoch": 0.24, + "grad_norm": 0.5257166308389952, + "learning_rate": 0.0005294984499924532, + "loss": 5.1523, + "step": 243 + }, + { + "epoch": 0.24, + "grad_norm": 0.440798061960299, + "learning_rate": 0.0005289201072266293, + "loss": 5.1289, + "step": 244 + }, + { + "epoch": 0.24, + "grad_norm": 0.4965659619997502, + "learning_rate": 0.0005283397613375339, + "loss": 5.1211, + "step": 245 + }, + { + "epoch": 0.25, + "grad_norm": 0.40267641703114215, + "learning_rate": 0.0005277574181692438, + "loss": 5.0586, + "step": 246 + }, + { + "epoch": 0.25, + "grad_norm": 0.4013007078780512, + "learning_rate": 0.0005271730835859485, + "loss": 5.0273, + "step": 247 + }, + { + "epoch": 0.25, + "grad_norm": 0.38447773033555227, + "learning_rate": 0.0005265867634718903, + "loss": 5.1367, + "step": 248 + }, + { + "epoch": 0.25, + "grad_norm": 0.37763602900633203, + "learning_rate": 0.0005259984637313066, + "loss": 5.1055, + "step": 249 + }, + { + "epoch": 0.25, + "grad_norm": 0.344024017964152, + "learning_rate": 0.0005254081902883689, + "loss": 5.0898, + "step": 250 + }, + { + "epoch": 0.25, + "grad_norm": 0.35441912273779097, + "learning_rate": 0.0005248159490871245, + "loss": 5.1016, + "step": 251 + }, + { + "epoch": 0.25, + "grad_norm": 0.2877284013478678, + "learning_rate": 0.0005242217460914358, + "loss": 5.0664, + "step": 252 + }, + { + "epoch": 0.25, + "grad_norm": 0.3143093064571279, + "learning_rate": 0.0005236255872849201, + "loss": 5.1484, + "step": 253 + }, + { + "epoch": 0.25, + "grad_norm": 0.31206187291371684, + "learning_rate": 0.00052302747867089, + "loss": 5.1328, + "step": 254 + }, + { + "epoch": 0.25, + "grad_norm": 0.3150920418962865, + "learning_rate": 0.000522427426272293, + "loss": 5.1289, + "step": 255 + }, + { + "epoch": 0.26, + "grad_norm": 0.3195539774191906, + "learning_rate": 0.0005218254361316495, + "loss": 5.0898, + "step": 256 + }, + { + "epoch": 0.26, + "grad_norm": 0.24548404338795576, + "learning_rate": 0.000521221514310994, + "loss": 5.1016, 
+ "step": 257 + }, + { + "epoch": 0.26, + "grad_norm": 0.25649802021467205, + "learning_rate": 0.0005206156668918122, + "loss": 5.1289, + "step": 258 + }, + { + "epoch": 0.26, + "grad_norm": 0.25018114739252273, + "learning_rate": 0.0005200078999749811, + "loss": 5.0508, + "step": 259 + }, + { + "epoch": 0.26, + "grad_norm": 0.2740344343745378, + "learning_rate": 0.0005193982196807067, + "loss": 5.082, + "step": 260 + }, + { + "epoch": 0.26, + "grad_norm": 0.30807201125247574, + "learning_rate": 0.0005187866321484628, + "loss": 5.0078, + "step": 261 + }, + { + "epoch": 0.26, + "grad_norm": 0.32367849723934244, + "learning_rate": 0.0005181731435369292, + "loss": 5.0625, + "step": 262 + }, + { + "epoch": 0.26, + "grad_norm": 0.3465653029312147, + "learning_rate": 0.0005175577600239292, + "loss": 5.0078, + "step": 263 + }, + { + "epoch": 0.26, + "grad_norm": 0.3716869632171198, + "learning_rate": 0.0005169404878063681, + "loss": 5.0977, + "step": 264 + }, + { + "epoch": 0.26, + "grad_norm": 0.37681584996379275, + "learning_rate": 0.0005163213331001702, + "loss": 5.082, + "step": 265 + }, + { + "epoch": 0.27, + "grad_norm": 0.34462519335888353, + "learning_rate": 0.0005157003021402166, + "loss": 4.9844, + "step": 266 + }, + { + "epoch": 0.27, + "grad_norm": 0.39514090390949574, + "learning_rate": 0.000515077401180282, + "loss": 5.0312, + "step": 267 + }, + { + "epoch": 0.27, + "grad_norm": 0.46469822376758096, + "learning_rate": 0.0005144526364929722, + "loss": 5.0234, + "step": 268 + }, + { + "epoch": 0.27, + "grad_norm": 0.34570371767844565, + "learning_rate": 0.0005138260143696608, + "loss": 5.0352, + "step": 269 + }, + { + "epoch": 0.27, + "grad_norm": 0.2920012584285204, + "learning_rate": 0.0005131975411204257, + "loss": 4.9805, + "step": 270 + }, + { + "epoch": 0.27, + "grad_norm": 0.34109638913820345, + "learning_rate": 0.0005125672230739852, + "loss": 4.9844, + "step": 271 + }, + { + "epoch": 0.27, + "grad_norm": 0.2976316922487618, + "learning_rate": 0.0005119350665776353, + "loss": 4.9727, + "step": 272 + }, + { + "epoch": 0.27, + "grad_norm": 0.38160864657971466, + "learning_rate": 0.0005113010779971848, + "loss": 5.0312, + "step": 273 + }, + { + "epoch": 0.27, + "grad_norm": 0.40407725833100544, + "learning_rate": 0.0005106652637168917, + "loss": 5.0312, + "step": 274 + }, + { + "epoch": 0.27, + "grad_norm": 0.36275741793161437, + "learning_rate": 0.0005100276301393987, + "loss": 5.0391, + "step": 275 + }, + { + "epoch": 0.28, + "grad_norm": 0.35097531980231905, + "learning_rate": 0.0005093881836856688, + "loss": 4.9844, + "step": 276 + }, + { + "epoch": 0.28, + "grad_norm": 0.3615382021996322, + "learning_rate": 0.000508746930794921, + "loss": 4.9453, + "step": 277 + }, + { + "epoch": 0.28, + "grad_norm": 0.3260265986197515, + "learning_rate": 0.0005081038779245643, + "loss": 5.0078, + "step": 278 + }, + { + "epoch": 0.28, + "grad_norm": 0.3230813193726234, + "learning_rate": 0.0005074590315501345, + "loss": 5.0, + "step": 279 + }, + { + "epoch": 0.28, + "grad_norm": 0.43011368510100667, + "learning_rate": 0.000506812398165227, + "loss": 4.9961, + "step": 280 + }, + { + "epoch": 0.28, + "grad_norm": 0.4688261606016039, + "learning_rate": 0.0005061639842814328, + "loss": 4.9883, + "step": 281 + }, + { + "epoch": 0.28, + "grad_norm": 0.4082387881237382, + "learning_rate": 0.0005055137964282728, + "loss": 4.9492, + "step": 282 + }, + { + "epoch": 0.28, + "grad_norm": 0.4102411273145604, + "learning_rate": 0.0005048618411531315, + "loss": 4.9492, + "step": 283 + }, + { + "epoch": 
0.28, + "grad_norm": 0.3333699558922032, + "learning_rate": 0.000504208125021191, + "loss": 4.9492, + "step": 284 + }, + { + "epoch": 0.28, + "grad_norm": 0.3014113897515229, + "learning_rate": 0.0005035526546153656, + "loss": 4.9922, + "step": 285 + }, + { + "epoch": 0.29, + "grad_norm": 0.33242045759712463, + "learning_rate": 0.000502895436536235, + "loss": 4.8906, + "step": 286 + }, + { + "epoch": 0.29, + "grad_norm": 0.27804952465315824, + "learning_rate": 0.000502236477401978, + "loss": 4.8828, + "step": 287 + }, + { + "epoch": 0.29, + "grad_norm": 0.346783453227663, + "learning_rate": 0.0005015757838483058, + "loss": 4.9453, + "step": 288 + }, + { + "epoch": 0.29, + "grad_norm": 0.33206265244928296, + "learning_rate": 0.000500913362528395, + "loss": 4.9102, + "step": 289 + }, + { + "epoch": 0.29, + "grad_norm": 0.31507543033475727, + "learning_rate": 0.000500249220112821, + "loss": 4.9336, + "step": 290 + }, + { + "epoch": 0.29, + "grad_norm": 0.34558992633865376, + "learning_rate": 0.0004995833632894907, + "loss": 4.8867, + "step": 291 + }, + { + "epoch": 0.29, + "grad_norm": 0.3596650694441014, + "learning_rate": 0.0004989157987635748, + "loss": 4.9141, + "step": 292 + }, + { + "epoch": 0.29, + "grad_norm": 0.26520540250540703, + "learning_rate": 0.0004982465332574405, + "loss": 4.9648, + "step": 293 + }, + { + "epoch": 0.29, + "grad_norm": 0.2957335916241638, + "learning_rate": 0.0004975755735105844, + "loss": 4.9297, + "step": 294 + }, + { + "epoch": 0.29, + "grad_norm": 0.33075169113632213, + "learning_rate": 0.0004969029262795634, + "loss": 4.9102, + "step": 295 + }, + { + "epoch": 0.3, + "grad_norm": 0.3588819230985392, + "learning_rate": 0.0004962285983379276, + "loss": 4.8672, + "step": 296 + }, + { + "epoch": 0.3, + "grad_norm": 0.3441202272395266, + "learning_rate": 0.0004955525964761522, + "loss": 4.8203, + "step": 297 + }, + { + "epoch": 0.3, + "grad_norm": 0.3150553179412103, + "learning_rate": 0.0004948749275015682, + "loss": 4.8945, + "step": 298 + }, + { + "epoch": 0.3, + "grad_norm": 0.31033579532429983, + "learning_rate": 0.0004941955982382948, + "loss": 4.9336, + "step": 299 + }, + { + "epoch": 0.3, + "grad_norm": 0.3118267914201189, + "learning_rate": 0.0004935146155271699, + "loss": 4.8125, + "step": 300 + } + ], + "logging_steps": 1, + "max_steps": 1000, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 100, + "total_flos": 0.0, + "train_batch_size": 32, + "trial_name": null, + "trial_params": null +} diff --git a/checkpoint-300/training_args.bin b/checkpoint-300/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..09b35ec8ac2a16eb45febe1d655d456e47af68d1 --- /dev/null +++ b/checkpoint-300/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbd9a6067cf818494e2505097746a1cad30533fc72eb13916de34f7671e3e456 +size 6520 diff --git a/checkpoint-300/zero_to_fp32.py b/checkpoint-300/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-300/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. 
Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = 
state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return 
_get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + 
offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def 
_zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. 
e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. 
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-400/config.json b/checkpoint-400/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-400/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-400/generation_config.json b/checkpoint-400/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..bc00b333fdf0ba3611d022ddfdaeaf527fab8da0 --- /dev/null +++ b/checkpoint-400/generation_config.json @@ -0,0 +1,6 @@ +{ + "_from_model_config": true, + "bos_token_id": 0, + "eos_token_id": 2, + "transformers_version": "4.38.2" +} diff --git a/checkpoint-400/global_step400/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-400/global_step400/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..f7b9939f41c8e261e7cac9ef63719a153477d5bb --- /dev/null +++ b/checkpoint-400/global_step400/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:a79a645fbfebd19d0c0bc738afcb5a06400eb6af3b5a9a758726a7ae9188b688 +size 973946896 diff --git a/checkpoint-400/global_step400/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-400/global_step400/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..e3ad2b0dedaeae43ce6cd31d8a8c2c058e4505f0 --- /dev/null +++ b/checkpoint-400/global_step400/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37810609d31b6e7e0ddbc5d8142074e8d5cace38a773d8c0c79857976b7f3452 +size 973946832 diff --git a/checkpoint-400/global_step400/mp_rank_00_model_states.pt b/checkpoint-400/global_step400/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..1abc504a1c257915da28038415e7fde3b4036626 --- /dev/null +++ b/checkpoint-400/global_step400/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbf23858dde2df1331a7e83cb65d6a167f69174c52c8bdde00c501d631e091d8 +size 324689964 diff --git a/checkpoint-400/latest b/checkpoint-400/latest new file mode 100644 index 0000000000000000000000000000000000000000..e5bdf58d4f29d34e909da25905fad376f73e7c29 --- /dev/null +++ b/checkpoint-400/latest @@ -0,0 +1 @@ +global_step400 \ No newline at end of file diff --git a/checkpoint-400/model.safetensors b/checkpoint-400/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..9136e3fab302e9d02503ea10b357d5942da86148 --- /dev/null +++ b/checkpoint-400/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66ed36c4dcc6ceec22851060da063587ce67e78fc89cf6d774ac3f1218f9ac5d +size 324662984 diff --git a/checkpoint-400/rng_state_0.pth b/checkpoint-400/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..cad18ac770da4331076b9ef49fc91a7f9a5989c3 --- /dev/null +++ b/checkpoint-400/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d0bb7d2ecdd48fd7d0be1e75b0e3f29004064381052fa203ed926e88b90ef530 +size 14512 diff --git a/checkpoint-400/rng_state_1.pth b/checkpoint-400/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..197bac5f7fe92d301270b1f25b8fa7a07b568293 --- /dev/null +++ b/checkpoint-400/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:177d534a379bd6b276474c2cb140e318dc65db4457b6c1b6f25a1a9dd563af82 +size 14512 diff --git a/checkpoint-400/trainer_state.json b/checkpoint-400/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..7d9b9ea1ca1b8c0eeb3c6a748f7f63407db8e15e --- /dev/null +++ b/checkpoint-400/trainer_state.json @@ -0,0 +1,2821 @@ +{ + "best_metric": null, + "best_model_checkpoint": null, + "epoch": 0.39980009995002497, + "eval_steps": 500, + "global_step": 400, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 0.0, + "grad_norm": 3.3340563149001086, + "learning_rate": 0.0, + "loss": 11.0, + "step": 1 + }, + { + "epoch": 0.0, + "grad_norm": 2.398812329952019, + "learning_rate": 5.9999999999999995e-05, + "loss": 10.125, + "step": 2 + }, + { + "epoch": 0.0, + "grad_norm": 2.394322446895115, + "learning_rate": 0.00011999999999999999, + "loss": 10.1172, + "step": 3 + }, + { + "epoch": 0.0, + "grad_norm": 1.9958816684399585, + "learning_rate": 
0.00017999999999999998, + "loss": 9.875, + "step": 4 + }, + { + "epoch": 0.0, + "grad_norm": 1.8270465897882062, + "learning_rate": 0.00023999999999999998, + "loss": 9.6641, + "step": 5 + }, + { + "epoch": 0.01, + "grad_norm": 1.7854046471397795, + "learning_rate": 0.0003, + "loss": 9.4844, + "step": 6 + }, + { + "epoch": 0.01, + "grad_norm": 1.719416749115252, + "learning_rate": 0.00035999999999999997, + "loss": 9.3281, + "step": 7 + }, + { + "epoch": 0.01, + "grad_norm": 1.4637825746112274, + "learning_rate": 0.00041999999999999996, + "loss": 9.2109, + "step": 8 + }, + { + "epoch": 0.01, + "grad_norm": 1.4393631015406718, + "learning_rate": 0.00047999999999999996, + "loss": 8.9453, + "step": 9 + }, + { + "epoch": 0.01, + "grad_norm": 1.2936734586915988, + "learning_rate": 0.00054, + "loss": 8.7109, + "step": 10 + }, + { + "epoch": 0.01, + "grad_norm": 1.0756922378227356, + "learning_rate": 0.0005999986405514987, + "loss": 8.4609, + "step": 11 + }, + { + "epoch": 0.01, + "grad_norm": 0.9277829127413892, + "learning_rate": 0.0005999945622196846, + "loss": 8.2344, + "step": 12 + }, + { + "epoch": 0.01, + "grad_norm": 0.8084581786682467, + "learning_rate": 0.0005999877650456265, + "loss": 8.125, + "step": 13 + }, + { + "epoch": 0.01, + "grad_norm": 0.7635084596900947, + "learning_rate": 0.000599978249097772, + "loss": 7.9766, + "step": 14 + }, + { + "epoch": 0.01, + "grad_norm": 0.9186699644247788, + "learning_rate": 0.0005999660144719463, + "loss": 7.8555, + "step": 15 + }, + { + "epoch": 0.02, + "grad_norm": 0.6609504256551479, + "learning_rate": 0.0005999510612913519, + "loss": 7.7734, + "step": 16 + }, + { + "epoch": 0.02, + "grad_norm": 0.7086232844782971, + "learning_rate": 0.0005999333897065673, + "loss": 7.7148, + "step": 17 + }, + { + "epoch": 0.02, + "grad_norm": 16.38048851691348, + "learning_rate": 0.0005999129998955453, + "loss": 8.4844, + "step": 18 + }, + { + "epoch": 0.02, + "grad_norm": 1.3057527590449889, + "learning_rate": 0.0005998898920636111, + "loss": 7.7539, + "step": 19 + }, + { + "epoch": 0.02, + "grad_norm": 0.6966048242948986, + "learning_rate": 0.00059986406644346, + "loss": 7.75, + "step": 20 + }, + { + "epoch": 0.02, + "grad_norm": 0.6348089115348993, + "learning_rate": 0.0005998355232951559, + "loss": 7.7031, + "step": 21 + }, + { + "epoch": 0.02, + "grad_norm": 0.7829163518610293, + "learning_rate": 0.0005998042629061279, + "loss": 7.6992, + "step": 22 + }, + { + "epoch": 0.02, + "grad_norm": 0.5900591778980369, + "learning_rate": 0.0005997702855911678, + "loss": 7.6016, + "step": 23 + }, + { + "epoch": 0.02, + "grad_norm": 0.4655170213064256, + "learning_rate": 0.0005997335916924268, + "loss": 7.5977, + "step": 24 + }, + { + "epoch": 0.02, + "grad_norm": 0.6287348258915756, + "learning_rate": 0.0005996941815794121, + "loss": 7.5586, + "step": 25 + }, + { + "epoch": 0.03, + "grad_norm": 0.6137321903884564, + "learning_rate": 0.0005996520556489831, + "loss": 7.5898, + "step": 26 + }, + { + "epoch": 0.03, + "grad_norm": 0.44962562710631065, + "learning_rate": 0.0005996072143253473, + "loss": 7.4336, + "step": 27 + }, + { + "epoch": 0.03, + "grad_norm": 0.46130046454703316, + "learning_rate": 0.0005995596580600566, + "loss": 7.4023, + "step": 28 + }, + { + "epoch": 0.03, + "grad_norm": 0.4686712675731326, + "learning_rate": 0.0005995093873320018, + "loss": 7.3789, + "step": 29 + }, + { + "epoch": 0.03, + "grad_norm": 0.4672147564288997, + "learning_rate": 0.0005994564026474087, + "loss": 7.3711, + "step": 30 + }, + { + "epoch": 0.03, + "grad_norm": 
0.40408354581233474, + "learning_rate": 0.0005994007045398324, + "loss": 7.3672, + "step": 31 + }, + { + "epoch": 0.03, + "grad_norm": 0.46032146732584733, + "learning_rate": 0.0005993422935701524, + "loss": 7.3477, + "step": 32 + }, + { + "epoch": 0.03, + "grad_norm": 0.4765534634593268, + "learning_rate": 0.0005992811703265664, + "loss": 7.3555, + "step": 33 + }, + { + "epoch": 0.03, + "grad_norm": 0.46208489386235113, + "learning_rate": 0.0005992173354245849, + "loss": 7.3047, + "step": 34 + }, + { + "epoch": 0.03, + "grad_norm": 0.2956144524964961, + "learning_rate": 0.0005991507895070244, + "loss": 7.3125, + "step": 35 + }, + { + "epoch": 0.04, + "grad_norm": 0.4834645389868856, + "learning_rate": 0.0005990815332440017, + "loss": 7.207, + "step": 36 + }, + { + "epoch": 0.04, + "grad_norm": 0.4411831350968505, + "learning_rate": 0.0005990095673329266, + "loss": 7.1758, + "step": 37 + }, + { + "epoch": 0.04, + "grad_norm": 0.24809297748968667, + "learning_rate": 0.0005989348924984951, + "loss": 7.2188, + "step": 38 + }, + { + "epoch": 0.04, + "grad_norm": 0.39402988416840584, + "learning_rate": 0.0005988575094926817, + "loss": 7.1953, + "step": 39 + }, + { + "epoch": 0.04, + "grad_norm": 0.3868345222189167, + "learning_rate": 0.0005987774190947328, + "loss": 7.1641, + "step": 40 + }, + { + "epoch": 0.04, + "grad_norm": 0.3777261230135448, + "learning_rate": 0.0005986946221111575, + "loss": 7.1328, + "step": 41 + }, + { + "epoch": 0.04, + "grad_norm": 0.4687511444077827, + "learning_rate": 0.0005986091193757206, + "loss": 7.0898, + "step": 42 + }, + { + "epoch": 0.04, + "grad_norm": 0.34935796211612463, + "learning_rate": 0.0005985209117494337, + "loss": 7.1367, + "step": 43 + }, + { + "epoch": 0.04, + "grad_norm": 0.38764476686849886, + "learning_rate": 0.0005984300001205466, + "loss": 7.125, + "step": 44 + }, + { + "epoch": 0.04, + "grad_norm": 0.3956487898882936, + "learning_rate": 0.0005983363854045386, + "loss": 7.1094, + "step": 45 + }, + { + "epoch": 0.05, + "grad_norm": 0.31140257544677513, + "learning_rate": 0.0005982400685441084, + "loss": 7.0898, + "step": 46 + }, + { + "epoch": 0.05, + "grad_norm": 0.3664476570531787, + "learning_rate": 0.0005981410505091662, + "loss": 7.0664, + "step": 47 + }, + { + "epoch": 0.05, + "grad_norm": 0.31891741142945207, + "learning_rate": 0.0005980393322968223, + "loss": 7.0273, + "step": 48 + }, + { + "epoch": 0.05, + "grad_norm": 0.4533529037337155, + "learning_rate": 0.0005979349149313778, + "loss": 7.0586, + "step": 49 + }, + { + "epoch": 0.05, + "grad_norm": 0.30532331638835586, + "learning_rate": 0.0005978277994643147, + "loss": 7.0195, + "step": 50 + }, + { + "epoch": 0.05, + "grad_norm": 0.6501991746260075, + "learning_rate": 0.0005977179869742844, + "loss": 6.9648, + "step": 51 + }, + { + "epoch": 0.05, + "grad_norm": 0.43904455901717926, + "learning_rate": 0.0005976054785670975, + "loss": 6.9805, + "step": 52 + }, + { + "epoch": 0.05, + "grad_norm": 0.4826001598483571, + "learning_rate": 0.0005974902753757124, + "loss": 6.9297, + "step": 53 + }, + { + "epoch": 0.05, + "grad_norm": 0.2924998027034648, + "learning_rate": 0.000597372378560224, + "loss": 6.8984, + "step": 54 + }, + { + "epoch": 0.05, + "grad_norm": 0.4439033666380787, + "learning_rate": 0.0005972517893078517, + "loss": 6.8945, + "step": 55 + }, + { + "epoch": 0.06, + "grad_norm": 0.6135914255073411, + "learning_rate": 0.0005971285088329284, + "loss": 6.9727, + "step": 56 + }, + { + "epoch": 0.06, + "grad_norm": 0.5575686565598483, + "learning_rate": 0.0005970025383768866, 
+ "loss": 6.9219, + "step": 57 + }, + { + "epoch": 0.06, + "grad_norm": 0.4820951675994578, + "learning_rate": 0.0005968738792082478, + "loss": 6.8516, + "step": 58 + }, + { + "epoch": 0.06, + "grad_norm": 0.40164190019465584, + "learning_rate": 0.0005967425326226082, + "loss": 6.7734, + "step": 59 + }, + { + "epoch": 0.06, + "grad_norm": 0.46129863945181293, + "learning_rate": 0.0005966084999426265, + "loss": 6.8125, + "step": 60 + }, + { + "epoch": 0.06, + "grad_norm": 0.33322355827118677, + "learning_rate": 0.0005964717825180101, + "loss": 6.7891, + "step": 61 + }, + { + "epoch": 0.06, + "grad_norm": 0.3847525153855558, + "learning_rate": 0.0005963323817255024, + "loss": 6.8242, + "step": 62 + }, + { + "epoch": 0.06, + "grad_norm": 0.3384433591375982, + "learning_rate": 0.0005961902989688674, + "loss": 6.707, + "step": 63 + }, + { + "epoch": 0.06, + "grad_norm": 0.3937003195165685, + "learning_rate": 0.000596045535678877, + "loss": 6.8203, + "step": 64 + }, + { + "epoch": 0.06, + "grad_norm": 0.35423488053528107, + "learning_rate": 0.0005958980933132962, + "loss": 6.7383, + "step": 65 + }, + { + "epoch": 0.07, + "grad_norm": 0.36005939745315396, + "learning_rate": 0.0005957479733568675, + "loss": 6.7109, + "step": 66 + }, + { + "epoch": 0.07, + "grad_norm": 0.3499278317706933, + "learning_rate": 0.0005955951773212976, + "loss": 6.7266, + "step": 67 + }, + { + "epoch": 0.07, + "grad_norm": 0.3708385192137018, + "learning_rate": 0.0005954397067452407, + "loss": 6.7617, + "step": 68 + }, + { + "epoch": 0.07, + "grad_norm": 0.3775657656205869, + "learning_rate": 0.0005952815631942839, + "loss": 6.7148, + "step": 69 + }, + { + "epoch": 0.07, + "grad_norm": 0.3040083750375816, + "learning_rate": 0.0005951207482609307, + "loss": 6.5938, + "step": 70 + }, + { + "epoch": 0.07, + "grad_norm": 0.3443020808841468, + "learning_rate": 0.0005949572635645861, + "loss": 6.6523, + "step": 71 + }, + { + "epoch": 0.07, + "grad_norm": 0.3520066316939, + "learning_rate": 0.0005947911107515389, + "loss": 6.6211, + "step": 72 + }, + { + "epoch": 0.07, + "grad_norm": 0.3739040572679613, + "learning_rate": 0.0005946222914949462, + "loss": 6.5547, + "step": 73 + }, + { + "epoch": 0.07, + "grad_norm": 0.34890731989025553, + "learning_rate": 0.000594450807494816, + "loss": 6.5859, + "step": 74 + }, + { + "epoch": 0.07, + "grad_norm": 0.40910932350136514, + "learning_rate": 0.0005942766604779903, + "loss": 6.5547, + "step": 75 + }, + { + "epoch": 0.08, + "grad_norm": 0.5698342865852906, + "learning_rate": 0.0005940998521981274, + "loss": 6.457, + "step": 76 + }, + { + "epoch": 0.08, + "grad_norm": 0.5179452709555474, + "learning_rate": 0.0005939203844356852, + "loss": 6.5547, + "step": 77 + }, + { + "epoch": 0.08, + "grad_norm": 0.5222512938673792, + "learning_rate": 0.0005937382589979016, + "loss": 6.5039, + "step": 78 + }, + { + "epoch": 0.08, + "grad_norm": 0.5682332793686307, + "learning_rate": 0.0005935534777187781, + "loss": 6.5547, + "step": 79 + }, + { + "epoch": 0.08, + "grad_norm": 0.3869287710460676, + "learning_rate": 0.0005933660424590598, + "loss": 6.5156, + "step": 80 + }, + { + "epoch": 0.08, + "grad_norm": 0.3078211032807607, + "learning_rate": 0.000593175955106218, + "loss": 6.4258, + "step": 81 + }, + { + "epoch": 0.08, + "grad_norm": 0.3611357511872241, + "learning_rate": 0.00059298321757443, + "loss": 6.4727, + "step": 82 + }, + { + "epoch": 0.08, + "grad_norm": 0.29633467844266953, + "learning_rate": 0.0005927878318045608, + "loss": 6.3281, + "step": 83 + }, + { + "epoch": 0.08, + "grad_norm": 
0.3257574200776832, + "learning_rate": 0.0005925897997641426, + "loss": 6.3203, + "step": 84 + }, + { + "epoch": 0.08, + "grad_norm": 0.2824054533852328, + "learning_rate": 0.0005923891234473562, + "loss": 6.4062, + "step": 85 + }, + { + "epoch": 0.09, + "grad_norm": 0.3056199770204573, + "learning_rate": 0.0005921858048750097, + "loss": 6.3984, + "step": 86 + }, + { + "epoch": 0.09, + "grad_norm": 0.2966438824341908, + "learning_rate": 0.000591979846094519, + "loss": 6.3555, + "step": 87 + }, + { + "epoch": 0.09, + "grad_norm": 0.32782438676663733, + "learning_rate": 0.0005917712491798866, + "loss": 6.4023, + "step": 88 + }, + { + "epoch": 0.09, + "grad_norm": 0.3538316399620157, + "learning_rate": 0.0005915600162316811, + "loss": 6.2812, + "step": 89 + }, + { + "epoch": 0.09, + "grad_norm": 0.375858298192913, + "learning_rate": 0.0005913461493770162, + "loss": 6.3086, + "step": 90 + }, + { + "epoch": 0.09, + "grad_norm": 0.5189251339815161, + "learning_rate": 0.0005911296507695284, + "loss": 6.2812, + "step": 91 + }, + { + "epoch": 0.09, + "grad_norm": 0.6304909542669104, + "learning_rate": 0.0005909105225893564, + "loss": 6.2969, + "step": 92 + }, + { + "epoch": 0.09, + "grad_norm": 0.4655662819622591, + "learning_rate": 0.0005906887670431187, + "loss": 6.1953, + "step": 93 + }, + { + "epoch": 0.09, + "grad_norm": 0.39035390983920965, + "learning_rate": 0.000590464386363891, + "loss": 6.2617, + "step": 94 + }, + { + "epoch": 0.09, + "grad_norm": 0.4918417851770978, + "learning_rate": 0.0005902373828111843, + "loss": 6.2148, + "step": 95 + }, + { + "epoch": 0.1, + "grad_norm": 0.35670770889552555, + "learning_rate": 0.0005900077586709219, + "loss": 6.2461, + "step": 96 + }, + { + "epoch": 0.1, + "grad_norm": 0.4177985869939347, + "learning_rate": 0.0005897755162554163, + "loss": 6.1797, + "step": 97 + }, + { + "epoch": 0.1, + "grad_norm": 0.3742471130708234, + "learning_rate": 0.000589540657903346, + "loss": 6.1406, + "step": 98 + }, + { + "epoch": 0.1, + "grad_norm": 0.28627666723978284, + "learning_rate": 0.0005893031859797322, + "loss": 6.2031, + "step": 99 + }, + { + "epoch": 0.1, + "grad_norm": 0.32238563846046103, + "learning_rate": 0.0005890631028759143, + "loss": 6.0625, + "step": 100 + }, + { + "epoch": 0.1, + "grad_norm": 0.2556625657587849, + "learning_rate": 0.0005888204110095265, + "loss": 6.1797, + "step": 101 + }, + { + "epoch": 0.1, + "grad_norm": 0.35463629701710253, + "learning_rate": 0.0005885751128244734, + "loss": 6.125, + "step": 102 + }, + { + "epoch": 0.1, + "grad_norm": 0.31975770214936095, + "learning_rate": 0.0005883272107909048, + "loss": 6.1836, + "step": 103 + }, + { + "epoch": 0.1, + "grad_norm": 0.3464621815245048, + "learning_rate": 0.0005880767074051915, + "loss": 6.125, + "step": 104 + }, + { + "epoch": 0.1, + "grad_norm": 0.3663428920796654, + "learning_rate": 0.0005878236051898998, + "loss": 6.0781, + "step": 105 + }, + { + "epoch": 0.11, + "grad_norm": 0.31594460565215293, + "learning_rate": 0.0005875679066937664, + "loss": 6.082, + "step": 106 + }, + { + "epoch": 0.11, + "grad_norm": 0.3552617109396582, + "learning_rate": 0.000587309614491672, + "loss": 6.1016, + "step": 107 + }, + { + "epoch": 0.11, + "grad_norm": 0.307016409692456, + "learning_rate": 0.0005870487311846164, + "loss": 6.1406, + "step": 108 + }, + { + "epoch": 0.11, + "grad_norm": 0.32188902148474213, + "learning_rate": 0.0005867852593996914, + "loss": 6.0039, + "step": 109 + }, + { + "epoch": 0.11, + "grad_norm": 0.25501199715105083, + "learning_rate": 0.0005865192017900551, + 
"loss": 6.0938, + "step": 110 + }, + { + "epoch": 0.11, + "grad_norm": 0.3416203070024056, + "learning_rate": 0.0005862505610349049, + "loss": 6.0234, + "step": 111 + }, + { + "epoch": 0.11, + "grad_norm": 0.3562508875852537, + "learning_rate": 0.0005859793398394498, + "loss": 6.0469, + "step": 112 + }, + { + "epoch": 0.11, + "grad_norm": 0.4443953757302568, + "learning_rate": 0.0005857055409348845, + "loss": 5.9766, + "step": 113 + }, + { + "epoch": 0.11, + "grad_norm": 0.42023839332714596, + "learning_rate": 0.0005854291670783607, + "loss": 6.0781, + "step": 114 + }, + { + "epoch": 0.11, + "grad_norm": 0.4618323255809241, + "learning_rate": 0.0005851502210529604, + "loss": 5.9727, + "step": 115 + }, + { + "epoch": 0.12, + "grad_norm": 0.379195014798667, + "learning_rate": 0.0005848687056676668, + "loss": 5.9922, + "step": 116 + }, + { + "epoch": 0.12, + "grad_norm": 0.3931552573296799, + "learning_rate": 0.0005845846237573366, + "loss": 5.9492, + "step": 117 + }, + { + "epoch": 0.12, + "grad_norm": 0.2567080044949908, + "learning_rate": 0.0005842979781826717, + "loss": 6.0273, + "step": 118 + }, + { + "epoch": 0.12, + "grad_norm": 0.4190305965377807, + "learning_rate": 0.0005840087718301895, + "loss": 6.0391, + "step": 119 + }, + { + "epoch": 0.12, + "grad_norm": 0.3996803869430228, + "learning_rate": 0.0005837170076121951, + "loss": 5.9531, + "step": 120 + }, + { + "epoch": 0.12, + "grad_norm": 0.478219248015785, + "learning_rate": 0.000583422688466751, + "loss": 6.0586, + "step": 121 + }, + { + "epoch": 0.12, + "grad_norm": 0.40869844309811526, + "learning_rate": 0.0005831258173576474, + "loss": 6.0117, + "step": 122 + }, + { + "epoch": 0.12, + "grad_norm": 0.3728598080697978, + "learning_rate": 0.0005828263972743733, + "loss": 5.9375, + "step": 123 + }, + { + "epoch": 0.12, + "grad_norm": 0.3560055462882015, + "learning_rate": 0.0005825244312320856, + "loss": 5.9531, + "step": 124 + }, + { + "epoch": 0.12, + "grad_norm": 0.40446932887864323, + "learning_rate": 0.0005822199222715787, + "loss": 5.9609, + "step": 125 + }, + { + "epoch": 0.13, + "grad_norm": 0.38514065739946723, + "learning_rate": 0.000581912873459255, + "loss": 5.8594, + "step": 126 + }, + { + "epoch": 0.13, + "grad_norm": 0.35367576386319416, + "learning_rate": 0.0005816032878870921, + "loss": 5.9023, + "step": 127 + }, + { + "epoch": 0.13, + "grad_norm": 0.3341681995122829, + "learning_rate": 0.0005812911686726135, + "loss": 5.9062, + "step": 128 + }, + { + "epoch": 0.13, + "grad_norm": 0.3387022688975784, + "learning_rate": 0.0005809765189588563, + "loss": 5.8945, + "step": 129 + }, + { + "epoch": 0.13, + "grad_norm": 0.31638659898934757, + "learning_rate": 0.0005806593419143395, + "loss": 5.8242, + "step": 130 + }, + { + "epoch": 0.13, + "grad_norm": 0.3229678508227436, + "learning_rate": 0.0005803396407330325, + "loss": 5.8516, + "step": 131 + }, + { + "epoch": 0.13, + "grad_norm": 0.35499490868584455, + "learning_rate": 0.0005800174186343226, + "loss": 5.9258, + "step": 132 + }, + { + "epoch": 0.13, + "grad_norm": 0.40753171542848754, + "learning_rate": 0.0005796926788629828, + "loss": 5.8242, + "step": 133 + }, + { + "epoch": 0.13, + "grad_norm": 0.3625374018348824, + "learning_rate": 0.0005793654246891389, + "loss": 5.832, + "step": 134 + }, + { + "epoch": 0.13, + "grad_norm": 0.3583489573569317, + "learning_rate": 0.000579035659408237, + "loss": 5.8398, + "step": 135 + }, + { + "epoch": 0.14, + "grad_norm": 0.39657706318861896, + "learning_rate": 0.0005787033863410095, + "loss": 5.8633, + "step": 136 + }, + { + 
"epoch": 0.14, + "grad_norm": 0.3965837889564036, + "learning_rate": 0.0005783686088334428, + "loss": 5.8633, + "step": 137 + }, + { + "epoch": 0.14, + "grad_norm": 0.29496474301865566, + "learning_rate": 0.0005780313302567424, + "loss": 5.8203, + "step": 138 + }, + { + "epoch": 0.14, + "grad_norm": 0.44637192639243695, + "learning_rate": 0.0005776915540073001, + "loss": 5.8477, + "step": 139 + }, + { + "epoch": 0.14, + "grad_norm": 0.39605473508683114, + "learning_rate": 0.0005773492835066587, + "loss": 5.7383, + "step": 140 + }, + { + "epoch": 0.14, + "grad_norm": 0.3008962634266945, + "learning_rate": 0.0005770045222014786, + "loss": 5.7617, + "step": 141 + }, + { + "epoch": 0.14, + "grad_norm": 0.36915495506607826, + "learning_rate": 0.0005766572735635022, + "loss": 5.7695, + "step": 142 + }, + { + "epoch": 0.14, + "grad_norm": 0.3282300349560706, + "learning_rate": 0.0005763075410895193, + "loss": 5.8281, + "step": 143 + }, + { + "epoch": 0.14, + "grad_norm": 0.2747449814083844, + "learning_rate": 0.0005759553283013323, + "loss": 5.7812, + "step": 144 + }, + { + "epoch": 0.14, + "grad_norm": 0.28905882704179764, + "learning_rate": 0.00057560063874572, + "loss": 5.7344, + "step": 145 + }, + { + "epoch": 0.15, + "grad_norm": 0.280625988867192, + "learning_rate": 0.000575243475994402, + "loss": 5.7773, + "step": 146 + }, + { + "epoch": 0.15, + "grad_norm": 0.41061863948012467, + "learning_rate": 0.0005748838436440035, + "loss": 5.7578, + "step": 147 + }, + { + "epoch": 0.15, + "grad_norm": 0.4920152483870267, + "learning_rate": 0.0005745217453160183, + "loss": 5.7305, + "step": 148 + }, + { + "epoch": 0.15, + "grad_norm": 0.5463207978955044, + "learning_rate": 0.0005741571846567725, + "loss": 5.7383, + "step": 149 + }, + { + "epoch": 0.15, + "grad_norm": 0.3986359831157306, + "learning_rate": 0.0005737901653373878, + "loss": 5.668, + "step": 150 + }, + { + "epoch": 0.15, + "grad_norm": 0.37908758170100293, + "learning_rate": 0.0005734206910537447, + "loss": 5.6875, + "step": 151 + }, + { + "epoch": 0.15, + "grad_norm": 0.35929793070492694, + "learning_rate": 0.0005730487655264451, + "loss": 5.7188, + "step": 152 + }, + { + "epoch": 0.15, + "grad_norm": 0.4217799574145456, + "learning_rate": 0.0005726743925007751, + "loss": 5.7305, + "step": 153 + }, + { + "epoch": 0.15, + "grad_norm": 0.4024411981587195, + "learning_rate": 0.0005722975757466667, + "loss": 5.6289, + "step": 154 + }, + { + "epoch": 0.15, + "grad_norm": 0.3472391905877033, + "learning_rate": 0.0005719183190586606, + "loss": 5.6523, + "step": 155 + }, + { + "epoch": 0.16, + "grad_norm": 0.31752956812138816, + "learning_rate": 0.0005715366262558675, + "loss": 5.6172, + "step": 156 + }, + { + "epoch": 0.16, + "grad_norm": 0.3170152384332457, + "learning_rate": 0.0005711525011819294, + "loss": 5.6172, + "step": 157 + }, + { + "epoch": 0.16, + "grad_norm": 0.40520629326601837, + "learning_rate": 0.0005707659477049818, + "loss": 5.625, + "step": 158 + }, + { + "epoch": 0.16, + "grad_norm": 0.3965976910198806, + "learning_rate": 0.0005703769697176137, + "loss": 5.6562, + "step": 159 + }, + { + "epoch": 0.16, + "grad_norm": 0.40422960541801994, + "learning_rate": 0.0005699855711368293, + "loss": 5.6836, + "step": 160 + }, + { + "epoch": 0.16, + "grad_norm": 0.3780813184050647, + "learning_rate": 0.0005695917559040079, + "loss": 5.5938, + "step": 161 + }, + { + "epoch": 0.16, + "grad_norm": 0.36917638857736573, + "learning_rate": 0.0005691955279848645, + "loss": 5.668, + "step": 162 + }, + { + "epoch": 0.16, + "grad_norm": 
0.37769176081037814, + "learning_rate": 0.0005687968913694098, + "loss": 5.4961, + "step": 163 + }, + { + "epoch": 0.16, + "grad_norm": 0.3255116524991148, + "learning_rate": 0.0005683958500719103, + "loss": 5.5117, + "step": 164 + }, + { + "epoch": 0.16, + "grad_norm": 0.31897629016848805, + "learning_rate": 0.0005679924081308471, + "loss": 5.5664, + "step": 165 + }, + { + "epoch": 0.17, + "grad_norm": 0.2869064236553046, + "learning_rate": 0.0005675865696088764, + "loss": 5.5391, + "step": 166 + }, + { + "epoch": 0.17, + "grad_norm": 0.29226729022634845, + "learning_rate": 0.0005671783385927873, + "loss": 5.5586, + "step": 167 + }, + { + "epoch": 0.17, + "grad_norm": 0.2534117210955766, + "learning_rate": 0.0005667677191934618, + "loss": 5.5312, + "step": 168 + }, + { + "epoch": 0.17, + "grad_norm": 0.289828484125484, + "learning_rate": 0.0005663547155458326, + "loss": 5.6484, + "step": 169 + }, + { + "epoch": 0.17, + "grad_norm": 0.2717242930342115, + "learning_rate": 0.0005659393318088419, + "loss": 5.5352, + "step": 170 + }, + { + "epoch": 0.17, + "grad_norm": 0.3595538109137759, + "learning_rate": 0.0005655215721653993, + "loss": 5.5742, + "step": 171 + }, + { + "epoch": 0.17, + "grad_norm": 0.4255054350471108, + "learning_rate": 0.0005651014408223398, + "loss": 5.5469, + "step": 172 + }, + { + "epoch": 0.17, + "grad_norm": 0.3670561941219979, + "learning_rate": 0.0005646789420103814, + "loss": 5.5078, + "step": 173 + }, + { + "epoch": 0.17, + "grad_norm": 0.40280130904983164, + "learning_rate": 0.0005642540799840822, + "loss": 5.5, + "step": 174 + }, + { + "epoch": 0.17, + "grad_norm": 0.41159472035983025, + "learning_rate": 0.0005638268590217984, + "loss": 5.5039, + "step": 175 + }, + { + "epoch": 0.18, + "grad_norm": 0.4316778037513652, + "learning_rate": 0.0005633972834256401, + "loss": 5.5352, + "step": 176 + }, + { + "epoch": 0.18, + "grad_norm": 0.5674781128363939, + "learning_rate": 0.000562965357521429, + "loss": 5.4336, + "step": 177 + }, + { + "epoch": 0.18, + "grad_norm": 0.41654662151365446, + "learning_rate": 0.0005625310856586541, + "loss": 5.6211, + "step": 178 + }, + { + "epoch": 0.18, + "grad_norm": 0.5159976364107484, + "learning_rate": 0.0005620944722104282, + "loss": 5.4844, + "step": 179 + }, + { + "epoch": 0.18, + "grad_norm": 0.34364678177014185, + "learning_rate": 0.0005616555215734438, + "loss": 5.4922, + "step": 180 + }, + { + "epoch": 0.18, + "grad_norm": 0.3708077784459011, + "learning_rate": 0.0005612142381679289, + "loss": 5.5234, + "step": 181 + }, + { + "epoch": 0.18, + "grad_norm": 0.3620051253453866, + "learning_rate": 0.0005607706264376028, + "loss": 5.4961, + "step": 182 + }, + { + "epoch": 0.18, + "grad_norm": 0.34735585210929654, + "learning_rate": 0.0005603246908496305, + "loss": 5.4453, + "step": 183 + }, + { + "epoch": 0.18, + "grad_norm": 0.37719874705792217, + "learning_rate": 0.0005598764358945783, + "loss": 5.4844, + "step": 184 + }, + { + "epoch": 0.18, + "grad_norm": 0.3749130664831207, + "learning_rate": 0.0005594258660863689, + "loss": 5.4648, + "step": 185 + }, + { + "epoch": 0.19, + "grad_norm": 0.40951353306235827, + "learning_rate": 0.0005589729859622351, + "loss": 5.5039, + "step": 186 + }, + { + "epoch": 0.19, + "grad_norm": 0.40146882563949804, + "learning_rate": 0.0005585178000826745, + "loss": 5.3672, + "step": 187 + }, + { + "epoch": 0.19, + "grad_norm": 0.4062987628428303, + "learning_rate": 0.0005580603130314043, + "loss": 5.3984, + "step": 188 + }, + { + "epoch": 0.19, + "grad_norm": 0.35626322654799136, + 
"learning_rate": 0.0005576005294153138, + "loss": 5.3984, + "step": 189 + }, + { + "epoch": 0.19, + "grad_norm": 0.3140647930801716, + "learning_rate": 0.0005571384538644188, + "loss": 5.3906, + "step": 190 + }, + { + "epoch": 0.19, + "grad_norm": 0.2990060538353662, + "learning_rate": 0.0005566740910318153, + "loss": 5.3711, + "step": 191 + }, + { + "epoch": 0.19, + "grad_norm": 0.3337525907515936, + "learning_rate": 0.0005562074455936315, + "loss": 5.4023, + "step": 192 + }, + { + "epoch": 0.19, + "grad_norm": 0.3381587051014816, + "learning_rate": 0.000555738522248982, + "loss": 5.4414, + "step": 193 + }, + { + "epoch": 0.19, + "grad_norm": 0.2954008999469894, + "learning_rate": 0.0005552673257199197, + "loss": 5.418, + "step": 194 + }, + { + "epoch": 0.19, + "grad_norm": 0.3242310900810155, + "learning_rate": 0.0005547938607513882, + "loss": 5.418, + "step": 195 + }, + { + "epoch": 0.2, + "grad_norm": 0.3149021804393678, + "learning_rate": 0.0005543181321111747, + "loss": 5.4375, + "step": 196 + }, + { + "epoch": 0.2, + "grad_norm": 0.32859412218218814, + "learning_rate": 0.0005538401445898612, + "loss": 5.4492, + "step": 197 + }, + { + "epoch": 0.2, + "grad_norm": 0.2960282598050701, + "learning_rate": 0.0005533599030007768, + "loss": 5.3867, + "step": 198 + }, + { + "epoch": 0.2, + "grad_norm": 0.2866762878199755, + "learning_rate": 0.0005528774121799489, + "loss": 5.3789, + "step": 199 + }, + { + "epoch": 0.2, + "grad_norm": 0.34865216327038784, + "learning_rate": 0.0005523926769860549, + "loss": 5.3711, + "step": 200 + }, + { + "epoch": 0.2, + "grad_norm": 0.4043023482242469, + "learning_rate": 0.0005519057023003725, + "loss": 5.3906, + "step": 201 + }, + { + "epoch": 0.2, + "grad_norm": 0.4069960968887199, + "learning_rate": 0.0005514164930267316, + "loss": 5.2773, + "step": 202 + }, + { + "epoch": 0.2, + "grad_norm": 0.4051152667506829, + "learning_rate": 0.0005509250540914641, + "loss": 5.3242, + "step": 203 + }, + { + "epoch": 0.2, + "grad_norm": 0.375026562862574, + "learning_rate": 0.0005504313904433546, + "loss": 5.4258, + "step": 204 + }, + { + "epoch": 0.2, + "grad_norm": 0.3326184185943848, + "learning_rate": 0.0005499355070535906, + "loss": 5.375, + "step": 205 + }, + { + "epoch": 0.21, + "grad_norm": 0.3695014522224558, + "learning_rate": 0.0005494374089157123, + "loss": 5.3984, + "step": 206 + }, + { + "epoch": 0.21, + "grad_norm": 0.2793258171824813, + "learning_rate": 0.0005489371010455625, + "loss": 5.2891, + "step": 207 + }, + { + "epoch": 0.21, + "grad_norm": 0.2879966080096621, + "learning_rate": 0.0005484345884812357, + "loss": 5.3867, + "step": 208 + }, + { + "epoch": 0.21, + "grad_norm": 0.32599687735840654, + "learning_rate": 0.0005479298762830281, + "loss": 5.3203, + "step": 209 + }, + { + "epoch": 0.21, + "grad_norm": 0.31305226164510963, + "learning_rate": 0.0005474229695333857, + "loss": 5.3281, + "step": 210 + }, + { + "epoch": 0.21, + "grad_norm": 0.3514527997420013, + "learning_rate": 0.000546913873336854, + "loss": 5.3008, + "step": 211 + }, + { + "epoch": 0.21, + "grad_norm": 0.38188707638514424, + "learning_rate": 0.0005464025928200261, + "loss": 5.3086, + "step": 212 + }, + { + "epoch": 0.21, + "grad_norm": 0.3865148796842015, + "learning_rate": 0.0005458891331314909, + "loss": 5.2656, + "step": 213 + }, + { + "epoch": 0.21, + "grad_norm": 0.4304784604066023, + "learning_rate": 0.0005453734994417819, + "loss": 5.3125, + "step": 214 + }, + { + "epoch": 0.21, + "grad_norm": 0.40269356862192995, + "learning_rate": 0.0005448556969433247, + "loss": 
5.2617, + "step": 215 + }, + { + "epoch": 0.22, + "grad_norm": 0.30541089575928587, + "learning_rate": 0.0005443357308503845, + "loss": 5.2422, + "step": 216 + }, + { + "epoch": 0.22, + "grad_norm": 0.29104576978792596, + "learning_rate": 0.0005438136063990142, + "loss": 5.2109, + "step": 217 + }, + { + "epoch": 0.22, + "grad_norm": 0.291891354913362, + "learning_rate": 0.0005432893288470012, + "loss": 5.2617, + "step": 218 + }, + { + "epoch": 0.22, + "grad_norm": 0.3301944866145271, + "learning_rate": 0.0005427629034738149, + "loss": 5.2188, + "step": 219 + }, + { + "epoch": 0.22, + "grad_norm": 0.33824328942983417, + "learning_rate": 0.0005422343355805525, + "loss": 5.293, + "step": 220 + }, + { + "epoch": 0.22, + "grad_norm": 0.3539026997032359, + "learning_rate": 0.0005417036304898872, + "loss": 5.2695, + "step": 221 + }, + { + "epoch": 0.22, + "grad_norm": 0.38720918633148693, + "learning_rate": 0.0005411707935460132, + "loss": 5.2227, + "step": 222 + }, + { + "epoch": 0.22, + "grad_norm": 0.4539797383631105, + "learning_rate": 0.0005406358301145925, + "loss": 5.2539, + "step": 223 + }, + { + "epoch": 0.22, + "grad_norm": 0.40620115793500733, + "learning_rate": 0.0005400987455827012, + "loss": 5.2852, + "step": 224 + }, + { + "epoch": 0.22, + "grad_norm": 0.3680272948713411, + "learning_rate": 0.0005395595453587743, + "loss": 5.2617, + "step": 225 + }, + { + "epoch": 0.23, + "grad_norm": 0.3919096232059878, + "learning_rate": 0.0005390182348725522, + "loss": 5.2305, + "step": 226 + }, + { + "epoch": 0.23, + "grad_norm": 0.3783288666206609, + "learning_rate": 0.0005384748195750255, + "loss": 5.2031, + "step": 227 + }, + { + "epoch": 0.23, + "grad_norm": 0.34519921770570766, + "learning_rate": 0.0005379293049383802, + "loss": 5.2227, + "step": 228 + }, + { + "epoch": 0.23, + "grad_norm": 0.3548414963147158, + "learning_rate": 0.0005373816964559426, + "loss": 5.2891, + "step": 229 + }, + { + "epoch": 0.23, + "grad_norm": 0.36291865229291537, + "learning_rate": 0.000536831999642124, + "loss": 5.2266, + "step": 230 + }, + { + "epoch": 0.23, + "grad_norm": 0.313916097271022, + "learning_rate": 0.0005362802200323654, + "loss": 5.1055, + "step": 231 + }, + { + "epoch": 0.23, + "grad_norm": 0.29232836352032804, + "learning_rate": 0.0005357263631830811, + "loss": 5.1406, + "step": 232 + }, + { + "epoch": 0.23, + "grad_norm": 0.34482143058503106, + "learning_rate": 0.0005351704346716036, + "loss": 5.2305, + "step": 233 + }, + { + "epoch": 0.23, + "grad_norm": 0.3079065808428287, + "learning_rate": 0.0005346124400961267, + "loss": 5.2031, + "step": 234 + }, + { + "epoch": 0.23, + "grad_norm": 0.2869436862887739, + "learning_rate": 0.0005340523850756497, + "loss": 5.2539, + "step": 235 + }, + { + "epoch": 0.24, + "grad_norm": 0.27208356804470046, + "learning_rate": 0.0005334902752499204, + "loss": 5.1484, + "step": 236 + }, + { + "epoch": 0.24, + "grad_norm": 0.27768753128858653, + "learning_rate": 0.0005329261162793785, + "loss": 5.1758, + "step": 237 + }, + { + "epoch": 0.24, + "grad_norm": 0.2701859056468535, + "learning_rate": 0.0005323599138450985, + "loss": 5.1562, + "step": 238 + }, + { + "epoch": 0.24, + "grad_norm": 0.2940215458662745, + "learning_rate": 0.0005317916736487328, + "loss": 5.1406, + "step": 239 + }, + { + "epoch": 0.24, + "grad_norm": 0.29636403080234647, + "learning_rate": 0.0005312214014124536, + "loss": 5.1719, + "step": 240 + }, + { + "epoch": 0.24, + "grad_norm": 0.3513688083715198, + "learning_rate": 0.0005306491028788964, + "loss": 5.0664, + "step": 241 + }, + { + 
"epoch": 0.24, + "grad_norm": 0.455104024911365, + "learning_rate": 0.0005300747838111007, + "loss": 5.1289, + "step": 242 + }, + { + "epoch": 0.24, + "grad_norm": 0.5257166308389952, + "learning_rate": 0.0005294984499924532, + "loss": 5.1523, + "step": 243 + }, + { + "epoch": 0.24, + "grad_norm": 0.440798061960299, + "learning_rate": 0.0005289201072266293, + "loss": 5.1289, + "step": 244 + }, + { + "epoch": 0.24, + "grad_norm": 0.4965659619997502, + "learning_rate": 0.0005283397613375339, + "loss": 5.1211, + "step": 245 + }, + { + "epoch": 0.25, + "grad_norm": 0.40267641703114215, + "learning_rate": 0.0005277574181692438, + "loss": 5.0586, + "step": 246 + }, + { + "epoch": 0.25, + "grad_norm": 0.4013007078780512, + "learning_rate": 0.0005271730835859485, + "loss": 5.0273, + "step": 247 + }, + { + "epoch": 0.25, + "grad_norm": 0.38447773033555227, + "learning_rate": 0.0005265867634718903, + "loss": 5.1367, + "step": 248 + }, + { + "epoch": 0.25, + "grad_norm": 0.37763602900633203, + "learning_rate": 0.0005259984637313066, + "loss": 5.1055, + "step": 249 + }, + { + "epoch": 0.25, + "grad_norm": 0.344024017964152, + "learning_rate": 0.0005254081902883689, + "loss": 5.0898, + "step": 250 + }, + { + "epoch": 0.25, + "grad_norm": 0.35441912273779097, + "learning_rate": 0.0005248159490871245, + "loss": 5.1016, + "step": 251 + }, + { + "epoch": 0.25, + "grad_norm": 0.2877284013478678, + "learning_rate": 0.0005242217460914358, + "loss": 5.0664, + "step": 252 + }, + { + "epoch": 0.25, + "grad_norm": 0.3143093064571279, + "learning_rate": 0.0005236255872849201, + "loss": 5.1484, + "step": 253 + }, + { + "epoch": 0.25, + "grad_norm": 0.31206187291371684, + "learning_rate": 0.00052302747867089, + "loss": 5.1328, + "step": 254 + }, + { + "epoch": 0.25, + "grad_norm": 0.3150920418962865, + "learning_rate": 0.000522427426272293, + "loss": 5.1289, + "step": 255 + }, + { + "epoch": 0.26, + "grad_norm": 0.3195539774191906, + "learning_rate": 0.0005218254361316495, + "loss": 5.0898, + "step": 256 + }, + { + "epoch": 0.26, + "grad_norm": 0.24548404338795576, + "learning_rate": 0.000521221514310994, + "loss": 5.1016, + "step": 257 + }, + { + "epoch": 0.26, + "grad_norm": 0.25649802021467205, + "learning_rate": 0.0005206156668918122, + "loss": 5.1289, + "step": 258 + }, + { + "epoch": 0.26, + "grad_norm": 0.25018114739252273, + "learning_rate": 0.0005200078999749811, + "loss": 5.0508, + "step": 259 + }, + { + "epoch": 0.26, + "grad_norm": 0.2740344343745378, + "learning_rate": 0.0005193982196807067, + "loss": 5.082, + "step": 260 + }, + { + "epoch": 0.26, + "grad_norm": 0.30807201125247574, + "learning_rate": 0.0005187866321484628, + "loss": 5.0078, + "step": 261 + }, + { + "epoch": 0.26, + "grad_norm": 0.32367849723934244, + "learning_rate": 0.0005181731435369292, + "loss": 5.0625, + "step": 262 + }, + { + "epoch": 0.26, + "grad_norm": 0.3465653029312147, + "learning_rate": 0.0005175577600239292, + "loss": 5.0078, + "step": 263 + }, + { + "epoch": 0.26, + "grad_norm": 0.3716869632171198, + "learning_rate": 0.0005169404878063681, + "loss": 5.0977, + "step": 264 + }, + { + "epoch": 0.26, + "grad_norm": 0.37681584996379275, + "learning_rate": 0.0005163213331001702, + "loss": 5.082, + "step": 265 + }, + { + "epoch": 0.27, + "grad_norm": 0.34462519335888353, + "learning_rate": 0.0005157003021402166, + "loss": 4.9844, + "step": 266 + }, + { + "epoch": 0.27, + "grad_norm": 0.39514090390949574, + "learning_rate": 0.000515077401180282, + "loss": 5.0312, + "step": 267 + }, + { + "epoch": 0.27, + "grad_norm": 
0.46469822376758096, + "learning_rate": 0.0005144526364929722, + "loss": 5.0234, + "step": 268 + }, + { + "epoch": 0.27, + "grad_norm": 0.34570371767844565, + "learning_rate": 0.0005138260143696608, + "loss": 5.0352, + "step": 269 + }, + { + "epoch": 0.27, + "grad_norm": 0.2920012584285204, + "learning_rate": 0.0005131975411204257, + "loss": 4.9805, + "step": 270 + }, + { + "epoch": 0.27, + "grad_norm": 0.34109638913820345, + "learning_rate": 0.0005125672230739852, + "loss": 4.9844, + "step": 271 + }, + { + "epoch": 0.27, + "grad_norm": 0.2976316922487618, + "learning_rate": 0.0005119350665776353, + "loss": 4.9727, + "step": 272 + }, + { + "epoch": 0.27, + "grad_norm": 0.38160864657971466, + "learning_rate": 0.0005113010779971848, + "loss": 5.0312, + "step": 273 + }, + { + "epoch": 0.27, + "grad_norm": 0.40407725833100544, + "learning_rate": 0.0005106652637168917, + "loss": 5.0312, + "step": 274 + }, + { + "epoch": 0.27, + "grad_norm": 0.36275741793161437, + "learning_rate": 0.0005100276301393987, + "loss": 5.0391, + "step": 275 + }, + { + "epoch": 0.28, + "grad_norm": 0.35097531980231905, + "learning_rate": 0.0005093881836856688, + "loss": 4.9844, + "step": 276 + }, + { + "epoch": 0.28, + "grad_norm": 0.3615382021996322, + "learning_rate": 0.000508746930794921, + "loss": 4.9453, + "step": 277 + }, + { + "epoch": 0.28, + "grad_norm": 0.3260265986197515, + "learning_rate": 0.0005081038779245643, + "loss": 5.0078, + "step": 278 + }, + { + "epoch": 0.28, + "grad_norm": 0.3230813193726234, + "learning_rate": 0.0005074590315501345, + "loss": 5.0, + "step": 279 + }, + { + "epoch": 0.28, + "grad_norm": 0.43011368510100667, + "learning_rate": 0.000506812398165227, + "loss": 4.9961, + "step": 280 + }, + { + "epoch": 0.28, + "grad_norm": 0.4688261606016039, + "learning_rate": 0.0005061639842814328, + "loss": 4.9883, + "step": 281 + }, + { + "epoch": 0.28, + "grad_norm": 0.4082387881237382, + "learning_rate": 0.0005055137964282728, + "loss": 4.9492, + "step": 282 + }, + { + "epoch": 0.28, + "grad_norm": 0.4102411273145604, + "learning_rate": 0.0005048618411531315, + "loss": 4.9492, + "step": 283 + }, + { + "epoch": 0.28, + "grad_norm": 0.3333699558922032, + "learning_rate": 0.000504208125021191, + "loss": 4.9492, + "step": 284 + }, + { + "epoch": 0.28, + "grad_norm": 0.3014113897515229, + "learning_rate": 0.0005035526546153656, + "loss": 4.9922, + "step": 285 + }, + { + "epoch": 0.29, + "grad_norm": 0.33242045759712463, + "learning_rate": 0.000502895436536235, + "loss": 4.8906, + "step": 286 + }, + { + "epoch": 0.29, + "grad_norm": 0.27804952465315824, + "learning_rate": 0.000502236477401978, + "loss": 4.8828, + "step": 287 + }, + { + "epoch": 0.29, + "grad_norm": 0.346783453227663, + "learning_rate": 0.0005015757838483058, + "loss": 4.9453, + "step": 288 + }, + { + "epoch": 0.29, + "grad_norm": 0.33206265244928296, + "learning_rate": 0.000500913362528395, + "loss": 4.9102, + "step": 289 + }, + { + "epoch": 0.29, + "grad_norm": 0.31507543033475727, + "learning_rate": 0.000500249220112821, + "loss": 4.9336, + "step": 290 + }, + { + "epoch": 0.29, + "grad_norm": 0.34558992633865376, + "learning_rate": 0.0004995833632894907, + "loss": 4.8867, + "step": 291 + }, + { + "epoch": 0.29, + "grad_norm": 0.3596650694441014, + "learning_rate": 0.0004989157987635748, + "loss": 4.9141, + "step": 292 + }, + { + "epoch": 0.29, + "grad_norm": 0.26520540250540703, + "learning_rate": 0.0004982465332574405, + "loss": 4.9648, + "step": 293 + }, + { + "epoch": 0.29, + "grad_norm": 0.2957335916241638, + "learning_rate": 
0.0004975755735105844, + "loss": 4.9297, + "step": 294 + }, + { + "epoch": 0.29, + "grad_norm": 0.33075169113632213, + "learning_rate": 0.0004969029262795634, + "loss": 4.9102, + "step": 295 + }, + { + "epoch": 0.3, + "grad_norm": 0.3588819230985392, + "learning_rate": 0.0004962285983379276, + "loss": 4.8672, + "step": 296 + }, + { + "epoch": 0.3, + "grad_norm": 0.3441202272395266, + "learning_rate": 0.0004955525964761522, + "loss": 4.8203, + "step": 297 + }, + { + "epoch": 0.3, + "grad_norm": 0.3150553179412103, + "learning_rate": 0.0004948749275015682, + "loss": 4.8945, + "step": 298 + }, + { + "epoch": 0.3, + "grad_norm": 0.31033579532429983, + "learning_rate": 0.0004941955982382948, + "loss": 4.9336, + "step": 299 + }, + { + "epoch": 0.3, + "grad_norm": 0.3118267914201189, + "learning_rate": 0.0004935146155271699, + "loss": 4.8125, + "step": 300 + }, + { + "epoch": 0.3, + "grad_norm": 0.2864784262575031, + "learning_rate": 0.0004928319862256821, + "loss": 4.9141, + "step": 301 + }, + { + "epoch": 0.3, + "grad_norm": 0.3461510509649134, + "learning_rate": 0.0004921477172079008, + "loss": 4.8789, + "step": 302 + }, + { + "epoch": 0.3, + "grad_norm": 0.344033222130809, + "learning_rate": 0.0004914618153644073, + "loss": 4.8477, + "step": 303 + }, + { + "epoch": 0.3, + "grad_norm": 0.3659600581082481, + "learning_rate": 0.0004907742876022257, + "loss": 4.8945, + "step": 304 + }, + { + "epoch": 0.3, + "grad_norm": 0.4139094159706144, + "learning_rate": 0.0004900851408447529, + "loss": 4.875, + "step": 305 + }, + { + "epoch": 0.31, + "grad_norm": 0.4555261722963931, + "learning_rate": 0.000489394382031689, + "loss": 4.8867, + "step": 306 + }, + { + "epoch": 0.31, + "grad_norm": 0.3920206505325869, + "learning_rate": 0.0004887020181189677, + "loss": 4.9453, + "step": 307 + }, + { + "epoch": 0.31, + "grad_norm": 0.3534794151350395, + "learning_rate": 0.00048800805607868586, + "loss": 4.8398, + "step": 308 + }, + { + "epoch": 0.31, + "grad_norm": 0.33621706530080386, + "learning_rate": 0.00048731250289903356, + "loss": 4.8086, + "step": 309 + }, + { + "epoch": 0.31, + "grad_norm": 0.3529975741380056, + "learning_rate": 0.0004866153655842235, + "loss": 4.8359, + "step": 310 + }, + { + "epoch": 0.31, + "grad_norm": 0.3762696746105136, + "learning_rate": 0.00048591665115442067, + "loss": 4.8672, + "step": 311 + }, + { + "epoch": 0.31, + "grad_norm": 0.37045506175739756, + "learning_rate": 0.00048521636664567195, + "loss": 4.7266, + "step": 312 + }, + { + "epoch": 0.31, + "grad_norm": 0.30161316500390495, + "learning_rate": 0.0004845145191098342, + "loss": 4.8242, + "step": 313 + }, + { + "epoch": 0.31, + "grad_norm": 0.3293026838774481, + "learning_rate": 0.00048381111561450447, + "loss": 4.8906, + "step": 314 + }, + { + "epoch": 0.31, + "grad_norm": 0.3172310803921575, + "learning_rate": 0.00048310616324294804, + "loss": 4.8164, + "step": 315 + }, + { + "epoch": 0.32, + "grad_norm": 0.3173002282317345, + "learning_rate": 0.00048239966909402763, + "loss": 4.793, + "step": 316 + }, + { + "epoch": 0.32, + "grad_norm": 0.31178084115992527, + "learning_rate": 0.00048169164028213137, + "loss": 4.7695, + "step": 317 + }, + { + "epoch": 0.32, + "grad_norm": 0.314811454990842, + "learning_rate": 0.00048098208393710154, + "loss": 4.7578, + "step": 318 + }, + { + "epoch": 0.32, + "grad_norm": 0.30440007615630466, + "learning_rate": 0.0004802710072041628, + "loss": 4.8359, + "step": 319 + }, + { + "epoch": 0.32, + "grad_norm": 0.3210092502981338, + "learning_rate": 0.00047955841724384976, + "loss": 4.8203, + 
"step": 320 + }, + { + "epoch": 0.32, + "grad_norm": 0.3440438445158875, + "learning_rate": 0.00047884432123193545, + "loss": 4.75, + "step": 321 + }, + { + "epoch": 0.32, + "grad_norm": 0.3193345492738186, + "learning_rate": 0.0004781287263593589, + "loss": 4.7148, + "step": 322 + }, + { + "epoch": 0.32, + "grad_norm": 0.30922502822632203, + "learning_rate": 0.00047741163983215233, + "loss": 4.793, + "step": 323 + }, + { + "epoch": 0.32, + "grad_norm": 0.3148665346557643, + "learning_rate": 0.0004766930688713693, + "loss": 4.75, + "step": 324 + }, + { + "epoch": 0.32, + "grad_norm": 0.42986342556682483, + "learning_rate": 0.00047597302071301136, + "loss": 4.7539, + "step": 325 + }, + { + "epoch": 0.33, + "grad_norm": 0.5024276449767565, + "learning_rate": 0.00047525150260795536, + "loss": 4.7656, + "step": 326 + }, + { + "epoch": 0.33, + "grad_norm": 0.46027171009254003, + "learning_rate": 0.00047452852182188073, + "loss": 4.8281, + "step": 327 + }, + { + "epoch": 0.33, + "grad_norm": 0.39490019666145704, + "learning_rate": 0.00047380408563519596, + "loss": 4.75, + "step": 328 + }, + { + "epoch": 0.33, + "grad_norm": 0.41799265356429827, + "learning_rate": 0.0004730782013429653, + "loss": 4.7266, + "step": 329 + }, + { + "epoch": 0.33, + "grad_norm": 0.32777355310017275, + "learning_rate": 0.0004723508762548356, + "loss": 4.7461, + "step": 330 + }, + { + "epoch": 0.33, + "grad_norm": 0.3248196972872973, + "learning_rate": 0.00047162211769496244, + "loss": 4.7227, + "step": 331 + }, + { + "epoch": 0.33, + "grad_norm": 0.28547358165842623, + "learning_rate": 0.00047089193300193637, + "loss": 4.7578, + "step": 332 + }, + { + "epoch": 0.33, + "grad_norm": 0.31153343447725707, + "learning_rate": 0.00047016032952870924, + "loss": 4.668, + "step": 333 + }, + { + "epoch": 0.33, + "grad_norm": 0.2760372416423509, + "learning_rate": 0.0004694273146425197, + "loss": 4.6758, + "step": 334 + }, + { + "epoch": 0.33, + "grad_norm": 0.29463081650056205, + "learning_rate": 0.0004686928957248197, + "loss": 4.6562, + "step": 335 + }, + { + "epoch": 0.34, + "grad_norm": 0.29189633228184125, + "learning_rate": 0.0004679570801711995, + "loss": 4.6914, + "step": 336 + }, + { + "epoch": 0.34, + "grad_norm": 0.26982925360638704, + "learning_rate": 0.00046721987539131364, + "loss": 4.7148, + "step": 337 + }, + { + "epoch": 0.34, + "grad_norm": 0.29787862686017463, + "learning_rate": 0.00046648128880880595, + "loss": 4.6602, + "step": 338 + }, + { + "epoch": 0.34, + "grad_norm": 0.3187987682957709, + "learning_rate": 0.00046574132786123527, + "loss": 4.7031, + "step": 339 + }, + { + "epoch": 0.34, + "grad_norm": 0.3585486863686879, + "learning_rate": 0.00046499999999999997, + "loss": 4.5742, + "step": 340 + }, + { + "epoch": 0.34, + "grad_norm": 0.3259002331050169, + "learning_rate": 0.0004642573126902635, + "loss": 4.7148, + "step": 341 + }, + { + "epoch": 0.34, + "grad_norm": 0.3275487016909797, + "learning_rate": 0.0004635132734108787, + "loss": 4.6797, + "step": 342 + }, + { + "epoch": 0.34, + "grad_norm": 0.3583530324503953, + "learning_rate": 0.000462767889654313, + "loss": 4.7422, + "step": 343 + }, + { + "epoch": 0.34, + "grad_norm": 0.3632248660395384, + "learning_rate": 0.00046202116892657245, + "loss": 4.6641, + "step": 344 + }, + { + "epoch": 0.34, + "grad_norm": 0.38122548821869745, + "learning_rate": 0.00046127311874712655, + "loss": 4.6836, + "step": 345 + }, + { + "epoch": 0.35, + "grad_norm": 0.36808830716304386, + "learning_rate": 0.0004605237466488322, + "loss": 4.6641, + "step": 346 + }, + { + 
"epoch": 0.35, + "grad_norm": 0.3137223061843467, + "learning_rate": 0.0004597730601778582, + "loss": 4.7422, + "step": 347 + }, + { + "epoch": 0.35, + "grad_norm": 0.3045761075947892, + "learning_rate": 0.00045902106689360903, + "loss": 4.625, + "step": 348 + }, + { + "epoch": 0.35, + "grad_norm": 0.34040203496490573, + "learning_rate": 0.0004582677743686486, + "loss": 4.6445, + "step": 349 + }, + { + "epoch": 0.35, + "grad_norm": 0.3456455549160395, + "learning_rate": 0.00045751319018862434, + "loss": 4.6875, + "step": 350 + }, + { + "epoch": 0.35, + "grad_norm": 0.39535233067020553, + "learning_rate": 0.00045675732195219046, + "loss": 4.5625, + "step": 351 + }, + { + "epoch": 0.35, + "grad_norm": 0.36300954605061336, + "learning_rate": 0.00045600017727093185, + "loss": 4.625, + "step": 352 + }, + { + "epoch": 0.35, + "grad_norm": 0.34321074019054293, + "learning_rate": 0.000455241763769287, + "loss": 4.6875, + "step": 353 + }, + { + "epoch": 0.35, + "grad_norm": 0.4028061367150823, + "learning_rate": 0.00045448208908447144, + "loss": 4.5664, + "step": 354 + }, + { + "epoch": 0.35, + "grad_norm": 0.4408851349392791, + "learning_rate": 0.00045372116086640074, + "loss": 4.6211, + "step": 355 + }, + { + "epoch": 0.36, + "grad_norm": 0.4191120056165356, + "learning_rate": 0.00045295898677761377, + "loss": 4.5781, + "step": 356 + }, + { + "epoch": 0.36, + "grad_norm": 0.3613069682757734, + "learning_rate": 0.00045219557449319506, + "loss": 4.6133, + "step": 357 + }, + { + "epoch": 0.36, + "grad_norm": 0.31397596134892436, + "learning_rate": 0.0004514309317006977, + "loss": 4.5039, + "step": 358 + }, + { + "epoch": 0.36, + "grad_norm": 0.2903352888009877, + "learning_rate": 0.00045066506610006633, + "loss": 4.6328, + "step": 359 + }, + { + "epoch": 0.36, + "grad_norm": 0.3038739076874848, + "learning_rate": 0.000449897985403559, + "loss": 4.6289, + "step": 360 + }, + { + "epoch": 0.36, + "grad_norm": 0.27322755655814623, + "learning_rate": 0.00044912969733566967, + "loss": 4.6484, + "step": 361 + }, + { + "epoch": 0.36, + "grad_norm": 0.2832382706505613, + "learning_rate": 0.0004483602096330509, + "loss": 4.6328, + "step": 362 + }, + { + "epoch": 0.36, + "grad_norm": 0.28631598952989207, + "learning_rate": 0.0004475895300444351, + "loss": 4.6094, + "step": 363 + }, + { + "epoch": 0.36, + "grad_norm": 0.2987917584998194, + "learning_rate": 0.0004468176663305572, + "loss": 4.6133, + "step": 364 + }, + { + "epoch": 0.36, + "grad_norm": 0.2967880899806566, + "learning_rate": 0.0004460446262640763, + "loss": 4.5117, + "step": 365 + }, + { + "epoch": 0.37, + "grad_norm": 0.3131220348773598, + "learning_rate": 0.0004452704176294972, + "loss": 4.5625, + "step": 366 + }, + { + "epoch": 0.37, + "grad_norm": 0.351287172376684, + "learning_rate": 0.00044449504822309245, + "loss": 4.5859, + "step": 367 + }, + { + "epoch": 0.37, + "grad_norm": 0.41479718509121755, + "learning_rate": 0.0004437185258528231, + "loss": 4.6133, + "step": 368 + }, + { + "epoch": 0.37, + "grad_norm": 0.500858047476452, + "learning_rate": 0.00044294085833826105, + "loss": 4.5195, + "step": 369 + }, + { + "epoch": 0.37, + "grad_norm": 0.44264696668268316, + "learning_rate": 0.00044216205351050935, + "loss": 4.5273, + "step": 370 + }, + { + "epoch": 0.37, + "grad_norm": 0.3491811100518669, + "learning_rate": 0.000441382119212124, + "loss": 4.4492, + "step": 371 + }, + { + "epoch": 0.37, + "grad_norm": 0.39094065310443904, + "learning_rate": 0.0004406010632970348, + "loss": 4.5938, + "step": 372 + }, + { + "epoch": 0.37, + 
"grad_norm": 0.42589950669151017, + "learning_rate": 0.00043981889363046604, + "loss": 4.4766, + "step": 373 + }, + { + "epoch": 0.37, + "grad_norm": 0.3735169848840402, + "learning_rate": 0.0004390356180888577, + "loss": 4.4961, + "step": 374 + }, + { + "epoch": 0.37, + "grad_norm": 0.3489922132224555, + "learning_rate": 0.00043825124455978563, + "loss": 4.5781, + "step": 375 + }, + { + "epoch": 0.38, + "grad_norm": 0.32411356449145856, + "learning_rate": 0.00043746578094188283, + "loss": 4.5273, + "step": 376 + }, + { + "epoch": 0.38, + "grad_norm": 0.3072457144479494, + "learning_rate": 0.0004366792351447589, + "loss": 4.5859, + "step": 377 + }, + { + "epoch": 0.38, + "grad_norm": 0.30778933355611565, + "learning_rate": 0.00043589161508892146, + "loss": 4.5391, + "step": 378 + }, + { + "epoch": 0.38, + "grad_norm": 0.30422537554955104, + "learning_rate": 0.0004351029287056957, + "loss": 4.5117, + "step": 379 + }, + { + "epoch": 0.38, + "grad_norm": 0.28548018463401653, + "learning_rate": 0.0004343131839371447, + "loss": 4.457, + "step": 380 + }, + { + "epoch": 0.38, + "grad_norm": 0.2987254595952286, + "learning_rate": 0.00043352238873598957, + "loss": 4.4531, + "step": 381 + }, + { + "epoch": 0.38, + "grad_norm": 0.32044706774364184, + "learning_rate": 0.0004327305510655292, + "loss": 4.5, + "step": 382 + }, + { + "epoch": 0.38, + "grad_norm": 0.3835631814807335, + "learning_rate": 0.0004319376788995602, + "loss": 4.5273, + "step": 383 + }, + { + "epoch": 0.38, + "grad_norm": 0.4987353932481338, + "learning_rate": 0.00043114378022229616, + "loss": 4.543, + "step": 384 + }, + { + "epoch": 0.38, + "grad_norm": 0.521234888411183, + "learning_rate": 0.00043034886302828837, + "loss": 4.5469, + "step": 385 + }, + { + "epoch": 0.39, + "grad_norm": 0.3367046977250759, + "learning_rate": 0.000429552935322344, + "loss": 4.5742, + "step": 386 + }, + { + "epoch": 0.39, + "grad_norm": 0.3574936591433018, + "learning_rate": 0.00042875600511944607, + "loss": 4.4805, + "step": 387 + }, + { + "epoch": 0.39, + "grad_norm": 0.3301760060022735, + "learning_rate": 0.000427958080444673, + "loss": 4.418, + "step": 388 + }, + { + "epoch": 0.39, + "grad_norm": 0.3209607031995241, + "learning_rate": 0.00042715916933311755, + "loss": 4.4375, + "step": 389 + }, + { + "epoch": 0.39, + "grad_norm": 0.2928558259134389, + "learning_rate": 0.00042635927982980534, + "loss": 4.4297, + "step": 390 + }, + { + "epoch": 0.39, + "grad_norm": 0.30495011187424853, + "learning_rate": 0.00042555841998961517, + "loss": 4.5586, + "step": 391 + }, + { + "epoch": 0.39, + "grad_norm": 0.29893525391168935, + "learning_rate": 0.00042475659787719663, + "loss": 4.4219, + "step": 392 + }, + { + "epoch": 0.39, + "grad_norm": 0.2662839149983714, + "learning_rate": 0.0004239538215668894, + "loss": 4.4648, + "step": 393 + }, + { + "epoch": 0.39, + "grad_norm": 0.2948338332982422, + "learning_rate": 0.000423150099142642, + "loss": 4.4336, + "step": 394 + }, + { + "epoch": 0.39, + "grad_norm": 0.3160166991490238, + "learning_rate": 0.0004223454386979305, + "loss": 4.5508, + "step": 395 + }, + { + "epoch": 0.4, + "grad_norm": 0.3094869555573116, + "learning_rate": 0.0004215398483356765, + "loss": 4.4453, + "step": 396 + }, + { + "epoch": 0.4, + "grad_norm": 0.2666983935113946, + "learning_rate": 0.00042073333616816607, + "loss": 4.5586, + "step": 397 + }, + { + "epoch": 0.4, + "grad_norm": 0.27831035395423853, + "learning_rate": 0.0004199259103169678, + "loss": 4.4141, + "step": 398 + }, + { + "epoch": 0.4, + "grad_norm": 0.3036373944797484, + 
"learning_rate": 0.00041911757891285086, + "loss": 4.4922, + "step": 399 + }, + { + "epoch": 0.4, + "grad_norm": 0.30775706426180766, + "learning_rate": 0.0004183083500957039, + "loss": 4.5078, + "step": 400 + } + ], + "logging_steps": 1, + "max_steps": 1000, + "num_input_tokens_seen": 0, + "num_train_epochs": 1, + "save_steps": 100, + "total_flos": 0.0, + "train_batch_size": 32, + "trial_name": null, + "trial_params": null +} diff --git a/checkpoint-400/training_args.bin b/checkpoint-400/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..09b35ec8ac2a16eb45febe1d655d456e47af68d1 --- /dev/null +++ b/checkpoint-400/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbd9a6067cf818494e2505097746a1cad30533fc72eb13916de34f7671e3e456 +size 6520 diff --git a/checkpoint-400/zero_to_fp32.py b/checkpoint-400/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-400/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + 
state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + 
frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * 
world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. 
+ + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. 
i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-50/config.json b/checkpoint-50/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-50/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-50/model.safetensors b/checkpoint-50/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..3d13f0f5ed51f42abdcbde2ce29bcf09e51819ee --- /dev/null +++ b/checkpoint-50/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26298f93d2a7ec443d157b065f20cc189f5f2554f891ede0b2d0ea60d78bf87a +size 324662984 diff --git a/checkpoint-50/training_args.bin b/checkpoint-50/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..9362a9e736fc862ece575b9f1b9d54b14c10d0b5 --- /dev/null +++ b/checkpoint-50/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36ce7ab48fa86ef42491eaad3583773d2b60353997a5e7b6fb4ffc1414828749 +size 6520 diff --git a/checkpoint-75/config.json b/checkpoint-75/config.json new file mode 100644 index 0000000000000000000000000000000000000000..b59b0b4c67b30baa7b62a3a87fc086e8dd1f8916 --- /dev/null +++ b/checkpoint-75/config.json @@ -0,0 +1,31 @@ +{ + "_name_or_path": "georgeyw/gpt-2-small-init-seed-5", + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + 
"bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/checkpoint-75/model.safetensors b/checkpoint-75/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..5d66d8527add355e39fa6c43d70d3be4a62a82c6 --- /dev/null +++ b/checkpoint-75/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b257dcf6c7ab062da763314d5ac648ac7e0a2c99289dc786c8b399a83efd0d8e +size 324662984 diff --git a/checkpoint-75/training_args.bin b/checkpoint-75/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..9362a9e736fc862ece575b9f1b9d54b14c10d0b5 --- /dev/null +++ b/checkpoint-75/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36ce7ab48fa86ef42491eaad3583773d2b60353997a5e7b6fb4ffc1414828749 +size 6520 diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..c716bd649b1683f55e5cb589034b4f23a4dc7ff6 --- /dev/null +++ b/config.json @@ -0,0 +1,30 @@ +{ + "architectures": [ + "GPTNeoXForCausalLM" + ], + "attention_bias": true, + "attention_dropout": 0.0, + "bos_token_id": 0, + "classifier_dropout": 0.1, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout": 0.0, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-05, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 1024, + "model_type": "gpt_neox", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "rope_scaling": null, + "rotary_emb_base": 10000, + "rotary_pct": 0.25, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.2", + "use_cache": true, + "use_parallel_residual": true, + "vocab_size": 50304 +} diff --git a/model.safetensors b/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..7b574dbd8c52f35e8042482c0d255306dd99f54c --- /dev/null +++ b/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d93475419b628d509ae8d03ef561df2236942d6dfd4bcaddbaed95090eabcd3 +size 324662984 diff --git a/training_args.bin b/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..6aa6ec73d0d0c1ca27e7d57a6ae8d91d1ebbf4f8 --- /dev/null +++ b/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c61feea1d5eea2b62bd23aba372cbba309ed78e154f83a8ff6144ac09e9b8d0 +size 6648