Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	File size: 3,102 Bytes
			
			| 499e141 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | import sys
from pathlib import Path
import torch
from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel
gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
sys.path.append(str(gluestick_path))
from gluestick import batch_to_np
from gluestick.models.two_view_pipeline import TwoViewPipeline
class GlueStick(BaseModel):
    """Joint point + line matcher wrapping GlueStick's ``TwoViewPipeline``.

    Both images are run through the pipeline. The raw predictions are then
    filtered so that ``keypoints0/1`` and ``lines0/1`` hold only matched
    pairs, aligned row-by-row.
    """

    default_conf = {
        "name": "two_view_pipeline",
        "model_name": "checkpoint_GlueStick_MD.tar",
        "use_lines": True,
        "max_keypoints": 1000,
        "max_lines": 300,
        "force_num_keypoints": False,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    def _init(self, conf):
        """Download the GlueStick weights and build the two-view pipeline.

        Args:
            conf: merged configuration. ``max_keypoints``,
                ``force_num_keypoints`` and ``max_lines`` are read from it;
                ``model_name`` is read from ``self.conf``.
        """
        # Weights live in the model repo under "<this module's stem>/<model_name>".
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
        )
        logger.info("Loading GlueStick model...")
        # User-tunable limits come straight from ``conf`` instead of being
        # written as literals and overwritten afterwards.
        gluestick_conf = {
            "name": "two_view_pipeline",
            # NOTE(review): ``conf["use_lines"]`` is not forwarded; lines are
            # always enabled because ``_forward`` unconditionally reads the
            # line outputs. Confirm the pipeline's behaviour before honouring it.
            "use_lines": True,
            "extractor": {
                "name": "wireframe",
                "sp_params": {
                    "force_num_keypoints": conf["force_num_keypoints"],
                    "max_num_keypoints": conf["max_keypoints"],
                },
                "wireframe_params": {
                    "merge_points": True,
                    "merge_line_endpoints": True,
                },
                "max_n_lines": conf["max_lines"],
            },
            "matcher": {
                "name": "gluestick",
                "weights": str(model_path),
                "trainable": False,
            },
            "ground_truth": {
                "from_pose_depth": False,
            },
        }
        self.net = TwoViewPipeline(gluestick_conf)

    @staticmethod
    def _gather_matches(items0, items1, matches):
        """Return the rows of ``items0``/``items1`` paired by ``matches``.

        ``matches[i]`` is the index into ``items1`` matched to ``items0[i]``,
        or -1 when ``items0[i]`` is unmatched.
        """
        valid = matches != -1
        return items0[valid], items1[matches[valid]]

    def _forward(self, data):
        """Run the pipeline and keep only matched keypoints and lines.

        Returns the pipeline's prediction dict with these changes:
        ``keypoints0/1`` become torch tensors of matched keypoints,
        ``lines0/1`` hold the matched line segments, and ``raw_lines0/1``
        keep all detected segments. The input ``data`` entries are merged
        on top, so they win on key collisions.
        """
        # batch_to_np presumably converts the batched tensors to NumPy arrays
        # (the results are later passed to torch.from_numpy).
        pred = batch_to_np(self.net(data))

        matched_kps0, matched_kps1 = self._gather_matches(
            pred["keypoints0"], pred["keypoints1"], pred["matches0"]
        )
        line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
        matched_lines0, matched_lines1 = self._gather_matches(
            line_seg0, line_seg1, pred["line_matches0"]
        )

        pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1
        # Lines stay NumPy arrays; only the keypoints are converted back to torch.
        pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1
        pred["keypoints0"] = torch.from_numpy(matched_kps0)
        pred["keypoints1"] = torch.from_numpy(matched_kps1)
        return {**pred, **data}
 | 
 
			
