diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..3b700f9387c4fb21b1dcab9161cf441a03ad8289 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +model-00002-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text +special_tokens_map.json filter=lfs diff=lfs merge=lfs -text +tokenizer_config.json filter=lfs diff=lfs merge=lfs -text +trainer_state.json filter=lfs diff=lfs merge=lfs -text +generation_config.json filter=lfs diff=lfs merge=lfs -text +merges.txt filter=lfs diff=lfs merge=lfs -text +model-00003-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text +tokenizer.json filter=lfs diff=lfs merge=lfs -text +training_args.bin filter=lfs diff=lfs merge=lfs -text +vocab.json filter=lfs diff=lfs merge=lfs -text +checkpoint-2500 filter=lfs diff=lfs merge=lfs -text +config.json filter=lfs diff=lfs merge=lfs -text +added_tokens.json filter=lfs diff=lfs merge=lfs -text +checkpoint-5000 filter=lfs diff=lfs merge=lfs -text +model-00001-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text +model-00004-of-00004.safetensors filter=lfs diff=lfs merge=lfs -text +model.safetensors.index.json filter=lfs diff=lfs merge=lfs -text diff --git a/added_tokens.json b/added_tokens.json new file mode 100644 index 0000000000000000000000000000000000000000..0d3573914a5e968f40a925a2577677f802a395e5 --- /dev/null +++ b/added_tokens.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f86f2a952b4888d195b8d77e3d56ad96864ecc696cca2155d0d4ac933fcfd55f +size 101 diff --git a/checkpoint-2500/added_tokens.json b/checkpoint-2500/added_tokens.json new file mode 100644 index 0000000000000000000000000000000000000000..0d3573914a5e968f40a925a2577677f802a395e5 --- /dev/null +++ b/checkpoint-2500/added_tokens.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f86f2a952b4888d195b8d77e3d56ad96864ecc696cca2155d0d4ac933fcfd55f +size 101 diff --git a/checkpoint-2500/config.json b/checkpoint-2500/config.json new file mode 100644 index 0000000000000000000000000000000000000000..93b56e80ef0ec4e6da4e5920918f813e4f35723c --- /dev/null +++ b/checkpoint-2500/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b0aa356f98509e29091097d9a97fdcec55f415127956bc889587850ca4723ea +size 3535 diff --git a/checkpoint-2500/generation_config.json b/checkpoint-2500/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..31fb5dc71c8e2284343cfd47aafcc44176fb585f --- /dev/null +++ b/checkpoint-2500/generation_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eea6735f058019c875227b53de54f34353f01bbdbb0e9d8fc6474567e1334c29 +size 248 diff --git a/checkpoint-2500/global_step2500/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..cbcfe29a22349d05e035017c8f57bdd4f9c33bd9 --- /dev/null +++ b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dccccb881a4798a8179ce95c4c98235f15d7cbe21afcffcd93547cfb652dfc48 +size 23001093630 diff --git a/checkpoint-2500/global_step2500/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c2193a3123e55088b2342a4f69b8bf70431768e2 --- /dev/null +++ b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:078a14da915d0a1ed4180613b966f0e96585e54b55c8714bb9e92601d86e1100 +size 23001093630 diff --git a/checkpoint-2500/global_step2500/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4fa9c89cf55e2d8de5321e652c8b01e4a60e791c --- /dev/null +++ b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dbcfb9298ddaf03bdf98d906f71a670336b10856888b0bb7331d563c644425a +size 23001093630 diff --git a/checkpoint-2500/global_step2500/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..dbbb6847cc0b05983f8d997d5ba938c4ef99ffae --- /dev/null +++ b/checkpoint-2500/global_step2500/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fac00950cdc4d178fb814aac798c537941b4a40e4858ca6c8eea0aa889464850 +size 23001093630 diff --git a/checkpoint-2500/global_step2500/zero_pp_rank_0_mp_rank_00_model_states.pt b/checkpoint-2500/global_step2500/zero_pp_rank_0_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..378ad2de178c8847740e79216843c2b722e9dc78 --- /dev/null +++ b/checkpoint-2500/global_step2500/zero_pp_rank_0_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea35318387888658a9ec7f67d5d67fb574b1a24a5c63a87ad359c8d61074c207 +size 207917738 diff --git a/checkpoint-2500/global_step2500/zero_pp_rank_1_mp_rank_00_model_states.pt b/checkpoint-2500/global_step2500/zero_pp_rank_1_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4acbb445f33483a77b125692f66f71c2fbfa7c94 --- /dev/null +++ b/checkpoint-2500/global_step2500/zero_pp_rank_1_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e9b4ca7c8d2725fabaddd1f30a3fd5f4ce570968f7d5cb2477d6cda8e7495ec +size 207917738 diff --git a/checkpoint-2500/global_step2500/zero_pp_rank_2_mp_rank_00_model_states.pt b/checkpoint-2500/global_step2500/zero_pp_rank_2_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..845db9b466155fd9123dbcfe6de00eb50fd7d0ee --- /dev/null +++ b/checkpoint-2500/global_step2500/zero_pp_rank_2_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f7fdde3d1609d461635a63c96ca18ce5b597321ea015725156b9162b0713c6e +size 207917738 diff --git a/checkpoint-2500/global_step2500/zero_pp_rank_3_mp_rank_00_model_states.pt b/checkpoint-2500/global_step2500/zero_pp_rank_3_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b72e8fce4949c20bd3971c78b20923ae35e9734f --- /dev/null +++ b/checkpoint-2500/global_step2500/zero_pp_rank_3_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7647b71b17beba2bc9a7a58778fc3dfb07065bf8e09ac7e9ca6535293d7bf91c +size 207917738 diff --git a/checkpoint-2500/latest b/checkpoint-2500/latest new file mode 100644 index 0000000000000000000000000000000000000000..98f8bed9a5485ee900d9931cc06950de69499848 --- /dev/null +++ b/checkpoint-2500/latest @@ -0,0 +1 @@ +global_step2500 \ No newline at end of file diff --git a/checkpoint-2500/merges.txt b/checkpoint-2500/merges.txt new file mode 100644 index 0000000000000000000000000000000000000000..80c1a19fae38f8f4c9ab32cc9d4e145c241147e6 --- /dev/null +++ b/checkpoint-2500/merges.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5 +size 1671853 diff --git a/checkpoint-2500/model-00001-of-00004.safetensors b/checkpoint-2500/model-00001-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..41f8b7000e3c9258919ab66ffbcb14e258883aa6 --- /dev/null +++ b/checkpoint-2500/model-00001-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:886643e79041530580e77bd781bce5e96c7b6381282b8dc98f56d10666a17cf3 +size 4877668032 diff --git a/checkpoint-2500/model-00002-of-00004.safetensors b/checkpoint-2500/model-00002-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..199832d2618b88015101fa6d24563deeb53c3ff8 --- /dev/null +++ b/checkpoint-2500/model-00002-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc41ab6b9c93602507e952dcd2624bb06e1420fa2beedbbe9424891af77af37f +size 4932751008 diff --git a/checkpoint-2500/model-00003-of-00004.safetensors b/checkpoint-2500/model-00003-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..01bdd35576f4a7a36bf0902598128fb5b0f76d83 --- /dev/null +++ b/checkpoint-2500/model-00003-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44c31eac059d021e74ee49ac7079c954c6e76593d0429eb210c4c0f3c98c226e +size 4994571904 diff --git a/checkpoint-2500/model-00004-of-00004.safetensors b/checkpoint-2500/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..4b401ed8a4c1c63bc3f602bb5b8db1f252aae891 --- /dev/null +++ b/checkpoint-2500/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:14620fe3fcbd31f18a1435fa968ce5cd2013b7836aec54242f57cfffdcaa6931 +size 1358630840 diff --git a/checkpoint-2500/model.safetensors.index.json b/checkpoint-2500/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..09affe6e92e41092d0a0694cabaa81f499f7c9ba --- /dev/null +++ b/checkpoint-2500/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de85b54d148ac26390347ed88b2b053e70b1f2c065023a9f7a05612b5b03e5cb +size 81426 diff --git a/checkpoint-2500/rng_state_0.pth b/checkpoint-2500/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..906dd6ee7eb8b62c9f27cfaf7dd8b22f23a3d19f --- /dev/null +++ b/checkpoint-2500/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4db6cccd20d5397e0e0cd02bdf3bbcf3948b1ea4f484ae2a9ecdc99549f37ee5 +size 15984 diff --git a/checkpoint-2500/rng_state_1.pth b/checkpoint-2500/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..ad38ef9e3ffbe2e7384880e865036363b96001d7 --- /dev/null +++ b/checkpoint-2500/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e481c7d820a38cdda48f1119bc0f1547a9d130f968495bb0587649f9fc6a96bb +size 15984 diff --git a/checkpoint-2500/rng_state_2.pth b/checkpoint-2500/rng_state_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..518bc9125909c591addf05357229bf42c28fced4 --- /dev/null +++ b/checkpoint-2500/rng_state_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6174572a2351be4990250de58417bd90f928ca0f7ff923ff268a38fb52ba2009 +size 15984 diff --git a/checkpoint-2500/rng_state_3.pth b/checkpoint-2500/rng_state_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..e18b1a75e3fce9fa25702890e4f9a42a86b993e2 --- /dev/null +++ b/checkpoint-2500/rng_state_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6345f8e8c88b711bdc7dbce9f86986b2ceaef14138cc8c33f6acdbbd5bb21cb5 +size 15984 diff --git a/checkpoint-2500/scheduler.pt b/checkpoint-2500/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..f7ee06d345996a502ae2a23ad72d5954b830e447 --- /dev/null +++ b/checkpoint-2500/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0c1f1debc73316b98daa7a20c743b9392d4332257e03130ff987498b8d24819b +size 1064 diff --git a/checkpoint-2500/special_tokens_map.json b/checkpoint-2500/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..9b4b7305715f2ea949deee2af90c1341c29dc30d --- /dev/null +++ b/checkpoint-2500/special_tokens_map.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b +size 367 diff --git a/checkpoint-2500/tokenizer.json b/checkpoint-2500/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..782308d2969d60592f406fc350f5f4b52aa4c231 --- /dev/null +++ b/checkpoint-2500/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eff4e7d2862ab8b716a62d06f298c8978e21d9525cc2fdced5e4dd85707a3241 +size 7028199 diff --git a/checkpoint-2500/tokenizer_config.json b/checkpoint-2500/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..f1ca16411d98e72ef64d9cdb65dbd6a838a22956 --- /dev/null +++ b/checkpoint-2500/tokenizer_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ae8d5ad312360a6efd9b24299ca74b0f8342b41fe82799130a958c7000a92ff +size 1539 diff --git a/checkpoint-2500/trainer_state.json b/checkpoint-2500/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..6440418cc25de1863d90c3cfe8ebe6d0f3d9d827 --- /dev/null +++ b/checkpoint-2500/trainer_state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:384677216e81874ab727b0fabba4a8f6d9e35475ba24d6e16f98470392a4991c +size 396638 diff --git a/checkpoint-2500/training_args.bin b/checkpoint-2500/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..065d307a39cebfbe6b68064b23e0237976671ca2 --- /dev/null +++ b/checkpoint-2500/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9ce4d4cc6b8e61a8bb732af6de64c0da8cc02dadd47772b2f1993a571f7eae6 +size 7864 diff --git a/checkpoint-2500/vocab.json b/checkpoint-2500/vocab.json new file mode 100644 index 0000000000000000000000000000000000000000..6c49fc63bcb109de13abe49e58f85a4cdba7b679 --- /dev/null +++ b/checkpoint-2500/vocab.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910 +size 2776833 diff --git a/checkpoint-2500/zero_to_fp32.py b/checkpoint-2500/zero_to_fp32.py new file mode 100755 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-2500/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/checkpoint-5000/added_tokens.json b/checkpoint-5000/added_tokens.json new file mode 100644 index 0000000000000000000000000000000000000000..0d3573914a5e968f40a925a2577677f802a395e5 --- /dev/null +++ b/checkpoint-5000/added_tokens.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f86f2a952b4888d195b8d77e3d56ad96864ecc696cca2155d0d4ac933fcfd55f +size 101 diff --git a/checkpoint-5000/config.json b/checkpoint-5000/config.json new file mode 100644 index 0000000000000000000000000000000000000000..93b56e80ef0ec4e6da4e5920918f813e4f35723c --- /dev/null +++ b/checkpoint-5000/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b0aa356f98509e29091097d9a97fdcec55f415127956bc889587850ca4723ea +size 3535 diff --git a/checkpoint-5000/generation_config.json b/checkpoint-5000/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..31fb5dc71c8e2284343cfd47aafcc44176fb585f --- /dev/null +++ b/checkpoint-5000/generation_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eea6735f058019c875227b53de54f34353f01bbdbb0e9d8fc6474567e1334c29 +size 248 diff --git a/checkpoint-5000/global_step5000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b7d3c35f37e16b89bf42c1fee690d1681c7ece02 --- /dev/null +++ b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca96c78fca22e84d0b0c4eca58c62f4b0c182648500654167def656aaedfca86 +size 23001093630 diff --git a/checkpoint-5000/global_step5000/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..29ecf85862e3d72ff0843823f6bd0a8992754d97 --- /dev/null +++ b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdc63176954a62641c9104f824b52c860b72ffb8a6ee3ba8704a1ebe0ae5b70d +size 23001093630 diff --git a/checkpoint-5000/global_step5000/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..0add4ad31e90d9e006655587c9f3fd3709e7c587 --- /dev/null +++ b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:599fc90213b0ff0e7028755ce3f78095691725b655d4ad2506f1fb48b876903d +size 23001093630 diff --git a/checkpoint-5000/global_step5000/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..52985acf6f6ca28982b60b643ed8ce38eb128fa1 --- /dev/null +++ b/checkpoint-5000/global_step5000/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b692453e50495054423ac456517d067495c357d0459e74162830e8cf112cac33 +size 23001093630 diff --git a/checkpoint-5000/global_step5000/zero_pp_rank_0_mp_rank_00_model_states.pt b/checkpoint-5000/global_step5000/zero_pp_rank_0_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..013fd027189ebfd788f3749528e1492c5908ad17 --- /dev/null +++ b/checkpoint-5000/global_step5000/zero_pp_rank_0_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b9625bcfc4cfb86f99bcc77c9dce84288b38727b44c478ad669fa23e60e71e1 +size 207917738 diff --git a/checkpoint-5000/global_step5000/zero_pp_rank_1_mp_rank_00_model_states.pt b/checkpoint-5000/global_step5000/zero_pp_rank_1_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..d7cd35420f69cfe6eb2dcdf089e5ff6e82006c5e --- /dev/null +++ b/checkpoint-5000/global_step5000/zero_pp_rank_1_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf3bb5b99a91b43fc306f56a0c1599c8af5970615fc79791ba3ef4b4f90315d0 +size 207917738 diff --git a/checkpoint-5000/global_step5000/zero_pp_rank_2_mp_rank_00_model_states.pt b/checkpoint-5000/global_step5000/zero_pp_rank_2_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..ceb49203640431532323d9f5da363e056c672441 --- /dev/null +++ b/checkpoint-5000/global_step5000/zero_pp_rank_2_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8e6a06380237dd58a1d3375e261ad01033ebc36bcbae0e9297918d94164f10ce +size 207917738 diff --git a/checkpoint-5000/global_step5000/zero_pp_rank_3_mp_rank_00_model_states.pt b/checkpoint-5000/global_step5000/zero_pp_rank_3_mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..f510fd6a4086ca6fe2ace7895241c44037a22e9e --- /dev/null +++ b/checkpoint-5000/global_step5000/zero_pp_rank_3_mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02ed0ab812c66e6f4fb3dc038f60c6cc1e0ab26bf514f59a6c030fc6675255ea +size 207917738 diff --git a/checkpoint-5000/latest b/checkpoint-5000/latest new file mode 100644 index 0000000000000000000000000000000000000000..f805186fa43374540c3fa51dfd3cca9ac06e56a5 --- /dev/null +++ b/checkpoint-5000/latest @@ -0,0 +1 @@ +global_step5000 \ No newline at end of file diff --git a/checkpoint-5000/merges.txt b/checkpoint-5000/merges.txt new file mode 100644 index 0000000000000000000000000000000000000000..80c1a19fae38f8f4c9ab32cc9d4e145c241147e6 --- /dev/null +++ b/checkpoint-5000/merges.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5 +size 1671853 diff --git a/checkpoint-5000/model-00001-of-00004.safetensors b/checkpoint-5000/model-00001-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..539b1bc31dca01fd276eba2d543f75a1be2f28a2 --- /dev/null +++ b/checkpoint-5000/model-00001-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c15539f67e5ce925a8a5af7cf2139b29ae127a1decc70b0068d8bab52ef37020 +size 4877668032 diff --git a/checkpoint-5000/model-00002-of-00004.safetensors b/checkpoint-5000/model-00002-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..b9b7253c8f7dddb4186694efdf772a2b1d92a8a8 --- /dev/null +++ b/checkpoint-5000/model-00002-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:23a4f997106cb5fe5f91c8afc7bb608e39f315978b0b86c319a86d88fd483bf8 +size 4932751008 diff --git a/checkpoint-5000/model-00003-of-00004.safetensors b/checkpoint-5000/model-00003-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..d1f9e50d431753651f4e2e3c5ba7122ca6627755 --- /dev/null +++ b/checkpoint-5000/model-00003-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:637eaaa8c80ad02ab734f20867b48659f5bc003fd2d5ee6ed9b054fcda89beb3 +size 4994571904 diff --git a/checkpoint-5000/model-00004-of-00004.safetensors b/checkpoint-5000/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..ee340c74169f0d01bc455f41a9443f635bceac4e --- /dev/null +++ b/checkpoint-5000/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee8ea551058d56ac6fb9423a69d6ed9b967d85f231e4979c0bf52f1dde4d709d +size 1358630840 diff --git a/checkpoint-5000/model.safetensors.index.json b/checkpoint-5000/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..09affe6e92e41092d0a0694cabaa81f499f7c9ba --- /dev/null +++ b/checkpoint-5000/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de85b54d148ac26390347ed88b2b053e70b1f2c065023a9f7a05612b5b03e5cb +size 81426 diff --git a/checkpoint-5000/rng_state_0.pth b/checkpoint-5000/rng_state_0.pth new file mode 100644 index 0000000000000000000000000000000000000000..8fe24e1b432da9b8c20539b7795ec9d3bd164814 --- /dev/null +++ b/checkpoint-5000/rng_state_0.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ddce993b94c6dddffa3b6ad61845f3692872e47eef35ff1cbbc1e4b1c6f3bceb +size 15984 diff --git a/checkpoint-5000/rng_state_1.pth b/checkpoint-5000/rng_state_1.pth new file mode 100644 index 0000000000000000000000000000000000000000..d356fc635b197cc2991723ec033badcf1451dd89 --- /dev/null +++ b/checkpoint-5000/rng_state_1.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:065a5967865cf5405e2e1c2b4589d19bbcaf0b67a22eb6a89ed35e8a51717796 +size 15984 diff --git a/checkpoint-5000/rng_state_2.pth b/checkpoint-5000/rng_state_2.pth new file mode 100644 index 0000000000000000000000000000000000000000..0a85e10606f830a801f4486a4e47d8eacdffd5cb --- /dev/null +++ b/checkpoint-5000/rng_state_2.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2ecaaa26f24e622a3818fdae4b46cd62e98c683be19b3b76c940635b6e2f160 +size 15984 diff --git a/checkpoint-5000/rng_state_3.pth b/checkpoint-5000/rng_state_3.pth new file mode 100644 index 0000000000000000000000000000000000000000..ba45db674a5106495648cd958ff848f86559108d --- /dev/null +++ b/checkpoint-5000/rng_state_3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5668e598c32f79eb5a2ecf74a34dd4a22bb0e5ca6237daa09e5aad893daea46 +size 15984 diff --git a/checkpoint-5000/scheduler.pt b/checkpoint-5000/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..3b4966434a8ff5a89fe172f67973aeecde85dcee --- /dev/null +++ b/checkpoint-5000/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:178b90171784f807a34b93abf47ce2963c0e51bb74e42f239733c7f0103ec145 +size 1064 diff --git a/checkpoint-5000/special_tokens_map.json b/checkpoint-5000/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..9b4b7305715f2ea949deee2af90c1341c29dc30d --- /dev/null +++ b/checkpoint-5000/special_tokens_map.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b +size 367 diff --git a/checkpoint-5000/tokenizer.json b/checkpoint-5000/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..782308d2969d60592f406fc350f5f4b52aa4c231 --- /dev/null +++ b/checkpoint-5000/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eff4e7d2862ab8b716a62d06f298c8978e21d9525cc2fdced5e4dd85707a3241 +size 7028199 diff --git a/checkpoint-5000/tokenizer_config.json b/checkpoint-5000/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..f1ca16411d98e72ef64d9cdb65dbd6a838a22956 --- /dev/null +++ b/checkpoint-5000/tokenizer_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ae8d5ad312360a6efd9b24299ca74b0f8342b41fe82799130a958c7000a92ff +size 1539 diff --git a/checkpoint-5000/trainer_state.json b/checkpoint-5000/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..db30bc83658da4b0ec299959aeb73f274be64051 --- /dev/null +++ b/checkpoint-5000/trainer_state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b67d8e637e5a2b3bb0d901d3ca37f2210c82813376d6f8d72ea1c1ef36ce14a +size 795008 diff --git a/checkpoint-5000/training_args.bin b/checkpoint-5000/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..065d307a39cebfbe6b68064b23e0237976671ca2 --- /dev/null +++ b/checkpoint-5000/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9ce4d4cc6b8e61a8bb732af6de64c0da8cc02dadd47772b2f1993a571f7eae6 +size 7864 diff --git a/checkpoint-5000/vocab.json b/checkpoint-5000/vocab.json new file mode 100644 index 0000000000000000000000000000000000000000..6c49fc63bcb109de13abe49e58f85a4cdba7b679 --- /dev/null +++ b/checkpoint-5000/vocab.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910 +size 2776833 diff --git a/checkpoint-5000/zero_to_fp32.py b/checkpoint-5000/zero_to_fp32.py new file mode 100755 index 0000000000000000000000000000000000000000..24cc342e78d1a006c782b3a4cd68d9ce786d8fd8 --- /dev/null +++ b/checkpoint-5000/zero_to_fp32.py @@ -0,0 +1,604 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage <= 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # handle shared params + shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dict = torch.load(f, map_location=device) + # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights + # and also handle the case where it was already removed by another helper script + state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) + state_dicts.append(state_dict) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage <= 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage <= 2: + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage <= 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _has_callable(obj, fn): + attr = getattr(obj, fn, None) + return callable(attr) + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in zero_model_states[0].frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + if pair[1] in state_dict: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-t", + "--tag", + type=str, + default=None, + help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/config.json b/config.json new file mode 100644 index 0000000000000000000000000000000000000000..204b0380b4777522c51c90240df748278a0bfde2 --- /dev/null +++ b/config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de864edd81f36443d001a8c15b0adf65cf77af7efef1ed8875b39fd22384be99 +size 3534 diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..31fb5dc71c8e2284343cfd47aafcc44176fb585f --- /dev/null +++ b/generation_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eea6735f058019c875227b53de54f34353f01bbdbb0e9d8fc6474567e1334c29 +size 248 diff --git a/merges.txt b/merges.txt new file mode 100644 index 0000000000000000000000000000000000000000..80c1a19fae38f8f4c9ab32cc9d4e145c241147e6 --- /dev/null +++ b/merges.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8831e4f1a044471340f7c0a83d7bd71306a5b867e95fd870f74d0c5308a904d5 +size 1671853 diff --git a/model-00001-of-00004.safetensors b/model-00001-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..d481e2d143ea49a24b0e554f165b38df126617c1 --- /dev/null +++ b/model-00001-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8c046c299e5ff9b49e694aca28d8fae279196e426037bd2a705d70d3caab806 +size 4877668032 diff --git a/model-00002-of-00004.safetensors b/model-00002-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..256a1db1931f17d26016d16d95aff9ddac41d568 --- /dev/null +++ b/model-00002-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9e9f49e7bd37795302e98f562d620bbfa61f7e79290e4498aaa704e39cfd88d +size 4932751008 diff --git a/model-00003-of-00004.safetensors b/model-00003-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..091aa0ce6a3c301ed9cca2653d6f6a551def41ae --- /dev/null +++ b/model-00003-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:856fd335f1f9dd44d63e63238cf45f6581589b83cb814bb2f045a71f8f8d68e6 +size 4994571904 diff --git a/model-00004-of-00004.safetensors b/model-00004-of-00004.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..0acd041bdedb85a5d3fdf851691fb93797aebfbe --- /dev/null +++ b/model-00004-of-00004.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aacc716741ea2cc16396ea82104724bb6bf9f42ea357e5dad10ddd4effb2b63d +size 1358630840 diff --git a/model.safetensors.index.json b/model.safetensors.index.json new file mode 100644 index 0000000000000000000000000000000000000000..09affe6e92e41092d0a0694cabaa81f499f7c9ba --- /dev/null +++ b/model.safetensors.index.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de85b54d148ac26390347ed88b2b053e70b1f2c065023a9f7a05612b5b03e5cb +size 81426 diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..9b4b7305715f2ea949deee2af90c1341c29dc30d --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f4f79e08d97f4d1c87f8d89264f525c8789da3b73b3bb55d1e12f692f41a7b1b +size 367 diff --git a/tokenizer.json b/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..782308d2969d60592f406fc350f5f4b52aa4c231 --- /dev/null +++ b/tokenizer.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eff4e7d2862ab8b716a62d06f298c8978e21d9525cc2fdced5e4dd85707a3241 +size 7028199 diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..f1ca16411d98e72ef64d9cdb65dbd6a838a22956 --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ae8d5ad312360a6efd9b24299ca74b0f8342b41fe82799130a958c7000a92ff +size 1539 diff --git a/trainer_state.json b/trainer_state.json new file mode 100644 index 0000000000000000000000000000000000000000..b83959c59b75c3f89b52964fdd7f054c3fa0c6b0 --- /dev/null +++ b/trainer_state.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6868cff51a5eba04a380d83c099b26969ec034d9c5fc39d51ba2822fb5d49633 +size 1033304 diff --git a/training_args.bin b/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..065d307a39cebfbe6b68064b23e0237976671ca2 --- /dev/null +++ b/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9ce4d4cc6b8e61a8bb732af6de64c0da8cc02dadd47772b2f1993a571f7eae6 +size 7864 diff --git a/vocab.json b/vocab.json new file mode 100644 index 0000000000000000000000000000000000000000..6c49fc63bcb109de13abe49e58f85a4cdba7b679 --- /dev/null +++ b/vocab.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca10d7e9fb3ed18575dd1e277a2579c16d108e32f27439684afa0e10b1440910 +size 2776833