File size: 2,077 Bytes
d4e7f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import Optional, Tuple
import torch

from .io import read_hwc, write_hwc
from .util import hwc2bchw, bchw2hwc, bchw2bhwc, bhwc2bchw, bhwc2hwc
from .draw import draw_bchw, draw_landmarks
from .show import show_bchw, show_bhw

from .face_detection import FaceDetector
from .face_parsing import FaceParser
from .face_alignment import FaceAlignment
from .face_attribute import FaceAttribute


def _split_name(name: str) -> Tuple[str, Optional[str]]:
    if '/' in name:
        detector_type, conf_name = name.split('/', 1)
    else:
        detector_type, conf_name = name, None
    return detector_type, conf_name


def face_detector(name: str, device: torch.device, **kwargs) -> FaceDetector:
    """Build a face detector from a ``"<type>[/<conf_name>]"`` spec.

    Only the ``'retinaface'`` type is recognised. It constructs
    ``RetinaFaceDetector``, which is imported locally from ``.face_detection``.

    Args:
        name: detector spec. The text after the first ``'/'`` (or ``None``
            if there is no ``'/'``) is passed as the constructor's first
            positional argument.
        device: used only for ``.to(device)`` on the constructed detector.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed detector after ``.to(device)``.

    Raises:
        RuntimeError: if the type part of ``name`` is not ``'retinaface'``.
    """
    kind, variant = _split_name(name)
    if kind != 'retinaface':
        raise RuntimeError(f'Unknown detector type: {kind}')
    from .face_detection import RetinaFaceDetector
    # NOTE(review): unlike the other factories, `device=` is not passed to
    # the constructor here. Confirm whether RetinaFaceDetector accepts it.
    detector = RetinaFaceDetector(variant, **kwargs)
    return detector.to(device)


def face_parser(name: str, device: torch.device, **kwargs) -> FaceParser:
    """Build a face parser from a ``"<type>[/<conf_name>]"`` spec.

    Only the ``'farl'`` type is recognised. It constructs ``FaRLFaceParser``,
    which is imported locally from ``.face_parsing``.

    Args:
        name: parser spec. The text after the first ``'/'`` (or ``None``
            if there is no ``'/'``) is passed as the constructor's first
            positional argument.
        device: forwarded to the constructor as ``device=``, and also used
            for ``.to(device)`` on the result.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed parser after ``.to(device)``.

    Raises:
        RuntimeError: if the type part of ``name`` is not ``'farl'``.
    """
    kind, variant = _split_name(name)
    if kind != 'farl':
        raise RuntimeError(f'Unknown parser type: {kind}')
    from .face_parsing import FaRLFaceParser
    parser = FaRLFaceParser(variant, device=device, **kwargs)
    return parser.to(device)


def face_aligner(name: str, device: torch.device, **kwargs) -> FaceAlignment:
    """Build a face aligner from a ``"<type>[/<conf_name>]"`` spec.

    Only the ``'farl'`` type is recognised. It constructs
    ``FaRLFaceAlignment``, which is imported locally from ``.face_alignment``.

    Args:
        name: aligner spec. The text after the first ``'/'`` (or ``None``
            if there is no ``'/'``) is passed as the constructor's first
            positional argument.
        device: forwarded to the constructor as ``device=``, and also used
            for ``.to(device)`` on the result.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed aligner after ``.to(device)``.

    Raises:
        RuntimeError: if the type part of ``name`` is not ``'farl'``.
    """
    kind, variant = _split_name(name)
    if kind != 'farl':
        raise RuntimeError(f'Unknown aligner type: {kind}')
    from .face_alignment import FaRLFaceAlignment
    aligner = FaRLFaceAlignment(variant, device=device, **kwargs)
    return aligner.to(device)

def face_attr(name: str, device: torch.device, **kwargs) -> FaceAttribute:
    """Build a face-attribute model from a ``"<type>[/<conf_name>]"`` spec.

    Only the ``'farl'`` type is recognised. It constructs
    ``FaRLFaceAttribute``, which is imported locally from ``.face_attribute``.

    Args:
        name: attribute-model spec. The text after the first ``'/'`` (or
            ``None`` if there is no ``'/'``) is passed as the constructor's
            first positional argument.
        device: forwarded to the constructor as ``device=``, and also used
            for ``.to(device)`` on the result.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed model after ``.to(device)``.

    Raises:
        RuntimeError: if the type part of ``name`` is not ``'farl'``.
    """
    kind, variant = _split_name(name)
    if kind != 'farl':
        raise RuntimeError(f'Unknown attribute type: {kind}')
    from .face_attribute import FaRLFaceAttribute
    attr_model = FaRLFaceAttribute(variant, device=device, **kwargs)
    return attr_model.to(device)