"""Face-analysis helpers and model factories.

This module re-exports image I/O (``read_hwc``/``write_hwc``), layout
conversion, drawing and display helpers. It also defines four factory
functions that build a model from a string ``name``:

* ``face_detector`` -> :class:`FaceDetector`
* ``face_parser``   -> :class:`FaceParser`
* ``face_aligner``  -> :class:`FaceAlignment`
* ``face_attr``     -> :class:`FaceAttribute`

A ``name`` has the form ``"<type>"`` or ``"<type>/<conf>"``. The text before
the first ``/`` selects the implementation. The remainder, or ``None`` when
there is no ``/``, is passed as the first positional argument to that
implementation's constructor.
"""
from typing import Optional, Tuple

import torch

from .io import read_hwc, write_hwc
from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
from .draw import draw_bchw, draw_landmarks
from .show import show_bchw, show_bhw
from .face_detection import FaceDetector
from .face_parsing import FaceParser
from .face_alignment import FaceAlignment
from .face_attribute import FaceAttribute


def _split_name(name: str) -> Tuple[str, Optional[str]]:
    """Split ``name`` at its first ``'/'`` into ``(type, conf_name)``.

    ``conf_name`` is ``None`` when ``name`` contains no ``'/'``. It is the
    (possibly empty) remainder otherwise. Only the first ``'/'`` is a
    separator, so ``'a/b/c'`` yields ``('a', 'b/c')``.
    """
    model_type, separator, conf_name = name.partition('/')
    # partition() returns '' for a missing tail; the callers expect None then.
    return model_type, (conf_name if separator else None)


def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
    """Build a face detector selected by ``name`` and move it to ``device``.

    Args:
        name: ``"retinaface"`` or ``"retinaface/<conf>"``. ``<conf>`` is
            passed to the detector's constructor as its first argument.
        device: Target passed to the model's ``.to(...)``.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed detector after ``.to(device)``.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognised.
    """
    detector_type, conf_name = _split_name(name)
    if detector_type == 'retinaface':
        # The concrete class is imported only when it is selected.
        from .face_detection import RetinaFaceDetector
        # NOTE(review): unlike the other factories, no ``device=`` kwarg is
        # passed to the constructor here; the model is only moved afterwards.
        # Presumably intentional — confirm against RetinaFaceDetector.__init__.
        return RetinaFaceDetector(conf_name, **kwargs).to(device)
    raise RuntimeError(f'Unknown detector type: {detector_type}')


def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
    """Build a face parser selected by ``name`` and move it to ``device``.

    Args:
        name: ``"farl"`` or ``"farl/<conf>"``. ``<conf>`` is passed to the
            parser's constructor as its first argument.
        device: Passed to the constructor as ``device=`` and then used for
            ``.to(device)``.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed parser after ``.to(device)``.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognised.
    """
    parser_type, conf_name = _split_name(name)
    if parser_type == 'farl':
        # The concrete class is imported only when it is selected.
        from .face_parsing import FaRLFaceParser
        return FaRLFaceParser(conf_name, device=device, **kwargs).to(device)
    raise RuntimeError(f'Unknown parser type: {parser_type}')


def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
    """Build a face-alignment model selected by ``name`` and move it to ``device``.

    Args:
        name: ``"farl"`` or ``"farl/<conf>"``. ``<conf>`` is passed to the
            model's constructor as its first argument.
        device: Passed to the constructor as ``device=`` and then used for
            ``.to(device)``.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed alignment model after ``.to(device)``.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognised.
    """
    aligner_type, conf_name = _split_name(name)
    if aligner_type == 'farl':
        # The concrete class is imported only when it is selected.
        from .face_alignment import FaRLFaceAlignment
        return FaRLFaceAlignment(conf_name, device=device, **kwargs).to(device)
    raise RuntimeError(f'Unknown aligner type: {aligner_type}')


def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
    """Build a face-attribute model selected by ``name`` and move it to ``device``.

    Args:
        name: ``"farl"`` or ``"farl/<conf>"``. ``<conf>`` is passed to the
            model's constructor as its first argument.
        device: Passed to the constructor as ``device=`` and then used for
            ``.to(device)``.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed attribute model after ``.to(device)``.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognised.
    """
    attr_type, conf_name = _split_name(name)
    if attr_type == 'farl':
        # The concrete class is imported only when it is selected.
        from .face_attribute import FaRLFaceAttribute
        return FaRLFaceAttribute(conf_name, device=device, **kwargs).to(device)
    raise RuntimeError(f'Unknown attribute type: {attr_type}')