File size: 2,490 Bytes
bd62227 171e2fc bd62227 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from typing import List
import yaml
import os
import torch
import torch.distributed as dist
import pydantic
from omegaconf import OmegaConf
from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader
class EvalConfig(pydantic.BaseModel):
    """Command-line options for a single evaluation run."""

    # Path to the checkpoint file to evaluate. The launcher reads
    # "all_config.yaml" from the directory that contains this file.
    checkpoint: str

    # Names copied into the pretraining config's `eval_save_outputs` field by
    # the launcher (presumably the tensors the evaluation routine dumps to
    # disk; confirm against `pretrain.evaluate`).
    save_outputs: List[str] = [
        "inputs",
        "labels",
        "puzzle_identifiers",
        "logits",
        "q_halt_logits",
        "q_continue_logits",
    ]
def _load_model_weights(model, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``model``.

    The checkpoint is read from disk once. It is first applied as-is; if
    ``load_state_dict`` rejects it with a ``RuntimeError``, the load is retried
    with the ``"_orig_mod."`` prefix removed from every key (the
    "unwrap torch.compile" fallback).

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.

    Raises:
        RuntimeError: If the state dict fits the model neither with nor
            without the ``"_orig_mod."`` prefix.
    """
    # NOTE(security): depending on the installed PyTorch version's
    # `weights_only` default, torch.load may unpickle arbitrary objects.
    # Only evaluate checkpoints that come from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cuda")
    try:
        model.load_state_dict(state_dict, assign=True)
    except RuntimeError:
        # Only a key/shape mismatch triggers the fallback; interrupts and
        # I/O errors propagate instead of being silently retried.
        unwrapped = {
            key.removeprefix("_orig_mod."): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(unwrapped, assign=True)


def launch():
    """Evaluate a saved checkpoint on the test split and print the metrics.

    Steps:
      1. Parse ``EvalConfig`` from the command line via OmegaConf.
      2. If ``LOCAL_RANK`` is set in the environment (e.g. under torchrun),
         initialise the NCCL process group and select the local CUDA device.
      3. Read ``all_config.yaml`` from the checkpoint's directory into a
         ``PretrainConfig`` and override ``eval_save_outputs`` and
         ``checkpoint_path`` on it.
      4. Build the train and test dataloaders, create the train state and
         load the checkpoint weights into its model.
      5. Set ``train_state.step`` from a ``step_<N>`` checkpoint file name
         (0 when the name does not start with ``step_``).
      6. Run ``evaluate`` and print its result when it is not ``None``.

    Raises:
        ValueError: If the checkpoint file name starts with ``step_`` but the
            remainder is not an integer.
    """
    eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli()))  # type: ignore
    checkpoint_dir = os.path.dirname(eval_cfg.checkpoint)

    # Single-process defaults; replaced below when launched distributed.
    rank = 0
    world_size = 1

    # Initialize distributed evaluation if in a distributed environment
    # (e.g. torchrun).
    distributed = "LOCAL_RANK" in os.environ
    if distributed:
        dist.init_process_group(backend="nccl")

        rank = dist.get_rank()
        world_size = dist.get_world_size()

        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # The configuration is read from the file next to the checkpoint.
    with open(os.path.join(checkpoint_dir, "all_config.yaml"), "r") as f:
        config = PretrainConfig(**yaml.safe_load(f))

    config.eval_save_outputs = eval_cfg.save_outputs
    config.checkpoint_path = checkpoint_dir

    # Dataloaders. Only the metadata of the train split is used here (it is
    # passed to init_train_state); the train loader itself is discarded.
    _, train_metadata = create_dataloader(
        config,
        "train",
        test_set_mode=False,
        epochs_per_iter=1,
        global_batch_size=config.global_batch_size,
        rank=rank,
        world_size=world_size,
    )
    eval_loader, eval_metadata = create_dataloader(
        config,
        "test",
        test_set_mode=True,
        epochs_per_iter=1,
        global_batch_size=config.global_batch_size,
        rank=rank,
        world_size=world_size,
    )

    # Model
    train_state = init_train_state(config, train_metadata, world_size=world_size)
    _load_model_weights(train_state.model, eval_cfg.checkpoint)

    # Recover the step from a checkpoint file named "step_<N>".
    train_state.step = 0
    ckpt_filename = os.path.basename(eval_cfg.checkpoint)
    if ckpt_filename.startswith("step_"):
        train_state.step = int(ckpt_filename.removeprefix("step_"))

    # Evaluate
    print("Starting evaluation")

    train_state.model.eval()
    metrics = evaluate(
        config,
        train_state,
        eval_loader,
        eval_metadata,
        rank=rank,
        world_size=world_size,
    )

    if metrics is not None:
        print(metrics)

    # Release the process group created above. Done on the success path only,
    # so a failing rank is not held up in collective teardown.
    if distributed:
        dist.destroy_process_group()
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    launch()
|