# File size: 2,118 Bytes | 92c1934 (file-viewer header, kept as a comment; line-number gutter dropped)
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import glob
def mae_torch(pred, gt):
    """Return the summed absolute error between ``pred`` and ``gt``, normalised.

    Both tensors are cast to float32, the absolute differences are summed over
    every element, and the sum is divided by ``H * W * 255`` (plus a small
    epsilon), where ``H`` and ``W`` are the first two dimensions of ``gt``.

    Args:
        pred: Prediction tensor, broadcast-compatible with ``gt``. Presumably
            on a 0-255 scale given the 255 divisor -- confirm with the caller.
        gt: Ground-truth tensor with at least two dimensions. Only its first
            two dimensions enter the divisor.

    Returns:
        A 0-dim float tensor holding the normalised error.
    """
    height, width = gt.shape[0:2]
    abs_error_sum = torch.sum(
        torch.absolute(torch.sub(pred.float(), gt.float()))
    )
    # The epsilon keeps the divisor non-zero even if a dimension of gt is 0.
    normaliser = float(height) * float(width) * 255.0 + 1e-4
    return torch.divide(abs_error_sum, normaliser)
def f1score_torch(pd, gt):
    """Compute precision, recall and F-measure curves over 255 thresholds.

    Pixels with ``gt > 128`` are treated as foreground and the rest as
    background. Prediction values of each group are histogrammed into 255
    bins over ``[0, 255]``; the reversed cumulative histograms give, for each
    threshold (highest bin first), how many foreground / background pixels
    score at or above it.

    Args:
        pd: Prediction tensor with the same shape as ``gt``. Values are
            binned over ``[0, 255]``. Integer tensors are promoted to float32.
        gt: Ground-truth tensor; values above 128 mark foreground pixels.

    Returns:
        A tuple ``(precision, recall, f1)`` of tensors, each reshaped to
        ``(1, 255)``. ``f1`` is the weighted F-measure with weight 0.3 on
        precision: ``1.3 * P * R / (0.3 * P + R)``.
    """
    # torch.histc rejects integer dtypes on some backends (notably CPU), so
    # promote non-floating predictions, mirroring the cast in mae_torch.
    # Floating inputs are left untouched to keep their dtype.
    if not torch.is_floating_point(pd):
        pd = pd.float()

    gt_count = torch.sum((gt > 128).float())  # number of ground-truth pixels

    fg_scores = pd[gt > 128]
    bg_scores = pd[gt <= 128]

    fg_hist = torch.histc(fg_scores, bins=255, min=0, max=255)
    bg_hist = torch.histc(bg_scores, bins=255, min=0, max=255)

    # Reverse so index 0 is the highest bin; the running sum then counts the
    # pixels whose score falls in that bin or any higher one.
    fg_cum = torch.cumsum(torch.flipud(fg_hist), dim=0)
    bg_cum = torch.cumsum(torch.flipud(bg_hist), dim=0)

    # The 1e-4 terms guard against division by zero.
    precision = fg_cum / (fg_cum + bg_cum + 1e-4)
    recall = fg_cum / (gt_count + 1e-4)
    f1 = (1 + 0.3) * precision * recall / (0.3 * precision + recall + 1e-4)

    return (
        torch.reshape(precision, (1, precision.shape[0])),
        torch.reshape(recall, (1, recall.shape[0])),
        torch.reshape(f1, (1, f1.shape[0])),
    )
def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
    """Evaluate one prediction and return its metrics as NumPy arrays.

    Runs ``f1score_torch`` and ``mae_torch`` on ``pred`` / ``gt``, prints the
    sample's image name and the elapsed evaluation time, and moves the
    results to the CPU.

    Args:
        pred: Prediction tensor passed through to both metric functions.
        gt: Ground-truth tensor. If it has more than two dimensions, only
            index 0 of the third axis is used.
        valid_dataset: Object whose ``dataset["im_name"][idx]`` is a string;
            it is used only for the progress print.
        idx: Index of the current sample within ``valid_dataset``.
        mybins: Unused by this function; kept for caller compatibility.
        hypar: Unused by this function; kept for caller compatibility.

    Returns:
        A tuple ``(precision, recall, f1, mae)`` of NumPy arrays.
    """
    import time

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()

    if len(gt.shape) > 2:
        gt = gt[:, :, 0]

    pre, rec, f1 = f1score_torch(pred, gt)
    mae = mae_torch(pred, gt)

    print(valid_dataset.dataset["im_name"][idx] + ".png")
    print("time for evaluation : ", time.perf_counter() - start)

    # detach() replaces the legacy `.data` accessor before converting to NumPy.
    return (
        pre.detach().cpu().numpy(),
        rec.detach().cpu().numpy(),
        f1.detach().cpu().numpy(),
        mae.detach().cpu().numpy(),
    )
# (trailing file-viewer separator, kept as a comment)