Spaces:
Runtime error
Runtime error
from email.policy import strict | |
import torch | |
import torchvision.models | |
import os.path as osp | |
import copy | |
from ...log_service import print_log | |
from .utils import \ | |
get_total_param, get_total_param_sum, \ | |
get_unit | |
# def load_state_dict(net, model_path): | |
# if isinstance(net, dict): | |
# for ni, neti in net.items(): | |
# paras = torch.load(model_path[ni], map_location=torch.device('cpu')) | |
# new_paras = neti.state_dict() | |
# new_paras.update(paras) | |
# neti.load_state_dict(new_paras) | |
# else: | |
# paras = torch.load(model_path, map_location=torch.device('cpu')) | |
# new_paras = net.state_dict() | |
# new_paras.update(paras) | |
# net.load_state_dict(new_paras) | |
# return | |
# def save_state_dict(net, path): | |
# if isinstance(net, (torch.nn.DataParallel, | |
# torch.nn.parallel.DistributedDataParallel)): | |
# torch.save(net.module.state_dict(), path) | |
# else: | |
# torch.save(net.state_dict(), path) | |
def singleton(class_):
    """Class decorator that makes ``class_`` a create-once factory.

    The returned callable builds ``class_`` with the arguments of its first
    call and hands back that very same object on every later call; arguments
    passed to later calls are ignored.
    """
    created = {}

    def getinstance(*args, **kwargs):
        # Fast path: the object already exists, constructor args are unused.
        if class_ in created:
            return created[class_]
        obj = class_(*args, **kwargs)
        created[class_] = obj
        return obj

    return getinstance
def preprocess_model_args(args):
    """Return a deep copy of ``args`` with nested specs resolved.

    * If the copy contains ``'layer_units'``, every entry is replaced by
      ``get_unit()(entry)``.
    * If the copy contains ``'backbone'``, it is replaced by the result of
      ``get_model()(backbone)``.

    The caller's ``args`` object is never modified. When either key is
    present the container is accessed by attribute (``args.layer_units``,
    ``args.backbone``), so it must support that style of access.
    """
    resolved = copy.deepcopy(args)

    if 'layer_units' in resolved:
        units = []
        for spec in resolved.layer_units:
            units.append(get_unit()(spec))
        resolved.layer_units = units

    if 'backbone' in resolved:
        builder = get_model()
        resolved.backbone = builder(resolved.backbone)

    return resolved
class get_model(object):
    """Registry and factory for model classes, keyed by their type name.

    The call sites in this file create a fresh ``get_model()`` each time they
    register or build a model, so the name -> class table is stored on the
    class and shared by every instance. A per-instance table would be empty
    again on the next ``get_model()`` call and every registration would be
    lost.
    """

    # Shared name -> model class table (see the class docstring).
    _registry = {}

    def __init__(self):
        # Kept as an instance attribute for backward compatibility; this is
        # an alias of the shared table, not a private copy.
        self.model = self._registry

    def register(self, model, name):
        """Store ``model`` under ``name``, replacing any previous entry."""
        self.model[name] = model

    def __call__(self, cfg, verbose=True):
        """
        Construct model based on the config.

        Args:
            cfg: config object read through attribute access (``cfg.type``,
                ``cfg.args``) and ``cfg.get(key, default)``. The optional
                keys ``pretrained`` (or the older ``pth``), ``map_location``
                (default ``'cpu'``) and ``strict_sd`` (default ``True``)
                control checkpoint loading. ``None`` yields ``None``.
            verbose: when true, report the loaded checkpoint and the
                parameter statistics through ``print_log``.

        Returns:
            The constructed network, or ``None`` when ``cfg`` is ``None``.

        Raises:
            KeyError: ``cfg.type`` has not been registered.
            ValueError: the checkpoint path has an unsupported extension.
        """
        if cfg is None:
            return None

        t = cfg.type

        # the register is in each file: these imports are done only for
        # their side effect of loading the package that matches the prefix.
        if t.startswith('pfd'):
            from .. import pfd
        elif t == 'autoencoderkl':
            from .. import autokl
        elif t.startswith(('clip', 'openclip')):
            from .. import clip
        elif t.startswith('openai_unet'):
            from .. import openaimodel
        elif t.startswith('controlnet'):
            from .. import controlnet
        elif t.startswith('seecoder'):
            from .. import seecoder
        elif t.startswith('swin'):
            from .. import swin

        args = preprocess_model_args(cfg.args)

        # Look the class up separately from instantiating it so that a
        # KeyError raised inside the model constructor is not relabelled.
        try:
            model_cls = self.model[t]
        except KeyError:
            raise KeyError(
                'Model type [{}] is not registered.'.format(t)) from None
        net = model_cls(**args)

        pretrained = cfg.get('pretrained', None)
        if pretrained is None:  # backward compatible
            pretrained = cfg.get('pth', None)
        map_location = cfg.get('map_location', 'cpu')
        strict_sd = cfg.get('strict_sd', True)

        if pretrained is not None:
            ext = osp.splitext(pretrained)[1]
            # NOTE: torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            if ext == '.pth':
                sd = torch.load(pretrained, map_location=map_location)
            elif ext == '.ckpt':
                # .ckpt files keep the weights under the 'state_dict' key.
                sd = torch.load(
                    pretrained, map_location=map_location)['state_dict']
            elif ext == '.safetensors':
                from safetensors.torch import load_file
                from collections import OrderedDict
                sd = load_file(pretrained, map_location)
                sd = OrderedDict(sd)
            else:
                # Previously this fell through and failed later on an
                # unbound ``sd``; fail here with an explicit message.
                raise ValueError(
                    'Unsupported checkpoint extension [{}] for [{}]; '
                    'expected .pth, .ckpt or .safetensors.'.format(
                        ext, pretrained))
            net.load_state_dict(sd, strict=strict_sd)
            if verbose:
                print_log('Load model from [{}] strict [{}].'.format(pretrained, strict_sd))

        # display param_num & param_sum
        if verbose:
            print_log(
                'Load {} with total {} parameters,'
                '{:.3f} parameter sum.'.format(
                    t,
                    get_total_param(net),
                    get_total_param_sum(net)))
        return net
def register(name):
    """Return a class decorator that files the class under ``name``.

    The decorated class is passed to ``get_model().register`` together with
    ``name`` and is then returned unchanged, so the decorator can be stacked
    on a class definition without altering it.
    """
    def wrapper(class_):
        registry = get_model()
        registry.register(class_, name)
        return class_

    return wrapper