GeminiFan207 committed on
Commit
87cbc7b
verified ·
1 Parent(s): 09c4008

Create base_model.safetensors

Files changed (1)
  1. base_model.safetensors +67 -0
base_model.safetensors ADDED
@@ -0,0 +1,67 @@
+ import os
+ import torch
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
+ from safetensors.torch import save_file
+
+ # Define model and output settings
+ model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ output_dir = "mixtral_8x7b_safetensors"
+ max_shard_size = 2 * 1024 * 1024 * 1024  # 2 GB per shard, in bytes
+ dtype = torch.float16  # Half precision to save space
+
+ # Create output directory
+ os.makedirs(output_dir, exist_ok=True)
+
+ try:
+     # Load config and tokenizer first (low memory footprint)
+     config = AutoConfig.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # Save config and tokenizer alongside the weights
+     config.save_pretrained(output_dir)
+     tokenizer.save_pretrained(output_dir)
+
+     # Load model with offloading to avoid OOM (device_map="auto" requires accelerate)
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=dtype,
+         device_map="auto",       # Auto-distribute across GPU/CPU
+         low_cpu_mem_usage=True,  # Reduce RAM usage during load
+     )
+
+     # Get the full state dict
+     state_dict = model.state_dict()
+
+     # Estimate total size and shard dynamically
+     total_size = sum(t.element_size() * t.nelement() for t in state_dict.values())
+     num_shards = max(1, int(total_size / max_shard_size) + 1)  # At least 1 shard
+
+     # Distribute parameters by size, not by count
+     shards = [{} for _ in range(num_shards)]
+     current_size = [0] * num_shards
+     shard_index = 0
+
+     for key, value in state_dict.items():
+         tensor_size = value.element_size() * value.nelement()
+         # Move to the next shard once the current one would exceed the size limit
+         while current_size[shard_index] + tensor_size > max_shard_size and shard_index < num_shards - 1:
+             shard_index += 1
+         shards[shard_index][key] = value
+         current_size[shard_index] += tensor_size
+
+     # Save each non-empty shard
+     for i, shard in enumerate(shards):
+         if shard:
+             shard_path = os.path.join(output_dir, f"model_shard_{i}.safetensors")
+             save_file(shard, shard_path, metadata={"format": "pt"})
+             print(f"Saved shard {i} to {shard_path}")
+
+     print(f"Model saved to {output_dir} with {len([s for s in shards if s])} shards")
+
+ except Exception as e:
+     print(f"Error occurred: {e}")
+ finally:
+     # Clean up memory
+     if "model" in locals():
+         del model
+     torch.cuda.empty_cache()  # Clear GPU memory if used (no-op without CUDA)
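
The shards written above use custom filenames (model_shard_*.safetensors) and no index file, so transformers' from_pretrained will not discover them on its own. Below is a minimal reload sketch, assuming the output_dir and shard naming from the script; it is not part of the committed file. Note that materializing the full fp16 state dict of Mixtral-8x7B takes on the order of 90 GB of RAM.

import glob
import os

import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

output_dir = "mixtral_8x7b_safetensors"  # same directory the export script wrote to

config = AutoConfig.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)

# Build a randomly initialized model from the saved config.
# Passing torch_dtype assumes a recent transformers version; otherwise build in fp32 and call .half().
model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

# Merge all shards back into one state dict, then load it.
state_dict = {}
for shard_path in sorted(glob.glob(os.path.join(output_dir, "model_shard_*.safetensors"))):
    state_dict.update(load_file(shard_path))

missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

Writing a model.safetensors.index.json that maps each parameter name to its shard file would make the directory loadable with from_pretrained directly; the sketch above skips that step.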