# Charm_15 / model_3_of_278.safetensors
import torch
from safetensors.torch import load_file, save_file
import logging
import asyncio
# Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Define model checkpoint path
MODEL_CHECKPOINT = "model-3-of-10.safetensors"
# Detect GPU availability
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Load model with efficient memory management
async def load_model(filepath: str) -> dict:
    """Asynchronously loads a model from a safetensors file."""
    try:
        logging.info(f"Loading model from {filepath} on {DEVICE}...")
        # load_file is blocking I/O; run it in a worker thread so the event loop stays responsive
        model_data = await asyncio.to_thread(load_file, filepath, device=DEVICE)
        logging.info(f"Model {filepath} successfully loaded.")
        return model_data
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise RuntimeError(f"Error loading model: {e}") from e
# Save model with optimized storage format
async def save_model(filepath: str, model_tensors: dict):
    """Asynchronously saves a model to a safetensors file."""
    try:
        logging.info(f"Saving model to {filepath}...")
        # save_file is blocking I/O; offload it to a thread as well
        await asyncio.to_thread(save_file, model_tensors, filepath)
        logging.info(f"Model saved at {filepath}")
    except Exception as e:
        logging.error(f"Error saving model: {e}")
        raise RuntimeError(f"Error saving model: {e}") from e
# Dynamically generate layers for efficient scaling
def initialize_model(layers: tuple = (4096, 8192, 16384), dtype: torch.dtype = torch.float16) -> dict:
    """Initializes a model with a random square weight tensor for each layer."""
    model_tensors = {}
    for i, size in enumerate(layers):
        layer_name = f"layer_{i+1}"
        logging.info(f"Initializing {layer_name} with size {size}x{size} on {DEVICE}...")
        model_tensors[layer_name] = torch.randn(size, size, dtype=dtype, device=DEVICE)
        if DEVICE == "cuda":
            torch.cuda.empty_cache()  # Release cached blocks between large allocations
    logging.info("Model initialization completed.")
    return model_tensors
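# Rough size check: at float16 (2 bytes per element) the default layers hold
# 4096^2 + 8192^2 + 16384^2 elements, i.e. 32 + 128 + 512 = 672 MiB of tensor
# data, so expect a checkpoint of roughly that size. A small sketch to verify
# the arithmetic (this helper is an illustrative addition):
def checkpoint_size_mib(layers: tuple = (4096, 8192, 16384), bytes_per_elem: int = 2) -> float:
    """Estimate the serialized tensor payload in MiB for square layers."""
    return sum(size * size * bytes_per_elem for size in layers) / 2**20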
# Main execution
async def main():
    model_data = initialize_model()
    # Save the model for deployment
    await save_model(MODEL_CHECKPOINT, model_data)
    # Load the model back for verification
    loaded_model_data = await load_model(MODEL_CHECKPOINT)
    # Verify loaded tensors match saved tensors
    for key in model_data:
        if not torch.allclose(model_data[key], loaded_model_data[key], atol=1e-5):
            logging.warning(f"Tensor mismatch in {key}!")
        else:
            logging.info(f"Tensor {key} verified successfully.")
# Run asynchronously
if __name__ == "__main__":
    asyncio.run(main())