Update
Browse files- .pre-commit-config.yaml +59 -34
- .vscode/settings.json +30 -0
- README.md +1 -1
- app.py +40 -46
- style.css +8 -0
    	
        .pre-commit-config.yaml
    CHANGED
    
    | @@ -1,35 +1,60 @@ | |
| 1 | 
             
            repos:
         | 
| 2 | 
            -
            - repo: https://github.com/pre-commit/pre-commit-hooks
         | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
              -  | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
| 25 | 
            -
             | 
| 26 | 
            -
            - repo: https://github.com/pre-commit/mirrors-mypy
         | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
             
            repos:
         | 
| 2 | 
            +
              - repo: https://github.com/pre-commit/pre-commit-hooks
         | 
| 3 | 
            +
                rev: v4.5.0
         | 
| 4 | 
            +
                hooks:
         | 
| 5 | 
            +
                  - id: check-executables-have-shebangs
         | 
| 6 | 
            +
                  - id: check-json
         | 
| 7 | 
            +
                  - id: check-merge-conflict
         | 
| 8 | 
            +
                  - id: check-shebang-scripts-are-executable
         | 
| 9 | 
            +
                  - id: check-toml
         | 
| 10 | 
            +
                  - id: check-yaml
         | 
| 11 | 
            +
                  - id: end-of-file-fixer
         | 
| 12 | 
            +
                  - id: mixed-line-ending
         | 
| 13 | 
            +
                    args: ["--fix=lf"]
         | 
| 14 | 
            +
                  - id: requirements-txt-fixer
         | 
| 15 | 
            +
                  - id: trailing-whitespace
         | 
| 16 | 
            +
              - repo: https://github.com/myint/docformatter
         | 
| 17 | 
            +
                rev: v1.7.5
         | 
| 18 | 
            +
                hooks:
         | 
| 19 | 
            +
                  - id: docformatter
         | 
| 20 | 
            +
                    args: ["--in-place"]
         | 
| 21 | 
            +
              - repo: https://github.com/pycqa/isort
         | 
| 22 | 
            +
                rev: 5.13.2
         | 
| 23 | 
            +
                hooks:
         | 
| 24 | 
            +
                  - id: isort
         | 
| 25 | 
            +
                    args: ["--profile", "black"]
         | 
| 26 | 
            +
              - repo: https://github.com/pre-commit/mirrors-mypy
         | 
| 27 | 
            +
                rev: v1.8.0
         | 
| 28 | 
            +
                hooks:
         | 
| 29 | 
            +
                  - id: mypy
         | 
| 30 | 
            +
                    args: ["--ignore-missing-imports"]
         | 
| 31 | 
            +
                    additional_dependencies:
         | 
| 32 | 
            +
                      [
         | 
| 33 | 
            +
                        "types-python-slugify",
         | 
| 34 | 
            +
                        "types-requests",
         | 
| 35 | 
            +
                        "types-PyYAML",
         | 
| 36 | 
            +
                        "types-pytz",
         | 
| 37 | 
            +
                      ]
         | 
| 38 | 
            +
              - repo: https://github.com/psf/black
         | 
| 39 | 
            +
                rev: 24.2.0
         | 
| 40 | 
            +
                hooks:
         | 
| 41 | 
            +
                  - id: black
         | 
| 42 | 
            +
                    language_version: python3.10
         | 
| 43 | 
            +
                    args: ["--line-length", "119"]
         | 
| 44 | 
            +
              - repo: https://github.com/kynan/nbstripout
         | 
| 45 | 
            +
                rev: 0.7.1
         | 
| 46 | 
            +
                hooks:
         | 
| 47 | 
            +
                  - id: nbstripout
         | 
| 48 | 
            +
                    args:
         | 
| 49 | 
            +
                      [
         | 
| 50 | 
            +
                        "--extra-keys",
         | 
| 51 | 
            +
                        "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
         | 
| 52 | 
            +
                      ]
         | 
| 53 | 
            +
              - repo: https://github.com/nbQA-dev/nbQA
         | 
| 54 | 
            +
                rev: 1.7.1
         | 
| 55 | 
            +
                hooks:
         | 
| 56 | 
            +
                  - id: nbqa-black
         | 
| 57 | 
            +
                  - id: nbqa-pyupgrade
         | 
| 58 | 
            +
                    args: ["--py37-plus"]
         | 
| 59 | 
            +
                  - id: nbqa-isort
         | 
| 60 | 
            +
                    args: ["--float-to-top"]
         | 
    	
        .vscode/settings.json
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "editor.formatOnSave": true,
         | 
| 3 | 
            +
                "files.insertFinalNewline": false,
         | 
