File size: 3,075 Bytes
d617811 dc8dfa3 a5c51bb d617811 dfe1f0b d617811 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
import argparse
import glob
import multiprocessing as mp
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on
import tempfile
import time
import warnings
import cv2
import numpy as np
import tqdm
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger
from cat_seg import add_cat_seg_config
from demo.predictor import VisualizationDemo
import gradio as gr
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
# constants
# NOTE(review): presumably a window title carried over from the upstream demo
# script; nothing in the code visible in this file reads it.
WINDOW_NAME = "MaskFormer demo"
def setup_cfg(args):
    """Build the frozen config object used by the demo.

    Starts from ``get_cfg()``, applies ``add_deeplab_config`` and
    ``add_cat_seg_config`` to it, then merges ``args.config_file`` followed by
    the ``args.opts`` override list.

    ``MODEL.DEVICE`` is set to "cpu" after both merges, so whatever device the
    config file or the overrides specify is replaced.

    Args:
        args: namespace providing ``config_file`` and ``opts``.

    Returns:
        The config object, after ``freeze()`` has been called on it.
    """
    config = get_cfg()
    # Apply the config extensions before merging, in the original order.
    for extend in (add_deeplab_config, add_cat_seg_config):
        extend(config)
    # File values first, then command-line style overrides on top.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.MODEL.DEVICE = "cpu"
    config.freeze()
    return config
def get_parser():
    """Create the command-line parser for the demo.

    Options:
        --config-file: path to the config file; defaults to
            "configs/vitl_swinb_384.yaml".
        --input: one or more values; no default, so it is ``None`` when
            omitted.
        --opts: consumes every remaining argument as 'KEY VALUE' config
            overrides; when omitted, the default override list below is used.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # Overrides used when --opts is not given on the command line. Built
    # inside the function so every parser gets its own list.
    fallback_opts = [
        "MODEL.WEIGHTS", "model_final.pth",
        "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
        "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
        "TEST.SLIDING_WINDOW", "True",
        "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
    ]
    input_help = (
        "A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'"
    )

    cli = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    cli.add_argument(
        "--config-file",
        metavar="FILE",
        default="configs/vitl_swinb_384.yaml",
        help="path to config file",
    )
    cli.add_argument("--input", nargs="+", help=input_help)
    cli.add_argument(
        "--opts",
        nargs=argparse.REMAINDER,
        default=fallback_opts,
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli
def save_masks(preds, text):
    """Write one binary PNG mask per class name under ``masks/``.

    The per-pixel class index is taken as the argmax over the first axis of
    ``preds['sem_seg']``. For the i-th name in ``text``, pixels whose index
    equals ``i`` are written as 255 and all others as 0, to
    ``masks/mask_<name>.png``.

    Args:
        preds: mapping whose ``'sem_seg'`` entry supports
            ``.argmax(dim=0).cpu().numpy()`` (presumably a torch tensor of
            per-class scores laid out C x H x W, per the original inline
            note -- confirm against the predictor).
        text: iterable of class-name strings; position ``i`` is paired with
            class index ``i``.

    A failed write emits a warning instead of raising, so a problem with one
    mask does not abort the prediction that called this.
    """
    class_map = preds['sem_seg'].argmax(dim=0).cpu().numpy()  # C H W
    out_dir = "masks"
    # cv2.imwrite does not create missing directories; it only reports failure.
    os.makedirs(out_dir, exist_ok=True)
    for i, t in enumerate(text):
        # Names come from user-entered text (see predict), so path separators
        # are neutralised to keep every file inside out_dir.
        safe_name = t.replace("/", "_").replace("\\", "_")
        path = f"{out_dir}/mask_{safe_name}.png"
        # Explicit 8-bit image: a bool array times 255 would be int64, which
        # relies on OpenCV's implicit depth conversion when encoding PNG.
        mask = (class_map == i).astype(np.uint8) * 255
        if not cv2.imwrite(path, mask):
            warnings.warn(f"could not write mask to {path}")
def predict(image, text):
    """Run the demo on ``image`` for the classes in ``text`` and render it.

    Args:
        image: image passed straight to ``VisualizationDemo.run_on_image``
            (the ``__main__`` block wires this to a ``gr.Image`` input).
        text: class names as a single string; it is handed to
            ``VisualizationDemo`` as-is and split on ',' for mask export.

    Returns:
        numpy.ndarray: the rendered figure as a (height, width, 3) array with
        the channel order reversed relative to the canvas's RGB pixels.

    Side effects:
        Calls ``save_masks``, which writes per-class mask PNGs.
    """
    # The config and the demo object are rebuilt on every call; arguments are
    # read from the process command line each time.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    demo = VisualizationDemo(cfg, text=text)
    predictions, visualized_output = demo.run_on_image(image)
    save_masks(predictions, text.split(','))

    canvas = fc(visualized_output.fig)
    canvas.draw()
    # tostring_rgb() was deprecated in Matplotlib 3.8 and later removed.
    # buffer_rgba() exposes the same pixels as (H, W, 4); drop the alpha
    # channel to keep the original three-channel result.
    rgb = np.asarray(canvas.buffer_rgba())[..., :3]
    # Reverse the channel order as before, copying so the result does not
    # alias the canvas's internal buffer.
    return np.ascontiguousarray(rgb[..., ::-1])
if __name__ == "__main__":
    # Parse the command line and build the config once at startup. predict()
    # builds its own copies per call, so these values only serve to surface
    # a bad argument or config file before the UI launches.
    args = get_parser().parse_args()
    cfg = setup_cfg(args)

    # Two inputs, in the order predict(image, text) expects them.
    image_input = gr.Image()
    classes_input = gr.Textbox(placeholder="Classes to segment")
    iface = gr.Interface(
        fn=predict,
        inputs=[image_input, classes_input],
        outputs="image",
    )
    iface.launch()
|