Spaces:
Runtime error
Runtime error
File size: 3,036 Bytes
548d634 a1cdc0f 548d634 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import tops
import torch
from tops import checkpointer
from tops.config import instantiate
from tops.logger import warn
def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
    """Load generator weights (and optional W cluster centers) from a checkpoint dict.

    The generator state is taken from ``ckpt["EMA_generator"]`` if present,
    otherwise from ``ckpt["running_average_generator"]``. It is optionally passed
    through ``ckpt_mapper`` and then loaded non-strictly via the module-level
    ``load_state_dict`` helper.

    Args:
        ckpt: Checkpoint mapping (as returned by the checkpoint loader).
        G: Generator module to load weights into.
        ckpt_mapper: Optional callable that transforms the generator state dict
            before loading (e.g. to rename keys).

    Raises:
        KeyError: If ``ckpt`` contains neither ``"EMA_generator"`` nor
            ``"running_average_generator"``.
    """
    if "EMA_generator" in ckpt:
        state = ckpt["EMA_generator"]
    elif "running_average_generator" in ckpt:
        state = ckpt["running_average_generator"]
    else:
        # Fail with an actionable message instead of a bare KeyError on one key.
        raise KeyError(
            "Checkpoint has no generator weights; expected 'EMA_generator' or "
            f"'running_average_generator'. Available keys: {list(ckpt.keys())}"
        )
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)
    load_state_dict(G, state)
    tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
    if "w_centers" in ckpt:
        # Cluster centers are stored next to the weights, not inside the state dict.
        G.style_net.w_centers = ckpt["w_centers"]
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
def build_trained_generator(cfg, map_location=None):
    """Instantiate the generator described by ``cfg`` and load trained weights into it.

    Weights come from ``cfg.common.model_url`` when that key is present.
    Otherwise they come from the checkpoint in ``cfg.checkpoint_dir``.

    Args:
        cfg: Config with a ``generator`` entry, a ``common`` section, and
            (for local loading) ``checkpoint_dir``. Optional ``data.imsize`` is
            copied onto the generator as ``G.imsize``; optional ``ckpt_mapper``
            is instantiated and applied to the generator state before loading.
        map_location: Device to place the generator on; defaults to
            ``tops.get_device()``. Checkpoints themselves are read onto the CPU.

    Returns:
        The generator in eval mode. If no local checkpoint is found, a warning
        is logged and the generator is returned with the weights produced by
        ``instantiate``.
    """
    map_location = map_location if map_location is not None else tops.get_device()
    G = instantiate(cfg.generator).to(map_location)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    ckpt_mapper = instantiate(cfg.ckpt_mapper) if hasattr(cfg, "ckpt_mapper") else None

    if "model_url" in cfg.common:
        ckpt = tops.load_file_or_url(
            cfg.common.model_url, md5sum=cfg.common.model_md5sum, map_location=torch.device("cpu"))
        load_generator_state(ckpt, G, ckpt_mapper)
        return G

    # Only the checkpoint read may legitimately be missing; errors raised while
    # loading the state must propagate rather than be reported as "not found".
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
        return G
    load_generator_state(ckpt, G, ckpt_mapper)
    return G
def build_trained_discriminator(cfg, map_location=None):
    """Instantiate the discriminator from ``cfg`` and load weights from ``cfg.checkpoint_dir``.

    Args:
        cfg: Config with a ``discriminator`` entry and ``checkpoint_dir``. An
            optional ``ckpt_mapper_D`` is instantiated and applied to the
            checkpoint's ``"discriminator"`` state before loading.
        map_location: Device to place the discriminator on; defaults to
            ``tops.get_device()``. The checkpoint itself is read onto the CPU.

    Returns:
        The discriminator in eval mode. If no checkpoint is found, a warning is
        logged and the discriminator is returned with the weights produced by
        ``instantiate``.
    """
    map_location = map_location if map_location is not None else tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()
    # Only the checkpoint read may legitimately be missing; mapper/loading
    # errors must propagate rather than be reported as "not found".
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
        return D
    state = ckpt["discriminator"]
    if hasattr(cfg, "ckpt_mapper_D"):
        state = instantiate(cfg.ckpt_mapper_D)(state)
    # Strict torch load here, unlike the generator path, which uses the
    # tolerant module-level load_state_dict helper.
    D.load_state_dict(state)
    return D
def load_state_dict(module: torch.nn.Module, state_dict: dict):
    """Load ``state_dict`` into ``module`` non-strictly, skipping shape mismatches.

    Entries whose shape differs from the module's entry of the same name are
    dropped with a warning instead of making ``load_state_dict`` raise.
    Warnings are also emitted for keys present only in ``state_dict`` and for
    keys present only in the module; those are then ignored by ``strict=False``.

    The caller's ``state_dict`` is left unmodified; a filtered copy is loaded.

    Args:
        module: Module to load parameters/buffers into.
        state_dict: Mapping from parameter/buffer names to tensors.
    """
    from collections import OrderedDict

    module_sd = module.state_dict()
    filtered = OrderedDict()
    # torch state dicts may carry version info in ``_metadata``; keep it so
    # version-aware ``_load_from_state_dict`` hooks behave as with the original.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata

    for key, item in state_dict.items():
        if key in module_sd and item.shape != module_sd[key].shape:
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {item.shape} for key: {key}")
            continue
        filtered[key] = item
    for key in filtered:
        if key not in module_sd:
            warn(f"Did not find key in model state dict: {key}")
    for key in module_sd:
        if key not in filtered:
            warn(f"Did not find key in state dict: {key}")
    module.load_state_dict(filtered, strict=False)