| 4 | 
            +
                "[python]": {
         | 
| 5 | 
            +
                    "editor.defaultFormatter": "ms-python.black-formatter",
         | 
| 6 | 
            +
                    "editor.formatOnType": true,
         | 
| 7 | 
            +
                    "editor.codeActionsOnSave": {
         | 
| 8 | 
            +
                        "source.organizeImports": "explicit"
         | 
| 9 | 
            +
                    }
         | 
| 10 | 
            +
                },
         | 
| 11 | 
            +
                "[jupyter]": {
         | 
| 12 | 
            +
                    "files.insertFinalNewline": false
         | 
| 13 | 
            +
                },
         | 
| 14 | 
            +
                "black-formatter.args": [
         | 
| 15 | 
            +
                    "--line-length=119"
         | 
| 16 | 
            +
                ],
         | 
| 17 | 
            +
                "isort.args": ["--profile", "black"],
         | 
| 18 | 
            +
                "flake8.args": [
         | 
| 19 | 
            +
                    "--max-line-length=119"
         | 
| 20 | 
            +
                ],
         | 
| 21 | 
            +
                "ruff.lint.args": [
         | 
| 22 | 
            +
                    "--line-length=119"
         | 
| 23 | 
            +
                ],
         | 
| 24 | 
            +
                "notebook.output.scrolling": true,
         | 
| 25 | 
            +
                "notebook.formatOnCellExecution": true,
         | 
| 26 | 
            +
                "notebook.formatOnSave.enabled": true,
         | 
| 27 | 
            +
                    "notebook.codeActionsOnSave": {
         | 
| 28 | 
            +
                        "source.organizeImports": "explicit"
         | 
| 29 | 
            +
                    }
         | 
| 30 | 
            +
            }
         | 
    	
        README.md
    CHANGED
    
    | @@ -4,7 +4,7 @@ emoji: 🌍 | |
| 4 | 
             
            colorFrom: green
         | 
| 5 | 
             
            colorTo: yellow
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version:  | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            ---
         | 
|  | |
| 4 | 
             
            colorFrom: green
         | 
| 5 | 
             
            colorTo: yellow
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            +
            sdk_version: 4.19.2
         | 
| 8 | 
             
            app_file: app.py
         | 
| 9 | 
             
            pinned: false
         | 
| 10 | 
             
            ---
         | 
    	
        app.py
    CHANGED
    
    | @@ -9,14 +9,9 @@ import shlex | |
| 9 | 
             
            import subprocess
         | 
| 10 | 
             
            import tarfile
         | 
| 11 |  | 
| 12 | 
            -
            if os.getenv( | 
| 13 | 
            -
                subprocess.run(
         | 
| 14 | 
            -
             | 
| 15 | 
            -
                        'pip install git+https://github.com/facebookresearch/[email protected]'
         | 
| 16 | 
            -
                    ))
         | 
| 17 | 
            -
                subprocess.run(
         | 
| 18 | 
            -
                    shlex.split(
         | 
| 19 | 
            -
                        'pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87'))
         | 
| 20 |  | 
| 21 | 
             
            import gradio as gr
         | 
| 22 | 
             
            import huggingface_hub
         | 
| @@ -27,26 +22,24 @@ from detectron2.data.detection_utils import read_image | |
| 27 | 
             
            from detectron2.engine.defaults import DefaultPredictor
         | 
| 28 | 
             
            from detectron2.utils.visualizer import Visualizer
         | 
| 29 |  | 
| 30 | 
            -
            DESCRIPTION =  | 
| 31 |  | 
| 32 | 
            -
            MODEL_REPO =  | 
| 33 |  | 
| 34 |  | 
| 35 | 
             
            def load_sample_image_paths() -> list[pathlib.Path]:
         | 
