File size: 3,576 Bytes
ec93f77
 
 
 
49ff668
ec93f77
 
 
2b6089a
0635e35
 
 
ec93f77
 
 
 
 
 
 
 
 
 
2b6089a
ec93f77
2b6089a
ec93f77
 
 
2b6089a
ec93f77
2b6089a
 
ec93f77
0635e35
2b6089a
ec93f77
 
 
2b6089a
 
ec93f77
 
 
 
 
 
 
 
0635e35
 
 
 
2b6089a
0635e35
2b6089a
ec93f77
 
2b6089a
ec93f77
2b6089a
ec93f77
 
 
 
 
 
 
 
 
 
 
 
49ff668
0635e35
49ff668
08eb34b
0635e35
08eb34b
 
 
2b6089a
 
 
0635e35
08eb34b
2b6089a
 
08eb34b
 
 
2b6089a
 
 
 
0635e35
2b6089a
 
 
0635e35
2b6089a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python

import os
import pathlib
import shlex
import subprocess
import tarfile

# Requirement specifiers installed at startup when the SYSTEM environment
# variable equals "spaces". Each entry must be a single whitespace-free token,
# because the install command line is tokenized with shlex.split.
# NOTE(review): the detectron2 ref was unreadable in the reviewed copy (it
# appeared as an obfuscated "[email protected]" token); "detectron2@v0.6" is
# an assumption -- confirm against the deployed app before merging.
SPACES_PIP_REQUIREMENTS = (
    "git+https://github.com/facebookresearch/detectron2@v0.6",
    "git+https://github.com/aim-uofa/AdelaiDet@7bf9d87",
    "Pillow==9.5.0",
)

if os.getenv("SYSTEM") == "spaces":
    # Install sequentially, in the listed order; check=True aborts startup on
    # the first failing install instead of continuing with a broken environment.
    for requirement in SPACES_PIP_REQUIREMENTS:
        subprocess.run(shlex.split(f"pip install {requirement}"), check=True)  # noqa: S603

import gradio as gr
import huggingface_hub
import numpy as np
import torch
from adet.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import Visualizer

# Markdown heading shown in the demo UI; it is a link to the upstream project.
DESCRIPTION = "# [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter)"

# Hub repository from which the SOLOv2 config and weights files are downloaded.
MODEL_REPO = "public-data/Yet-Another-Anime-Segmenter"


def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted entries of the local ``images`` directory.

    If ``images`` does not exist in the current working directory, the
    ``images.tar.gz`` archive is first downloaded from the
    ``hysts/sample-images-TADNE`` dataset repository and extracted into the
    current working directory.

    Returns:
        The sorted paths matched by ``images/*``.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters are available: refuse members that would
                # escape the destination (absolute paths, "..", unsafe links).
                f.extractall(filter="data")
            else:
                # Older Python without extraction filters; keep the legacy call.
                f.extractall()  # noqa: S202
    return sorted(image_dir.glob("*"))


def load_model(device: torch.device) -> DefaultPredictor:
    """Build a ``DefaultPredictor`` from the SOLOv2 files stored in ``MODEL_REPO``.

    Args:
        device: Target device; only its ``type`` is written into the config.

    Returns:
        A predictor constructed from the frozen config.
    """
    # Fetch the config first and the weights second, both from the same repo.
    yaml_file, weights_file = (
        huggingface_hub.hf_hub_download(MODEL_REPO, filename) for filename in ("SOLOv2.yaml", "SOLOv2.pth")
    )

    config = get_cfg()
    config.merge_from_file(yaml_file)
    config.MODEL.WEIGHTS = weights_file
    config.MODEL.DEVICE = device.type
    config.freeze()

    predictor = DefaultPredictor(config)
    return predictor


# Use the first CUDA device when CUDA is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The predictor is built once at import time and shared by every request.
model = load_model(device)


def predict(
    image_path: str, class_score_threshold: float, mask_score_threshold: float
) -> tuple[np.ndarray, np.ndarray]:
    """Run the segmenter on one image and render two views of the result.

    Args:
        image_path: Path of the image file to read.
        class_score_threshold: Value stored on the predictor as ``score_threshold``.
        mask_score_threshold: Value stored on the predictor as ``mask_threshold``.

    Returns:
        A pair ``(vis, masked)``. ``vis`` is the visualizer's rendering of the
        predicted instances. ``masked`` is the input image, with the channel
        order of the BGR read reversed, where every pixel not covered by any
        predicted mask is set to 255. When nothing is detected, ``masked`` is
        entirely 255.
    """
    # NOTE(review): these attributes are set on the predictor wrapper itself;
    # whether the wrapped network reads them is not visible from this file --
    # confirm that the thresholds actually influence inference.
    model.score_threshold = class_score_threshold
    model.mask_threshold = mask_score_threshold
    image = read_image(image_path, format="BGR")
    preds = model(image)
    instances = preds["instances"].to("cpu")

    # The image was read as BGR, so reverse the channel axis for drawing.
    visualizer = Visualizer(image[:, :, ::-1])
    vis = visualizer.draw_instance_predictions(predictions=instances)
    vis = vis.get_image()

    # Work on a copy so the array passed to the model is never modified.
    masked = image.copy()[:, :, ::-1]
    instance_masks = instances.pred_masks.cpu().numpy()
    if instance_masks.size == 0:
        # No detections: max() over an empty array raises ValueError, so treat
        # the whole image as background instead of crashing the request.
        masked[...] = 255
    else:
        # Union of all instance masks; pixels covered by no instance are blanked.
        mask = instance_masks.astype(int).max(axis=0)
        masked[mask == 0] = 255

    return vis, masked


# Sample images, fetched on first run if the local directory is missing.
image_paths = load_sample_image_paths()

# One example row per sample image, paired with the sliders' default values
# (0.1 for the score threshold, 0.5 for the mask score threshold).
examples = [[sample_path, 0.1, 0.5] for sample_path in image_paths]


# Build the demo UI. Components are created inside nested layout contexts, so
# the order of the statements below defines the page layout.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # First column: the input image, the two threshold sliders, and the run button.
        with gr.Column():
            # type="filepath" matches predict(), which takes a path and reads the file itself.
            image = gr.Image(label="Input", type="filepath")
            class_score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.1)
            mask_score_threshold = gr.Slider(label="Mask Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5)
            run_button = gr.Button()
        # Second column: the two images returned by predict().
        with gr.Column():
            result_instances = gr.Image(label="Instances")
            result_masked = gr.Image(label="Masked")

    # Shared by the examples widget and the button so both call predict() the
    # same way; the order matches predict()'s parameters and return tuple.
    inputs = [image, class_score_threshold, mask_score_threshold]
    outputs = [result_instances, result_masked]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=predict,
        # Example caching is opt-in through the CACHE_EXAMPLES environment variable.
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    run_button.click(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
    )

if __name__ == "__main__":
    # Enable the request queue with max_size=15, then start the app.
    queued_demo = demo.queue(max_size=15)
    queued_demo.launch()