# Spaces: Runtime error
import tops | |
import torch | |
from tops import checkpointer | |
from tops.config import instantiate | |
from tops.logger import warn | |
def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
    """Load generator weights (and optional W cluster centers) from a checkpoint.

    The EMA weights are preferred: ``ckpt["EMA_generator"]`` is used when
    present, otherwise ``ckpt["running_average_generator"]``.

    Args:
        ckpt: Dict-like checkpoint holding the generator state dict(s).
        G: Generator module that receives the weights (modified in place).
        ckpt_mapper: Optional callable applied to the selected state dict
            before it is loaded (e.g. to remap keys).

    Raises:
        KeyError: If ``ckpt`` contains neither generator state key.
    """
    if "EMA_generator" in ckpt:
        state = ckpt["EMA_generator"]
    elif "running_average_generator" in ckpt:
        state = ckpt["running_average_generator"]
    else:
        raise KeyError(
            "Checkpoint contains neither 'EMA_generator' nor "
            f"'running_average_generator'. Available keys: {list(ckpt.keys())}"
        )
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)
    # Tolerant load: shape mismatches and missing/unexpected keys are warned
    # about and skipped rather than raising.
    load_state_dict(G, state)
    tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
    if "w_centers" in ckpt:
        # Optional extra stored next to the weights; assigned onto the style network.
        G.style_net.w_centers = ckpt["w_centers"]
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
def build_trained_generator(cfg, map_location=None):
    """Instantiate the generator from ``cfg`` and load trained weights if available.

    Weights come from ``cfg.common.model_url`` when that key is set (passing
    ``cfg.common.model_md5sum`` as the checksum). Otherwise they come from the
    checkpoint in ``cfg.checkpoint_dir``. If no local checkpoint is found, only
    a warning is logged and the generator is returned without trained weights.

    Args:
        cfg: Experiment config providing ``generator`` and, optionally,
            ``data``, ``ckpt_mapper``, ``common`` and ``checkpoint_dir``.
        map_location: Device for the generator; defaults to ``tops.get_device()``.

    Returns:
        The generator in eval mode. ``G.imsize`` is set from ``cfg.data.imsize``,
        or to ``None`` when ``cfg`` has no ``data`` section.
    """
    if map_location is None:
        map_location = tops.get_device()
    G = instantiate(cfg.generator).to(map_location)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    ckpt_mapper = instantiate(cfg.ckpt_mapper) if hasattr(cfg, "ckpt_mapper") else None

    # Guard against configs without a `common` section instead of crashing.
    common = getattr(cfg, "common", None)
    if common is not None and "model_url" in common:
        # Checkpoints are read onto the CPU; loading copies the tensors into G's parameters.
        ckpt = tops.load_file_or_url(
            common.model_url, md5sum=common.model_md5sum, map_location=torch.device("cpu"))
    else:
        # Keep the try minimal so that only a missing checkpoint is reported as such.
        try:
            ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
        except FileNotFoundError:
            tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
            return G
    load_generator_state(ckpt, G, ckpt_mapper)
    return G
def build_trained_discriminator(cfg, map_location=None):
    """Instantiate the discriminator from ``cfg`` and load its checkpointed weights.

    The state dict is read from ``ckpt["discriminator"]`` in the checkpoint under
    ``cfg.checkpoint_dir``. If ``cfg.ckpt_mapper_D`` is set, it is instantiated
    and applied to that state dict first. If no checkpoint is found, only a
    warning is logged and the discriminator is returned without trained weights.

    Args:
        cfg: Experiment config providing ``discriminator`` and ``checkpoint_dir``,
            and optionally ``ckpt_mapper_D``.
        map_location: Device for the discriminator; defaults to ``tops.get_device()``.

    Returns:
        The discriminator in eval mode.
    """
    if map_location is None:
        map_location = tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()
    # Only the checkpoint lookup is guarded; errors from mapping or loading propagate.
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
        return D
    state = ckpt["discriminator"]
    if hasattr(cfg, "ckpt_mapper_D"):
        state = instantiate(cfg.ckpt_mapper_D)(state)
    # Strict load (torch default), unlike the tolerant generator path:
    # any key mismatch raises.
    D.load_state_dict(state)
    return D
def load_state_dict(module: torch.nn.Module, state_dict: dict):
    """Load ``state_dict`` into ``module``, skipping incompatible entries with warnings.

    Entries whose tensor shape differs from the module's are dropped, with an
    "Incorrect shape" warning. Keys present on only one side are reported with a
    warning. Loading is non-strict, and the caller's ``state_dict`` is not
    modified.

    Args:
        module: Module whose parameters/buffers are updated in place.
        state_dict: Mapping from parameter/buffer names to tensors.

    Returns:
        The ``missing_keys``/``unexpected_keys`` named tuple returned by
        ``torch.nn.Module.load_state_dict(..., strict=False)``.
    """
    module_sd = module.state_dict()
    # Build a filtered copy instead of popping from the caller's dict.
    compatible = {}
    for key, value in state_dict.items():
        if key in module_sd and value.shape != module_sd[key].shape:
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {value.shape} for key: {key}")
            continue
        compatible[key] = value
    for key in state_dict:
        if key not in module_sd:
            warn(f"Did not find key in model state dict: {key}")
    # Compare against the unfiltered keys so shape-mismatched entries
    # are not reported a second time as "missing".
    for key in module_sd:
        if key not in state_dict:
            warn(f"Did not find key in state dict: {key}")
    return module.load_state_dict(compatible, strict=False)