"""Factory entry point that picks between the regressor, classifier and CLIP-EBC builders."""

from typing import List, Tuple, Optional, Any, Union

from .model import _classifier, _regressor, Classifier, Regressor
from .clip import _clip_ebc, CLIP_EBC


# Backbone names accepted by the CLIP branch of `get_model`, i.e. what must
# remain after the leading five characters of the requested name are dropped.
clip_names = ["resnet50", "resnet50x4", "resnet50x16", "resnet50x64", "resnet101", "vit_b_16", "vit_b_32", "vit_l_14"]


def get_model(
    backbone: str,
    input_size: int,
    reduction: int,
    bins: Optional[List[Tuple[float, float]]] = None,
    anchor_points: Optional[List[float]] = None,
    **kwargs: Any,
) -> Union[Regressor, Classifier, CLIP_EBC]:
    """Build a model for the requested backbone.

    The builder is selected as follows:

    * If the lower-cased ``backbone`` contains ``"clip"``, its first five
      characters are dropped and the remainder must be one of ``clip_names``;
      the model is built with ``_clip_ebc``.
    * Otherwise, if both ``bins`` and ``anchor_points`` are ``None``, the
      model is built with ``_regressor``.
    * Otherwise both ``bins`` and ``anchor_points`` must be given and the
      model is built with ``_classifier``.

    Args:
        backbone: Backbone name; matched case-insensitively.
        input_size: Forwarded unchanged to the selected builder.
        reduction: Forwarded unchanged to the selected builder.
        bins: Forwarded to ``_clip_ebc`` / ``_classifier``.
        anchor_points: Forwarded to ``_clip_ebc`` / ``_classifier``.
        **kwargs: Forwarded to ``_clip_ebc`` only; ignored for the regressor
            and classifier builders.

    Returns:
        Whatever the selected builder returns.

    Raises:
        AssertionError: If a CLIP backbone is not in ``clip_names`` once its
            prefix is stripped, or if exactly one of ``bins`` and
            ``anchor_points`` is ``None`` for a non-CLIP backbone.
    """
    backbone = backbone.lower()

    if "clip" in backbone:
        # NOTE(review): the routing test is a substring match, but the strip
        # below is a fixed five characters -- presumably names are expected to
        # look like "clip_<name>". Kept as-is for backward compatibility.
        backbone = backbone[5:]
        # Raised explicitly (rather than via `assert`) so that the check is
        # not stripped when Python runs with -O; the exception type and
        # message are the same as before.
        if backbone not in clip_names:
            raise AssertionError(f"Expected backbone to be in {clip_names}, got {backbone}")
        return _clip_ebc(
            backbone=backbone,
            input_size=input_size,
            reduction=reduction,
            bins=bins,
            anchor_points=anchor_points,
            **kwargs
        )

    if bins is None and anchor_points is None:
        return _regressor(
            backbone=backbone,
            input_size=input_size,
            reduction=reduction,
        )

    # Reaching here means at least one of the two was supplied; the classifier
    # needs both. Explicit raise for the same -O reason as above.
    if bins is None or anchor_points is None:
        raise AssertionError(f"Expected bins and anchor_points to be both None or not None, got {bins} and {anchor_points}")
    return _classifier(
        backbone=backbone,
        input_size=input_size,
        reduction=reduction,
        bins=bins,
        anchor_points=anchor_points,
    )


__all__ = [
    "get_model",
]