|
import torch |
|
from safetensors.torch import load_file, save_file |
|
import logging |
|
import asyncio |
|
|
|
|
|
# Root-logger configuration: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


# File the demo writes and then reads back.
# NOTE(review): the name suggests shard 3 of a 10-file checkpoint, but this
# script only ever touches this single file — confirm naming intent.
MODEL_CHECKPOINT = "model-3-of-10.safetensors"


# Prefer the GPU when torch can see one; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
async def load_model(filepath: str) -> dict:
    """Asynchronously load a model's tensors from a safetensors file.

    The blocking ``load_file`` call is pushed onto the default thread-pool
    executor so the event loop stays responsive while the (potentially
    large) checkpoint is read from disk — the original called it inline,
    which blocked the loop for the whole read.

    Args:
        filepath: Path to the ``.safetensors`` file to read.

    Returns:
        Mapping of tensor name -> tensor, materialized on ``DEVICE``.

    Raises:
        RuntimeError: If loading fails for any reason; the original
            exception is chained as ``__cause__``.
    """
    try:
        logging.info("Loading model from %s on %s...", filepath, DEVICE)
        loop = asyncio.get_running_loop()
        # load_file does blocking disk I/O; run it off the event loop.
        model_data = await loop.run_in_executor(
            None, lambda: load_file(filepath, device=DEVICE)
        )
        logging.info("Model %s successfully loaded.", filepath)
        return model_data
    except Exception as e:
        logging.error("Error loading model: %s", e)
        raise RuntimeError(f"Error loading model: {str(e)}") from e
|
|
|
|
|
async def save_model(filepath: str, model_tensors: dict):
    """Asynchronously save a model's tensors to a safetensors file.

    The blocking ``save_file`` call runs in the default thread-pool
    executor so the event loop is not stalled during the write — the
    original invoked it inline inside the ``async def``.

    Args:
        filepath: Destination path for the ``.safetensors`` file.
        model_tensors: Mapping of tensor name -> tensor to serialize.

    Raises:
        RuntimeError: If saving fails for any reason; the original
            exception is chained as ``__cause__``.
    """
    try:
        logging.info("Saving model to %s...", filepath)
        loop = asyncio.get_running_loop()
        # save_file does blocking disk I/O; run it off the event loop.
        await loop.run_in_executor(None, lambda: save_file(model_tensors, filepath))
        logging.info("Model saved at %s", filepath)
    except Exception as e:
        logging.error("Error saving model: %s", e)
        raise RuntimeError(f"Error saving model: {str(e)}") from e
|
|
|
|
|
def initialize_model(layers=(4096, 8192, 16384), dtype: torch.dtype = torch.float16, device=None) -> dict:
    """Initialize a model with one random square tensor per layer.

    Args:
        layers: Edge sizes; entry ``i`` becomes a ``size x size`` tensor
            named ``layer_{i+1}``. The default is a tuple to avoid the
            mutable-default-argument pitfall of the original ``list``.
        dtype: Element type for every tensor.
        device: Target device string. Defaults to CUDA when available,
            else CPU (the same choice as the module-level ``DEVICE``).

    Returns:
        dict mapping layer name -> freshly initialized random tensor.
    """
    if device is None:
        # Mirrors the module-level DEVICE selection so default behavior
        # is unchanged while keeping the function self-contained.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model_tensors = {}
    for i, size in enumerate(layers):
        layer_name = f"layer_{i+1}"
        logging.info("Initializing %s with size %dx%d on %s...", layer_name, size, size, device)
        model_tensors[layer_name] = torch.randn(size, size, dtype=dtype, device=device)

    # empty_cache is a no-op without CUDA; the guard makes intent explicit.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logging.info("Model initialization completed.")
    return model_tensors
|
|
|
|
|
async def main():
    """Build, persist, reload, and verify the demo model checkpoint."""
    model_data = initialize_model()

    await save_model(MODEL_CHECKPOINT, model_data)

    loaded_model_data = await load_model(MODEL_CHECKPOINT)

    # Round-trip verification: every tensor must survive save/load intact.
    for key in model_data:
        if not torch.allclose(model_data[key], loaded_model_data[key], atol=1e-5):
            logging.warning(f"Tensor mismatch in {key}!")
        else:
            logging.info(f"Tensor {key} verified successfully.")


if __name__ == "__main__":
    # Guard the entry point so importing this module does not allocate
    # multi-GB tensors and write a checkpoint as a side effect.
    asyncio.run(main())