Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	fix: cached model
Browse files- common/app_class.py +2 -2
- common/config.yaml +20 -20
- common/utils.py +6 -7
- common/viz.py +7 -5
    	
        common/app_class.py
    CHANGED
    
    | @@ -5,7 +5,7 @@ from pathlib import Path | |
| 5 | 
             
            from typing import Dict, Any, Optional, Tuple, List, Union
         | 
| 6 | 
             
            from common.utils import (
         | 
| 7 | 
             
                ransac_zoo,
         | 
| 8 | 
            -
                 | 
| 9 | 
             
                load_config,
         | 
| 10 | 
             
                get_matcher_zoo,
         | 
| 11 | 
             
                run_matching,
         | 
| @@ -290,7 +290,7 @@ class ImageMatchingApp: | |
| 290 |  | 
| 291 | 
             
                            # estimate geo
         | 
| 292 | 
             
                            choice_geometry_type.change(
         | 
| 293 | 
            -
                                fn= | 
| 294 | 
             
                                inputs=[
         | 
| 295 | 
             
                                    input_image0,
         | 
| 296 | 
             
                                    input_image1,
         | 
|  | |
| 5 | 
             
            from typing import Dict, Any, Optional, Tuple, List, Union
         | 
| 6 | 
             
            from common.utils import (
         | 
| 7 | 
             
                ransac_zoo,
         | 
| 8 | 
            +
                generate_warp_images,
         | 
| 9 | 
             
                load_config,
         | 
| 10 | 
             
                get_matcher_zoo,
         | 
| 11 | 
             
                run_matching,
         | 
|  | |
| 290 |  | 
| 291 | 
             
                            # estimate geo
         | 
| 292 | 
             
                            choice_geometry_type.change(
         | 
| 293 | 
            +
                                fn=generate_warp_images,
         | 
| 294 | 
             
                                inputs=[
         | 
| 295 | 
             
                                    input_image0,
         | 
| 296 | 
             
                                    input_image1,
         | 
    	
        common/config.yaml
    CHANGED
    
    | @@ -16,26 +16,26 @@ defaults: | |
| 16 | 
             
              setting_geometry: Homography
         | 
| 17 |  | 
| 18 | 
             
            matcher_zoo:
         | 
| 19 | 
            -
              roma:
         | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
              dkm:
         | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
             
              loftr:
         | 
| 40 | 
             
                matcher: loftr
         | 
| 41 | 
             
                dense: true
         | 
|  | |
| 16 | 
             
              setting_geometry: Homography
         | 
| 17 |  | 
| 18 | 
             
            matcher_zoo:
         | 
| 19 | 
            +
              # roma:
         | 
| 20 | 
            +
              #   matcher: roma
         | 
| 21 | 
            +
              #   dense: true
         | 
| 22 | 
            +
              #   info: 
         | 
| 23 | 
            +
              #     name: RoMa #dispaly name
         | 
| 24 | 
            +
              #     source: "CVPR 2024"
         | 
| 25 | 
            +
              #     github: https://github.com/Parskatt/RoMa
         | 
| 26 | 
            +
              #     paper: https://arxiv.org/abs/2305.15404
         | 
| 27 | 
            +
              #     project: https://parskatt.github.io/RoMa
         | 
| 28 | 
            +
              #     display: true
         | 
| 29 | 
            +
              # dkm:
         | 
| 30 | 
            +
              #   matcher: dkm
         | 
| 31 | 
            +
              #   dense: true
         | 
| 32 | 
            +
              #   info: 
         | 
| 33 | 
            +
              #     name: DKM #dispaly name
         | 
| 34 | 
            +
              #     source: "CVPR 2023"
         | 
| 35 | 
            +
              #     github: https://github.com/Parskatt/DKM
         | 
| 36 | 
            +
              #     paper: https://arxiv.org/abs/2202.00667
         | 
| 37 | 
            +
              #     project: https://parskatt.github.io/DKM
         | 
| 38 | 
            +
              #     display: true
         | 
| 39 | 
             
              loftr:
         | 
| 40 | 
             
                matcher: loftr
         | 
| 41 | 
             
                dense: true
         | 
    	
        common/utils.py
    CHANGED
    
    | @@ -12,7 +12,6 @@ from hloc.utils.base_model import dynamic_load | |
| 12 | 
             
            from hloc import match_dense, match_features, extract_features
         | 
| 13 | 
             
            from hloc.utils.viz import add_text, plot_keypoints
         | 
| 14 | 
             
            from .viz import (
         | 
| 15 | 
            -
                draw_matches,
         | 
| 16 | 
             
                fig2im,
         | 
| 17 | 
             
                plot_images,
         | 
| 18 | 
             
                display_matches,
         | 
| @@ -242,7 +241,7 @@ def filter_matches( | |
| 242 | 
             
                return pred
         | 
| 243 |  | 
| 244 |  | 
| 245 | 
            -
            def  | 
| 246 | 
             
                pred: Dict[str, Any],
         | 
| 247 | 
             
                ransac_method: str = DEFAULT_RANSAC_METHOD,
         | 
| 248 | 
             
                ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
         | 
| @@ -373,7 +372,7 @@ def wrap_images( | |
| 373 | 
             
                    return None, None
         | 
| 374 |  | 
| 375 |  | 
| 376 | 
            -
            def  | 
| 377 | 
             
                input_image0: np.ndarray,
         | 
| 378 | 
             
                input_image1: np.ndarray,
         | 
| 379 | 
             
                matches_info: Dict[str, Any],
         | 
| @@ -475,7 +474,7 @@ def run_matching( | |
| 475 | 
             
                match_conf["model"]["match_threshold"] = match_threshold
         | 
| 476 | 
             
                match_conf["model"]["max_keypoints"] = extract_max_keypoints
         | 
| 477 | 
             
                t0 = time.time()
         | 
| 478 | 
            -
                cache_key = match_conf["model"]["name"]
         | 
| 479 | 
             
                if cache_key in models_already_loaded:
         | 
| 480 | 
             
                    matcher = models_already_loaded[cache_key]
         | 
| 481 | 
             
                    matcher.conf["max_keypoints"] = extract_max_keypoints
         | 
| @@ -499,7 +498,7 @@ def run_matching( | |
| 499 | 
             
                    # update extract config
         | 
| 500 | 
             
                    extract_conf["model"]["max_keypoints"] = extract_max_keypoints
         | 
| 501 | 
             
                    extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
         | 
| 502 | 
            -
                    cache_key = extract_conf["model"]["name"]
         | 
| 503 | 
             
                    if cache_key in models_already_loaded:
         | 
| 504 | 
             
                        extractor = models_already_loaded[cache_key]
         | 
| 505 | 
             
                        extractor.conf["max_keypoints"] = extract_max_keypoints
         | 
| @@ -567,8 +566,8 @@ def run_matching( | |
| 567 |  | 
| 568 | 
             
                t1 = time.time()
         | 
| 569 | 
             
                # plot wrapped images
         | 
| 570 | 
            -
                geom_info =  | 
| 571 | 
            -
                output_wrapped, _ =  | 
| 572 | 
             
                    pred["image0_orig"],
         | 
| 573 | 
             
                    pred["image1_orig"],
         | 
| 574 | 
             
                    {"geom_info": geom_info},
         | 
|  | |
| 12 | 
             
            from hloc import match_dense, match_features, extract_features
         | 
| 13 | 
             
            from hloc.utils.viz import add_text, plot_keypoints
         | 
| 14 | 
             
            from .viz import (
         | 
|  | |
| 15 | 
             
                fig2im,
         | 
| 16 | 
             
                plot_images,
         | 
| 17 | 
             
                display_matches,
         | 
|  | |
| 241 | 
             
                return pred
         | 
| 242 |  | 
| 243 |  | 
| 244 | 
            +
            def compute_geometry(
         | 
| 245 | 
             
                pred: Dict[str, Any],
         | 
| 246 | 
             
                ransac_method: str = DEFAULT_RANSAC_METHOD,
         | 
| 247 | 
             
                ransac_reproj_threshold: float = DEFAULT_RANSAC_REPROJ_THRESHOLD,
         | 
|  | |
| 372 | 
             
                    return None, None
         | 
| 373 |  | 
| 374 |  | 
| 375 | 
            +
            def generate_warp_images(
         | 
| 376 | 
             
                input_image0: np.ndarray,
         | 
| 377 | 
             
                input_image1: np.ndarray,
         | 
| 378 | 
             
                matches_info: Dict[str, Any],
         | 
|  | |
| 474 | 
             
                match_conf["model"]["match_threshold"] = match_threshold
         | 
| 475 | 
             
                match_conf["model"]["max_keypoints"] = extract_max_keypoints
         | 
| 476 | 
             
                t0 = time.time()
         | 
| 477 | 
            +
                cache_key = "{}_{}".format(key, match_conf["model"]["name"])
         | 
| 478 | 
             
                if cache_key in models_already_loaded:
         | 
| 479 | 
             
                    matcher = models_already_loaded[cache_key]
         | 
| 480 | 
             
                    matcher.conf["max_keypoints"] = extract_max_keypoints
         | 
|  | |
| 498 | 
             
                    # update extract config
         | 
| 499 | 
             
                    extract_conf["model"]["max_keypoints"] = extract_max_keypoints
         | 
| 500 | 
             
                    extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
         | 
| 501 | 
            +
                    cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
         | 
| 502 | 
             
                    if cache_key in models_already_loaded:
         | 
| 503 | 
             
                        extractor = models_already_loaded[cache_key]
         | 
| 504 | 
             
                        extractor.conf["max_keypoints"] = extract_max_keypoints
         | 
|  | |
| 566 |  | 
| 567 | 
             
                t1 = time.time()
         | 
| 568 | 
             
                # plot wrapped images
         | 
| 569 | 
            +
                geom_info = compute_geometry(pred)
         | 
| 570 | 
            +
                output_wrapped, _ = generate_warp_images(
         | 
| 571 | 
             
                    pred["image0_orig"],
         | 
| 572 | 
             
                    pred["image1_orig"],
         | 
| 573 | 
             
                    {"geom_info": geom_info},
         | 
    	
        common/viz.py
    CHANGED
    
    | @@ -247,7 +247,7 @@ def fig2im(fig: matplotlib.figure.Figure) -> np.ndarray: | |
| 247 | 
             
                return buf_ndarray.reshape(height, width, 3)
         | 
| 248 |  | 
| 249 |  | 
| 250 | 
            -
            def  | 
| 251 | 
             
                mkpts0: List[np.ndarray],
         | 
| 252 | 
             
                mkpts1: List[np.ndarray],
         | 
| 253 | 
             
                img0: np.ndarray,
         | 
| @@ -293,7 +293,7 @@ def draw_matches( | |
| 293 | 
             
                            mkpts1,
         | 
| 294 | 
             
                            color,
         | 
| 295 | 
             
                            titles=titles,
         | 
| 296 | 
            -
                            text= | 
| 297 | 
             
                            path=path,
         | 
| 298 | 
             
                            dpi=dpi,
         | 
| 299 | 
             
                            pad=pad,
         | 
| @@ -308,7 +308,7 @@ def draw_matches( | |
| 308 | 
             
                            mkpts1,
         | 
| 309 | 
             
                            color,
         | 
| 310 | 
             
                            titles=titles,
         | 
| 311 | 
            -
                            text= | 
| 312 | 
             
                            pad=pad,
         | 
| 313 | 
             
                            dpi=dpi,
         | 
| 314 | 
             
                        )
         | 
| @@ -406,7 +406,7 @@ def display_matches( | |
| 406 | 
             
                        mconf = pred["mconf"]
         | 
| 407 | 
             
                    else:
         | 
| 408 | 
             
                        mconf = np.ones(len(mkpts0))
         | 
| 409 | 
            -
                    fig_mkpts =  | 
| 410 | 
             
                        mkpts0,
         | 
| 411 | 
             
                        mkpts1,
         | 
| 412 | 
             
                        img0,
         | 
| @@ -445,7 +445,9 @@ def display_matches( | |
| 445 | 
             
                            mconf = pred["mconf"]
         | 
| 446 | 
             
                        else:
         | 
| 447 | 
             
                            mconf = np.ones(len(mkpts0))
         | 
| 448 | 
            -
                        fig_mkpts =  | 
|  | |
|  | |
| 449 | 
             
                        fig_lines = cv2.resize(
         | 
| 450 | 
             
                            fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
         | 
| 451 | 
             
                        )
         | 
|  | |
| 247 | 
             
                return buf_ndarray.reshape(height, width, 3)
         | 
| 248 |  | 
| 249 |  | 
| 250 | 
            +
            def draw_matches_core(
         | 
| 251 | 
             
                mkpts0: List[np.ndarray],
         | 
| 252 | 
             
                mkpts1: List[np.ndarray],
         | 
| 253 | 
             
                img0: np.ndarray,
         | 
|  | |
| 293 | 
             
                            mkpts1,
         | 
| 294 | 
             
                            color,
         | 
| 295 | 
             
                            titles=titles,
         | 
| 296 | 
            +
                            # text=texts,
         | 
| 297 | 
             
                            path=path,
         | 
| 298 | 
             
                            dpi=dpi,
         | 
| 299 | 
             
                            pad=pad,
         | 
|  | |
| 308 | 
             
                            mkpts1,
         | 
| 309 | 
             
                            color,
         | 
| 310 | 
             
                            titles=titles,
         | 
| 311 | 
            +
                            # text=texts,
         | 
| 312 | 
             
                            pad=pad,
         | 
| 313 | 
             
                            dpi=dpi,
         | 
| 314 | 
             
                        )
         | 
|  | |
| 406 | 
             
                        mconf = pred["mconf"]
         | 
| 407 | 
             
                    else:
         | 
| 408 | 
             
                        mconf = np.ones(len(mkpts0))
         | 
| 409 | 
            +
                    fig_mkpts = draw_matches_core(
         | 
| 410 | 
             
                        mkpts0,
         | 
| 411 | 
             
                        mkpts1,
         | 
| 412 | 
             
                        img0,
         | 
|  | |
| 445 | 
             
                            mconf = pred["mconf"]
         | 
| 446 | 
             
                        else:
         | 
| 447 | 
             
                            mconf = np.ones(len(mkpts0))
         | 
| 448 | 
            +
                        fig_mkpts = draw_matches_core(
         | 
| 449 | 
            +
                            mkpts0, mkpts1, img0, img1, mconf, dpi=300
         | 
| 450 | 
            +
                        )
         | 
| 451 | 
             
                        fig_lines = cv2.resize(
         | 
| 452 | 
             
                            fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
         | 
| 453 | 
             
                        )
         | 
 
			
