Spaces:
Runtime error
Runtime error
| # A reimplemented version in public environments by Xiao Fu and Mu Hu | |
| import numpy as np | |
| from scipy.optimize import least_squares | |
| import torch | |
def align_scale_shift(pred, target, clip_max, *, min_valid=10):
    """Fit ``target ~= scale * pred + shift`` by least squares on valid pixels.

    A pixel takes part in the fit only when ``0 < target < clip_max``. The fit
    is a degree-1 ``np.polyfit`` of the masked ``target`` values against the
    masked ``pred`` values.

    Args:
        pred: Predicted values, indexable by the boolean mask built from
            ``target`` (presumably a NumPy array with the same shape as
            ``target`` -- TODO confirm against callers).
        target: Reference values. Entries that are non-positive or
            ``>= clip_max`` are treated as invalid.
        clip_max: Exclusive upper bound for valid ``target`` values.
        min_valid: The fit runs only when strictly more than this many pixels
            are valid. Defaults to 10, the previously hard-coded threshold.

    Returns:
        ``(scale, shift)`` from the linear fit, or the identity ``(1, 0)``
        when there are not enough valid pixels.
    """
    mask = (target > 0) & (target < clip_max)
    if mask.sum() <= min_valid:
        # Too little support for a meaningful fit: fall back to the identity.
        return 1, 0
    # np.polyfit returns coefficients highest degree first: [slope, intercept].
    scale, shift = np.polyfit(pred[mask], target[mask], deg=1)
    return scale, shift
def align_scale(pred: torch.Tensor, target: torch.Tensor, *, min_valid: int = 10):
    """Scale ``pred`` so that its median matches ``target``'s median.

    Only pixels with ``target > 0`` contribute to the medians. Note that
    ``torch.median`` returns the lower of the two middle values for an
    even-sized input rather than their average.

    Args:
        pred: Predicted tensor, indexable by the boolean mask built from
            ``target`` (presumably the same shape as ``target`` -- TODO
            confirm against callers).
        target: Reference tensor. Non-positive entries are treated as invalid.
        min_valid: The median ratio is used only when strictly more than
            this many pixels are valid. Defaults to 10, the previously
            hard-coded threshold.

    Returns:
        ``(pred * scale, scale)``, where ``scale`` is a 0-dim tensor, or the
        int ``1`` when there are not enough valid pixels.
    """
    mask = target > 0
    if torch.sum(mask) > min_valid:
        # The 1e-8 guards against division by a zero median prediction.
        scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
    else:
        scale = 1
    pred_scale = pred * scale
    return pred_scale, scale
def align_shift(pred: torch.Tensor, target: torch.Tensor, *, min_valid: int = 10):
    """Shift ``pred`` so that its median matches ``target``'s median.

    Only pixels with ``target > 0`` contribute to the medians. Note that
    ``torch.median`` returns the lower of the two middle values for an
    even-sized input rather than their average.

    Args:
        pred: Predicted tensor, indexable by the boolean mask built from
            ``target`` (presumably the same shape as ``target`` -- TODO
            confirm against callers).
        target: Reference tensor. Non-positive entries are treated as invalid.
        min_valid: The median difference is used only when strictly more
            than this many pixels are valid. Defaults to 10, the previously
            hard-coded threshold.

    Returns:
        ``(pred + shift, shift)``, where ``shift`` is a 0-dim tensor, or the
        int ``0`` when there are not enough valid pixels.
    """
    mask = target > 0
    if torch.sum(mask) > min_valid:
        # Plain difference of medians. There is no division here, so no epsilon
        # guard is needed (the previous "+ 1e-8" only biased the result).
        shift = torch.median(target[mask]) - torch.median(pred[mask])
    else:
        shift = 0
    pred_shift = pred + shift
    return pred_shift, shift