# NOTE: "Spaces: Runtime error" -- page-status residue captured with this file; not part of the module.
import torch | |
import numpy as np | |
import cv2 | |
import glob | |
import argparse | |
from pathlib import Path | |
from tqdm import tqdm | |
from copy import deepcopy | |
from scipy.optimize import minimize | |
import os | |
from collections import defaultdict | |
def group_by_directory(pathes, idx=-1):
    """
    Group file paths by one component of their containing directory.

    Parameters:
    - pathes (iterable of str): File paths to group.
    - idx (int): Index into the "/"-separated components of each path's
      directory (``os.path.dirname(path)``). The default ``-1`` selects the
      immediate parent directory, ``-2`` the directory above it, and so on.

    Returns:
    - defaultdict(list): Maps each selected directory name to the list of
      paths that share it, in input order.

    Raises:
    - IndexError: If ``idx`` is out of range for a path's directory components.
    """
    grouped_pathes = defaultdict(list)
    for path in pathes:
        directory = os.path.dirname(path)
        # Normalise the platform separator so the "/" split below also works
        # for native Windows paths; this is a no-op where os.sep is "/".
        if os.sep != "/":
            directory = directory.replace(os.sep, "/")
        grouped_pathes[directory.split("/")[idx]].append(path)
    return grouped_pathes
def depth2disparity(depth, return_mask=False):
    """
    Convert depth to disparity (1 / depth) element-wise.

    Entries with depth <= 0 are treated as invalid and get disparity 0.

    Parameters:
    - depth (torch.Tensor or numpy.ndarray): Depth values. Integer inputs are
      promoted to a floating dtype so the reciprocal is not truncated.
    - return_mask (bool): If True, also return the boolean mask of entries
      with depth > 0.

    Returns:
    - The disparity, same container type and shape as ``depth``; or the tuple
      ``(disparity, mask)`` when ``return_mask`` is True.

    Raises:
    - TypeError: If ``depth`` is neither a torch.Tensor nor a numpy.ndarray.
    """
    if isinstance(depth, torch.Tensor):
        if not depth.is_floating_point():
            depth = depth.to(torch.get_default_dtype())
        disparity = torch.zeros_like(depth)
    elif isinstance(depth, np.ndarray):
        if not np.issubdtype(depth.dtype, np.floating):
            depth = depth.astype(np.float64)
        disparity = np.zeros_like(depth)
    else:
        raise TypeError(
            "depth must be a torch.Tensor or numpy.ndarray, "
            f"got {type(depth).__name__}"
        )

    # Only strictly positive depths have a meaningful reciprocal.
    positive_mask = depth > 0
    disparity[positive_mask] = 1.0 / depth[positive_mask]

    if return_mask:
        return disparity, positive_mask
    return disparity
def absolute_error_loss(params, predicted_depth, ground_truth_depth):
    """
    Sum of absolute residuals after applying a scale and shift to the prediction.

    Parameters:
    - params: Two-element sequence ``(s, t)`` holding the scale and the shift.
    - predicted_depth: Predicted values (array-like supporting NumPy arithmetic).
    - ground_truth_depth: Target values, broadcastable against the prediction.

    Returns:
    - The value of ``sum(|s * predicted_depth + t - ground_truth_depth|)``.
    """
    scale, shift = params
    residual = scale * predicted_depth + shift - ground_truth_depth
    return np.sum(np.abs(residual))
def absolute_value_scaling(predicted_depth, ground_truth_depth, s=1, t=0):
    """
    Fit a scale and shift minimising the L1 error between prediction and target.

    Minimises ``absolute_error_loss`` over ``(s, t)`` with
    ``scipy.optimize.minimize`` (default method), starting from the given
    initial guess.

    Parameters:
    - predicted_depth (torch.Tensor): Predicted values; flattened before fitting.
    - ground_truth_depth (torch.Tensor): Target values; flattened before fitting.
    - s: Initial scale. May be a Python number or a single-element tensor
      (on any device); it is converted to a float.
    - t: Initial shift, converted to a float in the same way.

    Returns:
    - tuple(float, float): The fitted ``(s, t)``.
    """
    # detach() so tensors that still carry autograd history can be exported.
    predicted_depth_np = predicted_depth.detach().cpu().numpy().reshape(-1)
    ground_truth_depth_np = ground_truth_depth.detach().cpu().numpy().reshape(-1)

    # SciPy needs plain numbers; callers may pass a (possibly CUDA) 0-dim tensor.
    initial_params = [float(s), float(t)]

    result = minimize(
        absolute_error_loss,
        initial_params,
        args=(predicted_depth_np, ground_truth_depth_np),
    )
    fitted_scale, fitted_shift = result.x
    return float(fitted_scale), float(fitted_shift)
