import torch
from tqdm import tqdm

from training.utils import VariableLossLogPrinter


def get_acc(outputs, targets):
    """Return top-1 accuracy (in percent) for a batch of logits."""
    _, predicted = torch.max(outputs, 1)
    total = targets.size(0)
    correct = (predicted == targets).sum().item()
    return correct / total * 100


def train(model, train_loader, optimizer, fdl, epoch):
    """Train `model` for one epoch with cross-entropy plus the FDL auxiliary loss."""
    model.train()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    loss_printer = VariableLossLogPrinter()
    iterator = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (data, target) in iterator:
        data_on_device = data.to(device)
        target_on_device = target.to(device)

        # Forward pass: the model returns logits together with intermediate feature maps.
        output, feature_maps = model(data_on_device, with_feature_maps=True)
        loss = torch.nn.functional.cross_entropy(output, target_on_device)
        fdl_loss = fdl(feature_maps, output)
        total_loss = loss + fdl_loss

        # Backward pass and parameter update.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Running metrics, weighted by the batch size.
        acc = get_acc(output, target_on_device)
        batch_size = data_on_device.size(0)
        loss_printer.log_loss("Train Acc", acc, batch_size)
        loss_printer.log_loss("CE-Loss", loss.item(), batch_size)
        loss_printer.log_loss("FDL", fdl_loss.item(), batch_size)
        loss_printer.log_loss("Total-Loss", total_loss.item(), batch_size)
        iterator.set_description(f"Train Epoch:{epoch} Metrics: {loss_printer.get_loss_string()}")
    print(f"Trained model for one epoch {epoch} with lr group 0: {optimizer.param_groups[0]['lr']}")
    return model


def test(model, test_loader, epoch):
    """Evaluate `model` on `test_loader`, reporting accuracy and cross-entropy loss."""
    model.eval()
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    loss_printer = VariableLossLogPrinter()
    iterator = tqdm(enumerate(test_loader), total=len(test_loader))
    with torch.no_grad():
        for batch_idx, (data, target) in iterator:
            data_on_device = data.to(device)
            target_on_device = target.to(device)

            # Forward pass only; gradients are disabled during evaluation.
            output, feature_maps = model(data_on_device, with_feature_maps=True)
            loss = torch.nn.functional.cross_entropy(output, target_on_device)

            # Running metrics, weighted by the batch size.
            acc = get_acc(output, target_on_device)
            batch_size = data_on_device.size(0)
            loss_printer.log_loss("Test Acc", acc, batch_size)
            loss_printer.log_loss("CE-Loss", loss.item(), batch_size)
            iterator.set_description(f"Test Epoch:{epoch} Metrics: {loss_printer.get_loss_string()}")
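

# --- Illustrative usage sketch (not part of the original training code) ---
# A minimal, hedged example of how train()/test() might be wired together,
# assuming a model that returns (logits, feature_maps) when called with
# with_feature_maps=True and an `fdl` callable taking (feature_maps, logits).
# DummyNet, dummy_fdl, and the random dataset below are hypothetical stand-ins,
# not the real model or FDL term used in this repository.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    class DummyNet(torch.nn.Module):
        def __init__(self, num_classes=10):
            super().__init__()
            self.backbone = torch.nn.Linear(32, 16)
            self.head = torch.nn.Linear(16, num_classes)

        def forward(self, x, with_feature_maps=False):
            feats = torch.relu(self.backbone(x))
            logits = self.head(feats)
            return (logits, feats) if with_feature_maps else logits

    def dummy_fdl(feature_maps, logits):
        # Placeholder auxiliary loss; the real FDL objective lives elsewhere in the repo.
        return 0.0 * feature_maps.mean()

    data = torch.randn(256, 32)
    targets = torch.randint(0, 10, (256,))
    loader = DataLoader(TensorDataset(data, targets), batch_size=64, shuffle=True)

    net = DummyNet()
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    for epoch in range(2):
        net = train(net, loader, opt, dummy_fdl, epoch)
        test(net, loader, epoch)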