File size: 673 Bytes
cff2458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch.optim as optim
import torch.nn as nn

class Trainer:
    """Minimal supervised training loop for a classifier.

    Pairs a model with a cross-entropy loss and an SGD optimizer, and runs a
    fixed number of epochs over an iterable of ``(inputs, labels)`` batches.
    Each input batch is flattened to ``(-1, input_size)`` before the forward
    pass.
    """

    def __init__(self, model, lr=0.01, input_size=28 * 28):
        """Create the loss function and optimizer for ``model``.

        Args:
            model: Module to train; its ``parameters()`` are handed to SGD.
            lr: SGD learning rate (default 0.01).
            input_size: Number of features each sample is flattened to before
                the forward pass (default 28*28, i.e. one 28x28 image).
        """
        self.model = model
        self.input_size = input_size
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)

    def train(self, train_loader, epochs=5):
        """Train the model for ``epochs`` passes over ``train_loader``.

        Args:
            train_loader: Iterable yielding ``(images, labels)`` tensor pairs.
                It is iterated once per epoch, so it must be re-iterable
                (e.g. a ``DataLoader`` or a list).
            epochs: Number of passes over ``train_loader``.

        Returns:
            A list with one entry per epoch: the mean training loss per sample
            for that epoch, or ``nan`` if the loader yielded no samples. The
            same value is printed at the end of each epoch.
        """
        # Enable training behaviour (dropout, batch-norm updates) in case the
        # model was previously switched to eval mode by the caller.
        self.model.train()
        epoch_losses = []
        for epoch in range(epochs):
            total_loss = 0.0
            num_samples = 0
            for images, labels in train_loader:
                self.optimizer.zero_grad()
                # reshape (unlike view) also accepts non-contiguous batches.
                outputs = self.model(images.reshape(-1, self.input_size))
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                # CrossEntropyLoss averages over the batch by default; weight
                # by batch size so a smaller final batch is not over-counted.
                batch_size = labels.size(0)
                total_loss += loss.item() * batch_size
                num_samples += batch_size

            # An empty loader used to crash here (``loss`` unbound); report
            # nan instead so the problem is visible without an exception.
            avg_loss = total_loss / num_samples if num_samples else float("nan")
            epoch_losses.append(avg_loss)
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
        return epoch_losses