We plan to add other information in the future:

index["metadata"]
{'total_size': 433245184}

The weight map is the main part of this index. It maps each parameter name (as usually found in a PyTorch model state_dict) to the file it is stored in:

index["weight_map"]
{'embeddings.LayerNorm.bias': 'pytorch_model-00001-of-00003.bin',
 'embeddings.LayerNorm.weight': 'pytorch_model-00001-of-00003.bin',
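
As a rough illustration of how the index can be used, the sketch below inverts the weight map to list the parameters stored in each shard file and converts the total size to MiB. The `index` dictionary here is a hand-written sample containing only the two entries shown above, and `params_by_shard` is a hypothetical helper, not part of the library.

```python
from collections import defaultdict

# Hand-written sample index (assumption): only the entries shown above.
index = {
    "metadata": {"total_size": 433245184},
    "weight_map": {
        "embeddings.LayerNorm.bias": "pytorch_model-00001-of-00003.bin",
        "embeddings.LayerNorm.weight": "pytorch_model-00001-of-00003.bin",
    },
}


def params_by_shard(weight_map):
    """Invert the weight map: shard file name -> sorted parameter names."""
    shards = defaultdict(list)
    for name, shard_file in weight_map.items():
        shards[shard_file].append(name)
    return {shard_file: sorted(names) for shard_file, names in shards.items()}


grouped = params_by_shard(index["weight_map"])
# total_size is a byte count; divide by 2**20 to express it in MiB.
size_mib = index["metadata"]["total_size"] / 2**20

for shard_file, names in grouped.items():
    print(shard_file, len(names), "parameters")
print(f"total size: {size_mib:.1f} MiB")
```

Grouping by file this way makes it easy to see which shard has to be opened to read a given parameter.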
 

If you want to load such a sharded checkpoint directly into a model without using [~PreTrainedModel.from_pretrained] (as you would with model.load_state_dict() for a full checkpoint), use [~modeling_utils.load_sharded_checkpoint]:

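import tempfile
from transformers.modeling_utils import load_sharded_checkpoint

# `model` is assumed to be an already instantiated model
with tempfile.TemporaryDirectory() as tmp_dir:
    # Save in shards of at most 200MB, then load them back into the model
    model.save_pretrained(tmp_dir, max_shard_size="200MB")
    load_sharded_checkpoint(model, tmp_dir)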
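
To make the idea of shard-by-shard loading concrete, here is a conceptual sketch in plain Python. It is not the library's implementation: `load_in_shards` and `read_shard` are hypothetical names, and plain dictionaries stand in for the shard files and the model's state. In real use, `read_shard` might wrap something like `torch.load`.

```python
def load_in_shards(target, weight_map, read_shard):
    """Fill `target` one shard at a time.

    Returns the sorted parameter names listed in `weight_map`
    that no shard actually supplied.
    """
    missing = set(weight_map)
    for shard_file in sorted(set(weight_map.values())):
        shard = read_shard(shard_file)  # only this shard is held at once
        for name, value in shard.items():
            target[name] = value
            missing.discard(name)
        del shard  # drop the reference before reading the next shard
    return sorted(missing)


# Stand-in data (assumption): two tiny "shards" held in a dictionary.
fake_shards = {"a.bin": {"w1": 1.0}, "b.bin": {"w2": 2.0}}
weight_map = {"w1": "a.bin", "w2": "b.bin", "w3": "b.bin"}

state = {}
missing = load_in_shards(state, weight_map, fake_shards.__getitem__)
```

Because each shard is released before the next one is read, the peak memory needed on top of the model itself is roughly the size of the largest shard rather than the size of the whole checkpoint.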

Low memory loading
Sharded checkpoints reduce memory usage during step 2 of the workflow mentioned above, but to use the model in a low-memory setting, we recommend leveraging our tools based on the Accelerate library.