hsshin98 commited on
Commit
1fdfa56
·
1 Parent(s): e20de5f

cuda support

Browse files
Files changed (2) hide show
  1. app.py +9 -6
  2. requirements.txt +4 -4
app.py CHANGED
@@ -4,7 +4,7 @@ import argparse
4
  import glob
5
  import multiprocessing as mp
6
  import os
7
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
8
  os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
9
 
10
  # fmt: off
@@ -28,6 +28,7 @@ from detectron2.utils.logger import setup_logger
28
  from cat_seg import add_cat_seg_config
29
  from demo.predictor import VisualizationDemo
30
  import gradio as gr
 
31
  from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
32
 
33
  # constants
@@ -41,6 +42,8 @@ def setup_cfg(args):
41
  add_cat_seg_config(cfg)
42
  cfg.merge_from_file(args.config_file)
43
  cfg.merge_from_list(args.opts)
 
 
44
  cfg.freeze()
45
  return cfg
46
 
@@ -62,14 +65,14 @@ def get_parser():
62
  parser.add_argument(
63
  "--opts",
64
  help="Modify config options using the command-line 'KEY VALUE' pairs",
65
- default=["MODEL.WEIGHTS", "model_final.pth",
 
66
  "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
67
  "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
68
  "TEST.SLIDING_WINDOW", "True",
69
  "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
70
- "MODEL.DEVICE", "cpu",
71
- "MODEL.PROMPT_ENSEMBLE_TYPE", "single"
72
- ],
73
  nargs=argparse.REMAINDER,
74
  )
75
  return parser
@@ -103,7 +106,7 @@ if __name__ == "__main__":
103
  description="""## CAT-Seg Demo
104
  Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.
105
 
106
- Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model.
107
 
108
  To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
109
  iface.launch()
 
4
  import glob
5
  import multiprocessing as mp
6
  import os
7
+ #os.environ["CUDA_VISIBLE_DEVICES"] = ""
8
  os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
9
 
10
  # fmt: off
 
28
  from cat_seg import add_cat_seg_config
29
  from demo.predictor import VisualizationDemo
30
  import gradio as gr
31
+ import torch
32
  from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
33
 
34
  # constants
 
42
  add_cat_seg_config(cfg)
43
  cfg.merge_from_file(args.config_file)
44
  cfg.merge_from_list(args.opts)
45
+ if torch.cuda.is_available():
46
+ cfg.MODEL.DEVICE = "cuda"
47
  cfg.freeze()
48
  return cfg
49
 
 
65
  parser.add_argument(
66
  "--opts",
67
  help="Modify config options using the command-line 'KEY VALUE' pairs",
68
+ default=(
69
+ ["MODEL.WEIGHTS", "model_final.pth",
70
  "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
71
  "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
72
  "TEST.SLIDING_WINDOW", "True",
73
  "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
74
+ "MODEL.PROMPT_ENSEMBLE_TYPE", "single",
75
+ "MODEL.DEVICE", "cpu"]),
 
76
  nargs=argparse.REMAINDER,
77
  )
78
  return parser
 
106
  description="""## CAT-Seg Demo
107
  Welcome to the CAT-Seg Demo! Here, we present the CAT-Seg with ViT-L model for open-vocabulary semantic segmentation.
108
 
109
+ Please note that this is an optimized version of the full model, and as such, its performance may be limited compared to the full model. Also, the demo runs on a CPU, so it may take a little time to process your image.
110
 
111
  To get started, simply upload an image and a comma-separated list of categories, and let the model work its magic!""")
112
  iface.launch()
requirements.txt CHANGED
@@ -7,7 +7,7 @@ imageio==2.4.1
7
  timm==0.8.3.dev0
8
  regex
9
  einops
10
- torch==1.13.0+cpu
11
- torchvision==0.14.0+cpu
12
- torchaudio==0.13.0
13
- --extra-index-url https://download.pytorch.org/whl/cpu
 
7
  timm==0.8.3.dev0
8
  regex
9
  einops
10
+ torch==1.13.1+cu116
11
+ torchvision==0.14.1+cu116
12
+ torchaudio==0.13.1
13
+ --extra-index-url https://download.pytorch.org/whl/cu116