Spaces:
Sleeping
Sleeping
try:
    import spaces

    # On Hugging Face Spaces, use the real GPU-allocation decorator.
    gpu_decorator = spaces.GPU
except ImportError:
    def gpu_decorator(func=None, **_options):
        """No-op stand-in for ``spaces.GPU`` when ``spaces`` is unavailable.

        Supports both usages of the real decorator so that decorated code
        behaves the same whether or not the package is installed:

        * bare: ``@gpu_decorator`` -- ``func`` is the decorated callable and
          is returned unchanged;
        * parameterised: ``@gpu_decorator(duration=60)`` -- ``func`` is
          ``None``, the keyword options are ignored, and an identity
          decorator is returned.
        """
        if func is None:
            # Called with options only: hand back a decorator that leaves
            # the wrapped function untouched.
            return lambda wrapped: wrapped
        return func
| import torch | |
| from transformers import OwlViTProcessor, OwlViTForObjectDetection | |
| from .model import OwlViTForClassification | |
def load_xclip(device: str = "cuda:0",
               n_classes: int = 183,
               use_teacher_logits: bool = False,
               custom_box_head: bool = False,
               model_path: str = 'data/models/peeb_pretrain.pt',
               ):
    """Build the OWL-ViT based classifier and its matching processor.

    Args:
        device: Device the detection backbone and the classifier are moved to.
        n_classes: Forwarded to ``OwlViTForClassification`` as ``num_classes``.
        use_teacher_logits: Forwarded as ``logits_from_teacher``.
        custom_box_head: Forwarded as ``custom_box_head``.
        model_path: Checkpoint file to load into the classifier; pass ``None``
            to skip loading and keep the freshly constructed weights.

    Returns:
        A ``(model, processor)`` tuple: the ``OwlViTForClassification``
        instance on ``device`` and the ``OwlViTProcessor`` whose image
        normalisation statistics have been overridden.
    """
    hf_checkpoint = "google/owlvit-base-patch32"
    processor = OwlViTProcessor.from_pretrained(hf_checkpoint)
    detector = OwlViTForObjectDetection.from_pretrained(hf_checkpoint).to(device)

    # Replace the processor's default normalisation with the BirdSoup
    # mean / std statistics.
    image_processor = processor.image_processor
    image_processor.image_mean = [0.48168647, 0.49244233, 0.42851609]
    image_processor.image_std = [0.18656386, 0.18614962, 0.19659419]

    # Finetuned OWL-ViT model; every loss term is given a weight of zero.
    loss_weights = dict(
        loss_ce=0,
        loss_bbox=0,
        loss_giou=0,
        loss_sym_box_label=0,
        loss_xclip=0,
    )
    model = OwlViTForClassification(
        owlvit_det_model=detector,
        num_classes=n_classes,
        device=device,
        weight_dict=loss_weights,
        logits_from_teacher=use_teacher_logits,
        custom_box_head=custom_box_head,
    )

    if model_path is not None:
        # The checkpoint is read onto the CPU first; strict=False means keys
        # that do not line up with the model are skipped instead of raising.
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version's defaults -- only use trusted files.
        model.load_state_dict(
            torch.load(model_path, map_location='cpu'),
            strict=False,
        )

    model.to(device)
    return model, processor