GeminiFan207 commited on
Commit
80219c0
·
verified ·
1 Parent(s): 0fabaf6

Update base_model.safetensors

Browse files
Files changed (1) hide show
  1. base_model.safetensors +127 -66
base_model.safetensors CHANGED
@@ -1,67 +1,128 @@
1
- import os
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
4
- from safetensors.torch import save_file
5
-
6
- # Define model and output settings
7
- model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1" # Corrected to a real model
8
- output_dir = "mixtral_8x7b_safetensors"
9
- max_shard_size = 2 * 1024 * 1024 * 1024 # 2GB per shard in bytes
10
- dtype = torch.float16 # Half-precision to save space
11
-
12
- # Create output directory
13
- os.makedirs(output_dir, exist_ok=True)
14
-
15
- try:
16
- # Load config and tokenizer first (low memory footprint)
17
- config = AutoConfig.from_pretrained(model_name)
18
- tokenizer = AutoTokenizer.from_pretrained(model_name)
19
-
20
- # Save config and tokenizer for later use
21
- config.save_pretrained(output_dir)
22
- tokenizer.save_pretrained(output_dir)
23
-
24
- # Load model with offloading to avoid OOM (if GPU available)
25
- model = AutoModelForCausalLM.from_pretrained(
26
- model_name,
27
- torch_dtype=dtype,
28
- device_map="auto", # Auto-distribute across GPU/CPU
29
- low_cpu_mem_usage=True # Reduce RAM usage during load
30
- )
31
-
32
- # Get state dict
33
- state_dict = model.state_dict()
34
-
35
- # Estimate total size and shard dynamically
36
- total_size = sum(t.element_size() * t.nelement() for t in state_dict.values())
37
- num_shards = max(1, int(total_size / max_shard_size) + 1) # At least 1 shard
38
-
39
- # Distribute parameters by size, not count
40
- shards = [{} for _ in range(num_shards)]
41
- current_size = [0] * num_shards
42
- shard_index = 0
43
-
44
- for key, value in state_dict.items():
45
- tensor_size = value.element_size() * value.nelement()
46
- # Move to next shard if current one exceeds size limit
47
- while current_size[shard_index] + tensor_size > max_shard_size and shard_index < num_shards - 1:
48
- shard_index += 1
49
- shards[shard_index][key] = value
50
- current_size[shard_index] += tensor_size
51
-
52
- # Save each shard
53
- for i, shard in enumerate(shards):
54
- if shard: # Only save non-empty shards
55
- shard_path = os.path.join(output_dir, f"model_shard_{i}.safetensors")
56
- save_file(shard, shard_path)
57
- print(f"Saved shard {i} to {shard_path}")
58
-
59
- print(f"Model saved to {output_dir} with {len([s for s in shards if s])} shards")
60
-
61
- except Exception as e:
62
- print(f"Error occurred: {str(e)}")
63
- finally:
64
- # Clean up memory
65
- if 'model' in locals():
66
- del model
67
- torch.cuda.empty_cache() # Clear GPU memory if used
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import load_file, save_file
2
  import torch
