Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
File size: 3,129 Bytes
db195b0
 
ea195e7
db195b0
 
 
ea195e7
 
 
db195b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import albumentations as A
from transformers import PreTrainedModel
# from PIL import Image
import numpy as np
import torch
import cv2

from .train import SphereClassifier
from .configuration_cetacean_classifier import CetaceanClassifierConfig


# Species names, one per species-logit column, in alphabetical order.
# ``forward`` indexes this array with positions taken from an ``argsort`` of
# the species logits, so entry ``i`` is the label of logit column ``i``;
# reordering this list would silently mislabel predictions.
WHALE_CLASSES = np.array(
    """
    beluga
    blue_whale
    bottlenose_dolphin
    brydes_whale
    commersons_dolphin
    common_dolphin
    cuviers_beaked_whale
    dusky_dolphin
    false_killer_whale
    fin_whale
    frasiers_dolphin
    gray_whale
    humpback_whale
    killer_whale
    long_finned_pilot_whale
    melon_headed_whale
    minke_whale
    pantropic_spotted_dolphin
    pygmy_killer_whale
    rough_toothed_dolphin
    sei_whale
    short_finned_pilot_whale
    southern_right_whale
    spinner_dolphin
    spotted_dolphin
    white_sided_dolphin
    """.split()
)


class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    """Hugging Face wrapper around ``SphereClassifier`` for single-image inference.

    ``forward`` preprocesses one BGR image, runs the wrapped network (which
    returns an id-logit head and a species-logit head), and decodes only the
    species head into the names of the three highest-scoring species using
    ``WHALE_CLASSES``.
    """

    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        super().__init__(config)

        # Network is built from the config dict only; no checkpoint is read
        # here (presumably ``from_pretrained`` loads the weights — confirm).
        self.model = SphereClassifier(cfg=config.to_dict())
        self.model.eval()
        self.config = config
        # Random-crop augmentation is a training-time transform; applying it
        # at inference made every prediction nondeterministic, so it is off.
        self.transforms = self.make_transforms(data_aug=False)

    def make_transforms(self, data_aug: bool):
        """Build the albumentations pipeline applied after resizing.

        Args:
            data_aug: when True, include a ``RandomResizedCrop`` to
                ``(config.image_size[0], config.image_size[1])`` (read as
                height, width) whose scale/ratio come from ``config.aug``
                keys ``crop_scale``, ``crop_l`` and ``crop_r``. When False the
                pipeline is empty, i.e. an identity transform.

        Returns:
            An ``A.Compose`` over the selected transforms.
        """
        augments = []
        if data_aug:
            aug = self.config.aug
            augments = [
                A.RandomResizedCrop(
                    self.config.image_size[0],
                    self.config.image_size[1],
                    scale=(aug["crop_scale"], 1.0),
                    ratio=(aug["crop_l"], aug["crop_r"]),
                ),
            ]
        return A.Compose(augments)

    def preprocess_image(self, img) -> torch.Tensor:
        """Turn one BGR image into a float32 batch of shape ``(1, 3, H, W)``.

        ``img`` goes through ``cv2.cvtColor(..., COLOR_BGR2RGB)``, so it must be
        a 3-channel BGR array (e.g. as returned by ``cv2.imread``).
        ``config.image_size`` is read as ``(height, width)``, the same order
        ``make_transforms`` uses. Pixel values are left in their original
        range (no normalization is applied here).
        """
        height, width = self.config.image_size
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # cv2.resize expects dsize as (width, height), not (height, width).
        image = cv2.resize(rgb, (width, height), interpolation=cv2.INTER_CUBIC)
        image = self.transforms(image=image)["image"]
        # HWC -> CHW. The previous ``transpose(2, 0)`` produced CWH, which fed
        # the network a spatially transposed image.
        tensor = torch.as_tensor(image, dtype=torch.float32).permute(2, 0, 1)
        return tensor.unsqueeze(0).contiguous()

    def forward(self, img, labels=None):
        """Predict the three most likely species for ``img``.

        Args:
            img: BGR image accepted by ``preprocess_image``.
            labels: ignored; kept so callers passing ``labels=`` still work.

        Returns:
            ``{"predictions": names}`` where ``names`` is a numpy array of the
            top-3 species names from ``WHALE_CLASSES``, highest logit first.
        """
        # Move input to wherever the model's parameters live (CPU or GPU).
        tensor = self.preprocess_image(img).to(self.device)
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            _head_id_logits, head_species_logits = self.model(tensor)
        # .cpu() so .numpy() also works when the model runs on a GPU.
        species_logits = head_species_logits.cpu().numpy()
        # Indices of the first (only) batch row's logits, highest first.
        ranked = species_logits.argsort()[0][::-1]
        top_three = ranked[:3]
        return {"predictions": WHALE_CLASSES[top_three]}