| import os | |
| from omegaconf import OmegaConf | |
| import torch | |
| import tempfile | |
| from safetensors.torch import load_file | |
| import requests | |
| import yaml | |
def _checkpoint_suffix(path, is_url):
    """Return the file extension of *path*, ignoring any URL query or fragment.

    For a URL such as ``.../model.safetensors?download=true`` the extension must
    come from the URL's path component only. ``os.path.splitext`` on the raw URL
    would yield ``".safetensors?download=true"`` and misroute the file to
    ``torch.load``.
    """
    if is_url:
        from urllib.parse import urlparse

        path = urlparse(path).path
    return os.path.splitext(path)[-1]


def _load_checkpoint_file(ckpt_path, suffix):
    """Load a checkpoint file from disk, choosing the loader by *suffix*."""
    if suffix == ".safetensors":
        return load_file(ckpt_path)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints (including URLs) from trusted sources.
    return torch.load(ckpt_path, map_location="cpu", weights_only=False)


def _download(url, dest, timeout):
    """Stream *url* into the local file *dest*, raising on HTTP error status."""
    # Stream in chunks so a multi-GB checkpoint is never held fully in memory.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest, "wb") as fh:
            for chunk in response.iter_content(chunk_size=1 << 20):
                fh.write(chunk)


def get_ckpt(path, key="state_dict", timeout=60):
    """Load a checkpoint from a local path or an http(s) URL.

    Files ending in ``.safetensors`` are read with ``safetensors.torch.load_file``.
    Anything else goes through ``torch.load`` onto the CPU. URLs are first
    downloaded into a temporary directory, which is removed after loading.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``. For URLs, the extension is taken from the URL path,
            so query strings such as ``?download=true`` are ignored.
        key: If not None and the loaded checkpoint is a dict containing this
            key, return ``checkpoint[key]`` instead of the whole object.
        timeout: Seconds passed to ``requests.get`` for URL downloads. This
            bounds connect time and the gap between received bytes, not the
            total download time. ``None`` disables it.

    Returns:
        The loaded object, or its ``key`` entry when present.

    Raises:
        requests.HTTPError: If a URL download returns an error status.
    """
    is_url = path.startswith(("http://", "https://"))
    suffix = _checkpoint_suffix(path, is_url)
    if is_url:
        print(f"Loading checkpoint from URL: {path}")
        # A temp *directory* (rather than an open NamedTemporaryFile) lets the
        # loaders reopen the file by name on every platform, including Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt_path = os.path.join(tmp_dir, "checkpoint" + suffix)
            _download(path, ckpt_path, timeout)
            checkpoint = _load_checkpoint_file(ckpt_path, suffix)
    else:
        print(f"Loading checkpoint from local path: {path}")
        checkpoint = _load_checkpoint_file(path, suffix)
    # Only dicts can be unwrapped. `key in tensor` would raise RuntimeError.
    if key is not None and isinstance(checkpoint, dict) and key in checkpoint:
        checkpoint = checkpoint[key]
    return checkpoint
def get_yaml_config(path, timeout=60):
    """Load a YAML configuration into an OmegaConf object from a path or URL.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``. URL contents are fetched with ``requests.get`` and
            parsed with ``OmegaConf.create``.
        timeout: Seconds passed to ``requests.get`` for URLs. Without it, a
            stalled server would block forever. ``None`` disables it.

    Returns:
        The config object produced by ``OmegaConf.create`` (URL) or
        ``OmegaConf.load`` (local file).

    Raises:
        requests.HTTPError: If fetching a URL returns an error status.
    """
    if path.startswith(("http://", "https://")):
        response = requests.get(path, timeout=timeout)
        response.raise_for_status()
        return OmegaConf.create(response.text)
    # Read as UTF-8 explicitly, not the platform-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return OmegaConf.load(f)