3
+ from typing import List, Dict, Optional
4
+ import logging
5
+ from tqdm import tqdm
6
+ import os
7
+ import hashlib
8
+ from concurrent.futures import ThreadPoolExecutor, as_completed
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def calculate_checksum(file_path: str) -> str:
15
+ """
16
+ Calculate the SHA-256 checksum of a file.
17
+
18
+ Args:
19
+ file_path (str): Path to the file.
20
+
21
+ Returns:
22
+ str: SHA-256 checksum of the file.
23
+ """
24
+ sha256 = hashlib.sha256()
25
+ with open(file_path, "rb") as f:
26
+ for chunk in iter(lambda: f.read(4096), b""):
27
+ sha256.update(chunk)
28
+ return sha256.hexdigest()
29
+
30
+ def verify_checksums(model_parts: List[str], expected_checksums: List[str]) -> None:
31
+ """
32
+ Verify the checksums of model part files.
33
+
34
+ Args:
35
+ model_parts (list): List of model part file paths.
36
+ expected_checksums (list): List of expected checksums for each part.
37
+
38
+ Raises:
39
+ RuntimeError: If any checksum does not match.
40
+ """
41
+ for part, expected_checksum in zip(model_parts, expected_checksums):
42
+ actual_checksum = calculate_checksum(part)
43
+ if actual_checksum != expected_checksum:
44
+ raise RuntimeError(f"Checksum mismatch for {part}: expected {expected_checksum}, got {actual_checksum}")
45
+
46
+ def load_part(part: str) -> Dict[str, torch.Tensor]:
47
+ """
48
+ Load a single model part.
49
+
50
+ Args:
51
+ part (str): Path to the model part file.
52
+
53
+ Returns:
54
+ dict: State dictionary of the model part.
55
+ """
56
+ return load_file(part)
57
+
58
+ def load_charm_model(model_parts: List[str], expected_checksums: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
59
+ """
60
+ Load and merge multiple .safetensors model files.
61
+
62
+ Args:
63
+ model_parts (list): List of model part file paths (e.g., ["model-1-of-10.safetensors", ...]).
64
+ expected_checksums (list, optional): List of expected checksums for each part.
65
+
66
+ Returns:
67
+ dict: Merged model state dictionary.
68
+
69
+ Raises:
70
+ FileNotFoundError: If any model part file is missing.
71
+ RuntimeError: If there is an issue loading or merging the model parts.
72
+ """
73
+ merged_state_dict = {}
74
+
75
+ # Check if all model parts exist
76
+ for part in model_parts:
77
+ if not os.path.exists(part):
78
+ raise FileNotFoundError(f"Model part not found: {part}")
79
+
80
+ # Verify checksums if provided
81
+ if expected_checksums:
82
+ logger.info("Verifying checksums...")
83
+ verify_checksums(model_parts, expected_checksums)
84
+ logger.info("Checksums verified successfully.")
85
+
86
+ # Load and merge model parts in parallel
87
+ try:
88
+ logger.info("Loading and merging model parts...")
89
+ with ThreadPoolExecutor() as executor:
90
+ futures = {executor.submit(load_part, part): part for part in model_parts}
91
+ for future in tqdm(as_completed(futures), total=len(futures), desc="Loading model parts"):
92
+ part = futures[future]
93
+ try:
94
+ state_dict = future.result()
95
+ merged_state_dict.update(state_dict) # Merge parameters
96
+ logger.debug(f"Loaded part: {part}")
97
+ except Exception as e:
98
+ logger.error(f"Error loading part {part}: {e}")
99
+ raise RuntimeError(f"Failed to load part: {part}")
100
+
101
+ logger.info("Model parts loaded and merged successfully.")
102
+ return merged_state_dict
103
+ except Exception as e:
104
+ logger.error(f"Error loading or merging model parts: {e}")
105
+ raise RuntimeError("Failed to load or merge model parts.")
106
+
107
+ # Example usage
108
+ if __name__ == "__main__":
109
+ try:
110
+ # List of model part files
111
+ model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]
112
+
113
+ # Optional: List of expected checksums for each part
114
+ expected_checksums = [
115
+ "checksum_for_model-1-of-10.safetensors",
116
+ "checksum_for_model-2-of-10.safetensors",
117
+ # Add checksums for all parts...
118
+ ]
119
+
120
+ # Load and merge the model
121
+ charm_model = load_charm_model(model_files, expected_checksums)
122
+
123
+ # Save the merged model as a .safetensors file
124
+ output_file = "merged_model.safetensors"
125
+ save_file(charm_model, output_file)
126
+ logger.info(f"Merged model saved as '{output_file}'.")
127
+ except Exception as e:
128
+ logger.error(f"An error occurred: {e}")