hsshin98
committed on
Commit
·
1fdfa56
1
Parent(s):
e20de5f
cuda support
Browse files
- app.py +9 -6
- requirements.txt +4 -4
app.py
CHANGED
@@ -4,7 +4,7 @@ import argparse
|
|
4 |
import glob
|
5 |
import multiprocessing as mp
|
6 |
import os
|
7 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
8 |
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
9 |
|
10 |
# fmt: off
|
@@ -28,6 +28,7 @@ from detectron2.utils.logger import setup_logger
|
|
28 |
from cat_seg import add_cat_seg_config
|
29 |
from demo.predictor import VisualizationDemo
|
30 |
import gradio as gr
|
|
|
31 |
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
|
32 |
|
33 |
# constants
|
@@ -41,6 +42,8 @@ def setup_cfg(args):
|
|
41 |
add_cat_seg_config(cfg)
|
42 |
cfg.merge_from_file(args.config_file)
|
43 |
cfg.merge_from_list(args.opts)
|
|
|
|
|
44 |
cfg.freeze()
|
45 |
return cfg
|
46 |
|
@@ -62,14 +65,14 @@ def get_parser():
|
|
62 |
parser.add_argument(
|
63 |
"--opts",
|
64 |
help="Modify config options using the command-line 'KEY VALUE' pairs",
|
65 |
-
default=
|
|
|
66 |
"MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
|
67 |
"MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
|
68 |
"TEST.SLIDING_WINDOW", "True",
|
69 |
"MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
|
70 |
-
"MODEL.
|
71 |
-
"MODEL.
|
72 |
-
],
|
73 |
nargs=argparse.REMAINDER,
|
74 |
)
|
75 |
return parser
|
@@ -103,7 +106,7 @@ if __name__ == "__main__":
|
|
103 |
description="""## CAT-Seg Demo
|
104 |
Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.
|
105 |
|
106 |
-
Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
|
107 |
|
108 |
To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
|
109 |
iface.launch()
|
|
|
4 |
import glob
|
5 |
import multiprocessing as mp
|
6 |
import os
|
7 |
+
#os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
8 |
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
|
9 |
|
10 |
# fmt: off
|
|
|
28 |
from cat_seg import add_cat_seg_config
|
29 |
from demo.predictor import VisualizationDemo
|
30 |
import gradio as gr
|
31 |
+
import torch
|
32 |
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
|
33 |
|
34 |
# constants
|
|
|
42 |
add_cat_seg_config(cfg)
|
43 |
cfg.merge_from_file(args.config_file)
|
44 |
cfg.merge_from_list(args.opts)
|
45 |
+
if torch.cuda.is_available():
|
46 |
+
cfg.MODEL.DEVICE = "cuda"
|
47 |
cfg.freeze()
|
48 |
return cfg
|
49 |
|
|
|
65 |
parser.add_argument(
|
66 |
"--opts",
|
67 |
help="Modify config options using the command-line 'KEY VALUE' pairs",
|
68 |
+
default=(
|
69 |
+
["MODEL.WEIGHTS", "model_final.pth",
|
70 |
"MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
|
71 |
"MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
|
72 |
"TEST.SLIDING_WINDOW", "True",
|
73 |
"MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
|
74 |
+
"MODEL.PROMPT_ENSEMBLE_TYPE", "single",
|
75 |
+
"MODEL.DEVICE", "cpu"]),
|
|
|
76 |
nargs=argparse.REMAINDER,
|
77 |
)
|
78 |
return parser
|
|
|
106 |
description="""## CAT-Seg Demo
|
107 |
Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.
|
108 |
|
109 |
+
Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model. Also, the demo runs on a CPU, so it may take a little time to process your image.
|
110 |
|
111 |
To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
|
112 |
iface.launch()
|
requirements.txt
CHANGED
@@ -7,7 +7,7 @@ imageio==2.4.1
|
|
7 |
timm==0.8.3.dev0
|
8 |
regex
|
9 |
einops
|
10 |
-
torch==1.13.
|
11 |
-
torchvision==0.14.
|
12 |
-
torchaudio==0.13.
|
13 |
-
--extra-index-url https://download.pytorch.org/whl/
|
|
|
7 |
timm==0.8.3.dev0
|
8 |
regex
|
9 |
einops
|
10 |
+
torch==1.13.1+cu116
|
11 |
+
torchvision==0.14.1+cu116
|
12 |
+
torchaudio==0.13.1
|
13 |
+
--extra-index-url https://download.pytorch.org/whl/cu116
|