GeminiFan207 committed
Commit 10bfb12 · verified · 1 Parent(s): 80219c0

Update base_model.safetensors

Files changed (1)
  1. base_model.safetensors +48 -125
base_model.safetensors CHANGED
@@ -1,128 +1,51 @@
- from safetensors.torch import load_file, save_file
  import torch
- from typing import List, Dict, Optional
- import logging
- from tqdm import tqdm
- import os
- import hashlib
- from concurrent.futures import ThreadPoolExecutor, as_completed
-
- # Configure logging
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
- logger = logging.getLogger(__name__)
-
- def calculate_checksum(file_path: str) -> str:
-     """
-     Calculate the SHA-256 checksum of a file.
-
-     Args:
-         file_path (str): Path to the file.
-
-     Returns:
-         str: SHA-256 checksum of the file.
-     """
-     sha256 = hashlib.sha256()
-     with open(file_path, "rb") as f:
-         for chunk in iter(lambda: f.read(4096), b""):
-             sha256.update(chunk)
-     return sha256.hexdigest()
-
- def verify_checksums(model_parts: List[str], expected_checksums: List[str]) -> None:
-     """
-     Verify the checksums of model part files.
-
-     Args:
-         model_parts (list): List of model part file paths.
-         expected_checksums (list): List of expected checksums for each part.
-
-     Raises:
-         RuntimeError: If any checksum does not match.
-     """
-     for part, expected_checksum in zip(model_parts, expected_checksums):
-         actual_checksum = calculate_checksum(part)
-         if actual_checksum != expected_checksum:
-             raise RuntimeError(f"Checksum mismatch for {part}: expected {expected_checksum}, got {actual_checksum}")
-
- def load_part(part: str) -> Dict[str, torch.Tensor]:
-     """
-     Load a single model part.
-
-     Args:
-         part (str): Path to the model part file.
-
-     Returns:
-         dict: State dictionary of the model part.
-     """
-     return load_file(part)
-
- def load_charm_model(model_parts: List[str], expected_checksums: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
-     """
-     Load and merge multiple .safetensors model files.
-
-     Args:
-         model_parts (list): List of model part file paths (e.g., ["model-1-of-10.safetensors", ...]).
-         expected_checksums (list, optional): List of expected checksums for each part.
-
-     Returns:
-         dict: Merged model state dictionary.
-
-     Raises:
-         FileNotFoundError: If any model part file is missing.
-         RuntimeError: If there is an issue loading or merging the model parts.
-     """
-     merged_state_dict = {}
-
-     # Check if all model parts exist
-     for part in model_parts:
-         if not os.path.exists(part):
-             raise FileNotFoundError(f"Model part not found: {part}")
-
-     # Verify checksums if provided
-     if expected_checksums:
-         logger.info("Verifying checksums...")
-         verify_checksums(model_parts, expected_checksums)
-         logger.info("Checksums verified successfully.")
-
-     # Load and merge model parts in parallel
-     try:
-         logger.info("Loading and merging model parts...")
-         with ThreadPoolExecutor() as executor:
-             futures = {executor.submit(load_part, part): part for part in model_parts}
-             for future in tqdm(as_completed(futures), total=len(futures), desc="Loading model parts"):
-                 part = futures[future]
-                 try:
-                     state_dict = future.result()
-                     merged_state_dict.update(state_dict)  # Merge parameters
-                     logger.debug(f"Loaded part: {part}")
-                 except Exception as e:
-                     logger.error(f"Error loading part {part}: {e}")
-                     raise RuntimeError(f"Failed to load part: {part}")
-
-         logger.info("Model parts loaded and merged successfully.")
-         return merged_state_dict
-     except Exception as e:
-         logger.error(f"Error loading or merging model parts: {e}")
-         raise RuntimeError("Failed to load or merge model parts.")

  # Example usage
- if __name__ == "__main__":
-     try:
-         # List of model part files
-         model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]
-
-         # Optional: List of expected checksums for each part
-         expected_checksums = [
-             "checksum_for_model-1-of-10.safetensors",
-             "checksum_for_model-2-of-10.safetensors",
-             # Add checksums for all parts...
-         ]
-
-         # Load and merge the model
-         charm_model = load_charm_model(model_files, expected_checksums)
-
-         # Save the merged model as a .safetensors file
-         output_file = "merged_model.safetensors"
-         save_file(charm_model, output_file)
-         logger.info(f"Merged model saved as '{output_file}'.")
-     except Exception as e:
-         logger.error(f"An error occurred: {e}")
 
  import torch
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+ # Specify the model name and safetensors file path
+ MODEL_NAME = "mistral-8x7B"
+ SAFETENSORS_PATH = "path_to_your_model.safetensors"
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+ # Initialize an empty (meta-device) model from the config, with no weights loaded yet
+ config = AutoConfig.from_pretrained(MODEL_NAME)
+ with init_empty_weights():
+     model = AutoModelForCausalLM.from_config(config)
+
+ # Use Hugging Face's `accelerate` to load the weights from the safetensors file efficiently;
+ # this allows for sharding and offloading to CPU/disk if needed
+ model = load_checkpoint_and_dispatch(
+     model,
+     SAFETENSORS_PATH,
+     device_map="auto",  # Automatically handles GPU/CPU placement and offloading
+     no_split_module_classes=["MistralDecoderLayer"],  # Layer classes that must not be split across devices
+     dtype=torch.float16,  # Use half precision for memory efficiency
+ )
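+ # If weights must spill to disk, load_checkpoint_and_dispatch also accepts an offload_folder argument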
+
+ # The model is already placed by device_map="auto"; `device` is only needed for the inputs below
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
  # Example usage
+ input_text = "Hello, how are you?"
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
+
+ # Generate output with efficient memory usage
+ with torch.no_grad():
+     outputs = model.generate(
+         inputs["input_ids"],
+         max_length=50,
+         num_return_sequences=1,
+         do_sample=True,  # Enable sampling so temperature/top_k/top_p take effect
+         temperature=0.7,
+         top_k=50,
+         top_p=0.95,
+     )
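+ # Note: max_length counts the prompt tokens as well; max_new_tokens can bound only the newly generated text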
+
+ # Decode and print the output
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+ print("Generated Text:", generated_text)