Spaces:
Sleeping
Sleeping
File size: 737 Bytes
fa84113 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import os
import logging
from collections import OrderedDict
from timm.models import load_checkpoint
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
def load_pretrained(model, url, filter_fn=None, strict=True):
    """Fetch a checkpoint from ``url`` and load it into ``model``.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        url: Location of the checkpoint. When falsy (``None`` or ``""``) a
            warning is logged and ``model`` is left untouched.
        filter_fn: Optional callable that receives the fetched state dict
            and returns the state dict that should actually be loaded.
        strict: Forwarded unchanged to ``model.load_state_dict``.

    Returns:
        None. The weights are loaded into ``model`` in place.
    """
    if not url:
        # Use a named logger instead of the module-level ``logging.warning``
        # helper: that helper logs on the root logger and implicitly calls
        # ``logging.basicConfig()`` when the root has no handlers, which
        # would silently reconfigure logging for the host application.
        logging.getLogger(__name__).warning(
            "Pretrained model URL is empty, using random initialization. "
            "Did you intend to use a `tf_` variant of the model?")
        return

    # ``map_location='cpu'`` requests CPU placement for the loaded tensors;
    # the progress bar is disabled.
    state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')

    # Give the caller a chance to rename/drop entries before loading.
    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)
|