Spaces:
Sleeping
Sleeping
| """Utility functions for training and inference.""" | |
| import math | |
| import pickle | |
| import sys | |
| from contextlib import nullcontext | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union | |
| import lightning as L | |
| import torch | |
| import torch.nn as nn | |
| import torch.utils._device | |
| from lightning.fabric.strategies import FSDPStrategy | |
| from lightning.fabric.utilities.load import _lazy_load as lazy_load | |
| from torch.serialization import normalize_storage_type | |
| if TYPE_CHECKING: | |
| from lit_gpt import GPT | |
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``.

    Args:
        n: value to round up.
        k: step size; must be positive (checked with an ``assert``).

    Returns:
        ``n`` itself when it is already a multiple of ``k``, otherwise the next multiple above it.
    """
    assert k > 0
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    """Count the elements held by the parameters of ``module``.

    Args:
        module: module whose ``parameters()`` are iterated.
        requires_grad: when ``None`` every parameter is counted; otherwise only parameters whose
            ``requires_grad`` flag equals this value.

    Returns:
        The summed element count of the selected parameters.
    """

    def _element_count(param) -> int:
        # bitsandbytes 4bit layer support: such parameters carry a `quant_state`, and the count is
        # taken from the product of `quant_state[1]` rather than from `numel()`
        if hasattr(param, "quant_state"):
            return math.prod(param.quant_state[1])
        return param.numel()

    return sum(
        _element_count(param)
        for param in module.parameters()
        if requires_grad is None or param.requires_grad == requires_grad
    )
def gptq_quantization(enabled: bool = False) -> ContextManager:
    """Return a context manager that swaps ``torch.nn.Linear`` for a GPTQ quantized linear layer.

    Args:
        enabled: when false, a no-op ``nullcontext()`` is returned and nothing is imported.

    Returns:
        A ``nullcontext`` when disabled; otherwise a ``_ClassReplacementContextManager`` mapping
        ``"torch.nn.Linear"`` to a ``ColBlockQuantizedLinear`` subclass built with ``bits=4`` and
        ``tile_cols=-1``.
    """
    if enabled:
        # imported lazily so the quantization dependencies are only needed when requested
        from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager

        from quantize.gptq import ColBlockQuantizedLinear

        class QuantizedLinear(ColBlockQuantizedLinear):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, bits=4, tile_cols=-1, **kwargs)

        return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})

    return nullcontext()
def check_valid_checkpoint_dir(checkpoint_dir: Path, model_name: str) -> None:
    """Verify that ``checkpoint_dir`` contains the files needed for ``model_name``.

    The directory must exist and contain the weights file, ``lit_config.json``, a tokenizer
    (``tokenizer.json`` or ``tokenizer.model``) and ``tokenizer_config.json``. When it does, the
    function returns silently. Otherwise an explanation is printed to stderr, together with any
    entries found under ``checkpoints/*/*`` relative to the current working directory.

    Args:
        checkpoint_dir: directory expected to hold the checkpoint.
        model_name: selects the weights file name; two pythia variants use dedicated file names,
            every other value uses ``lit_model.pth``.

    Raises:
        SystemExit: with exit code 1 when the directory is missing or incomplete.
    """
    weights_file_by_model = {
        "pythia_160m_deduped_huggingface": "pythia_160m_deduped_hf.pth",
        "pythia_160m_deduped_custom": "pythia_160m_deduped_custom.pth",
    }
    selected_model_name = weights_file_by_model.get(model_name, "lit_model.pth")

    files = {
        # keyed by the file that is actually looked up, so the "missing files" report names the
        # right weights file instead of always saying "lit_model.pth"
        selected_model_name: (checkpoint_dir / selected_model_name).is_file(),
        "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
        "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or (
            checkpoint_dir / "tokenizer.model"
        ).is_file(),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
    else:
        problem = " is not a checkpoint directory"

    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    if available:
        options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    else:
        extra = ""

    error_message = (
        f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
        f"{extra}\nSee all download options by running:\n python scripts/download.py"
    )
    print(error_message, file=sys.stderr)
    raise SystemExit(1)
class SavingProxyForStorage:
    """Stand-in for a storage whose bytes have already been handed to ``saver``.

    Constructing the proxy passes the storage to ``saver._write_storage_and_return_key`` and keeps
    the resulting ``("storage", storage_type, key, location, numel)`` tuple in ``storage_info``.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        is_typed = isinstance(obj, torch.storage.TypedStorage)
        if not is_typed and not torch.is_storage(obj):
            raise TypeError(f"expected storage, not {type(obj)}")

        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if is_typed:
            # PT upstream wants to deprecate this eventually...
            untyped = obj._untyped_storage
            kind = getattr(torch, obj._pickle_storage_type())
            numel = obj._size()
        else:
            untyped = obj
            kind = normalize_storage_type(type(obj))
            numel = untyped.nbytes()

        # the storage is written first; its location tag is looked up afterwards
        key = saver._write_storage_and_return_key(untyped)
        location = torch.serialization.location_tag(untyped)

        self.storage_info = ("storage", kind, key, location, numel)

    def __reduce_ex__(self, protocol_version):
        # proxies are not meant to go through the regular pickling path
        assert False, "this should be handled with out of band"
class SavingProxyForTensor:
    """Stand-in for a tensor whose storage is written through ``saver`` at construction time.

    The tensor's own ``__reduce_ex__`` result is captured, with the storage inside it replaced by
    a ``SavingProxyForStorage``. Pickling the proxy later replays that modified reduce result.
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, original_args = tensor.__reduce_ex__(protocol_version)
        if original_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes
            (rebuild_fn, second_arg, (storage, *rebuild_rest), *trailing) = original_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (rebuild_fn, second_arg, (proxy, *rebuild_rest), *trailing)
        else:
            # otherwise the storage is the first reduce argument
            (storage, *trailing) = original_args
            assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
            proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version)
            self.reduce_args = (proxy, *trailing)

    def __reduce_ex__(self, protocol_version):
        # only the protocol version used at construction time is valid for the captured arguments
        if protocol_version == self.protocol_version:
            return self.reduce_ret_fn, self.reduce_args
        raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}")
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that replaces storages with persistent ids and writes their bytes through ``saver``.

    ``saver`` must provide ``_write_storage_and_return_key(storage)``. The remaining arguments are
    forwarded unchanged to ``pickle.Pickler``.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # maps storage.data_ptr() -> dtype first seen for that memory
        self.storage_dtypes = {}
        self.saver = saver
        # maps storage._cdata -> key returned by the saver, so a storage is written only once
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        """Return the persistent id tuple for storages and storage proxies, ``None`` otherwise.

        Raises:
            RuntimeError: if two storages share a non-null data pointer but differ in dtype.
        """
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            # the proxy already wrote its storage; reuse the info it recorded
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            # write the storage on first sight only; later references reuse the stored key
            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        # anything else is pickled normally
        return None
class incremental_save:
    """Context manager that writes a checkpoint through ``torch._C.PyTorchFileWriter`` in pieces.

    Tensors can be written ahead of time with ``store_early``. ``save`` then pickles the final
    object into the ``data.pkl`` record, and leaving the ``with`` block writes the end of file.
    Once ``save`` has run, no further storage can be written.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        self.has_saved = False
        # key handed to the next storage record ("data/<key>")
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        """Write ``tensor``'s storage now and return a proxy to place in the object passed to ``save``."""
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(f"can only store tensors early, not {type(tensor)}")
        return SavingProxyForTensor(tensor, self)

    def save(self, obj):
        """Pickle ``obj`` into the ``data.pkl`` record; may only be called once."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        buffer = BytesIO()
        IncrementalPyTorchPickler(self, buffer, protocol=5).dump(obj)
        payload = buffer.getvalue()
        self.zipfile.write_record("data.pkl", payload, len(payload))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        """Write ``storage`` as record ``data/<key>`` and return the integer key used."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key = key + 1
        # non-cpu storages are moved to cpu before their data pointer is handed to the writer
        if storage.device.type != "cpu":
            storage = storage.cpu()
        self.zipfile.write_record(f"data/{key}", storage.data_ptr(), storage.nbytes())
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()
# Generic type variable; no definition visible in this part of the file refers to it.
T = TypeVar("T")
def _mean_unmasked_chunk_loss(logit_chunks, target_chunks, targets: torch.Tensor) -> torch.Tensor:
    """Sum the per-element cross entropy of paired chunks and divide by the count of ``targets != -1``."""
    loss_chunks = [
        torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none")
        for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
    ]
    non_masked_elems = (targets != -1).sum()
    # clamp on the tensor rather than builtin `max(1, tensor)`: the value is the same (at least 1,
    # so an all-masked batch does not divide by zero) without converting a tensor to a Python bool
    return torch.cat(loss_chunks).sum() / non_masked_elems.clamp(min=1)


def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
) -> torch.Tensor:
    """Compute the mean cross entropy over non-ignored targets, optionally in chunks.

    With large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk
    which can dominate the memory usage in fine-tuning settings with low number of parameters.
    As a workaround hack, the cross entropy computation is chunked to force it to deallocate on
    the go, reducing the memory spike's magnitude.

    Args:
        logits: either a single tensor, or a list of tensors that are chunks along dimension 1
            (the lm_head was chunked).
        targets: class indices; entries equal to ``-1`` are ignored.
        chunk_size: ``0`` disables chunking. For tensor ``logits`` it is the number of flattened
            rows per chunk. For list ``logits`` the chunking follows the list itself and
            ``targets`` is split along dimension 1 by the size of the first chunk.

    Returns:
        A scalar tensor. In the chunked paths the divisor is clamped to at least 1, so a batch with
        every target ignored yields 0.
    """
    # lm_head was chunked (we are fine-tuning)
    if isinstance(logits, list):
        # don't want to chunk cross entropy
        if chunk_size == 0:
            logits = torch.cat(logits, dim=1)
            logits = logits.reshape(-1, logits.size(-1))
            targets = targets.reshape(-1)
            return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

        # chunk cross entropy
        logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
        target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)]
        return _mean_unmasked_chunk_loss(logit_chunks, target_chunks, targets)

    # no chunking at all
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if chunk_size == 0:
        return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)

    # lm_head wasn't chunked, chunk cross entropy
    return _mean_unmasked_chunk_loss(logits.split(chunk_size), targets.split(chunk_size), targets)
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    """Rename entries of ``state_dict`` in place according to ``mapping``.

    Args:
        state_dict: dictionary to update; it is mutated and also returned.
        mapping: old key suffix -> new key suffix.
        prefix: string prepended to both suffixes to form the full keys.

    Returns:
        The same ``state_dict`` object. Keys absent from it are skipped.
    """
    for old_suffix, new_suffix in mapping.items():
        old_key = prefix + old_suffix
        if old_key not in state_dict:
            continue
        state_dict[prefix + new_suffix] = state_dict.pop(old_key)
    return state_dict
def get_default_supported_precision(training: bool) -> str:
    """Return default precision that is supported by the hardware: either `bf16` or `16`.

    Args:
        training: selects the `-mixed` (truthy) or `-true` (falsy) version of the precision

    Returns:
        `16-*` when MPS is available, or when CUDA is available without bf16 support; `bf16-*` otherwise
    """
    from lightning.fabric.accelerators import MPSAccelerator

    bf16_unavailable = MPSAccelerator.is_available() or (
        torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
    )
    if bf16_unavailable:
        precision = "16-mixed" if training else "16-true"
    else:
        precision = "bf16-mixed" if training else "bf16-true"
    return precision
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    Args:
        fabric: its ``strategy`` decides the loading path.
        model: module that receives the weights.
        checkpoint_path: checkpoint file to read.
        strict: forwarded to ``fabric.load_raw`` or ``model.load_state_dict``.
    """
    if isinstance(fabric.strategy, FSDPStrategy):
        # FSDP: let fabric load the raw checkpoint into the model
        fabric.load_raw(checkpoint_path, model, strict=strict)
        return

    checkpoint = lazy_load(checkpoint_path)
    # the weights may be nested under a "model" key; otherwise the checkpoint is the state dict
    weights = checkpoint.get("model", checkpoint)
    model.load_state_dict(weights, strict=strict)
def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
    """Estimate the FLOPs spent on one sequence of ``max_seq_length`` tokens.

    Args:
        max_seq_length: sequence length assumed for every sample.
        n_layer: number of layers, used for the attention term.
        n_embd: embedding size, used for the attention term.
        n_params: number of parameters taken into account.

    Returns:
        Parameter FLOPs per sequence plus the attention FLOPs per sequence.
    """
    # each parameter is used for a MAC (2 FLOPS) per network operation
    # this assumes that all samples have a fixed length equal to the block size
    # which is most likely false during finetuning
    param_flops = 2 * n_params * max_seq_length
    attention_flops = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return param_flops + attention_flops
def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures estimated FLOPs for MFU.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2

    Args:
        model: provides ``max_seq_length``, ``config.n_layer``, ``config.n_embd`` and the parameters.
        training: when true, trainable parameters are weighted by 3 and frozen ones by 2;
            otherwise both are weighted by 1.
    """
    # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
    # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
    # (~10%) compared to the measured FLOPs, making those lower but more realistic.
    # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
    seq_len = model.max_seq_length
    n_layer = model.config.n_layer
    n_embd = model.config.n_embd

    trainable_flops = flops_per_param(seq_len, n_layer, n_embd, num_parameters(model, requires_grad=True))
    frozen_flops = flops_per_param(seq_len, n_layer, n_embd, num_parameters(model, requires_grad=False))

    if training:
        # trainable: forward + backward + gradients (assumes no gradient accumulation)
        # frozen: forward + backward
        trainable_passes, frozen_passes = 3, 2
    else:
        trainable_passes, frozen_passes = 1, 1
    return trainable_passes * trainable_flops + frozen_passes * frozen_flops