import pickle | |
from glob import glob | |
import torch | |
import os | |
# Default device for loaded models: the first CUDA device when one is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def load_model(folder, base_folder='production_models', map_location=None):
    """Load a prediction model made of a pickled wrapper object and a torch model.

    The directory ``<base_folder>/<folder>`` must contain a ``.pkl`` file
    (unpickled with ``pickle.load``) and a ``.pt`` file (loaded with
    ``torch.load``). The ``.pt`` contents are attached to the unpickled object
    as its ``model`` attribute.

    Args:
        folder: name of the sub-folder of ``base_folder`` to load from (str).
        base_folder: root directory holding the model sub-folders; defaults to
            ``'production_models'`` (the previously hard-coded value).
        map_location: forwarded to ``torch.load``; when ``None`` the
            module-level ``device`` is used.

    Returns:
        The unpickled object, with its ``model`` attribute set to the result
        of ``torch.load``.

    Raises:
        FileNotFoundError: if no ``.pt`` or no ``.pkl`` file exists in the
            model directory.
    """
    model_dir = os.path.join(base_folder, folder)
    model_path = _first_match(model_dir, '*.pt')
    clf_path = _first_match(model_dir, '*.pkl')

    # NOTE: pickle.load and torch.load can execute arbitrary code embedded in
    # the files; only point this at trusted, locally produced model folders.
    with open(clf_path, 'rb') as file:
        clf = pickle.load(file)

    if map_location is None:
        map_location = device
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects fully pickled nn.Module objects. If the .pt file holds a
    # whole model rather than tensors/state_dict, confirm the torch version in use.
    clf.model = torch.load(model_path, map_location=map_location)
    return clf


def _first_match(directory, pattern):
    """Return the alphabetically first path in *directory* matching *pattern*.

    ``glob`` yields matches in arbitrary, filesystem-dependent order, so the
    result is sorted to make the choice deterministic when several files match.

    Raises:
        FileNotFoundError: if nothing matches.
    """
    matches = sorted(glob(os.path.join(directory, pattern)))
    if not matches:
        raise FileNotFoundError(
            f"No file matching {pattern!r} found in {directory!r}"
        )
    return matches[0]