File size: 655 Bytes
a87c588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import pickle
from glob import glob
import torch
import os
# Target device for loaded models: the current CUDA device when CUDA is
# usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def _first_match(directory, pattern):
    """Return the first file (in sorted order) inside *directory* matching *pattern*.

    Sorting makes the choice deterministic when several files match; plain
    ``glob`` returns matches in arbitrary filesystem order.

    Raises:
        FileNotFoundError: if no file in *directory* matches *pattern*.
    """
    from glob import escape  # local import: the module only imports glob.glob

    # Escape the directory part so characters like '[' or '*' in folder
    # names are matched literally; only *pattern* acts as a wildcard.
    matches = sorted(glob(os.path.join(escape(directory), pattern)))
    if not matches:
        raise FileNotFoundError(
            f"No file matching {pattern!r} found in {directory!r}"
        )
    return matches[0]


def load_model(folder: str, base_folder: str = 'production_models',
               map_location=None):
    """Load a prediction model stored as a pickled wrapper plus torch weights.

    The directory ``<base_folder>/<folder>`` must contain a ``*.pkl`` file
    holding the pickled wrapper object and a ``*.pt`` file readable by
    ``torch.load``. The object loaded from the ``.pt`` file is attached to the
    wrapper as its ``model`` attribute.

    Args:
        folder: folder to load from, relative to *base_folder*.
        base_folder: root directory of the stored models; a relative path is
            resolved against the current working directory.
        map_location: forwarded to ``torch.load``; when ``None`` (default) the
            module-level ``device`` is used.

    Returns:
        The unpickled wrapper object with its ``model`` attribute set.

    Raises:
        FileNotFoundError: if the folder contains no ``*.pt`` or no ``*.pkl`` file.
    """
    model_dir = os.path.join(base_folder, folder)
    model_path = _first_match(model_dir, '*.pt')
    clf_path = _first_match(model_dir, '*.pkl')
    if map_location is None:
        map_location = device

    # SECURITY: pickle.load (and torch.load of full-object pickles) can execute
    # arbitrary code; only load files from a trusted model directory.
    with open(clf_path, 'rb') as file:
        clf = pickle.load(file)
    # NOTE(review): torch >= 2.6 defaults torch.load(weights_only=True), which
    # rejects pickled nn.Module objects; if the .pt file holds a full model
    # rather than tensors / a state_dict, weights_only=False may be needed.
    clf.model = torch.load(model_path, map_location=map_location)
    return clf