santit96's picture
Create the streamlit app that classifies the trash in an image into classes
fa84113
raw
history blame
737 Bytes
import logging
import os
from collections import OrderedDict

import torch
from timm.models import load_checkpoint

# Fallback for torch builds where torch.hub does not provide load_state_dict_from_url.
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
_logger = logging.getLogger(__name__)


def load_pretrained(model, url, filter_fn=None, strict=True, *,
                    map_location='cpu', progress=False):
    """Download a state dict from ``url`` and load it into ``model``.

    If ``url`` is falsy (``""`` or ``None``), a warning is logged and the model
    is left untouched (it keeps whatever weights it currently has).

    Args:
        model: Object whose ``load_state_dict(state_dict, strict=...)`` receives
            the downloaded weights (normally a ``torch.nn.Module``).
        url: Checkpoint location, passed to ``load_state_dict_from_url``.
        filter_fn: Optional callable applied to the downloaded state dict before
            loading; its return value is what gets loaded.
        strict: Forwarded to ``model.load_state_dict``.
        map_location: Forwarded to ``load_state_dict_from_url``. Defaults to
            ``'cpu'``, the value that was previously hard-coded.
        progress: Whether the download shows a progress bar. Defaults to
            ``False``, the value that was previously hard-coded.

    Returns:
        ``None`` when ``url`` is empty. Otherwise, whatever
        ``model.load_state_dict`` returns; for ``torch.nn.Module`` that is a
        named tuple with ``missing_keys`` and ``unexpected_keys``.
    """
    if not url:
        # Use a named logger rather than logging.warning(): the module-level
        # helper calls logging.basicConfig() when the root logger has no
        # handlers, which would reconfigure the host application's logging.
        _logger.warning("Pretrained model URL is empty, using random initialization. "
                        "Did you intend to use a `tf_` variant of the model?")
        return None
    state_dict = load_state_dict_from_url(url, progress=progress, map_location=map_location)
    if filter_fn is not None:
        state_dict = filter_fn(state_dict)
    return model.load_state_dict(state_dict, strict=strict)