|
from collections import OrderedDict |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torchvision.models import ( |
|
ResNet50_Weights, |
|
VGG16_BN_Weights, |
|
VGG16_Weights, |
|
resnet50, |
|
vgg16, |
|
vgg16_bn, |
|
) |
|
|
|
from engine.BiRefNet.config import Config |
|
from engine.BiRefNet.models.backbones.pvt_v2 import ( |
|
pvt_v2_b0, |
|
pvt_v2_b1, |
|
pvt_v2_b2, |
|
pvt_v2_b5, |
|
) |
|
from engine.BiRefNet.models.backbones.swin_v1 import ( |
|
swin_v1_b, |
|
swin_v1_l, |
|
swin_v1_s, |
|
swin_v1_t, |
|
) |
|
|
|
# Module-level configuration instance; load_weights reads checkpoint paths from config.weights.
config = Config()
|
|
|
|
|
def build_backbone(bb_name, pretrained=True, params_settings=""):
    """Construct the encoder backbone named *bb_name*.

    Args:
        bb_name: ``"vgg16"``, ``"vgg16bn"`` or ``"resnet50"`` for the
            torchvision backbones; any other value is treated as the name of a
            constructor visible in this module (e.g. the imported ``swin_v1_*``
            / ``pvt_v2_*`` builders) and is called via ``eval``.
        pretrained: For torchvision backbones, load torchvision's default
            weights. For other backbones, load the checkpoint registered in
            ``config.weights[bb_name]`` through ``load_weights``.
        params_settings: Source text of the arguments spliced verbatim into the
            constructor call for non-torchvision backbones.

    Returns:
        For torchvision backbones, an ``nn.Sequential`` whose children are the
        stages ``conv1``..``conv4``. For other backbones, the constructed
        module, or ``None`` if ``load_weights`` could not match any weights.
    """
    if bb_name == "vgg16":
        # The first child of a torchvision VGG model is its convolutional
        # feature stack; ``weights=`` replaces the deprecated ``pretrained=``.
        bb_net = list(
            vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None).children()
        )[0]
        # Split the feature stack into four stages; later layers are discarded.
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": bb_net[:4],
                    "conv2": bb_net[4:9],
                    "conv3": bb_net[9:16],
                    "conv4": bb_net[16:23],
                }
            )
        )
    elif bb_name == "vgg16bn":
        bb_net = list(
            vgg16_bn(
                weights=VGG16_BN_Weights.DEFAULT if pretrained else None
            ).children()
        )[0]
        # Same staging as "vgg16", shifted to account for the BatchNorm layers.
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": bb_net[:6],
                    "conv2": bb_net[6:13],
                    "conv3": bb_net[13:23],
                    "conv4": bb_net[23:33],
                }
            )
        )
    elif bb_name == "resnet50":
        bb_net = list(
            resnet50(
                weights=ResNet50_Weights.DEFAULT if pretrained else None
            ).children()
        )
        # NOTE(review): child index 3 (the stem max-pool in torchvision's
        # ResNet) is not included in any stage — confirm this is intended.
        bb = nn.Sequential(
            OrderedDict(
                {
                    "conv1": nn.Sequential(*bb_net[0:3]),
                    "conv2": bb_net[4],
                    "conv3": bb_net[5],
                    "conv4": bb_net[6],
                }
            )
        )
    else:
        # NOTE(review): eval executes arbitrary code. bb_name and
        # params_settings must come from trusted configuration only, never
        # from user-supplied input.
        bb = eval("{}({})".format(bb_name, params_settings))
        if pretrained:
            # May return None when the checkpoint does not match the model.
            bb = load_weights(bb, bb_name)
    return bb
|
|
|
|
|
def load_weights(model, model_name): |
|
save_model = torch.load( |
|
config.weights[model_name], map_location="cpu", weights_only=True |
|
) |
|
model_dict = model.state_dict() |
|
state_dict = { |
|
k: v if v.size() == model_dict[k].size() else model_dict[k] |
|
for k, v in save_model.items() |
|
if k in model_dict.keys() |
|
} |
|
|
|
if not state_dict: |
|
save_model_keys = list(save_model.keys()) |
|
sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None |
|
state_dict = { |
|
k: v if v.size() == model_dict[k].size() else model_dict[k] |
|
for k, v in save_model[sub_item].items() |
|
if k in model_dict.keys() |
|
} |
|
if not state_dict or not sub_item: |
|
print( |
|
"Weights are not successully loaded. Check the state dict of weights file." |
|
) |
|
return None |
|
else: |
|
print( |
|
'Found correct weights in the "{}" item of loaded state_dict.'.format( |
|
sub_item |
|
) |
|
) |
|
model_dict.update(state_dict) |
|
model.load_state_dict(model_dict) |
|
return model |
|
|