# File size: 6,440 Bytes
# e371ddd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
import cv2
from tqdm import tqdm
import json
from skimage.metrics import structural_similarity as calculate_ssim
import lpips
# Shared LPIPS perceptual-metric model (net='alex', version='0.1'), built once when
# the module is loaded and called by calc_2D_metrics for every image pair.
LPIPS = lpips.LPIPS(net='alex', version='0.1')


def calc_2D_metrics(pred_np, gt_np):
    """Compute pixel MSE, LPIPS, SSIM and PSNR between one prediction and its ground truth.

    Args:
        pred_np: predicted image, [H, W, 3], values in [0, 255], np.uint8.
        gt_np: ground-truth image with the same shape and value range as ``pred_np``.

    Returns:
        Tuple ``(loss, lpips, ssim, psnr)``:
            loss  - mean squared error on the images rescaled to [-1, 1],
            lpips - output of the module-level ``LPIPS`` model on the rescaled images,
            ssim  - ``calculate_ssim`` on the raw arrays (channels on axis 2),
            psnr  - ``cv2.PSNR`` on the raw arrays.

    Raises:
        ValueError: if the two images do not have identical shapes.
    """
    # Fail early with a readable message; otherwise a wrong crop (e.g. an empty
    # slice taken from a too-short image strip) surfaces as an opaque
    # broadcasting error deep inside torch.
    if pred_np.shape != gt_np.shape:
        raise ValueError(
            f"pred and gt must have the same shape, got {pred_np.shape} and {gt_np.shape}"
        )

    # [H, W, 3] -> [1, 3, H, W], and [0, 255] -> [-1, 1]
    pred_image = torch.from_numpy(pred_np).unsqueeze(0).permute(0, 3, 1, 2).float() / 127.5 - 1
    gt_image = torch.from_numpy(gt_np).unsqueeze(0).permute(0, 3, 1, 2).float() / 127.5 - 1

    # Pure evaluation: no autograd graph is needed for either metric.
    with torch.no_grad():
        # pixel loss (mean over all elements, so the batch dim of 1 does not matter)
        loss = F.mse_loss(pred_image, gt_image).item()
        # LPIPS on the batched [-1, 1] tensors; named lpips_value so the imported
        # `lpips` module is not shadowed.
        lpips_value = LPIPS(pred_image, gt_image).item()

    # SSIM and PSNR operate on the original [0, 255] arrays
    ssim = calculate_ssim(pred_np, gt_np, channel_axis=2)
    psnr = cv2.PSNR(gt_np, pred_np)

    return loss, lpips_value, ssim, psnr


# Script to evaluate the model on the GSO / RTMV / NeRF datasets.
# TODO: modify LOG_DIR and the data_dir paths below for the local machine.
LOG_DIR = "../logs_6DoF"      # 6DoF, 4DoF
DATASET = "NeRF" # GSO RTMV NeRF GSO_6dof

# Per-dataset settings. Every entry defines the same keys so that DATA_DIR,
# DATA_TYPE, T_ins, total_views and start_id are always bound afterwards
# (None where a dataset does not use the value).
#   data_dir    - root folder holding one sub-folder per object
#   T_ins       - values substituted into the log folder name "N{T_in}M..."
#                 (presumably the number of input views; confirm)
#   total_views - exclusive upper bound of the evaluated view indices
#   data_type   - sub-folder of each object that holds the ground-truth renders
#   start_id    - first evaluated view index
_DATASET_CONFIGS = {
    "GSO": {
        "data_dir": "/home/xin/data/EscherNet/Data/GSO30/",
        "T_ins": [1, 2, 3, 5, 10],
        "total_views": 25,
        "data_type": "render_mvs_25",
        "start_id": 10,   # from 10 to total_views as test views
    },
    "GSO_6dof": {
        "data_dir": "/home/xin/data/dreammapping/evaluation/GSO/syncdreamer/gso/",    # TODO modify the path
        "T_ins": [1, 2, 3, 5, 10],
        "total_views": 25,
        "data_type": "render_6dof_25",
        "start_id": 10,
    },
    "RTMV": {
        "data_dir": None,   # the evaluation loop points DATA_DIR at the log folder
        "T_ins": [1, 2, 3, 5, 10],
        "total_views": 20,
        "data_type": None,
        "start_id": 10,
    },
    "NeRF": {
        "data_dir": "/home/xin/data/EscherNet/Data/nerf_synthetic/",
        "T_ins": [1, 2, 3, 5, 10, 20, 50, 100],
        "total_views": 200,
        "data_type": None,
        "start_id": 0,
    },
}

if DATASET not in _DATASET_CONFIGS:
    raise NotImplementedError(
        f"Unsupported DATASET {DATASET!r}; expected one of {sorted(_DATASET_CONFIGS)}"
    )

_dataset_cfg = _DATASET_CONFIGS[DATASET]
DATA_DIR = _dataset_cfg["data_dir"]
T_ins = _dataset_cfg["T_ins"]
total_views = _dataset_cfg["total_views"]
DATA_TYPE = _dataset_cfg["data_type"]
start_id = _dataset_cfg["start_id"]

# Side length in pixels of every evaluated view. GSO/RTMV predictions (and the
# RTMV ground truth) are stored as one horizontal strip of views of this width.
_VIEW_SIZE = 256
# RGBA value written over ground-truth pixels whose last channel is exactly 0.
_BACKGROUND = [1., 1., 1., 1.]


