File size: 454 Bytes
cff2458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

class Predictor:
    """Run a classifier over a data loader and collect its class predictions.

    Every batch is flattened to ``(-1, input_dim)`` before being passed to the
    model, so the wrapped model must accept flat feature vectors (784 features,
    i.e. a 28x28 image, by default).
    """

    def __init__(self, model, input_dim=28 * 28):
        """Wrap ``model`` for inference.

        Args:
            model: A ``torch.nn.Module`` taking ``(batch, input_dim)`` inputs and
                returning a tensor whose dimension 1 holds one score per class.
            input_dim: Number of features each sample is flattened to. Defaults
                to ``28 * 28``, the value this class previously hard-coded.
        """
        self.model = model
        self.input_dim = input_dim

    def predict(self, test_loader):
        """Return the predicted class index for every sample in ``test_loader``.

        Args:
            test_loader: An iterable of batches. A batch is either a tuple/list
                such as ``(images, labels)``, whose first element is used and the
                rest ignored, or a bare image tensor.

        Returns:
            A list with one NumPy integer class index per sample, in loader
            order. An empty loader yields an empty list.
        """
        model = self.model

        # Inputs must be on the same device as the weights; infer it from the
        # first parameter. Parameter-free models receive inputs unchanged.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

        # Remember the caller's mode so that predicting mid-training does not
        # leave layers such as dropout/batch-norm stuck in eval behaviour.
        was_training = model.training
        model.eval()
        predictions = []
        try:
            with torch.no_grad():
                for batch in test_loader:
                    images = batch[0] if isinstance(batch, (tuple, list)) else batch
                    if device is not None:
                        images = images.to(device)
                    # reshape (unlike view) also works on non-contiguous tensors.
                    outputs = model(images.reshape(-1, self.input_dim))
                    predicted = outputs.argmax(dim=1)
                    predictions.extend(predicted.cpu().numpy())
        finally:
            model.train(was_training)

        return predictions