File size: 2,118 Bytes
92c1934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import glob


def mae_torch(pred, gt):
    """Return the mean absolute error between ``pred`` and ``gt``, scaled by 255.

    Both tensors are cast to float before subtracting, so unsigned integer
    inputs cannot wrap around. The summed absolute difference is divided by
    ``H * W * 255`` (with a small epsilon), where ``H`` and ``W`` are the first
    two dimensions of ``gt``. With inputs on a 0-255 scale, this gives a value
    in [0, 1] (that scale is assumed from the 255 divisor — confirm at the
    call site).

    Args:
        pred: prediction tensor.
        gt: ground-truth tensor; at least 2-D.

    Returns:
        0-d float tensor holding the normalised MAE.
    """
    height, width = gt.shape[:2]
    abs_diff = (pred.float() - gt.float()).abs()
    # The epsilon keeps the division finite for a degenerate 0-sized gt.
    denominator = float(height) * float(width) * 255.0 + 1e-4
    return abs_diff.sum() / denominator


def f1score_torch(pd, gt):
    """Compute precision, recall and F-measure curves over 255 score thresholds.

    Scores of ``pd`` are split by the ground-truth mask (``gt > 128`` is
    foreground, ``gt <= 128`` is background). Each split is histogrammed into
    255 unit-width bins over [0, 255]; ``torch.histc`` ignores values outside
    that range. The histograms are flipped and cumulatively summed, so entry
    ``k`` of each curve counts the scores in the ``k + 1`` highest bins. The
    threshold therefore sweeps from high to low along the curve.

    Args:
        pd: prediction scores, same shape as ``gt``. Integer tensors are cast
            to float32 because ``torch.histc`` rejects integer input on CPU.
        gt: ground-truth tensor; pixels with value > 128 count as positives.

    Returns:
        Tuple ``(precision, recall, f1)``, each a tensor of shape ``(1, 255)``.
        ``f1`` is the weighted F-measure with beta^2 = 0.3.
    """
    beta_sq = 0.3  # F-measure weight (beta^2); < 1 emphasises precision

    # Only integer inputs are converted; float inputs keep their dtype so
    # their binning is unchanged.
    if not torch.is_floating_point(pd):
        pd = pd.float()

    fg_mask = gt > 128
    gt_num = torch.sum(fg_mask.float())  # number of ground-truth foreground pixels

    # Named so as not to shadow the module-level ``torch.nn as nn`` import.
    pos_scores = pd[fg_mask]
    neg_scores = pd[gt <= 128]

    pos_hist = torch.histc(pos_scores, bins=255, min=0, max=255)
    neg_hist = torch.histc(neg_scores, bins=255, min=0, max=255)

    # Flip so index 0 is the highest-score bin, then accumulate: entry k is
    # the number of scores at or above the k-th threshold from the top.
    pos_cum = torch.cumsum(torch.flipud(pos_hist), dim=0)
    neg_cum = torch.cumsum(torch.flipud(neg_hist), dim=0)

    # The 1e-4 terms guard against division by zero (no predictions / no positives).
    precision = pos_cum / (pos_cum + neg_cum + 1e-4)
    recall = pos_cum / (gt_num + 1e-4)
    f1 = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall + 1e-4)

    return (
        torch.reshape(precision, (1, precision.shape[0])),
        torch.reshape(recall, (1, recall.shape[0])),
        torch.reshape(f1, (1, f1.shape[0])),
    )


def f1_mae_torch(pred, gt, valid_dataset, idx, mybins, hypar):
    """Evaluate one prediction: precision/recall/F-measure curves and MAE.

    Prints the sample's file name and the time spent on evaluation.

    Args:
        pred: prediction tensor, scored by ``f1score_torch`` and ``mae_torch``
            (both use a 0-255 scale).
        gt: ground-truth tensor; if it has more than two dimensions, only
            ``gt[:, :, 0]`` is used.
        valid_dataset: object whose ``dataset["im_name"][idx]`` is the sample
            name (a string; ``".png"`` is appended when printing).
        idx: index of the sample in ``valid_dataset``.
        mybins: unused here; kept for call-site compatibility.
        hypar: unused here; kept for call-site compatibility.

    Returns:
        Tuple of numpy arrays ``(precision, recall, f1, mae)``. The first three
        have shape ``(1, 255)``; ``mae`` is 0-d.
    """
    import time

    # perf_counter is monotonic and high-resolution, unlike time.time().
    tic = time.perf_counter()

    if len(gt.shape) > 2:
        gt = gt[:, :, 0]  # keep only the first channel of a multi-channel mask

    pre, rec, f1 = f1score_torch(pred, gt)
    mae = mae_torch(pred, gt)

    print(valid_dataset.dataset["im_name"][idx] + ".png")
    print("time for evaluation : ", time.perf_counter() - tic)

    # .detach() is the supported way to drop autograd history (legacy .data
    # bypasses autograd's safety checks) before converting to numpy.
    return tuple(t.detach().cpu().numpy() for t in (pre, rec, f1, mae))