import os
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import timm
from PIL import Image
import gradio as gr

# configuration
device = torch.device("cpu")
input_width, input_height = 224, 224

# load ear detector
ear_detector = torch.hub.load(
    "ultralytics/yolov5",
    "custom",
    path=os.path.join(os.path.dirname(__file__), "weights", "ear_YOLOv5_n.pt"),
)
ear_detector.to(device)
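# (calling the detector on an image returns a yolov5 Detections object whose
# .xyxy[0] rows have the layout [x1, y1, x2, y2, confidence, class];
# predict() below relies on this)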

# initialize model
model = timm.create_model("hf-hub:BVRA/MegaDescriptor-T-224", pretrained=True, num_classes=0)
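# (num_classes=0 removes the classification head, so the model outputs
# pooled feature embeddings rather than class logits)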

# load a checkpoint that may hold either a full training state or just the model weights
state_dict = torch.load("weights/weights.pt", map_location=device)
if "optimizer" in state_dict:
    model.load_state_dict(state_dict["model"])
else:
    model.load_state_dict(state_dict)

model.to(device)
model.eval()


transforms = T.Compose([
    T.Resize([input_height, input_width]),
    T.ToTensor(),
    # standard ImageNet normalisation statistics
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

database = np.load("database.npz", allow_pickle=True)
features, identities = database["features"], database["identities"]
features = torch.from_numpy(features).to(device)
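
# The gallery features must have the same dimensionality as the embeddings
# this backbone produces, since the two are compared by matrix product in
# predict(). For reference, a hypothetical sketch of how a compatible
# database.npz could be built (gallery_paths/gallery_names are assumptions,
# not part of this repo):
#
#   gallery_embeddings = []
#   with torch.inference_mode():
#       for path in gallery_paths:  # paths to pre-cropped gallery ear images
#           img = transforms(Image.open(path).convert("RGB")).unsqueeze(0)
#           gallery_embeddings.append(model(img).squeeze(0).numpy())
#   np.savez("database.npz",
#            features=np.stack(gallery_embeddings),
#            identities=np.array(gallery_names))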


def strings2ints(a):
    """Map a sequence of identity strings to integer labels; return the
    labels as an int64 tensor plus the inverse mapping (label -> string)."""
    idx = {v: i for i, v in enumerate(set(a))}
    return torch.tensor([idx[e] for e in a], dtype=torch.int64), {v: k for k, v in idx.items()}
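# Example: strings2ints(["anna", "ben", "anna"]) might return
# (tensor([0, 1, 0]), {0: "anna", 1: "ben"}); "might", because the iteration
# order of a set is arbitrary. (Currently unused by the app itself.)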

def predict(image, n_individuals_to_return=20):
    with torch.inference_mode():
        output = ear_detector(image)
        n_preds = len(output.pred[0])
        if n_preds == 0:
            return "Error: Unable to detect elephant ears"

        xyxy = output.xyxy[0].tolist()

        # pick the detection whose box centre is closest to the image centre
        noncenterness = [
            ((box[0] + box[2]) / 2 - image.width / 2) ** 2
            + ((box[1] + box[3]) / 2 - image.height / 2) ** 2
            for box in xyxy
        ]
        centermost_idx = int(np.argmin(noncenterness))

        image = image.crop(tuple(xyxy[centermost_idx][:4]))

        # the last element of a detection row is the class index; class 1
        # (presumably the right ear) is mirrored so that every crop has the
        # same orientation as the gallery embeddings
        if xyxy[centermost_idx][-1] >= 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        image = transforms(image).unsqueeze(0).to(device)

        embedding = model(image)
        # cosine similarity of the query against every gallery embedding
        similarity = torch.matmul(F.normalize(embedding), F.normalize(features).T)

        similarity_sorted_idx = torch.argsort(similarity[0], descending=True).cpu().numpy().reshape(-1)
        # np.unique returns the candidates alphabetically, plus the index of
        # each one's first (i.e. best) occurrence in the similarity-sorted list
        candidates, candidates_unique_idx = np.unique(identities.reshape(-1)[similarity_sorted_idx], return_index=True)

        # reorder the unique candidates back into similarity order
        candidates = candidates[np.argsort(candidates_unique_idx)]
        candidates_similarity = similarity[0, similarity_sorted_idx].cpu().numpy()[np.sort(candidates_unique_idx)]

    # identity strings are assumed to end in an image index (e.g. "name_3"),
    # which is stripped before display
    return "Individuals ranked by similarity:\n\n" + "\n\n".join(
        f"{'_'.join(candidate.split('_')[:-1])} ({candidate_similarity:.3f})"
        for candidate, candidate_similarity in zip(
            candidates[:n_individuals_to_return],
            candidates_similarity[:n_individuals_to_return],
        )
    )
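
# Hypothetical local smoke test (not part of the app):
#   print(predict(Image.open("query_elephant.jpg").convert("RGB")))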

# basic-auth credentials come from the environment; both USERNAME and PASSWORD
# must be set, otherwise auth will contain None and login cannot succeed
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Markdown(),
).launch(auth=(os.environ.get("USERNAME"), os.environ.get("PASSWORD")))