"""Utilities for loading, verifying, and merging sharded .safetensors models."""

from safetensors.torch import load_file, save_file
import torch
from typing import List, Dict, Optional
import logging
from tqdm import tqdm
import os
import hashlib
from concurrent.futures import ThreadPoolExecutor, as_completed

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def calculate_checksum(file_path: str) -> str:
    """
    Calculate the SHA-256 checksum of a file.

    Args:
        file_path (str): Path to the file.

    Returns:
        str: SHA-256 checksum of the file (hex digest).
    """
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # Read in fixed-size chunks so arbitrarily large files fit in memory.
        for chunk in iter(lambda: f.read(4096), b""):
            sha256.update(chunk)
    return sha256.hexdigest()


def verify_checksums(model_parts: List[str], expected_checksums: List[str]) -> None:
    """
    Verify the checksums of model part files.

    Args:
        model_parts (list): List of model part file paths.
        expected_checksums (list): List of expected checksums for each part.

    Raises:
        ValueError: If the number of checksums does not match the number of parts.
        RuntimeError: If any checksum does not match.
    """
    # zip() would silently truncate on mismatched lengths, leaving some parts
    # unverified -- fail loudly instead.
    if len(model_parts) != len(expected_checksums):
        raise ValueError(
            f"Expected {len(model_parts)} checksums, got {len(expected_checksums)}"
        )
    for part, expected_checksum in zip(model_parts, expected_checksums):
        actual_checksum = calculate_checksum(part)
        if actual_checksum != expected_checksum:
            raise RuntimeError(f"Checksum mismatch for {part}: expected {expected_checksum}, got {actual_checksum}")


def load_part(part: str) -> Dict[str, torch.Tensor]:
    """
    Load a single model part.

    Args:
        part (str): Path to the model part file.

    Returns:
        dict: State dictionary of the model part.
    """
    return load_file(part)


def load_charm_model(
    model_parts: List[str],
    expected_checksums: Optional[List[str]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Load and merge multiple .safetensors model files.

    Args:
        model_parts (list): List of model part file paths
            (e.g., ["model-1-of-10.safetensors", ...]).
        expected_checksums (list, optional): List of expected checksums for
            each part; must contain exactly one entry per part when given.

    Returns:
        dict: Merged model state dictionary.

    Raises:
        FileNotFoundError: If any model part file is missing.
        ValueError: If the checksum list length does not match the part list.
        RuntimeError: If there is an issue loading or merging the model parts.
    """
    merged_state_dict: Dict[str, torch.Tensor] = {}

    # Check that all model parts exist before doing any expensive work.
    for part in model_parts:
        if not os.path.exists(part):
            raise FileNotFoundError(f"Model part not found: {part}")

    # Verify checksums if provided
    if expected_checksums:
        logger.info("Verifying checksums...")
        verify_checksums(model_parts, expected_checksums)
        logger.info("Checksums verified successfully.")

    # Load model parts in parallel; loading is I/O-bound, so threads overlap
    # the waits despite the GIL.
    logger.info("Loading and merging model parts...")
    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(load_part, part): part for part in model_parts}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Loading model parts"):
            part = futures[future]
            try:
                state_dict = future.result()
            except Exception as e:
                logger.error("Error loading part %s: %s", part, e)
                # Chain the original exception so the root cause is preserved.
                raise RuntimeError(f"Failed to load part: {part}") from e
            # Warn when a later part silently overwrites keys from an earlier one.
            overlap = merged_state_dict.keys() & state_dict.keys()
            if overlap:
                logger.warning("Duplicate tensor keys overwritten: %s", sorted(overlap))
            merged_state_dict.update(state_dict)  # Merge parameters
            logger.debug("Loaded part: %s", part)

    logger.info("Model parts loaded and merged successfully.")
    return merged_state_dict


# Example usage
if __name__ == "__main__":
    try:
        # List of model part files
        model_files = [f"model-{i}-of-10.safetensors" for i in range(1, 11)]

        # Optional: expected SHA-256 checksums, one per part file.  Pass None
        # to skip verification; a partial list now raises ValueError instead
        # of silently verifying only a prefix of the parts.
        expected_checksums = None

        # Load and merge the model
        charm_model = load_charm_model(model_files, expected_checksums)

        # Save the merged model as a .safetensors file
        output_file = "merged_model.safetensors"
        save_file(charm_model, output_file)
        logger.info("Merged model saved as '%s'.", output_file)
    except Exception as e:
        logger.error("An error occurred: %s", e)