import os
import sys
import tempfile
from pathlib import Path

import torch
from PIL import Image

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# Make the vendored RoMa and DaD packages under third_party importable
roma_path = Path(__file__).parent / "../../third_party/RoMa"
sys.path.append(str(roma_path))
from romatch.models.model_zoo import roma_model

dad_path = Path(__file__).parent / "../../third_party/dad"
sys.path.append(str(dad_path))
import dad as dad_detector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Dad(BaseModel):
    default_conf = {
        "name": "two_view_pipeline",
        "model_name": "roma_outdoor.pth",
        "model_utils_name": "dinov2_vitl14_pretrain.pth",
        "max_keypoints": 3000,
        "coarse_res": (560, 560),  # resolution for the coarse RoMa match
        "upsample_res": (864, 1152),  # resolution for upsampled predictions
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    # Initialize the DaD detector and the RoMa matcher
    def _init(self, conf):
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format("roma", self.conf["model_name"]),
        )

        dinov2_weights = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format("roma", self.conf["model_utils_name"]),
        )

        logger.info("Loading Dad + Roma model")
        # Load the RoMa checkpoint and the DINOv2 backbone weights on CPU
        weights = torch.load(model_path, map_location="cpu")
        dinov2_weights = torch.load(dinov2_weights, map_location="cpu")

        # Mixed precision only pays off on GPU; keep full precision on CPU
        if str(device) == "cpu":
            amp_dtype = torch.float32
        else:
            amp_dtype = torch.float16

        self.matcher = roma_model(
            resolution=self.conf["coarse_res"],
            upsample_preds=True,
            weights=weights,
            dinov2_weights=dinov2_weights,
            device=device,
            amp_dtype=amp_dtype,
        )
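        # Upsample coarse predictions to the target resolution; match one direction only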
        self.matcher.upsample_res = self.conf["upsample_res"]
        self.matcher.symmetric = False

        self.detector = dad_detector.load_DaD()
        logger.info("Load Dad + Roma model done.")

    def _forward(self, data):
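        # Recover 8-bit RGB PIL images from the input float tensors in [0, 1]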
        img0 = data["image0"].cpu().numpy().squeeze() * 255
        img1 = data["image1"].cpu().numpy().squeeze() * 255
        img0 = img0.transpose(1, 2, 0)
        img1 = img1.transpose(1, 2, 0)
        img0 = Image.fromarray(img0.astype("uint8"))
        img1 = Image.fromarray(img1.astype("uint8"))
        W_A, H_A = img0.size
        W_B, H_B = img1.size

        # Hack: RoMa and DaD consume image paths, so write the tensors to
        # temporary PNG files first (removed after matching and detection)
        with (
            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_img0,
            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_img1,
        ):
            img0_path = temp_img0.name
            img1_path = temp_img1.name
            img0.save(img0_path)
            img1.save(img1_path)

        # Dense matching with RoMa: estimate a warp and a certainty map
        warp, certainty = self.matcher.match(img0_path, img1_path, device=device)
        # Detect keypoints in both images with DaD
        keypoints_A = self.detector.detect_from_path(
            img0_path,
            num_keypoints=self.conf["max_keypoints"],
        )["keypoints"][0]
        keypoints_B = self.detector.detect_from_path(
            img1_path,
            num_keypoints=self.conf["max_keypoints"],
        )["keypoints"][0]
        matches = self.matcher.match_keypoints(
            keypoints_A,
            keypoints_B,
            warp,
            certainty,
            return_tuple=False,
        )

        # Clean up the temporary images now that matching and detection are done
        os.unlink(img0_path)
        os.unlink(img1_path)

        # Convert matches to pixel coordinates and undo the detector's top-left offset
        kpts1, kpts2 = self.matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
        offset = self.detector.topleft
        kpts1, kpts2 = kpts1 - offset, kpts2 - offset
        pred = {
            "keypoints0": self.matcher._to_pixel_coordinates(keypoints_A, H_A, W_A),
            "keypoints1": self.matcher._to_pixel_coordinates(keypoints_B, H_B, W_B),
            "mkeypoints0": kpts1,
            "mkeypoints1": kpts2,
            "mconf": torch.ones_like(kpts1[:, 0]),
        }
        return pred
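

# A minimal usage sketch (illustrative; assumes the hloc-style BaseModel API
# where the instance is constructed from a conf dict and called directly with
# the input dict, and that images are float tensors in [0, 1] of shape
# (1, 3, H, W)):
#
#     model = Dad({"max_keypoints": 2000})
#     data = {
#         "image0": torch.rand(1, 3, 480, 640),
#         "image1": torch.rand(1, 3, 480, 640),
#     }
#     pred = model(data)  # keypoints0/1, mkeypoints0/1, mconf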