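"""Training script for RIPE.

Trains the keypoint detection network with a REINFORCE-style policy-gradient
objective: keypoint log-probabilities are weighted by rewards derived from
RANSAC inliers between matched image pairs.
"""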
import pyrootutils

root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

import collections
import os
import random
from pathlib import Path

import hydra
import numpy as np
import torch
import tqdm
import wandb
from hydra.utils import instantiate
from lightning.fabric import Fabric
from torch.optim.adamw import AdamW
from torch.utils.data import DataLoader

from ripe import utils
from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark
from ripe.utils.utils import get_rewards
from ripe.utils.wandb_utils import get_flattened_wandb_cfg

log = utils.get_pylogger(__name__)

# Fix all random seeds for reproducibility.
SEED = 32000
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
def unpack_batch(batch):
    src_image = batch["src_image"]
    trg_image = batch["trg_image"]
    trg_mask = batch["trg_mask"]
    src_mask = batch["src_mask"]
    label = batch["label"]
    H = batch["homography"]

    return src_image, trg_image, src_mask, trg_mask, H, label
@hydra.main(config_path="../conf/", config_name="config", version_base=None)
def train(cfg):
    """Main training function for the RIPE model."""
    # Prepare model, data and hyperparameters.
    strategy = "ddp" if cfg.num_gpus > 1 else "auto"

    fabric = Fabric(
        accelerator="cuda",
        devices=cfg.num_gpus,
        precision=cfg.precision,
        strategy=strategy,
    )
    fabric.launch()

    # Derive a readable experiment name from the Hydra output directory layout.
    output_dir = Path(cfg.output_dir)
    experiment_name = output_dir.parent.parent.parent.name
    run_id = output_dir.parent.parent.name
    timestamp = output_dir.parent.name + "_" + output_dir.name
    experiment_name = run_id + " " + timestamp + " " + experiment_name
    # setup logger
    wandb_logger = wandb.init(
        project=cfg.project_name,
        name=experiment_name,
        config=get_flattened_wandb_cfg(cfg),
        dir=cfg.output_dir,
        mode=cfg.wandb_mode,
    )

    # minimal number of correspondences required by each transformation model
    min_nums_matches = {"homography": 4, "fundamental": 8, "fundamental_7pt": 7}
    min_num_matches = min_nums_matches[cfg.transformation_model]
    log.info(f"Minimum number of matches for {cfg.transformation_model} is {min_num_matches}")
    batch_size = cfg.batch_size
    steps = cfg.num_steps
    lr = cfg.lr
    # Gradient accumulation simulates a larger effective batch size
    # (batch_size * num_grad_accs); set to 1 to disable.
    num_grad_accs = cfg.num_grad_accs
    # instantiate dataset
    ds = instantiate(cfg.data)

    # prepare dataloader
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        persistent_workers=False,
        num_workers=cfg.num_workers,
    )
    dl = fabric.setup_dataloaders(dl)
    i_dl = iter(dl)
    # create matcher
    matcher = instantiate(cfg.matcher)

    if cfg.desc_loss_weight != 0.0:
        descriptor_loss = instantiate(cfg.descriptor_loss)
    else:
        log.warning(
            "Descriptor loss weight is 0.0: the descriptor loss will not be used "
            "and the 1x1 descriptor convolution will be deactivated!"
        )
        descriptor_loss = None

    upsampler = instantiate(cfg.upsampler) if "upsampler" in cfg else None
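    # Note: cfg.network is assumed to be a Hydra partial here; instantiate() returns
    # the constructor, which is then called with the backbone and descriptor settings.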
    # create network
    net = instantiate(cfg.network)(
        net=instantiate(cfg.backbones),
        upsampler=upsampler,
        descriptor_dim=cfg.descriptor_dim if descriptor_loss is not None else None,
        device=fabric.device,
    ).train()

    # log the number of trainable parameters
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    log.info(f"Number of parameters: {num_params}")

    fp_penalty = cfg.fp_penalty  # small penalty for keypoints without a match
    kp_penalty = cfg.kp_penalty  # small penalty for low-logprob keypoints

    opt_pi = AdamW(filter(lambda x: x.requires_grad, net.parameters()), lr=lr, weight_decay=1e-5)
    net, opt_pi = fabric.setup(net, opt_pi)
    if cfg.lr_scheduler:
        scheduler = instantiate(cfg.lr_scheduler)(optimizer=opt_pi, steps_init=0)
    else:
        scheduler = None

    val_benchmark = IMW_2020_Benchmark(
        use_predefined_subset=True,
        conf_inference=cfg.conf_inference,
        edge_input_divisible_by=None,
    )

    # Moving average over the last 100 batches of how many pairs were skipped
    # because too few keypoints were detected; useful to detect when the model
    # is not learning anything, as the value should stay at zero.
    ma_skipped_batches = collections.deque(maxlen=100)

    opt_pi.zero_grad()

    # initialize the schedulers for the penalty weight (alpha), the reward mix
    # (beta) and the RANSAC inlier threshold
    alpha_scheduler = instantiate(cfg.alpha_scheduler)
    beta_scheduler = instantiate(cfg.beta_scheduler)
    inl_th_scheduler = instantiate(cfg.inl_th)
    # ====== Training Loop ======
    # make sure the model is in training mode
    net.train()

    with tqdm.tqdm(total=steps) as pbar:
        for i_step in range(steps):
            alpha = alpha_scheduler(i_step)
            beta = beta_scheduler(i_step)
            inl_th = inl_th_scheduler(i_step)

            if scheduler:
                scheduler.step()

            # Initialize the accumulators for the current step. Batching has to be
            # handled manually because every image can yield an arbitrary number
            # of keypoints.
            sum_reward_batch = 0
            sum_num_keypoints_1 = 0
            sum_num_keypoints_2 = 0

            loss = None
            loss_policy_stack = None
            loss_desc_stack = None
            loss_kp_stack = None

            # draw the next batch; restart the dataloader once it is exhausted
            try:
                batch = next(i_dl)
            except StopIteration:
                i_dl = iter(dl)
                batch = next(i_dl)

            p1, p2, mask_padding_1, mask_padding_2, Hs, label = unpack_batch(batch)
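            # Forward pass on both images. Per image, the network returns the sampled
            # keypoint locations, their log-probabilities, the selection mask, the
            # padding mask on the keypoint grid, the logits of the selected cells and
            # the raw network outputs (including the coarse descriptors).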
            (
                kpts1,
                logprobs1,
                selected_mask1,
                mask_padding_grid_1,
                logits_selected_1,
                out1,
            ) = net(p1, mask_padding_1, training=True)
            (
                kpts2,
                logprobs2,
                selected_mask2,
                mask_padding_grid_2,
                logits_selected_2,
                out2,
            ) = net(p2, mask_padding_2, training=True)

            # upsample coarse descriptors for all keypoints from the encoder's
            # intermediate feature maps
            desc_1 = net.get_descs(out1["coarse_descs"], p1, kpts1, p1.shape[2], p1.shape[3])
            desc_2 = net.get_descs(out2["coarse_descs"], p2, kpts2, p2.shape[2], p2.shape[3])
            if cfg.padding_filter_mode == "ignore":
                # remove keypoints that lie in the padding area
                batch_mask_selection_for_matching_1 = selected_mask1 & mask_padding_grid_1
                batch_mask_selection_for_matching_2 = selected_mask2 & mask_padding_grid_2
            elif cfg.padding_filter_mode == "punish":
                # keep all keypoints; those in the padding area get punished via the reward
                batch_mask_selection_for_matching_1 = selected_mask1
                batch_mask_selection_for_matching_2 = selected_mask2
            else:
                raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")
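            # Match the descriptors and verify the matches with RANSAC for the whole
            # batch; returns per-pair relative/absolute match indices, the RANSAC
            # inlier masks and the estimated transformation models.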
            (
                batch_rel_idx_matches,
                batch_abs_idx_matches,
                batch_ransac_inliers,
                batch_Fm,
            ) = matcher(
                kpts1,
                kpts2,
                desc_1,
                desc_2,
                batch_mask_selection_for_matching_1,
                batch_mask_selection_for_matching_2,
                inl_th,
                label if cfg.no_filtering_negatives else None,
            )
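            # Loop over the pairs in the batch and stack the per-pair loss terms;
            # pairs for which no model could be estimated are skipped and counted.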
            for b in range(batch_size):
                # ignore the pair if fewer than 16 keypoints have been detected
                if batch_rel_idx_matches[b] is None:
                    ma_skipped_batches.append(1)
                    continue
                else:
                    ma_skipped_batches.append(0)

                mask_selection_for_matching_1 = batch_mask_selection_for_matching_1[b]
                mask_selection_for_matching_2 = batch_mask_selection_for_matching_2[b]

                rel_idx_matches = batch_rel_idx_matches[b]
                abs_idx_matches = batch_abs_idx_matches[b]
                ransac_inliers = batch_ransac_inliers[b]

                if cfg.selected_only:
                    # every SELECTED keypoint paired with every other SELECTED keypoint
                    dense_logprobs = logprobs1[b][mask_selection_for_matching_1].view(-1, 1) + logprobs2[b][
                        mask_selection_for_matching_2
                    ].view(1, -1)
                else:
                    if cfg.padding_filter_mode == "ignore":
                        # every keypoint paired with every other keypoint, WITHOUT keypoints in the padding area
                        dense_logprobs = logprobs1[b][mask_padding_grid_1[b]].view(-1, 1) + logprobs2[b][
                            mask_padding_grid_2[b]
                        ].view(1, -1)
                    elif cfg.padding_filter_mode == "punish":
                        # every keypoint paired with every other keypoint, INCLUDING keypoints in the
                        # padding areas -> those will be punished by the reward
                        dense_logprobs = logprobs1[b].view(-1, 1) + logprobs2[b].view(1, -1)
                    else:
                        raise ValueError(f"Unknown padding filter mode: {cfg.padding_filter_mode}")
                reward = None
                if cfg.reward_type == "inlier":
                    # constant reward: 1.0 for a positive pair, 0.5 for a negative pair
                    # when negatives are not filtered out
                    reward = 0.5 if cfg.no_filtering_negatives and not label[b] else 1.0
                elif cfg.reward_type == "inlier_ratio":
                    # reward is the inlier ratio -> higher if more matches are RANSAC inliers
                    ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
                    reward = ratio_inlier
                elif cfg.reward_type == "inlier+inlier_ratio":
                    # blend of the constant reward and the inlier ratio; beta gradually
                    # shifts the weight towards the inlier ratio
                    ratio_inlier = ransac_inliers.sum() / len(abs_idx_matches)
                    reward = (1.0 - beta) * 1.0 + beta * ratio_inlier
                else:
                    raise ValueError(f"Unknown reward type: {cfg.reward_type}")
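                # Expand the scalar reward into a dense reward matrix over the keypoint
                # pairs (roughly: inlier matches receive the reward, keypoints without
                # a match the alpha-scheduled false-positive penalty).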
                dense_rewards = get_rewards(
                    reward,
                    kpts1[b],
                    kpts2[b],
                    mask_selection_for_matching_1,
                    mask_selection_for_matching_2,
                    mask_padding_grid_1[b],
                    mask_padding_grid_2[b],
                    rel_idx_matches,
                    abs_idx_matches,
                    ransac_inliers,
                    label[b],
                    fp_penalty * alpha,
                    use_whitening=cfg.use_whitening,
                    selected_only=cfg.selected_only,
                    filter_mode=cfg.padding_filter_mode,
                )
                if descriptor_loss is not None:
                    hard_loss = descriptor_loss(
                        desc1=desc_1[b],
                        desc2=desc_2[b],
                        matches=abs_idx_matches,
                        inliers=ransac_inliers,
                        label=label[b],
                        logits_1=None,
                        logits_2=None,
                    )

                    loss_desc_stack = (
                        hard_loss if loss_desc_stack is None else torch.hstack((loss_desc_stack, hard_loss))
                    )
                sum_reward_batch += dense_rewards.sum()

                # REINFORCE-style policy term: reward-weighted keypoint-pair log-probabilities
                current_loss_policy = (dense_rewards * dense_logprobs).view(-1)

                loss_policy_stack = (
                    current_loss_policy
                    if loss_policy_stack is None
                    else torch.hstack((loss_policy_stack, current_loss_policy))
                )

                if kp_penalty != 0.0:
                    # Keypoints with a low logprob are penalized: since their logprob is a
                    # large negative value, multiplying it by the penalty increases the loss.
                    loss_kp = (
                        logprobs1[b][mask_selection_for_matching_1]
                        * torch.full_like(
                            logprobs1[b][mask_selection_for_matching_1],
                            kp_penalty * alpha,
                        )
                    ).mean() + (
                        logprobs2[b][mask_selection_for_matching_2]
                        * torch.full_like(
                            logprobs2[b][mask_selection_for_matching_2],
                            kp_penalty * alpha,
                        )
                    ).mean()

                    loss_kp_stack = loss_kp if loss_kp_stack is None else torch.hstack((loss_kp_stack, loss_kp))

                sum_num_keypoints_1 += mask_selection_for_matching_1.sum()
                sum_num_keypoints_2 += mask_selection_for_matching_2.sum()
            # Assemble the total loss; the policy term is a reward to be maximized,
            # hence its sign is flipped.
            loss = loss_policy_stack.mean()
            if loss_kp_stack is not None:
                loss += loss_kp_stack.mean()
            loss = -loss

            if descriptor_loss is not None:
                loss += cfg.desc_loss_weight * loss_desc_stack.mean()

            pbar.set_description(
                f"LP: {loss.item():.4f} - "
                f"Det: ({sum_num_keypoints_1 / batch_size:.4f}, {sum_num_keypoints_2 / batch_size:.4f}), "
                f"#mRwd: {sum_reward_batch / batch_size:.1f}"
            )
            pbar.update()

            # backward pass with gradient accumulation
            loss /= num_grad_accs
            fabric.backward(loss)

            if i_step % num_grad_accs == 0:
                opt_pi.step()
                opt_pi.zero_grad()
            if i_step % cfg.log_interval == 0:
                wandb_logger.log(
                    {
                        "loss": loss.item(),
                        "loss_policy": -loss_policy_stack.mean().item(),
                        "loss_kp": loss_kp_stack.mean().item() if loss_kp_stack is not None else 0.0,
                        "loss_hard": loss_desc_stack.mean().item() if loss_desc_stack is not None else 0.0,
                        "mean_num_det_kpts1": sum_num_keypoints_1 / batch_size,
                        "mean_num_det_kpts2": sum_num_keypoints_2 / batch_size,
                        "mean_reward": sum_reward_batch / batch_size,
                        "lr": opt_pi.param_groups[0]["lr"],
                        "ma_skipped_batches": sum(ma_skipped_batches) / len(ma_skipped_batches),
                        "inl_th": inl_th,
                    },
                    step=i_step,
                )
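            # Periodic validation on the IMW 2020 benchmark; evaluation puts the
            # model into eval mode, so training mode is restored afterwards.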
            if i_step % cfg.val_interval == 0:
                val_benchmark.evaluate(net, fabric.device, progress_bar=False)
                val_benchmark.log_results(logger=wandb_logger, step=i_step)
                # ensure that the model is in training mode again
                net.train()

    # save the final model
    torch.save(
        net.state_dict(),
        output_dir / f"model_{i_step + 1}_final.pth",
    )
if __name__ == "__main__":
    train()
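# A minimal launch sketch with hypothetical override values; the available keys
# and their defaults are defined by the Hydra configs under ../conf/:
#   python train.py num_gpus=1 batch_size=8 num_steps=20000 wandb_mode=offline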