File size: 3,573 Bytes
ec93f77
 
 
 
 
 
 
49ff668
ec93f77
 
 
2b6089a
 
 
ec93f77
 
 
 
 
 
 
 
 
 
2b6089a
ec93f77
2b6089a
ec93f77
 
 
2b6089a
ec93f77
2b6089a
 
ec93f77
 
2b6089a
ec93f77
 
 
2b6089a
 
ec93f77
 
 
 
 
 
 
 
2b6089a
 
 
ec93f77
 
2b6089a
ec93f77
2b6089a
ec93f77
 
 
 
 
 
 
 
 
 
 
 
49ff668
 
 
2b6089a
49ff668
 
08eb34b
 
2b6089a
08eb34b
 
 
2b6089a
 
 
 
08eb34b
2b6089a
 
08eb34b
 
 
2b6089a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python

from __future__ import annotations

import functools
import os
import pathlib
import shlex
import subprocess
import tarfile

if os.getenv("SYSTEM") == "spaces":
    subprocess.run(shlex.split("pip install git+https://github.com/facebookresearch/[email protected]"))
    subprocess.run(shlex.split("pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87"))

import gradio as gr
import huggingface_hub
import numpy as np
import torch
from adet.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import Visualizer

DESCRIPTION = "# [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter)"

MODEL_REPO = "public-data/Yet-Another-Anime-Segmenter"


def load_sample_image_paths() -> list[pathlib.Path]:
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            f.extractall()
    return sorted(image_dir.glob("*"))


def load_model(device: torch.device) -> DefaultPredictor:
    config_path = huggingface_hub.hf_hub_download(MODEL_REPO, "SOLOv2.yaml")
    model_path = huggingface_hub.hf_hub_download(MODEL_REPO, "SOLOv2.pth")
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.WEIGHTS = model_path
    cfg.MODEL.DEVICE = device.type
    cfg.freeze()
    return DefaultPredictor(cfg)


def predict(
    image_path: str, class_score_threshold: float, mask_score_threshold: float, model: DefaultPredictor
) -> tuple[np.ndarray, np.ndarray]:
    model.score_threshold = class_score_threshold
    model.mask_threshold = mask_score_threshold
    image = read_image(image_path, format="BGR")
    preds = model(image)
    instances = preds["instances"].to("cpu")

    visualizer = Visualizer(image[:, :, ::-1])
    vis = visualizer.draw_instance_predictions(predictions=instances)
    vis = vis.get_image()

    masked = image.copy()[:, :, ::-1]
    mask = instances.pred_masks.cpu().numpy().astype(int).max(axis=0)
    masked[mask == 0] = 255

    return vis, masked


image_paths = load_sample_image_paths()
examples = [[path.as_posix(), 0.1, 0.5] for path in image_paths]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = load_model(device)

fn = functools.partial(predict, model=model)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="filepath")
            class_score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.1)
            mask_score_threshold = gr.Slider(label="Mask Score Threshold", minimum=0, maximum=1, step=0.05, value=0.5)
            run_button = gr.Button("Run")
        with gr.Column():
            result_instances = gr.Image(label="Instances")
            result_masked = gr.Image(label="Masked")

    inputs = [image, class_score_threshold, mask_score_threshold]
    outputs = [result_instances, result_masked]
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=fn,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    run_button.click(
        fn=fn,
        inputs=inputs,
        outputs=outputs,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=15).launch()