import os.path
import re
from typing import List, Tuple

from hfutils.operate import get_hf_fs
from hfutils.utils import hf_fs_path, parse_hf_fs_path
from imgutils.data import ImageTyping
from imgutils.detect import detect_heads

from .base import ObjectDetection

# Model names are expected to look like ``head_detect_best_<level>``; the
# named group ``level`` captures everything after the fixed prefix.
_MODEL_NAME_PATTERN = re.compile(r'^head_detect_best_(?P<level>[\s\S]+?)$')


def _parse_model_name(model_name: str) -> str:
    """
    Extract the level suffix from a head detection model name.

    :param model_name: Model name of the form ``head_detect_best_<level>``.
    :return: The ``<level>`` part of the name (e.g. ``'s'`` for
        ``'head_detect_best_s'``).
    :raises ValueError: If ``model_name`` does not follow the expected form.
    """
    matching = _MODEL_NAME_PATTERN.fullmatch(model_name)
    if matching is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise ValueError(
            f'Invalid head detection model name: {model_name!r}, '
            f'expected \'head_detect_best_<level>\'.'
        )
    return matching.group('level')


class HeadDetection(ObjectDetection):
    """
    Head detection backend built on :func:`imgutils.detect.detect_heads`.

    Model names are derived from the ``.onnx`` files found under the
    ``head_detect/`` directory of the configured Hugging Face model repository.
    """

    def __init__(self):
        # Hugging Face model repository that is searched for available models.
        self.repo_id = 'deepghs/imgutils-models'

    def _get_default_model(self) -> str:
        """
        Return the name of the model used by default.

        :return: ``'head_detect_best_s'``.
        """
        return 'head_detect_best_s'

    def _list_models(self) -> List[str]:
        """
        List the available model names.

        Globs ``head_detect/*.onnx`` in the model repository ``self.repo_id``
        and returns each file's base name with its extension removed.

        :return: Model names, in the order yielded by the glob.
        """
        hf_fs = get_hf_fs()
        return [
            # '<dir>/<name>.onnx' -> '<name>'
            os.path.splitext(os.path.basename(parse_hf_fs_path(path).filename))[0]
            for path in hf_fs.glob(hf_fs_path(
                repo_id=self.repo_id,
                repo_type='model',
                filename='head_detect/*.onnx',
            ))
        ]

    def _get_default_iou_and_score(self, model_name: str) -> Tuple[float, float]:
        """
        Return the default thresholds for the given model.

        :param model_name: Model name; not used, the same values are returned
            for every model.
        :return: Tuple ``(iou_threshold, score_threshold)``, always
            ``(0.7, 0.3)``.
        """
        return 0.7, 0.3

    def _get_labels(self, model_name: str) -> List[str]:
        """
        Return the labels the given model can produce.

        :param model_name: Model name; not used, the same labels are returned
            for every model.
        :return: ``['head']``.
        """
        return ['head']

    def detect(
            self,
            image: ImageTyping,
            model_name: str,
            iou_threshold: float = 0.7,
            score_threshold: float = 0.25,
    ) -> List[Tuple[Tuple[float, float, float, float], str, float]]:
        """
        Run head detection on an image.

        :param image: Image to run detection on.
        :param model_name: Model name of the form ``head_detect_best_<level>``;
            the ``<level>`` part is forwarded to ``detect_heads`` as ``level``.
        :param iou_threshold: Forwarded to ``detect_heads`` as
            ``iou_threshold``.
        :param score_threshold: Forwarded to ``detect_heads`` as
            ``conf_threshold``.
        :return: The result of ``detect_heads``; per the annotation, a list of
            ``(box, label, score)`` tuples where ``box`` is a 4-tuple of floats.
        :raises ValueError: If ``model_name`` does not follow the expected form.
        """
        level = _parse_model_name(model_name)
        return detect_heads(
            image=image,
            level=level,
            iou_threshold=iou_threshold,
            conf_threshold=score_threshold,
        )