| 36 | 
            -
                image_dir = pathlib.Path( | 
| 37 | 
             
                if not image_dir.exists():
         | 
| 38 | 
            -
                    dataset_repo =  | 
| 39 | 
            -
                    path = huggingface_hub.hf_hub_download(dataset_repo,
         | 
| 40 | 
            -
                                                           'images.tar.gz',
         | 
| 41 | 
            -
                                                           repo_type='dataset')
         | 
| 42 | 
             
                    with tarfile.open(path) as f:
         | 
| 43 | 
             
                        f.extractall()
         | 
| 44 | 
            -
                return sorted(image_dir.glob( | 
| 45 |  | 
| 46 |  | 
| 47 | 
             
            def load_model(device: torch.device) -> DefaultPredictor:
         | 
| 48 | 
            -
                config_path = huggingface_hub.hf_hub_download(MODEL_REPO,  | 
| 49 | 
            -
                model_path = huggingface_hub.hf_hub_download(MODEL_REPO,  | 
| 50 | 
             
                cfg = get_cfg()
         | 
| 51 | 
             
                cfg.merge_from_file(config_path)
         | 
| 52 | 
             
                cfg.MODEL.WEIGHTS = model_path
         | 
| @@ -55,14 +48,14 @@ def load_model(device: torch.device) -> DefaultPredictor: | |
| 55 | 
             
                return DefaultPredictor(cfg)
         | 
| 56 |  | 
| 57 |  | 
| 58 | 
            -
            def predict( | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
             
                model.score_threshold = class_score_threshold
         | 
| 62 | 
             
                model.mask_threshold = mask_score_threshold
         | 
| 63 | 
            -
                image = read_image(image_path, format= | 
| 64 | 
             
                preds = model(image)
         | 
| 65 | 
            -
                instances = preds[ | 
| 66 |  | 
| 67 | 
             
                visualizer = Visualizer(image[:, :, ::-1])
         | 
| 68 | 
             
                vis = visualizer.draw_instance_predictions(predictions=instances)
         | 
| @@ -78,37 +71,38 @@ def predict(image_path: str, class_score_threshold: float, | |
| 78 | 
             
            image_paths = load_sample_image_paths()
         | 
| 79 | 
             
            examples = [[path.as_posix(), 0.1, 0.5] for path in image_paths]
         | 
| 80 |  | 
| 81 | 
            -
            device = torch.device( | 
| 82 | 
             
            model = load_model(device)
         | 
| 83 |  | 
| 84 | 
             
            fn = functools.partial(predict, model=model)
         | 
| 85 |  | 
| 86 | 
            -
            with gr.Blocks(css= | 
| 87 | 
             
                gr.Markdown(DESCRIPTION)
         | 
| 88 | 
             
                with gr.Row():
         | 
| 89 | 
             
                    with gr.Column():
         | 
| 90 | 
            -
                        image = gr.Image(label= | 
| 91 | 
            -
                        class_score_threshold = gr.Slider(label= | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
                                                          step=0.05,
         | 
| 95 | 
            -
                                                          value=0.1)
         | 
| 96 | 
            -
                        mask_score_threshold = gr.Slider(label='Mask Score Threshold',
         | 
| 97 | 
            -
                                                         minimum=0,
         | 
| 98 | 
            -
                                                         maximum=1,
         | 
| 99 | 
            -
                                                         step=0.05,
         | 
| 100 | 
            -
                                                         value=0.5)
         | 
| 101 | 
            -
                        run_button = gr.Button('Run')
         | 
| 102 | 
             
                    with gr.Column():
         | 
| 103 | 
            -
                        result_instances = gr.Image(label= | 
| 104 | 
            -
                        result_masked = gr.Image(label= | 
| 105 |  | 
| 106 | 
             
                inputs = [image, class_score_threshold, mask_score_threshold]
         | 
| 107 | 
             
                outputs = [result_instances, result_masked]
         | 
| 108 | 
            -
                gr.Examples( | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
             | 
| 112 | 
            -
             | 
| 113 | 
            -
             | 
| 114 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 9 | 
             
            import subprocess
         | 
| 10 | 
             
            import tarfile
         | 
| 11 |  | 
| 12 | 
            +
            if os.getenv("SYSTEM") == "spaces":
         | 
| 13 | 
            +
                subprocess.run(shlex.split("pip install git+https://github.com/facebookresearch/[email protected]"))
         | 
| 14 | 
            +
                subprocess.run(shlex.split("pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87"))
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 15 |  | 
| 16 | 
             
            import gradio as gr
         | 
| 17 | 
             
            import huggingface_hub
         | 
|  | |
| 22 | 
             
            from detectron2.engine.defaults import DefaultPredictor
         | 
| 23 | 
             
            from detectron2.utils.visualizer import Visualizer
         | 
| 24 |  | 
| 25 | 
            +
            DESCRIPTION = "# [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter)"
         | 
| 26 |  | 
| 27 | 
            +
            MODEL_REPO = "public-data/Yet-Another-Anime-Segmenter"
         | 
| 28 |  | 
| 29 |  | 
| 30 | 
             
            def load_sample_image_paths() -> list[pathlib.Path]:
         | 
| 31 | 
            +
                image_dir = pathlib.Path("images")
         | 
| 32 | 
             
                if not image_dir.exists():
         | 
| 33 | 
            +
                    dataset_repo = "hysts/sample-images-TADNE"
         | 
| 34 | 
            +
                    path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
         | 
|  | |
|  | |
| 35 | 
             
                    with tarfile.open(path) as f:
         | 
| 36 | 
             
                        f.extractall()
         | 
| 37 | 
            +
                return sorted(image_dir.glob("*"))
         | 
| 38 |  | 
| 39 |  | 
| 40 | 
             
            def load_model(device: torch.device) -> DefaultPredictor:
         | 
| 41 | 
            +
                config_path = huggingface_hub.hf_hub_download(MODEL_REPO, "SOLOv2.yaml")
         | 
| 42 | 
            +
                model_path = huggingface_hub.hf_hub_download(MODEL_REPO, "SOLOv2.pth")
         | 
| 43 | 
             
                cfg = get_cfg()
         | 
| 44 | 
             
                cfg.merge_from_file(config_path)
         | 
| 45 | 
             
                cfg.MODEL.WEIGHTS = model_path
         | 
|  | |
| 48 | 
             
                return DefaultPredictor(cfg)
         | 
| 49 |  | 
| 50 |  | 
| 51 | 
            +
            def predict(
         | 
| 52 | 
            +
                image_path: str, class_score_threshold: float, mask_score_threshold: float, model: DefaultPredictor
         | 
| 53 | 
            +
            ) -> tuple[np.ndarray, np.ndarray]:
         | 
| 54 | 
             
                model.score_threshold = class_score_threshold
         | 
| 55 | 
             
                model.mask_threshold = mask_score_threshold
         | 
| 56 | 
            +
                image = read_image(image_path, format="BGR")
         | 
| 57 | 
             
                preds = model(image)
         | 
| 58 | 
            +
                instances = preds["instances"].to("cpu")
         | 
| 59 |  | 
| 60 | 
             
                visualizer = Visualizer(image[:, :, ::-1])
         | 
| 61 | 
             
                vis = visualizer.draw_instance_predictions(predictions=instances)
         | 
|  | |
| 71 | 
             
            image_paths = load_sample_image_paths()
         | 
| 72 | 
             
            examples = [[path.as_posix(), 0.1, 0.5] for path in image_paths]
         | 
| 73 |  | 
| 74 | 
            +
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         | 
| 75 | 
             
            model = load_model(device)
         | 
| 76 |  | 
| 77 | 
             
            fn = functools.partial(predict, model=model)
         | 
| 78 |  | 
| 79 | 
            +
            with gr.Blocks(css="style.css") as demo:
         | 
| 80 | 
             
                gr.Markdown(DESCRIPTION)
         | 
| 81 | 
             
                with gr.Row():
         | 
| 82 | 
             
                    with gr.Column():
         | 
| 83 | 
            +
                        image = gr.Image(label="Input", type="filepath")
         | 
| 84 | 
            +
                        class_score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.1)
         | 
| 85 | 
            +
                        mask_score_threshold = gr.Slider(label="Mask Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5)
         | 
| 86 | 
            +
                        run_button = gr.Button("Run")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 87 | 
             
                    with gr.Column():
         | 
| 88 | 
            +
                        result_instances = gr.Image(label="Instances")
         | 
| 89 | 
            +
                        result_masked = gr.Image(label="Masked")
         | 
| 90 |  | 
| 91 | 
             
                inputs = [image, class_score_threshold, mask_score_threshold]
         | 
| 92 | 
             
                outputs = [result_instances, result_masked]
         | 
| 93 | 
            +
                gr.Examples(
         | 
| 94 | 
            +
                    examples=examples,
         | 
| 95 | 
            +
                    inputs=inputs,
         | 
| 96 | 
            +
                    outputs=outputs,
         | 
| 97 | 
            +
                    fn=fn,
         | 
| 98 | 
            +
                    cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
         | 
| 99 | 
            +
                )
         | 
| 100 | 
            +
                run_button.click(
         | 
| 101 | 
            +
                    fn=fn,
         | 
| 102 | 
            +
                    inputs=inputs,
         | 
| 103 | 
            +
                    outputs=outputs,
         | 
| 104 | 
            +
                    api_name="predict",
         | 
| 105 | 
            +
                )
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            if __name__ == "__main__":
         | 
| 108 | 
            +
                demo.queue(max_size=15).launch()
         | 
    	
        style.css
    CHANGED
    
    | @@ -1,3 +1,11 @@ | |
| 1 | 
             
            h1 {
         | 
| 2 | 
             
              text-align: center;
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 3 | 
             
            }
         | 
|  | |
| 1 | 
             
            h1 {
         | 
| 2 | 
             
              text-align: center;
         | 
| 3 | 
            +
              display: block;
         | 
| 4 | 
            +
            }
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            #duplicate-button {
         | 
| 7 | 
            +
              margin: auto;
         | 
| 8 | 
            +
              color: #fff;
         | 
| 9 | 
            +
              background: #1565c0;
         | 
| 10 | 
            +
              border-radius: 100vh;
         | 
| 11 | 
             
            }
         | 
