import gradio as gr
import pandas as pd
from hub_utils import check_for_discussion, report_results
from model_utils import calculate_memory, get_model
from huggingface_hub.utils import HfHubHTTPError


# We need to store the model as a global because Gradio doesn't have a way for us to pass it into the button callback
MODEL = None


def get_results(model_name: str, library: str, precision: list, training: list, access_token: str, zero_stage: int, num_nodes: int, num_gpus: int, offloading: list, zero_init: list, additional_buffer_factor: float):
    global MODEL
    MODEL = get_model(model_name, library, access_token)
    try:
        has_discussion = check_for_discussion(model_name)
    except HfHubHTTPError:
        has_discussion = True

    options = {
        "precision": precision,
        "zero_stage": zero_stage,
        "cpu_offload": True if "Optimizer" in offloading else False,
        "cpu_offload_params": True if "Parameters" in offloading else False,
        "zero_init": True if "zero.Init" in zero_init else False,
        "num_nodes": num_nodes,
        "num_gpus_per_node": num_gpus,
        "training_regime": training,
        "additional_buffer_factor": additional_buffer_factor
    }
    data = calculate_memory(MODEL, options)
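    # NOTE: a sketch of the rows `calculate_memory` is assumed to return, inferred from the
    # DataFrame headers defined further down (the actual keys live in model_utils.calculate_memory):
    #   [{"dtype": ..., "Largest Layer": ..., "Total Size": ..., "Training using Adam": ...}, ...]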
    
    title = f"## Memory usage for '{model_name}'"
    return [title, gr.update(visible=True, value=pd.DataFrame(data)), gr.update(visible=not has_discussion)]


with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            """<img src="https://huggingface.co/spaces/andstor/model-memory-usage/resolve/main/measure_model_size.png" style="float: left;" width="250" height="250"><h1>🤗 DeepSpeed Model Memory Calculator</h1>

    This tool helps you calculate how much vRAM is needed to train and perform big model inference
    on a model hosted on the 🤗 Hugging Face Hub. The minimum recommended vRAM for a model
    is denoted by the size of its "largest layer", and training a model takes roughly 4x its total size (with Adam).

    These calculations are accurate to within a few percent, e.g. `bert-base-cased` is 413.68 MB and the calculator estimates 413.18 MB.

    When performing inference, expect to add up to an additional 20% on top of this, as found by [EleutherAI](https://blog.eleuther.ai/transformer-math/).
    More tests will be performed in the future to get a more accurate benchmark for each model.

    Currently this tool supports all hosted models that use `transformers` and `timm`.

    To use this tool, pass in the URL or name of the model you want to calculate the memory usage for,
    select which framework it originates from ("auto" will try to detect it from the model metadata), and
    select which precisions you want to use."""
        )
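        # A worked illustration of the rules of thumb in the description above (an estimate only;
        # the real figures come from calculate_memory): bert-base-cased is ~413.68 MB in float32,
        # so inference needs roughly 413.68 MB * 1.2 ≈ 496 MB of vRAM and training with Adam
        # roughly 413.68 MB * 4 ≈ 1655 MB.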
        out_text = gr.Markdown()
        out = gr.DataFrame(
            headers=["dtype", "Largest Layer", "Total Size", "Training using Adam"],
            interactive=False,
            visible=False,
        )
        with gr.Row():
            inp = gr.Textbox(label="Model Name or URL", value="bert-base-cased")
        with gr.Row():
            library = gr.Radio(["auto", "transformers", "timm"], label="Library", value="auto")
            precision = gr.CheckboxGroup(
                ["float32", "float16/bfloat16"],
                value="float32",
                label="Model Precision",
            )
            training = gr.Radio(
                ["Mixed precision", "Single precision"],
                value="Mixed precision",
                label="Training Paradigm",
            )
            access_token = gr.Textbox(label="API Token", placeholder="Optional (for gated models)")
            num_gpus = gr.Number(label="GPUs per node", value=4, minimum=1, step=1)
            num_nodes = gr.Number(label="Nodes", value=1, minimum=1, step=1)
        with gr.Column(variant="panel"):
            with gr.Row(equal_height=True):

                zero_stage = gr.Radio(["Stage 0", "Stage 1", "Stage 2", "Stage 3"], label="ZeRO Stage", value="Stage 3", type="index")
                zero_description = gr.CheckboxGroup(["Optimizer state",  "Gradients", "Parameters"], label="Partitioning", value=["Optimizer state",  "Gradients", "Parameters"], interactive=False)

            with gr.Row(equal_height=True):
                #with gr.Column():
                offloading = gr.CheckboxGroup(["Optimizer",  "Parameters"], label="ZeRO-Offload", info="Offloading data and compute to CPU", value=["Optimizer",  "Parameters"])
                zero_init = gr.CheckboxGroup(["zero.Init"], value=["zero.Init"], label="Initialization")
                
                #with gr.Column():
                additional_buffer_factor = gr.Number(label="Additional Buffer Factor", value=1.5, minimum=1, step=0.1)
        with gr.Row():
            btn = gr.Button("Calculate Memory Usage")
            post_to_hub = gr.Button(
                value="Report results in this model repo's discussions!\n(Will open in a new tab)", visible=False
            )

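    # The handlers below keep the ZeRO controls consistent: selecting a stage toggles which
    # offloading options and the zero.Init checkbox are visible, and mirrors what that stage
    # partitions in the read-only "Partitioning" group.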
    def change_zero_settings(evt: gr.SelectData):  # SelectData is a subclass of EventData
        # Show/hide the ZeRO-Offload and zero.Init controls depending on the selected stage
        if evt.index == 0:  # Stage 0: no offloading or zero.Init
            return [gr.update(visible=False), gr.update(visible=False)]
        if evt.index in (1, 2):  # Stages 1-2: only optimizer offloading
            return [gr.update(choices=["Optimizer"], visible=True), gr.update(visible=False)]
        if evt.index == 3:  # Stage 3: optimizer + parameter offloading and zero.Init
            return [gr.update(choices=["Optimizer", "Parameters"], visible=True), gr.update(visible=True)]
    
    def change_zero_description(evt: gr.SelectData):  # SelectData is a subclass of EventData
        if evt.index == 0:
            return gr.update(value=None)
        if evt.index == 1:
            return gr.update(value=["Optimizer state"])
        if evt.index == 2:
            return gr.update(value=["Optimizer state", "Gradients"])
        if evt.index == 3:
            return gr.update(value=["Optimizer state", "Gradients", "Parameters"])
    
    def change_offloading(evt: gr.SelectData, zero_stage):  # SelectData is a subclass of EventData
        # Keep the ZeRO-Offload choices consistent: parameter offloading requires optimizer
        # offloading, and is only offered for ZeRO Stage 3
        if evt.value == "Optimizer" and not evt.selected:
            return gr.CheckboxGroup.update(choices=["Optimizer"], value=[])

        if evt.value == "Optimizer" and evt.selected:
            if zero_stage in [1, 2]:
                return gr.CheckboxGroup.update(choices=["Optimizer"], value=["Optimizer"])
            elif zero_stage == 3:
                return gr.CheckboxGroup.update(choices=["Optimizer", "Parameters"], value=["Optimizer"])

        if evt.value == "Parameters" and not evt.selected:
            return gr.CheckboxGroup.update(value=["Optimizer"])

        if evt.value == "Parameters" and evt.selected:
            return gr.CheckboxGroup.update(value=["Optimizer", "Parameters"])



    zero_stage.select(change_zero_settings, None, [offloading, zero_init])
    zero_stage.select(change_zero_description, None, zero_description)
    offloading.select(change_offloading, zero_stage, offloading)


    btn.click(
        get_results,
        inputs=[inp, library, precision, training, access_token, zero_stage, num_nodes, num_gpus, offloading, zero_init, additional_buffer_factor],
        outputs=[out_text, out, post_to_hub],
    )

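    # Hide the report button once clicked, then post the results to the model repo's
    # discussions via report_results (from hub_utils)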
    post_to_hub.click(lambda: gr.Button.update(visible=False), outputs=post_to_hub).then(
        report_results, inputs=[inp, library, access_token]
    )


demo.launch()