def absolute_value_scaling2(
    predicted_depth,
    ground_truth_depth,
    s_init=1.0,
    t_init=0.0,
    lr=1e-4,
    max_iters=1000,
    tol=1e-6,
):
    """
    Fit a scale and shift minimising the L1 error using Adam.

    Minimises ``sum(|s * predicted_depth + t - ground_truth_depth|)`` over the
    scalars ``s`` and ``t``.

    Parameters:
    - predicted_depth (torch.Tensor): Predicted values.
    - ground_truth_depth (torch.Tensor): Target values, broadcastable against
      the prediction.
    - s_init (float): Initial scale.
    - t_init (float): Initial shift.
    - lr (float): Adam learning rate.
    - max_iters (int): Maximum number of optimisation steps.
    - tol (float): Stop early once the loss changes by less than this amount
      between two consecutive iterations.

    Returns:
    - tuple(float, float): The fitted ``(s, t)``.
    """
    # The fit must not touch the caller's autograd graph.
    predicted_depth = predicted_depth.detach()
    ground_truth_depth = ground_truth_depth.detach()
    # Parameters that require grad must be floating point.
    if not predicted_depth.is_floating_point():
        predicted_depth = predicted_depth.to(torch.get_default_dtype())

    # Evaluation code commonly runs under torch.no_grad(); re-enable autograd
    # locally so loss.backward() works regardless of the caller's context.
    with torch.enable_grad():
        s = torch.tensor(
            [s_init],
            requires_grad=True,
            device=predicted_depth.device,
            dtype=predicted_depth.dtype,
        )
        t = torch.tensor(
            [t_init],
            requires_grad=True,
            device=predicted_depth.device,
            dtype=predicted_depth.dtype,
        )
        optimizer = torch.optim.Adam([s, t], lr=lr)

        prev_loss = None
        for _ in range(max_iters):
            optimizer.zero_grad()
            predicted_aligned = s * predicted_depth + t
            loss = torch.sum(torch.abs(predicted_aligned - ground_truth_depth))
            loss.backward()
            optimizer.step()

            # The loss compared here is the one evaluated before this step.
            current_loss = loss.item()
            if prev_loss is not None and abs(prev_loss - current_loss) < tol:
                break
            prev_loss = current_loss

    return s.detach().item(), t.detach().item()
def _apply_affine(depth, scale, shift):
    """Return ``depth * scale + shift``, skipping whichever term is None."""
    if scale is not None:
        depth = depth * scale
    if shift is not None:
        depth = depth + shift
    return depth


