|
import os |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision.transforms as T |
|
import timm |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
|
|
# All inference in this script runs on the CPU.
device = torch.device("cpu")

# Spatial size (in pixels) that every cropped ear is resized to before it is
# passed to the embedding model.
input_width = 224
input_height = 224
|
|
|
|
|
# YOLOv5 ear detector, loaded through torch.hub from custom weights that are
# stored in the "weights" directory next to this file.
_ear_detector_weights = os.path.join(
    os.path.dirname(__file__), "weights", "ear_YOLOv5_n.pt"
)
ear_detector = torch.hub.load(
    "ultralytics/yolov5", "custom", path=_ear_detector_weights
)
ear_detector.to(device)
|
|
|
|
|
# Embedding backbone: MegaDescriptor-T-224 pulled from the Hugging Face hub.
# num_classes=0 asks timm for a model without a classification head, so the
# forward pass yields feature embeddings.
model = timm.create_model("hf-hub:BVRA/MegaDescriptor-T-224", pretrained=True, num_classes=0)

# Resolve the fine-tuned weights relative to this file, the same way the ear
# detector weights are located, so start-up does not depend on the current
# working directory.
_model_weights_path = os.path.join(os.path.dirname(__file__), "weights", "weights.pt")
state_dict = torch.load(_model_weights_path, map_location=device)

# A file that also carries an "optimizer" entry stores the network weights
# under "model"; otherwise the file itself is treated as the state dict.
if "optimizer" in state_dict:
    model.load_state_dict(state_dict["model"])
else:
    model.load_state_dict(state_dict)

model.to(device)
# Inference only: switch layers such as dropout / batch-norm to eval behaviour.
model.eval()
|
|
|
|
|
# Preprocessing applied to every cropped ear image before embedding:
#   1. resize to the fixed model input size (height, width),
#   2. convert the PIL image to a tensor,
#   3. normalise each channel with the usual ImageNet mean / std values.
_resize = T.Resize([input_height, input_width])
_to_tensor = T.ToTensor()
_normalize = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

transforms = T.Compose([_resize, _to_tensor, _normalize])
|
|
|
# Reference database of pre-computed embeddings ("features") and the identity
# label that belongs to each of them ("identities").
#
# The path is resolved relative to this file (like the detector weights) so the
# script can be started from any working directory, and the archive is opened
# in a ``with`` block so its file handle is closed once both arrays are read.
#
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects,
# which can execute code. Only ever point this at a trusted database file.
_database_path = os.path.join(os.path.dirname(__file__), "database.npz")
with np.load(_database_path, allow_pickle=True) as database:
    features, identities = database["features"], database["identities"]

features = torch.from_numpy(features).to(device)
|
|
|
|
|
def strings2ints(a):
    """Encode a sequence of hashable labels as integer codes.

    Args:
        a: Iterable of hashable labels (for example identity strings).

    Returns:
        A tuple ``(codes, code_to_label)``. ``codes`` is a 1-D ``torch.int64``
        tensor with one entry per element of ``a``; ``code_to_label`` is a dict
        mapping each integer code back to its label. Codes are assigned in
        order of first appearance, so the mapping is identical on every run.
    """
    # Materialise once so one-shot iterables (generators) can be read twice.
    labels = list(a)
    # dict.fromkeys de-duplicates while keeping first-seen order; a plain set
    # would make the label -> code assignment vary between interpreter runs.
    label_to_code = {label: code for code, label in enumerate(dict.fromkeys(labels))}
    # Create the tensor as int64 directly. Going through the float32 default of
    # torch.Tensor(...) would silently corrupt codes larger than 2**24.
    codes = torch.tensor([label_to_code[label] for label in labels], dtype=torch.int64)
    code_to_label = {code: label for label, code in label_to_code.items()}
    return codes, code_to_label
|
|
|
def _centermost_box_index(boxes, image_width, image_height):
    """Return the index of the box whose centre lies closest to the image centre."""
    image_cx, image_cy = image_width / 2, image_height / 2
    squared_distances = [
        ((box[0] + box[2]) / 2 - image_cx) ** 2 + ((box[1] + box[3]) / 2 - image_cy) ** 2
        for box in boxes
    ]
    return int(np.argmin(squared_distances))


def predict(image, n_individuals_to_return=20):
    """Rank the known individuals by similarity to the ear shown in ``image``.

    Steps: run the ear detector, keep the detection closest to the image
    centre, crop to it (mirroring the crop when the detected class value is
    >= 0.5), embed the crop with the model and compare it by cosine similarity
    against the reference ``features``.

    Args:
        image: PIL image to analyse; ``None`` is reported as an error.
        n_individuals_to_return: Maximum number of ranked identities to list.

    Returns:
        A Markdown string: either the ranked list of identities (best match
        first, each identity listed once) or a message starting with
        ``"Error:"`` when no image was given or no ear was detected.
    """
    # Guard against a missing image (e.g. the form was submitted empty) instead
    # of letting the detector fail on it.
    if image is None:
        return "Error: No image provided"

    with torch.inference_mode():
        output = ear_detector(image)

        # Convert the detector results once; each row starts with the box
        # corners (x1, y1, x2, y2) and ends with the predicted class.
        predictions = output.pred[0].tolist()
        if not predictions:
            return "Error: Unable to detect elephant ears"
        boxes = output.xyxy[0].tolist()

        # Several ears may be detected; use the one nearest the image centre.
        centermost_idx = _centermost_box_index(boxes, image.width, image.height)
        image = image.crop(tuple(boxes[centermost_idx][:4]))

        # The last value of a prediction row is the class. Crops of class 1 are
        # mirrored -- presumably so that left and right ears share one
        # orientation in the database; TODO confirm against the detector's
        # class definitions.
        if predictions[centermost_idx][-1] >= 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        query = transforms(image).unsqueeze(0).to(device)
        embedding = model(query)

        # Cosine similarity between the query embedding and every reference
        # embedding (both sides are L2-normalised first).
        similarity = torch.matmul(F.normalize(embedding), F.normalize(features).T)
        ranked_idx = torch.argsort(similarity[0], descending=True).cpu().numpy().reshape(-1)

        # np.unique reports, for each identity, the position of its first
        # (i.e. best-scoring) occurrence in the ranking; ordering the unique
        # identities by that position yields a de-duplicated ranking.
        ranked_identities = identities.reshape(-1)[ranked_idx]
        candidates, first_rank = np.unique(ranked_identities, return_index=True)
        candidates = candidates[np.argsort(first_rank)]

    # Display each identity without its final "_"-separated component.
    names = ["_".join(candidate.split("_")[:-1]) for candidate in candidates[:n_individuals_to_return]]
    return "Individuals ranked by similarity:\n\n" + "\n\n".join(names)
|
|
|
# Web UI: one PIL image in, the Markdown ranking produced by predict() out.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Markdown())

# The login credentials come from the USERNAME and PASSWORD environment
# variables.
# NOTE(review): if either variable is unset the tuple contains None -- confirm
# the deployment always defines both.
_login = (os.environ.get("USERNAME"), os.environ.get("PASSWORD"))
demo.launch(auth=_login)