Spaces:
Running
on
Zero
Running
on
Zero
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
from torchvision.models import ( | |
ResNet50_Weights, | |
VGG16_BN_Weights, | |
VGG16_Weights, | |
resnet50, | |
vgg16, | |
vgg16_bn, | |
) | |
from engine.BiRefNet.config import Config | |
from engine.BiRefNet.models.backbones.pvt_v2 import ( | |
pvt_v2_b0, | |
pvt_v2_b1, | |
pvt_v2_b2, | |
pvt_v2_b5, | |
) | |
from engine.BiRefNet.models.backbones.swin_v1 import ( | |
swin_v1_b, | |
swin_v1_l, | |
swin_v1_s, | |
swin_v1_t, | |
) | |
config = Config() | |
def build_backbone(bb_name, pretrained=True, params_settings=""): | |
if bb_name == "vgg16": | |
bb_net = list( | |
vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children() | |
)[0] | |
bb = nn.Sequential( | |
OrderedDict( | |
{ | |
"conv1": bb_net[:4], | |
"conv2": bb_net[4:9], | |
"conv3": bb_net[9:16], | |
"conv4": bb_net[16:23], | |
} | |
) | |
) | |
elif bb_name == "vgg16bn": | |
bb_net = list( | |
vgg16_bn( | |
pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None | |
).children() | |
)[0] | |
bb = nn.Sequential( | |
OrderedDict( | |
{ | |
"conv1": bb_net[:6], | |
"conv2": bb_net[6:13], | |
"conv3": bb_net[13:23], | |
"conv4": bb_net[23:33], | |
} | |
) | |
) | |
elif bb_name == "resnet50": | |
bb_net = list( | |
resnet50( | |
pretrained=ResNet50_Weights.DEFAULT if pretrained else None | |
).children() | |
) | |
bb = nn.Sequential( | |
OrderedDict( | |
{ | |
"conv1": nn.Sequential(*bb_net[0:3]), | |
"conv2": bb_net[4], | |
"conv3": bb_net[5], | |
"conv4": bb_net[6], | |
} | |
) | |
) | |
else: | |
bb = eval("{}({})".format(bb_name, params_settings)) | |
if pretrained: | |
bb = load_weights(bb, bb_name) | |
return bb | |
def load_weights(model, model_name): | |
save_model = torch.load( | |
config.weights[model_name], map_location="cpu", weights_only=True | |
) | |
model_dict = model.state_dict() | |
state_dict = { | |
k: v if v.size() == model_dict[k].size() else model_dict[k] | |
for k, v in save_model.items() | |
if k in model_dict.keys() | |
} | |
# to ignore the weights with mismatched size when I modify the backbone itself. | |
if not state_dict: | |
save_model_keys = list(save_model.keys()) | |
sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None | |
state_dict = { | |
k: v if v.size() == model_dict[k].size() else model_dict[k] | |
for k, v in save_model[sub_item].items() | |
if k in model_dict.keys() | |
} | |
if not state_dict or not sub_item: | |
print( | |
"Weights are not successully loaded. Check the state dict of weights file." | |
) | |
return None | |
else: | |
print( | |
'Found correct weights in the "{}" item of loaded state_dict.'.format( | |
sub_item | |
) | |
) | |
model_dict.update(state_dict) | |
model.load_state_dict(model_dict) | |
return model | |