deep_privacy2 / dp2 /infer.py
haakohu's picture
fix
a1cdc0f
raw
history blame
3.04 kB
from collections import OrderedDict

import tops
import torch
from tops import checkpointer
from tops.config import instantiate
from tops.logger import warn
def load_generator_state(ckpt, G: torch.nn.Module, ckpt_mapper=None):
    """Load generator weights (and optional W cluster centers) from ``ckpt`` into ``G``.

    Args:
        ckpt: Checkpoint mapping. The generator state dict is read from
            ``"EMA_generator"`` if present, otherwise from
            ``"running_average_generator"``. If ``"w_centers"`` is present it
            is assigned to ``G.style_net.w_centers``.
        G: Generator module; updated in place.
        ckpt_mapper: Optional callable applied to the generator state dict
            before it is loaded into ``G``.

    Raises:
        KeyError: If ``ckpt`` contains neither generator state entry.
    """
    if "EMA_generator" in ckpt:
        state = ckpt["EMA_generator"]
    elif "running_average_generator" in ckpt:
        state = ckpt["running_average_generator"]
    else:
        raise KeyError(
            "Checkpoint contains neither 'EMA_generator' nor 'running_average_generator'; "
            f"available keys: {list(ckpt.keys())}")
    if ckpt_mapper is not None:
        state = ckpt_mapper(state)
    # Lenient load: mismatched, missing or unexpected keys are reported as
    # warnings by load_state_dict instead of raising.
    load_state_dict(G, state)
    tops.logger.log(f"Generator loaded, num parameters: {tops.num_parameters(G)/1e6}M")
    if "w_centers" in ckpt:
        G.style_net.w_centers = ckpt["w_centers"]
        tops.logger.log(f"W cluster centers loaded. Number of centers: {len(G.style_net.w_centers)}")
def build_trained_generator(cfg, map_location=None):
    """Instantiate the generator described by ``cfg`` and load trained weights if available.

    Weight source, in order of preference:
      1. ``cfg.common.model_url`` (when that key exists), fetched with
         ``tops.load_file_or_url`` and checked against ``cfg.common.model_md5sum``.
      2. ``checkpointer.load_checkpoint(cfg.checkpoint_dir)``. If this raises
         ``FileNotFoundError``, a warning is logged and the generator is
         returned without loaded weights.

    Args:
        cfg: Experiment config. Reads ``cfg.generator``, ``cfg.common`` and
            ``cfg.checkpoint_dir``, plus ``cfg.data.imsize`` and
            ``cfg.ckpt_mapper`` when present.
        map_location: Device to place the generator on; defaults to
            ``tops.get_device()``.

    Returns:
        The generator in eval mode, with ``G.imsize`` set to
        ``tuple(cfg.data.imsize)`` (or ``None`` if ``cfg`` has no ``data``).
    """
    if map_location is None:
        map_location = tops.get_device()
    G = instantiate(cfg.generator).to(map_location)
    G.eval()
    G.imsize = tuple(cfg.data.imsize) if hasattr(cfg, "data") else None
    ckpt_mapper = instantiate(cfg.ckpt_mapper) if hasattr(cfg, "ckpt_mapper") else None

    # Checkpoints are deserialized on CPU; loading copies the tensors into G's
    # existing parameters, which already live on map_location.
    if "model_url" in cfg.common:
        ckpt = tops.load_file_or_url(
            cfg.common.model_url,
            md5sum=cfg.common.model_md5sum,
            map_location=torch.device("cpu"))
        load_generator_state(ckpt, G, ckpt_mapper)
        return G

    # Only the checkpoint read is guarded: a FileNotFoundError raised while
    # applying the state must not be mistaken for a missing checkpoint.
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find generator checkpoint in: {cfg.checkpoint_dir}")
    else:
        load_generator_state(ckpt, G, ckpt_mapper)
    return G
def build_trained_discriminator(cfg, map_location=None):
    """Instantiate the discriminator described by ``cfg`` and load its checkpoint if present.

    Weights come from the ``"discriminator"`` entry of
    ``checkpointer.load_checkpoint(cfg.checkpoint_dir)``, optionally passed
    through ``instantiate(cfg.ckpt_mapper_D)`` first. If the checkpoint read
    raises ``FileNotFoundError``, a warning is logged and the discriminator is
    returned without loaded weights.

    Args:
        cfg: Experiment config. Reads ``cfg.discriminator`` and
            ``cfg.checkpoint_dir``, plus ``cfg.ckpt_mapper_D`` when present.
        map_location: Device to place the discriminator on; defaults to
            ``tops.get_device()``.

    Returns:
        The discriminator in eval mode.
    """
    if map_location is None:
        map_location = tops.get_device()
    D = instantiate(cfg.discriminator).to(map_location)
    D.eval()

    # Only the checkpoint read is guarded, so a FileNotFoundError from the
    # mapper or from loading the state is not reported as a missing checkpoint.
    try:
        ckpt = checkpointer.load_checkpoint(cfg.checkpoint_dir, map_location="cpu")
    except FileNotFoundError:
        tops.logger.warn(f"Did not find discriminator checkpoint in: {cfg.checkpoint_dir}")
        return D

    state = ckpt["discriminator"]
    if hasattr(cfg, "ckpt_mapper_D"):
        state = instantiate(cfg.ckpt_mapper_D)(state)
    # Default strict loading: unlike the generator path, any key mismatch raises.
    D.load_state_dict(state)
    return D
def load_state_dict(module: torch.nn.Module, state_dict: dict):
    """Leniently load ``state_dict`` into ``module``, warning about every mismatch.

    - Entries whose tensor shape differs from the module's are skipped, with a
      warning.
    - Keys present only in ``state_dict`` are warned about.
    - Module keys absent from the filtered state (including skipped
      shape-mismatched ones) are warned about.
    - The rest is loaded with ``strict=False``.

    The caller's ``state_dict`` is not modified; filtering happens on a new
    ``OrderedDict``.

    Args:
        module: Module to load parameters/buffers into; updated in place.
        state_dict: Mapping from parameter/buffer name to tensor.
    """
    module_sd = module.state_dict()

    filtered = OrderedDict()
    for key, value in state_dict.items():
        if key in module_sd and value.shape != module_sd[key].shape:
            warn(f"Incorrect shape. Current model: {module_sd[key].shape}, in state dict: {value.shape} for key: {key}")
            continue
        filtered[key] = value

    # Keep torch's per-module version metadata (as nn.Module.load_state_dict
    # itself does) so version-aware _load_from_state_dict hooks behave the
    # same as if the original dict had been passed.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata

    for key in filtered:
        if key not in module_sd:
            warn(f"Did not find key in model state dict: {key}")
    for key in module_sd:
        if key not in filtered:
            warn(f"Did not find key in state dict: {key}")
    module.load_state_dict(filtered, strict=False)