satia / utils /production_model
stinoco's picture
Added classification models for subcategories
a87c588
raw
history blame
1.02 kB
import torch
import torch.nn.functional as F
# Default compute device for the module: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class ProductionModel:
    """Inference wrapper that turns raw observations into per-label probabilities.

    The model itself is not loaded by the constructor: assign it to
    ``self.model`` (for example with ``torch.load``) before calling
    ``predict``.
    """

    def __init__(self, tokenizer, dict_labels):
        """Store the tokenizer and the label mapping; the model starts unset.

        Args:
            tokenizer: Object exposing ``tokenize(X)``; its result is passed
                directly to ``torch.tensor``.
            dict_labels: Mapping (or sequence) from class index to label name.
        """
        self.model = None
        self.tokenizer = tokenizer
        self.dict_labels = dict_labels

    def _model_device(self):
        """Return the device of the model's first parameter, else the module default."""
        parameters = getattr(self.model, 'parameters', None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        # Parameter-free or non-Module models: keep the module-level default.
        return device

    def predict(self, X):
        """Generate predictions for new data.

        Args:
            X: List of observations; each element is one observation.

        Returns:
            A dict mapping each label to its probability when exactly one
            observation is scored; otherwise a list of such dicts, one per
            observation.

        Raises:
            ValueError: If ``self.model`` has not been assigned yet.
        """
        if self.model is None:
            raise ValueError('Debes cargar el modelo con self.model = torch.load(model_file.pt)')

        token_ids = self.tokenizer.tokenize(X)
        # Build the input on the same device as the model weights so a model
        # loaded onto the CPU still works on a CUDA-capable host.
        inputs = torch.tensor(token_ids, device=self._model_device())

        self.model.eval()
        with torch.no_grad():
            # The first element of the model output is used as the class scores;
            # softmax over dim 1 turns each row into probabilities.
            logits = self.model(inputs)[0]
            probabilities = F.softmax(logits, dim=1)

        output = [
            {self.dict_labels[i]: float(p) for i, p in enumerate(row)}
            for row in probabilities.cpu().tolist()
        ]
        # A single observation is returned unwrapped, as a plain dict.
        if len(output) == 1:
            return output[0]
        return output