def _read_rgb_uint8(path, background=None):
    """Read an image with plt.imread and return its first three channels as uint8.

    When ``background`` is given, every pixel whose last channel equals 0 is
    overwritten with it before the channels are cut and scaled by 255.
    """
    img = plt.imread(path)
    if background is not None:
        img[img[:, :, -1] == 0.] = background
    return np.uint8(img[:, :, :3] * 255.)


def _resize_to_view(rgb):
    """Resize an [H, W, 3] uint8 array to _VIEW_SIZE x _VIEW_SIZE via PIL."""
    return np.array(Image.fromarray(rgb).resize((_VIEW_SIZE, _VIEW_SIZE)))


def _crop_view(strip, i):
    """Return the i-th _VIEW_SIZE-wide column block of a horizontal image strip."""
    return strip[:, _VIEW_SIZE * i:_VIEW_SIZE * (i + 1), :]


# For every T_in: score each object's test views against the ground truth,
# append the per-object and the averaged metrics to <log_dir>/metrics2D.txt.
for T_in in tqdm(T_ins):
    if DATASET == "GSO":
        log_dir = os.path.join(LOG_DIR, f"GSO25/N{T_in}M25")
    elif DATASET == "GSO_6dof":
        log_dir = os.path.join(LOG_DIR, f"GSO25_6dof/N{T_in}M25")
    elif DATASET == "RTMV":
        log_dir = os.path.join(LOG_DIR, f"RTMV/N{T_in}M20")
        # RTMV keeps its ground truth next to the predictions
        DATA_DIR = log_dir
    elif DATASET == "NeRF":
        log_dir = os.path.join(LOG_DIR, f"NeRF/N{T_in}M200")

    metrics_path = os.path.join(log_dir, "metrics2D.txt")
    uses_pred_strip = DATASET in ("GSO", "GSO_6dof", "RTMV")

    # running sums over objects
    val_lpips = 0
    val_ssim = 0
    val_psnr = 0
    val_loss = 0
    val_num = 0

    # get all objects; sorted so the log order does not depend on os.listdir
    objects = sorted(f for f in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, f)))

    for obj in objects:
        # running sums over this object's views (obj_* names keep the imported
        # `lpips` module from being rebound to a number)
        obj_lpips = 0
        obj_ssim = 0
        obj_psnr = 0
        obj_loss = 0
        num = 0

        if DATASET == "NeRF":
            # load test info; the number of views comes from the object itself
            with open(os.path.join(DATA_DIR, obj, "transforms_test.json"), "r") as f:
                test_info = json.load(f)["frames"]
            total_views = len(test_info)

        # Strip images are identical for every view of an object, so decode
        # them once here instead of once per view.
        pred_strip = None
        gt_strip = None
        if uses_pred_strip:
            pred_strip = _read_rgb_uint8(os.path.join(log_dir, obj, "0.png"))
        if DATASET == "RTMV":
            gt_strip = _read_rgb_uint8(os.path.join(DATA_DIR, obj, "gt.png"))

        for i in range(start_id, total_views):
            # load the ground truth and the prediction for view i
            if DATASET == "GSO" or DATASET == "GSO_6dof":
                gt_path = os.path.join(DATA_DIR, obj, DATA_TYPE, "model", f"{i:03d}.png")    # 001 for target view
                gt = _resize_to_view(_read_rgb_uint8(gt_path, _BACKGROUND))
                pred = _crop_view(pred_strip, i)
            elif DATASET == "RTMV":
                gt = _crop_view(gt_strip, i)
                pred = _crop_view(pred_strip, i)
            elif DATASET == "NeRF":
                gt_path = os.path.join(DATA_DIR, obj, test_info[i]["file_path"] + ".png")
                gt = _resize_to_view(_read_rgb_uint8(gt_path, _BACKGROUND))
                # predictions are stored one file per view
                pred = _resize_to_view(_read_rgb_uint8(os.path.join(log_dir, obj, f"{i}.png")))

            # 2D metrics
            loss_i, lpips_i, ssim_i, psnr_i = calc_2D_metrics(pred, gt)

            obj_loss += loss_i
            obj_lpips += lpips_i
            obj_ssim += ssim_i
            obj_psnr += psnr_i
            num += 1

        if num == 0:
            # nothing to average; report it instead of dividing by zero
            print(f"obj: {obj}, no views in range({start_id}, {total_views}), skipped")
            continue

        obj_loss /= num
        obj_lpips /= num
        obj_ssim /= num
        obj_psnr /= num

        val_loss += obj_loss
        val_lpips += obj_lpips
        val_ssim += obj_ssim
        val_psnr += obj_psnr
        val_num += 1

        obj_line = f"obj: {obj}, loss: {obj_loss}, lpips: {obj_lpips}, ssim: {obj_ssim}, psnr: {obj_psnr}"
        print(obj_line)
        # save the results to txt
        with open(metrics_path, "a") as f:
            f.write(obj_line + "\n")

    if val_num == 0:
        # no object produced metrics; report it instead of dividing by zero
        print(f"T_in: {T_in}, no objects evaluated under {DATA_DIR}, skipped")
        continue

    avg_line = f"avg loss: {val_loss / val_num}, avg lpips: {val_lpips / val_num}, avg ssim: {val_ssim / val_num}, avg psnr: {val_psnr / val_num}"
    print(avg_line)
    # save the results to txt
    with open(metrics_path, "a") as f:
        f.write(avg_line + "\n")