#!/usr/bin/env python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import model_info
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

from accelerate import init_empty_weights
from accelerate.commands.utils import CustomArgumentParser
from accelerate.utils import (
    calculate_maximum_sizes,
    convert_bytes,
    is_timm_available,
    is_transformers_available,
)


if is_transformers_available():
    import transformers
    from transformers import AutoConfig, AutoModel

if is_timm_available():
    import timm


def verify_on_hub(repo: str, token: str = None):
    "Verifies that the model is on the hub and returns the model info."
    try:
        return model_info(repo, token=token)
    except GatedRepoError:
        return "gated"
    except RepositoryNotFoundError:
        return "repo"


def check_has_model(error):
    """
    Checks what library spawned `error` when a model is not found
    """
    if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in error.args[0]:
        return "timm"
    elif (
        is_transformers_available()
        and isinstance(error, OSError)
        and "does not appear to have a file named" in error.args[0]
    ):
        return "transformers"
    else:
        return "unknown"


def create_empty_model(model_name: str, library_name: str, trust_remote_code: bool = False, access_token: str = None):
    """
    Creates an empty model from its parent library on the `Hub` to calculate the overall memory consumption.

    Args:
        model_name (`str`):
            The model name on the Hub
        library_name (`str`):
            The library the model has an integration with, such as `transformers`. Will be used if `model_name` has no
            metadata on the Hub to determine the library.
        trust_remote_code (`bool`, `optional`, defaults to `False`):
            Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
            should only be set to `True` for repositories you trust and in which you have read the code, as it will
            execute code present on the Hub on your local machine.
        access_token (`str`, `optional`, defaults to `None`):
            The access token to use to access private or gated models on the Hub. (for use on the Gradio app)

    Returns:
        `torch.nn.Module`: The torch model that has been initialized on the `meta` device.
    """
    model_info = verify_on_hub(model_name, access_token)
    # Simplified errors
    if model_info == "gated":
        raise GatedRepoError(
            f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`."
        )
    elif model_info == "repo":
        raise RepositoryNotFoundError(
            f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo,"
            " make sure you are authenticated via `huggingface-cli login` and have access."
        )
    if library_name is None:
        library_name = getattr(model_info, "library_name", False)
        if not library_name:
            raise ValueError(
                f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)"
            )
    if library_name == "transformers":
        if not is_transformers_available():
            raise ImportError(
                f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`"
            )
        print(f"Loading pretrained config for `{model_name}` from `transformers`...")
        if model_info.config is None:
            raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.")

        auto_map = model_info.config.get("auto_map", False)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)

        with init_empty_weights():
            # remote code could specify a specific `AutoModel` class in the `auto_map`
            constructor = AutoModel
            if isinstance(auto_map, dict):
                value = None
                for key in auto_map.keys():
                    if key.startswith("AutoModelFor"):
                        value = key
                        break
                if value is not None:
                    constructor = getattr(transformers, value)
            model = constructor.from_config(config, trust_remote_code=trust_remote_code)
    elif library_name == "timm":
        if not is_timm_available():
            raise ImportError(
                f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`"
            )
        print(f"Loading pretrained config for `{model_name}` from `timm`...")

        with init_empty_weights():
            model = timm.create_model(model_name, pretrained=False)
    else:
        raise ValueError(
            f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support."
        )
    return model
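

# Illustrative sketch of calling `create_empty_model` programmatically, using
# `calculate_maximum_sizes` and `convert_bytes` from the imports at the top of this file.
# The checkpoint name `"bert-base-cased"` is only an example and assumes network access to the Hub.
def _example_create_empty_model():
    model = create_empty_model("bert-base-cased", library_name="transformers")
    total_size, largest_layer = calculate_maximum_sizes(model)
    # The model lives on the `meta` device, so no real memory was allocated for its weights
    print(f"Total size: {convert_bytes(total_size)}, largest layer: {convert_bytes(largest_layer[0])}")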


def create_ascii_table(headers: list, rows: list, title: str):
    "Creates a pretty table from a list of rows, minimal version of `tabulate`."
    sep_char, in_between = "│", "─"
    column_widths = []
    for i in range(len(headers)):
        column_values = [row[i] for row in rows] + [headers[i]]
        max_column_width = max(len(value) for value in column_values)
        column_widths.append(max_column_width)

    formats = [f"%{column_widths[i]}s" for i in range(len(rows[0]))]

    pattern = f"{sep_char}{sep_char.join(formats)}{sep_char}"
    diff = 0

    def make_row(left_char, middle_char, right_char):
        return f"{left_char}{middle_char.join([in_between * n for n in column_widths])}{in_between * diff}{right_char}"

    separator = make_row("├", "┼", "┤")
    if len(title) > sum(column_widths):
        diff = abs(len(title) - len(separator))
        column_widths[-1] += diff

    # Update with diff
    separator = make_row("├", "┼", "┤")
    initial_rows = [
        make_row("┌", in_between, "┐"),
        f"{sep_char}{title.center(len(separator) - 2)}{sep_char}",
        make_row("├", "┬", "┤"),
    ]
    table = "\n".join(initial_rows) + "\n"
    column_widths[-1] += diff
    centered_line = [text.center(column_widths[i]) for i, text in enumerate(headers)]
    table += f"{pattern % tuple(centered_line)}\n{separator}\n"
    for i, line in enumerate(rows):
        centered_line = [t.center(column_widths[i]) for i, t in enumerate(line)]
        table += f"{pattern % tuple(centered_line)}\n"
    table += f'└{"┴".join([in_between * n for n in column_widths])}┘'
    return table
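

# Minimal sketch of what `create_ascii_table` produces: a boxed table with a centered title row,
# a header row, and one row per entry. All cells must already be strings, since column widths come
# from `len`; the headers, sizes, and title below are placeholder values.
def _example_create_ascii_table():
    headers = ["dtype", "Total Size"]
    rows = [["float32", "418.58 MB"], ["float16", "209.29 MB"]]
    print(create_ascii_table(headers, rows, "Memory Usage for loading `bert-base-cased`"))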


def estimate_command_parser(subparsers=None):
    if subparsers is not None:
        parser = subparsers.add_parser("estimate-memory")
    else:
        parser = CustomArgumentParser(description="Model size estimator for fitting a model onto CUDA memory.")

    parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
    parser.add_argument(
        "--library_name",
        type=str,
        help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.",
        choices=["timm", "transformers"],
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        default=["float32", "float16", "int8", "int4"],
        help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`",
        choices=["float32", "float16", "int8", "int4"],
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag
        should only be used for repositories you trust and in which you have read the code, as it will execute
        code present on the Hub on your local machine.""",
        default=False,
    )

    if subparsers is not None:
        parser.set_defaults(func=estimate_command)
    return parser


def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: str = None) -> dict:
    """
    Given an amount of `bytes` and `mixed_precision`, calculates how much training memory is needed for a batch size
    of 1.

    Args:
        bytes (`int`):
            The size of the model being trained.
        mixed_precision (`str`):
            The mixed precision that would be run.
        msamp_config (`str`):
            The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`.
    """
    memory_sizes = {"model": -1, "optimizer": -1, "gradients": -1, "step": -1}
    fp32_size = bytes
    fp16_size = bytes // 2

    if mixed_precision == "float32":
        memory_sizes["model"] = fp32_size
        memory_sizes["gradients"] = fp32_size
        memory_sizes["optimizer"] = fp32_size * 2
        memory_sizes["step"] = fp32_size * 4
    elif mixed_precision in ("float16", "bfloat16") or (mixed_precision == "fp8" and msamp_config is None):
        # With native `TransformersEngine`, there are no memory savings with FP8
        # With mixed precision training, the model weights are stored
        # in both FP16 and FP32
        memory_sizes["model"] = fp32_size
        # 1.5x from weight gradient + computation (GEMM)
        memory_sizes["gradients"] = fp32_size + fp16_size
        # 2x from optimizer states
        memory_sizes["optimizer"] = fp32_size * 2  # Optimizer states
        memory_sizes["step"] = memory_sizes["optimizer"]
    return memory_sizes
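

# Worked example of the arithmetic above for a hypothetical 4 GB model trained in full float32
# with Adam: weights 4 GB, gradients 4 GB, optimizer states 8 GB, and 16 GB at the optimizer step.
# That peak is what `estimate_command` later reports under "Training using Adam".
def _example_estimate_training_usage():
    four_gb = 4 * 1024**3
    sizes = estimate_training_usage(four_gb, "float32")
    assert sizes["model"] == four_gb
    assert sizes["gradients"] == four_gb
    assert sizes["optimizer"] == 2 * four_gb
    assert max(sizes.values()) == sizes["step"] == 4 * four_gb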


def gather_data(args):
    "Creates an empty model and gathers the data for the sizes"
    try:
        model = create_empty_model(
            args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code
        )
    except (RuntimeError, OSError) as e:
        library = check_has_model(e)
        if library != "unknown":
            raise RuntimeError(
                f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo."
            )
        raise e

    total_size, largest_layer = calculate_maximum_sizes(model)
    data = []

    for dtype in args.dtypes:
        dtype_total_size = total_size
        dtype_largest_layer = largest_layer[0]
        dtype_training_size = estimate_training_usage(dtype_total_size, dtype)
        if dtype == "float16":
            dtype_total_size /= 2
            dtype_largest_layer /= 2
        elif dtype == "int8":
            dtype_total_size /= 4
            dtype_largest_layer /= 4
        elif dtype == "int4":
            dtype_total_size /= 8
            dtype_largest_layer /= 8
        data.append([dtype, dtype_largest_layer, dtype_total_size, dtype_training_size])
    return data
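

# Sketch of driving `gather_data` without the CLI: a plain `argparse.Namespace` stands in for the
# parsed arguments, and `"bert-base-cased"` is only an illustrative checkpoint. Each returned row
# is `[dtype, largest_layer_bytes, total_bytes, training_memory_dict]`, which `estimate_command`
# below converts to human-readable strings.
def _example_gather_data():
    from argparse import Namespace

    args = Namespace(
        model_name="bert-base-cased",
        library_name="transformers",
        trust_remote_code=False,
        dtypes=["float32", "float16"],
    )
    for dtype, largest_layer, total_size, training in gather_data(args):
        print(dtype, convert_bytes(largest_layer), convert_bytes(total_size), convert_bytes(max(training.values())))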


def estimate_command(args):
    data = gather_data(args)
    for row in data:
        for i, item in enumerate(row):
            if isinstance(item, (int, float)):
                row[i] = convert_bytes(item)
            elif isinstance(item, dict):
                training_usage = max(item.values())
                row[i] = convert_bytes(training_usage) if training_usage != -1 else "N/A"

    headers = ["dtype", "Largest Layer", "Total Size", "Training using Adam"]

    title = f"Memory Usage for loading `{args.model_name}`"
    table = create_ascii_table(headers, data, title)
    print(table)


def main():
    parser = estimate_command_parser()
    args = parser.parse_args()
    estimate_command(args)


if __name__ == "__main__":
    main()
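

# The same estimate is usually produced from the command line. The exact wrapper depends on how
# this module is installed, but with the `accelerate` CLI the subcommand registered above would be
# invoked roughly as:
#
#   accelerate estimate-memory bert-base-cased --library_name transformers --dtypes float32 float16
#
# where `bert-base-cased` is only an illustrative checkpoint.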