import asyncio
import logging

import torch
from safetensors.torch import load_file, save_file

# Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Define model checkpoint path
MODEL_CHECKPOINT = "model-3-of-10.safetensors"

# Detect GPU availability
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Load model with efficient memory management
async def load_model(filepath: str) -> dict:
    """Asynchronously loads a model from a safetensors file."""
    try:
        logging.info(f"Loading model from {filepath} on {DEVICE}...")
        # load_file is blocking I/O; run it in a worker thread so the event loop stays free.
        model_data = await asyncio.to_thread(load_file, filepath, device=DEVICE)
        logging.info(f"Model {filepath} successfully loaded.")
        return model_data
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise RuntimeError(f"Error loading model: {e}") from e


# Save model with optimized storage format
async def save_model(filepath: str, model_tensors: dict):
    """Asynchronously saves a model to a safetensors file."""
    try:
        logging.info(f"Saving model to {filepath}...")
        # save_file is blocking I/O as well; offload it to a worker thread.
        await asyncio.to_thread(save_file, model_tensors, filepath)
        logging.info(f"Model saved at {filepath}")
    except Exception as e:
        logging.error(f"Error saving model: {e}")
        raise RuntimeError(f"Error saving model: {e}") from e


# Dynamically generate layers for efficient scaling
def initialize_model(layers: tuple = (4096, 8192, 16384), dtype: torch.dtype = torch.float16) -> dict:
    """Initializes a model with a random square tensor for each layer."""
    model_tensors = {}
    for i, size in enumerate(layers):
        layer_name = f"layer_{i + 1}"
        logging.info(f"Initializing {layer_name} with size {size}x{size} on {DEVICE}...")
        model_tensors[layer_name] = torch.randn(size, size, dtype=dtype, device=DEVICE)
        if DEVICE == "cuda":
            torch.cuda.empty_cache()  # Release cached blocks between large allocations
    logging.info("Model initialization completed.")
    return model_tensors


# Main execution
async def main():
    model_data = initialize_model()

    # Save the model for deployment
    await save_model(MODEL_CHECKPOINT, model_data)

    # Load the model for verification
    loaded_model_data = await load_model(MODEL_CHECKPOINT)

    # Verify loaded tensors match saved tensors
    for key in model_data:
        if not torch.allclose(model_data[key], loaded_model_data[key], atol=1e-5):
            logging.warning(f"Tensor mismatch in {key}!")
        else:
            logging.info(f"Tensor {key} verified successfully.")


# Run asynchronously
if __name__ == "__main__":
    asyncio.run(main())
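

# --- Optional: lazy per-tensor access (a sketch, not part of the pipeline above) ---
# load_file materializes every tensor in the file at once. When only a few tensors
# from a large shard are needed, safetensors' safe_open can read them one at a time
# without loading the whole checkpoint. The helper name peek_tensor is hypothetical;
# the example reuses MODEL_CHECKPOINT and the "layer_N" keys written by main().
from safetensors import safe_open


def peek_tensor(filepath: str, tensor_name: str) -> torch.Tensor:
    """Reads a single named tensor from a safetensors file, leaving the rest on disk."""
    with safe_open(filepath, framework="pt", device=DEVICE) as f:
        if tensor_name not in f.keys():
            raise KeyError(f"{tensor_name} not found in {filepath}")
        return f.get_tensor(tensor_name)


# Example usage (assumes main() has already written MODEL_CHECKPOINT):
# layer_2 = peek_tensor(MODEL_CHECKPOINT, "layer_2")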