# satia / utils / load_model.py
# Commit a87c588 (stinoco): "Added classification models for subcategories"
# File size: 655 bytes
import pickle
from glob import glob
import torch
import os
# Device that loaded tensors are mapped onto: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def _first_match(pattern):
    """Return the alphabetically first path matching the glob *pattern*.

    Sorting makes the choice deterministic when several files match, since
    ``glob`` itself returns matches in arbitrary order.

    Raises:
        IndexError: if no file matches. ``IndexError`` is kept (rather than
            ``FileNotFoundError``) so callers that relied on the failure mode
            of a bare ``glob(...)[0]`` keep working; the message now names
            the pattern that found nothing.
    """
    matches = sorted(glob(pattern))
    if not matches:
        raise IndexError(f"no file matching {pattern!r}")
    return matches[0]


def load_model(folder, base_folder='production_models'):
    """Load a prediction model stored as a ``.pt`` file plus a ``.pkl`` object.

    The ``.pkl`` file is unpickled, the ``.pt`` file is loaded with
    ``torch.load`` onto the module-level ``device``, and the result is
    attached to the unpickled object as its ``model`` attribute.

    Args:
        folder: Name of the sub-folder of *base_folder* to load from (str).
        base_folder: Directory that contains the model folders. Defaults to
            ``'production_models'``, resolved against the current working
            directory.

    Returns:
        The unpickled object with its ``model`` attribute set to the loaded
        ``.pt`` content.

    Raises:
        IndexError: if the folder holds no ``.pt`` or no ``.pkl`` file.
    """
    model_dir = os.path.join(base_folder, folder)
    # Resolve both paths before opening anything so a missing file is
    # reported before any deserialization happens.
    model_path = _first_match(os.path.join(model_dir, '*.pt'))
    clf_path = _first_match(os.path.join(model_dir, '*.pkl'))

    # SECURITY: pickle.load (and torch.load, which unpickles internally) can
    # execute arbitrary code. Only point this at model folders you trust.
    with open(clf_path, 'rb') as file:
        clf = pickle.load(file)

    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True; if the .pt file holds a fully pickled module this
    # call may need an explicit weights_only=False. Confirm against the
    # torch version used in production.
    clf.model = torch.load(model_path, map_location=device)
    return clf