File size: 3,576 Bytes
ec93f77 49ff668 ec93f77 2b6089a 0635e35 ec93f77 2b6089a ec93f77 2b6089a ec93f77 2b6089a ec93f77 2b6089a ec93f77 0635e35 2b6089a ec93f77 2b6089a ec93f77 0635e35 2b6089a 0635e35 2b6089a ec93f77 2b6089a ec93f77 2b6089a ec93f77 49ff668 0635e35 49ff668 08eb34b 0635e35 08eb34b 2b6089a 0635e35 08eb34b 2b6089a 08eb34b 2b6089a 0635e35 2b6089a 0635e35 2b6089a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
#!/usr/bin/env python
import os
import pathlib
import shlex
import subprocess
import tarfile
if os.getenv("SYSTEM") == "spaces":
    # Extra runtime dependencies, installed (in this order) only when SYSTEM=spaces,
    # presumably i.e. when running as a Hugging Face Space. `check=True` aborts
    # start-up on the first failed install.
    # NOTE(review): the detectron2 URL below contains "[email protected]", which looks
    # like an email-obfuscation artifact from wherever this file was copied. It was
    # presumably "detectron2@<tag>" originally; restore the real ref, otherwise this
    # install is expected to fail.
    for _pip_command in (
        "pip install git+https://github.com/facebookresearch/[email protected]",
        "pip install git+https://github.com/aim-uofa/AdelaiDet@7bf9d87",
        "pip install Pillow==9.5.0",
    ):
        subprocess.run(shlex.split(_pip_command), check=True)  # noqa: S603
import gradio as gr
import huggingface_hub
import numpy as np
import torch
from adet.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
# Markdown header rendered at the top of the demo page.
DESCRIPTION: str = "# [Yet-Another-Anime-Segmenter](https://github.com/zymk9/Yet-Another-Anime-Segmenter)"
# Hugging Face Hub repo that `load_model` pulls SOLOv2.yaml and SOLOv2.pth from.
MODEL_REPO: str = "public-data/Yet-Another-Anime-Segmenter"
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return the sorted paths of the sample images used as Gradio examples.

    If ``./images`` does not exist yet, ``images.tar.gz`` is downloaded from the
    ``hysts/sample-images-TADNE`` dataset repo and extracted into the current
    working directory. The archive is expected to contain the ``images``
    directory.

    Returns:
        Every entry directly under ``./images`` (non-recursive), sorted.
        Returns an empty list if the directory is still missing after extraction.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            if hasattr(tarfile, "data_filter"):
                # PEP 706 "data" filter: rejects absolute paths, ".." traversal,
                # links pointing outside the destination, device files, etc.
                # Unfiltered extraction is deprecated since Python 3.12.
                f.extractall(filter="data")
            else:
                # Interpreters without extraction-filter support.
                f.extractall()  # noqa: S202
    return sorted(image_dir.glob("*"))
def load_model(device: torch.device) -> DefaultPredictor:
    """Download the SOLOv2 config and weights from ``MODEL_REPO`` and build a predictor.

    Args:
        device: Device to run the model on. Its full string form (e.g. ``"cuda:1"``)
            is written to ``cfg.MODEL.DEVICE``, so a non-default GPU index is kept.

    Returns:
        A ``DefaultPredictor`` built from the frozen config.
    """
    config_path = huggingface_hub.hf_hub_download(MODEL_REPO, "SOLOv2.yaml")
    weights_path = huggingface_hub.hf_hub_download(MODEL_REPO, "SOLOv2.pth")

    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.WEIGHTS = weights_path
    # `device.type` would drop the index ("cuda:1" -> "cuda"); keep the full spec.
    cfg.MODEL.DEVICE = str(device)
    cfg.freeze()
    return DefaultPredictor(cfg)
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Built once at import time; `predict` uses this global predictor for every request.
model = load_model(device)
def predict(
    image_path: str, class_score_threshold: float, mask_score_threshold: float
) -> tuple[np.ndarray, np.ndarray]:
    """Run instance segmentation on one image with the global SOLOv2 predictor.

    Args:
        image_path: Path to the input image file (the Gradio input component
            uses ``type="filepath"``).
        class_score_threshold: Class-score threshold applied to SOLOv2 detections.
        mask_score_threshold: Threshold applied to SOLOv2 soft masks.

    Returns:
        ``(vis, masked)``:
        - ``vis``: the detectron2 visualization of all predicted instances.
        - ``masked``: an RGB copy of the input in which every pixel not covered
          by any instance mask is painted white (255).

    Raises:
        gr.Error: If no image was supplied.
    """
    if image_path is None:
        # The Run button can be clicked before an image has been uploaded.
        raise gr.Error("Please upload an image.")

    # The thresholds are attributes of the SOLOv2 network wrapped by
    # DefaultPredictor (`model.model`). Assigning them on the predictor object
    # itself is never seen by the network, so the sliders would have no effect.
    network = model.model
    network.score_threshold = class_score_threshold
    network.mask_threshold = mask_score_threshold

    image = read_image(image_path, format="BGR")
    instances = model(image)["instances"].to("cpu")

    # The image is read as BGR; flip the channels to RGB for display.
    rgb = image[:, :, ::-1]
    vis = Visualizer(rgb).draw_instance_predictions(predictions=instances).get_image()

    pred_masks = instances.pred_masks.numpy()
    if pred_masks.ndim == 3 and len(pred_masks) > 0:
        # Union of all instance masks -> per-pixel foreground flag.
        foreground = pred_masks.any(axis=0)
    else:
        # No instance survived the thresholds. Reducing an empty stack with
        # `max` raises, so treat the whole image as background instead.
        foreground = np.zeros(image.shape[:2], dtype=bool)

    masked = rgb.copy()
    masked[~foreground] = 255
    return vis, masked
image_paths = load_sample_image_paths()
# One example row per sample image:
# (image, score threshold 0.1, mask score threshold 0.5), the same defaults as the sliders.
examples = [[sample_path, 0.1, 0.5] for sample_path in image_paths]
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: input image, the two thresholds, and the trigger button.
        with gr.Column():
            image = gr.Image(label="Input", type="filepath")
            class_score_threshold = gr.Slider(
                label="Score Threshold",
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.1,
            )
            mask_score_threshold = gr.Slider(
                label="Mask Score Threshold",
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.5,
            )
            run_button = gr.Button()
        # Right column: the two images returned by `predict`.
        with gr.Column():
            result_instances = gr.Image(label="Instances")
            result_masked = gr.Image(label="Masked")

    # Order must match predict's parameters and the tuple it returns.
    inputs = [image, class_score_threshold, mask_score_threshold]
    outputs = [result_instances, result_masked]

    # Example outputs are pre-computed only when CACHE_EXAMPLES=1.
    gr.Examples(
        examples=examples,
        inputs=inputs,
        outputs=outputs,
        fn=predict,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    run_button.click(fn=predict, inputs=inputs, outputs=outputs)
if __name__ == "__main__":
    # Enable request queuing with at most 15 pending events, then start the server.
    app = demo.queue(max_size=15)
    app.launch()
|