def depth_evaluation(
    predicted_depth_original,
    ground_truth_depth_original,
    max_depth=80,
    custom_mask=None,
    post_clip_min=None,
    post_clip_max=None,
    pre_clip_min=None,
    pre_clip_max=None,
    align_with_lstsq=False,
    align_with_lad=False,
    align_with_lad2=False,
    metric_scale=False,
    lr=1e-4,
    max_iters=1000,
    use_gpu=False,
    align_with_scale=False,
    disp_input=False,
):
    """
    Align a predicted depth map to the ground truth and compute depth metrics.

    Valid pixels are those with ground truth > 0 (and < ``max_depth`` when it
    is not None). The alignment is fitted on the valid pixels only and then
    applied to the whole prediction. Exactly one alignment is used, chosen in
    this order of precedence: ``metric_scale`` (no alignment),
    ``align_with_lstsq`` (least-squares scale + shift), ``align_with_lad``
    (L1 scale + shift via SciPy), ``align_with_lad2`` (L1 scale + shift via
    Adam), ``align_with_scale`` (scale only, iteratively reweighted), and
    otherwise median scaling.

    Args:
        predicted_depth_original (numpy.ndarray or torch.Tensor): Predicted
            map. A 3D input is flattened to 2D as ``(-1, w)``.
        ground_truth_depth_original (numpy.ndarray or torch.Tensor): Ground
            truth map, same shape as the prediction.
        max_depth (float or None): Upper bound (exclusive) on valid ground
            truth values; None disables the bound.
        custom_mask (numpy.ndarray or torch.Tensor, optional): Extra mask with
            the same shape as the ground truth. It restricts the pixels used
            for the metrics but not those used to fit the alignment. Non-zero
            entries count as selected.
        post_clip_min, post_clip_max (float, optional): Clamp applied to the
            aligned prediction before the metrics are computed.
        pre_clip_min, pre_clip_max (float, optional): Clamp applied to the
            prediction before the alignment is fitted.
        align_with_lstsq, align_with_lad, align_with_lad2, metric_scale,
        align_with_scale (bool): Alignment selectors, see above.
        lr (float): Learning rate forwarded to ``absolute_value_scaling2``.
        max_iters (int): Iteration cap forwarded to ``absolute_value_scaling2``.
        use_gpu (bool): Move the inputs to CUDA before evaluating.
        disp_input (bool): Fit the alignment against the reciprocal of the
            ground truth, then convert the aligned prediction back with
            ``depth2disparity`` before computing the metrics.

    Returns:
        tuple:
            - dict: The metrics ("Abs Rel", "Sq Rel", "RMSE", "Log RMSE", the
              four δ thresholds) and "valid_pixels". All metrics are 0 when no
              pixel is valid.
            - torch.Tensor: Per-pixel relative error, 0 at invalid pixels.
            - torch.Tensor: The aligned full prediction.
            - torch.Tensor: The ground truth with invalid pixels set to 0.
    """
    if isinstance(predicted_depth_original, np.ndarray):
        predicted_depth_original = torch.from_numpy(predicted_depth_original)
    if isinstance(ground_truth_depth_original, np.ndarray):
        ground_truth_depth_original = torch.from_numpy(ground_truth_depth_original)
    if isinstance(custom_mask, np.ndarray):
        custom_mask = torch.from_numpy(custom_mask)

    # Fold a leading batch dimension into the rows. reshape (rather than view)
    # also accepts non-contiguous tensors.
    if predicted_depth_original.dim() == 3:
        width = predicted_depth_original.shape[-1]
        predicted_depth_original = predicted_depth_original.reshape(-1, width)
        ground_truth_depth_original = ground_truth_depth_original.reshape(-1, width)
        if custom_mask is not None:
            custom_mask = custom_mask.reshape(-1, width)

    if use_gpu:
        predicted_depth_original = predicted_depth_original.cuda()
        ground_truth_depth_original = ground_truth_depth_original.cuda()

    # Valid pixels: positive ground truth, optionally below max_depth.
    mask = ground_truth_depth_original > 0
    if max_depth is not None:
        mask = mask & (ground_truth_depth_original < max_depth)

    predicted_depth = predicted_depth_original[mask]
    ground_truth_depth = ground_truth_depth_original[mask]

    if pre_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=pre_clip_min)
    if pre_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=pre_clip_max)

    if disp_input:
        # Fit the alignment in disparity space; keep the real depths for the metrics.
        real_gt = ground_truth_depth.clone()
        ground_truth_depth = 1 / (ground_truth_depth + 1e-8)

    # Fit the alignment on the valid pixels. `scale` / `shift` stay None when
    # the chosen method does not use them.
    scale, shift = None, None
    if metric_scale:
        pass
    elif align_with_lstsq:
        predicted_depth_np = predicted_depth.cpu().numpy().reshape(-1, 1)
        ground_truth_depth_np = ground_truth_depth.cpu().numpy().reshape(-1, 1)
        # Design matrix [pred, 1] so the solution is (scale, shift).
        design = np.hstack([predicted_depth_np, np.ones_like(predicted_depth_np)])
        solution = np.linalg.lstsq(design, ground_truth_depth_np, rcond=None)[0]
        scale = torch.tensor(solution[0], device=predicted_depth_original.device)
        shift = torch.tensor(solution[1], device=predicted_depth_original.device)
    elif align_with_lad:
        # Pass a plain float as the initial guess (as the LAD2 branch does) so
        # the SciPy-based fit never receives a tensor.
        init_scale = (
            torch.median(ground_truth_depth) / torch.median(predicted_depth)
        ).item()
        fitted_s, fitted_t = absolute_value_scaling(
            predicted_depth,
            ground_truth_depth,
            s=init_scale,
        )
        scale, shift = float(fitted_s), float(fitted_t)
    elif align_with_lad2:
        init_scale = (
            torch.median(ground_truth_depth) / torch.median(predicted_depth)
        ).item()
        scale, shift = absolute_value_scaling2(
            predicted_depth,
            ground_truth_depth,
            s_init=init_scale,
            lr=lr,
            max_iters=max_iters,
        )
    elif align_with_scale:
        # Initial scale: ratio of the (NaN-ignoring) means.
        scale = torch.nanmean(ground_truth_depth) / torch.nanmean(predicted_depth)
        # Iteratively reweighted least squares: weights 1/|residual| make the
        # weighted L2 solution approach the L1 one.
        for _ in range(10):
            residuals = scale * predicted_depth - ground_truth_depth
            weights = 1.0 / (residuals.abs() + 1e-8)  # epsilon avoids division by zero
            weighted_dot_pred_gt = torch.sum(
                weights * predicted_depth * ground_truth_depth
            )
            weighted_dot_pred_pred = torch.sum(weights * predicted_depth**2)
            scale = weighted_dot_pred_gt / weighted_dot_pred_pred
        # Guard against a degenerate scale and drop any autograd history.
        scale = scale.clamp(min=1e-3).detach()
    else:
        # Median scaling.
        scale = torch.median(ground_truth_depth) / torch.median(predicted_depth)

    predicted_depth = _apply_affine(predicted_depth, scale, shift)

    if disp_input:
        # Back to depth: restore the real ground truth and invert the prediction.
        ground_truth_depth = real_gt
        predicted_depth = depth2disparity(predicted_depth)

    if post_clip_min is not None:
        predicted_depth = torch.clamp(predicted_depth, min=post_clip_min)
    if post_clip_max is not None:
        predicted_depth = torch.clamp(predicted_depth, max=post_clip_max)

    mask_within_mask = None
    if custom_mask is not None:
        assert (
            custom_mask.shape == ground_truth_depth_original.shape
        ), "custom_mask must have the same shape as the ground truth"
        # Index on the same device as `mask` and force a boolean dtype so an
        # integer 0/1 mask selects pixels instead of being used as indices.
        mask_within_mask = custom_mask.to(mask.device)[mask].to(torch.bool)
        predicted_depth = predicted_depth[mask_within_mask]
        ground_truth_depth = ground_truth_depth[mask_within_mask]

    # Metrics over the selected pixels.
    diff = predicted_depth - ground_truth_depth
    abs_rel = torch.mean(torch.abs(diff) / ground_truth_depth).item()
    sq_rel = torch.mean((diff**2) / ground_truth_depth).item()
    rmse = torch.sqrt(torch.mean(diff**2)).item()

    # Clamp to avoid log(0) and division by zero in the ratio below.
    predicted_depth = torch.clamp(predicted_depth, min=1e-5)
    log_rmse = torch.sqrt(
        torch.mean((torch.log(predicted_depth) - torch.log(ground_truth_depth)) ** 2)
    ).item()

    max_ratio = torch.maximum(
        predicted_depth / ground_truth_depth, ground_truth_depth / predicted_depth
    )
    # NOTE: max_ratio is max(x, 1/x) of a positive ratio, hence >= 1, so the
    # "δ < 1." entry is expected to be 0; kept for output compatibility.
    threshold_0 = torch.mean((max_ratio < 1.0).float()).item()
    threshold_1 = torch.mean((max_ratio < 1.25).float()).item()
    threshold_2 = torch.mean((max_ratio < 1.25**2).float()).item()
    threshold_3 = torch.mean((max_ratio < 1.25**3).float()).item()

    # Apply the same alignment to the full prediction for the per-pixel maps.
    predicted_depth_original = _apply_affine(predicted_depth_original, scale, shift)
    if disp_input:
        predicted_depth_original = depth2disparity(predicted_depth_original)
    depth_error_parity_map = (
        torch.abs(predicted_depth_original - ground_truth_depth_original)
        / ground_truth_depth_original
    )

    # Zero out invalid pixels (where the division above is meaningless).
    zeros = torch.zeros_like(ground_truth_depth_original)
    depth_error_parity_map_full = torch.where(mask, depth_error_parity_map, zeros)
    predict_depth_map_full = predicted_depth_original
    gt_depth_map_full = torch.where(mask, ground_truth_depth_original, zeros)

    if mask_within_mask is None:
        num_valid_pixels = torch.sum(mask).item()
    else:
        num_valid_pixels = torch.sum(mask_within_mask).item()

    if num_valid_pixels == 0:
        # Means over an empty selection are NaN; report zeros instead.
        abs_rel = sq_rel = rmse = log_rmse = 0
        threshold_0 = threshold_1 = threshold_2 = threshold_3 = 0

    results = {
        "Abs Rel": abs_rel,
        "Sq Rel": sq_rel,
        "RMSE": rmse,
        "Log RMSE": log_rmse,
        "δ < 1.": threshold_0,
        "δ < 1.25": threshold_1,
        "δ < 1.25^2": threshold_2,
        "δ < 1.25^3": threshold_3,
        "valid_pixels": num_valid_pixels,
    }

    return (
        results,
        depth_error_parity_map_full,
        predict_depth_map_full,
        gt_depth_map_full,
    )