File size: 569 Bytes
d1aaae4
 
7064c13
d1aaae4
7064c13
d1aaae4
 
 
 
 
7064c13
 
 
 
 
 
 
 
 
 
d1aaae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torchtext.data.utils import get_tokenizer
from model_arch import TextClassifierModel, load_state_dict

# Load the trained checkpoint and the saved vocabulary.
# NOTE(review): torch.load unpickles arbitrary objects — only load these
# files from a trusted source (or use weights_only=True where applicable).
model_trained = torch.load('model_checkpoint.pth')
vocab = torch.load('vocab.pt')
# Spanish spaCy tokenizer — assumes the spaCy "es" model is installed; TODO confirm.
tokenizer = get_tokenizer("spacy", language="es")

# Convert raw text into a list of vocabulary indices.
text_pipeline = lambda x: vocab(tokenizer(x))

# Hyperparameters — must match the values used when the checkpoint was trained.
num_class = 11
vocab_size = len(vocab)
embed_size = 300
lr = 0.4

model = TextClassifierModel(vocab_size, embed_size, num_class)
# BUG FIX: original referenced undefined name `model_test` (NameError);
# also use the `lr` constant instead of repeating the literal 0.4.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)


# Restore model and optimizer state from the checkpoint.
model, optimizer = load_state_dict(model, optimizer, model_trained, vocab)