File size: 2,077 Bytes
d4e7f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from typing import Optional, Tuple
import torch
from .io import read_hwc, write_hwc
from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
from .draw import draw_bchw, draw_landmarks
from .show import show_bchw, show_bhw
from .face_detection import FaceDetector
from .face_parsing import FaceParser
from .face_alignment import FaceAlignment
from .face_attribute import FaceAttribute
def _split_name(name: str) -> Tuple[str, Optional[str]]:
if '/' in name:
detector_type, conf_name = name.split('/', 1)
else:
detector_type, conf_name = name, None
return detector_type, conf_name
def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
    """Build a face detector from a ``'<type>[/<config>]'`` spec.

    Args:
        name: Detector spec. Only the ``'retinaface'`` type is recognized.
            The text after the first ``'/'`` (or ``None`` when there is no
            ``'/'``) is passed positionally to the detector constructor.
        device: Passed to ``.to()`` on the constructed detector.
        **kwargs: Forwarded unchanged to the detector constructor.

    Returns:
        The result of ``.to(device)`` on the constructed detector.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    detector_type, config_name = _split_name(name)
    if detector_type != 'retinaface':
        raise RuntimeError(f'Unknown detector type: {detector_type}')

    from .face_detection import RetinaFaceDetector

    # Unlike the other factories in this module, ``device`` is not handed to
    # the constructor; placement happens only through ``.to(device)``.
    detector = RetinaFaceDetector(config_name, **kwargs)
    return detector.to(device)
def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
    """Build a face parser from a ``'<type>[/<config>]'`` spec.

    Args:
        name: Parser spec. Only the ``'farl'`` type is recognized. The text
            after the first ``'/'`` (or ``None`` when there is no ``'/'``) is
            passed positionally to the parser constructor.
        device: Given to the constructor as ``device=`` and also passed to
            ``.to()`` on the constructed parser.
        **kwargs: Forwarded unchanged to the parser constructor.

    Returns:
        The result of ``.to(device)`` on the constructed parser.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    parser_type, config_name = _split_name(name)
    if parser_type != 'farl':
        raise RuntimeError(f'Unknown parser type: {parser_type}')

    from .face_parsing import FaRLFaceParser

    parser = FaRLFaceParser(config_name, device=device, **kwargs)
    return parser.to(device)
def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
    """Build a face-alignment model from a ``'<type>[/<config>]'`` spec.

    Args:
        name: Aligner spec. Only the ``'farl'`` type is recognized. The text
            after the first ``'/'`` (or ``None`` when there is no ``'/'``) is
            passed positionally to the aligner constructor.
        device: Given to the constructor as ``device=`` and also passed to
            ``.to()`` on the constructed aligner.
        **kwargs: Forwarded unchanged to the aligner constructor.

    Returns:
        The result of ``.to(device)`` on the constructed aligner.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    aligner_type, config_name = _split_name(name)
    if aligner_type != 'farl':
        raise RuntimeError(f'Unknown aligner type: {aligner_type}')

    from .face_alignment import FaRLFaceAlignment

    aligner = FaRLFaceAlignment(config_name, device=device, **kwargs)
    return aligner.to(device)
def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
    """Build a face-attribute model from a ``'<type>[/<config>]'`` spec.

    Args:
        name: Attribute-model spec. Only the ``'farl'`` type is recognized.
            The text after the first ``'/'`` (or ``None`` when there is no
            ``'/'``) is passed positionally to the model constructor.
        device: Given to the constructor as ``device=`` and also passed to
            ``.to()`` on the constructed model.
        **kwargs: Forwarded unchanged to the model constructor.

    Returns:
        The result of ``.to(device)`` on the constructed model.

    Raises:
        RuntimeError: If the type part of ``name`` is not recognized.
    """
    attr_type, config_name = _split_name(name)
    if attr_type != 'farl':
        raise RuntimeError(f'Unknown attribute type: {attr_type}')

    from .face_attribute import FaRLFaceAttribute

    attribute_model = FaRLFaceAttribute(config_name, device=device, **kwargs)
    return attribute_model.to(device)