narugo commited on
Commit
466ef55
·
1 Parent(s): 2d64837

dev(narugo): 1 more

Browse files
Files changed (3) hide show
  1. app2.py +2 -2
  2. detection/__init__.py +1 -0
  3. detection/censor.py +33 -0
app2.py CHANGED
@@ -2,7 +2,7 @@ import os
2
 
3
  import gradio as gr
4
 
5
- from detection import EyesDetection, FaceDetection, HeadDetection, PersonDetection, HandDetection
6
 
7
  _GLOBAL_CSS = """
8
  .limit-height {
@@ -26,7 +26,7 @@ if __name__ == '__main__':
26
  with gr.Tab('Hand Detection'):
27
  HandDetection().make_ui()
28
  with gr.Tab('Censor Point Detection'):
29
- pass
30
  with gr.Tab('Manbits Detection\n(Deprecated)'):
31
  pass
32
 
 
2
 
3
  import gradio as gr
4
 
5
+ from detection import EyesDetection, FaceDetection, HeadDetection, PersonDetection, HandDetection, CensorDetection
6
 
7
  _GLOBAL_CSS = """
8
  .limit-height {
 
26
  with gr.Tab('Hand Detection'):
27
  HandDetection().make_ui()
28
  with gr.Tab('Censor Point Detection'):
29
+ CensorDetection().make_ui()
30
  with gr.Tab('Manbits Detection\n(Deprecated)'):
31
  pass
32
 
detection/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
  from .base import ObjectDetection, DeepGHSObjectDetection
 
2
  from .eyes import EyesDetection
3
  from .face import FaceDetection
4
  from .hand import HandDetection
 
1
  from .base import ObjectDetection, DeepGHSObjectDetection
2
+ from .censor import CensorDetection
3
  from .eyes import EyesDetection
4
  from .face import FaceDetection
5
  from .hand import HandDetection
detection/censor.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List, Tuple
3
+
4
+ from imgutils.data import ImageTyping
5
+ from imgutils.detect.censor import detect_censors, _LABELS
6
+
7
+ from .base import DeepGHSObjectDetection
8
+
9
+
10
+ def _parse_model_name(model_name: str):
11
+ matching = re.fullmatch(r'^censor_detect_(?P<version>[\s\S]+?)_(?P<level>[\s\S]+?)$', model_name)
12
+ return matching.group('version'), matching.group('level')
13
+
14
+
15
+ class CensorDetection(DeepGHSObjectDetection):
16
+ def __init__(self):
17
+ DeepGHSObjectDetection.__init__(self, repo_id='deepghs/anime_censor_detection')
18
+
19
+ def _get_default_model(self) -> str:
20
+ return 'censor_detect_v1.0_s'
21
+
22
+ def _get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
23
+ return 0.7, 0.3
24
+
25
+ def _get_labels(self, model_name: str) -> List[str]:
26
+ return _LABELS
27
+
28
+ def detect(self, image: ImageTyping, model_name: str,
29
+ iou_threshold: float = 0.7, score_threshold: float = 0.25) \
30
+ -> List[Tuple[Tuple[float, float, float, float], str, float]]:
31
+ version, level = _parse_model_name(model_name)
32
+ return detect_censors(image, level=level, version=version,
33
+ conf_threshold=score_threshold, iou_threshold=iou_threshold)