import numpy
import torch
def predict(mask, label, threshold=0.5, score_type='combined'):
    """Turn network outputs into hard 0/1 predictions plus the scores used.

    Args:
        mask: Predicted mask tensor with the batch on dim 0 (rank >= 2).
            Every non-batch dim is averaged into one value per sample.
            Not read when ``score_type == 'binary'``.
        label: Predicted per-sample label/score tensor. It is assumed to
            broadcast against the per-sample mask mean and to be on the same
            scale (e.g. probabilities) — TODO confirm with the model.
        threshold: A sample is predicted positive when its score is strictly
            greater than this value.
        score_type: 'pixel' scores with the mask mean only; 'binary' scores
            with ``label`` only; any other value (default 'combined') uses
            the average of the two.

    Returns:
        ``(preds, score)``. ``preds`` is a float32 tensor of 0.0/1.0 on the
        same device as ``score``. ``score`` is the tensor that was
        thresholded; in 'binary' mode it is ``label`` itself.
    """
    with torch.no_grad():
        if score_type == 'binary':
            score = label
        else:
            # Average over all non-batch dims so masks of any rank >= 2 work;
            # for the usual 4-D mask this is exactly dims (1, 2, 3).
            pixel_score = torch.mean(mask, dim=tuple(range(1, mask.dim())))
            if score_type == 'pixel':
                score = pixel_score
            else:
                score = (pixel_score + label) / 2
        # .float() keeps the result on its current device; the former
        # .type(torch.FloatTensor) always produced a CPU tensor.
        preds = (score > threshold).float()
    return preds, score
def test_accuracy(model, test_dl):
    """Return the percentage of predictions that match ``label`` over ``test_dl``.

    Each batch's model outputs are thresholded with ``predict`` (default
    threshold and score type) and compared elementwise with ``label``.

    Args:
        model: Callable mapping ``img`` to ``(net_mask, net_label)``.
        test_dl: Iterable of ``(img, mask, label)`` batches; ``mask`` is not
            used here.

    Returns:
        Accuracy in percent (0-100) as a float.

    Raises:
        ValueError: If ``test_dl`` yields nothing to compare.
    """
    correct = 0
    total = 0
    # Evaluation needs no gradients; without this the forward pass records
    # an autograd graph for every batch.
    # NOTE(review): the model is not switched to eval() here; if it uses
    # dropout/batch-norm the caller must do that — confirm.
    with torch.no_grad():
        for img, _, label in test_dl:
            net_mask, net_label = model(img)
            preds, _ = predict(net_mask, net_label)
            # Move predictions onto the label's device so CUDA labels do not
            # clash with predictions produced on another device.
            # Assumes preds and label have matching shapes — TODO confirm.
            matches = preds.to(label.device) == label
            correct += matches.sum().item()
            # Count comparisons actually made rather than len(dataset), which
            # is wrong with drop_last/samplers and absent for iterable datasets.
            total += matches.numel()
    if total == 0:
        raise ValueError("test_dl yielded no samples")
    return correct / total * 100


# Named test_* but takes arguments: keep pytest from collecting it as a test.
test_accuracy.__test__ = False
def test_loss(model, test_dl, loss_fn):
    """Return the average of per-batch mean losses of ``model`` over ``test_dl``.

    Each batch contributes ``torch.mean(loss_fn(...))`` with equal weight, so
    a smaller final batch counts as much as a full one.

    Args:
        model: Callable mapping ``img`` to ``(net_mask, net_label)``.
        test_dl: Iterable of ``(img, mask, label)`` batches.
        loss_fn: Called as ``loss_fn(net_mask, net_label, mask, label)``; its
            result is reduced with ``torch.mean``.

    Returns:
        The mean loss as a Python float.

    Raises:
        ValueError: If ``test_dl`` yields no batches.
    """
    total_loss = 0.0
    n_batches = 0
    # Evaluation needs no gradients; skipping autograd avoids building and
    # holding a graph for every batch.
    with torch.no_grad():
        for img, mask, label in test_dl:
            net_mask, net_label = model(img)
            batch_losses = loss_fn(net_mask, net_label, mask, label)
            total_loss += torch.mean(batch_losses).item()
            # Count batches actually seen instead of relying on len(test_dl),
            # which iterable datasets may not provide.
            n_batches += 1
    if n_batches == 0:
        raise ValueError("test_dl yielded no batches")
    return total_loss / n_batches


# Named test_* but takes arguments: keep pytest from collecting it as a test.
test_loss.__test__ = False