GeminiFan207 commited on
Commit
0fabaf6
·
verified ·
1 Parent(s): 6332035

Rename model to model-1-of-10.safetensors

Browse files
Files changed (2) hide show
  1. model +0 -0
  2. model-1-of-10.safetensors +60 -0
model DELETED
File without changes
model-1-of-10.safetensors ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import load_file, save_file
2
+ import torch
3
+ from typing import List, Dict
4
+ import logging
5
+ from tqdm import tqdm
6
+ import os
7
+
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def load_charm_model(model_parts: List[str]) -> Dict[str, torch.Tensor]:
13
+ """
14
+ Load and merge multiple .safetensors model files.
15
+
16
+ Args:
17
+ model_parts (list): List of model part file paths (e.g., ["model-1-of-10.safetensors", ...])
18
+
19
+ Returns:
20
+ dict: Merged model state dictionary
21
+
22
+ Raises:
23
+ FileNotFoundError: If any model part file is missing.
24
+ RuntimeError: If there is an issue loading or merging the model parts.
25
+ """
26
+ merged_state_dict = {}
27
+
28
+ # Check if all model parts exist
29
+ for part in model_parts:
30
+ if not os.path.exists(part):
31
+ raise FileNotFoundError(f"Model part not found: {part}")
32
+
33
+ # Load and merge model parts
34
+ try:
35
+ logger.info("Loading and merging model parts...")
36
+ for part in tqdm(model_parts, desc="Loading model parts"):
37
+ state_dict = load_file(part)
38
+ merged_state_dict.update(state_dict) # Merge parameters
39
+ logger.debug(f"Loaded part: {part}")
40
+
41
+ logger.info("Model parts loaded and merged successfully.")
42
+ return merged_state_dict
43
+ except Exception as e:
44
+ logger.error(f"Error loading or merging model parts: {e}")
45
+ raise RuntimeError("Failed to load or merge model parts.")
46
+
47
+ # Example usage
48
+ if __name__ == "__main__":
49
+ try:
50
+ # List of model part files
51
+ model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]
52
+
53
+ # Load and merge the model
54
+ charm_model = load_charm_model(model_files)
55
+
56
+ # Save the merged model as a .safetensors file
57
+ save_file(charm_model, "merged_model.safetensors")
58
+ logger.info("Merged model saved as 'merged_model.safetensors'.")
59
+ except Exception as e:
60
+ logger.error(f"An error occurred: {e}")