File size: 698 Bytes
d1aaae4
 
7064c13
d1aaae4
7064c13
d1aaae4
 
 
 
 
7064c13
 
 
 
 
 
8ab7bd5
7064c13
8ab7bd5
 
 
 
 
d1aaae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from torchtext.data.utils import get_tokenizer
from model_arch import TextClassifierModel, load_state_dict

# Restore the serialized training artifacts from the current working directory.
# map_location='cpu' lets a checkpoint that was written on a GPU machine be
# loaded on a CPU-only host; predict() builds its input tensors on the CPU too.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# files that come from a trusted source.
model_trained = torch.load('model_checkpoint.pth', map_location='cpu')
vocab = torch.load('vocab.pt')
tokenizer = get_tokenizer("spacy", language="es")


def text_pipeline(x):
    """Tokenize the raw string *x* and pass the tokens through ``vocab``."""
    return vocab(tokenizer(x))


# Constructor arguments for TextClassifierModel.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained -- confirm against the training script.
num_class = 11
vocab_size = len(vocab)
embed_size = 300

model = TextClassifierModel(vocab_size, embed_size, num_class)

# load_state_dict is the project helper imported from model_arch (not
# nn.Module.load_state_dict); its return value replaces ``model``.
model = load_state_dict(model, model_trained, vocab)

def predict(text, model=model, text_pipeline=text_pipeline):
    """Return the index of the highest-scoring output for a single text.

    Args:
        text: Raw input string to classify.
        model: Callable model invoked as ``model(token_ids, offsets)``.
            Defaults to the module-level ``model`` (bound when this function
            is defined).
        text_pipeline: Callable that turns ``text`` into a sequence of
            integer token ids. Defaults to the module-level ``text_pipeline``.

    Returns:
        A Python ``int``: the argmax over dimension 1 of the model output.
    """
    # Put the model in evaluation mode before running the forward pass.
    model.eval()
    # Gradients are not needed for inference; skip building the autograd graph.
    with torch.no_grad():
        # Pin the dtype: torch.tensor([]) would otherwise default to float32
        # when the pipeline yields no token ids, which is not a valid index type.
        text_tensor = torch.tensor(text_pipeline(text), dtype=torch.int64)
        # The single offset 0 marks the whole id sequence as one sample.
        # NOTE(review): assumes the model takes (token_ids, offsets) in the
        # EmbeddingBag style -- confirm against model_arch.
        return model(text_tensor, torch.tensor([0])).argmax(1).item()