timmhaucke committed on
Commit 77cd70d · verified · 1 Parent(s): 85f31a1

Upload 4 files

Files changed (4)
  1. Dockerfile +9 -0
  2. app.py +81 -0
  3. database.npz +3 -0
  4. weights.pt +3 -0
Dockerfile ADDED
@@ -0,0 +1,9 @@
+ FROM python:3.11-slim
+
+ WORKDIR /usr/src/app
+ COPY . .
+ RUN pip install --no-cache-dir -r requirements.txt
+ EXPOSE 7860
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
+
+ CMD ["python", "app.py"]
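Note that the RUN step installs from a requirements.txt that is not among the four uploaded files, and app.py expects detector weights at weights/ear_YOLOv5_n.pt; both presumably exist elsewhere in the repository. Once the image is built and running with port 7860 published, the Gradio endpoint can also be called programmatically. A minimal sketch, not part of the commit, assuming the gradio_client package and the /predict endpoint name that gr.Interface auto-generates; the host URL and image path are placeholders:

from gradio_client import Client

client = Client("http://localhost:7860")
# older gradio_client versions accept a plain file path for image inputs;
# newer ones may require wrapping it with gradio_client.handle_file(...)
result = client.predict("elephant.jpg", api_name="/predict")
print(result)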
app.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import timm
+ from PIL import Image
+ import gradio as gr
+
+ # configuration
+ device = torch.device("cpu")
+ input_width, input_height = 224, 224
+
+ # load ear detector
+ ear_detector = torch.hub.load("ultralytics/yolov5", "custom", path=os.path.join(os.path.dirname(__file__), "weights", "ear_YOLOv5_n.pt"))
+ ear_detector.to(device)
+
+ # initialize embedding model
+ model = timm.create_model("hf-hub:BVRA/MegaDescriptor-T-224", pretrained=True, num_classes=0)
+
+ # the checkpoint holds either a full training state or just the model weights
+ state_dict = torch.load("weights.pt", map_location=device)
+ if "optimizer" in state_dict:
+     model.load_state_dict(state_dict["model"])
+ else:
+     model.load_state_dict(state_dict)
+
+ model.to(device)
+ model.eval()
+
+ # preprocessing: resize to the model input size, then ImageNet normalization
+ transforms = T.Compose([
+     T.Resize([input_height, input_width]),
+     T.ToTensor(),
+     T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+ ])
+
+ database = np.load("database.npz", allow_pickle=True)  # gallery of known individuals
+ features, identities = database["features"], database["identities"]
+ features = torch.from_numpy(features).to(device)
+
+ # map strings to integer ids and return the inverse mapping (currently unused)
+ def strings2ints(a):
+     idx = {v: i for i, v in enumerate(set([*a]))}
+     return torch.Tensor([idx[e] for e in a]).to(dtype=torch.int64), {v: k for k, v in idx.items()}
+
+ def predict(image):
+     with torch.inference_mode():
+         # detect ears and bail out early if none are found
+         output = ear_detector(image)
+         n_preds = len(output.pred[0].tolist())
+         if n_preds == 0:
+             return "Error: Unable to detect elephant ears"
+
+         xyxy = output.xyxy[0].tolist()
+         # squared distance of each box center from the image center
+         noncenterness = [(image.width / 2 - (xyxy[i][0] + xyxy[i][2]) / 2) ** 2 + (image.height / 2 - (xyxy[i][1] + xyxy[i][3]) / 2) ** 2 for i in range(n_preds)]
+         centermost_idx = np.argmin(noncenterness)
+
+         image = image.crop(tuple(xyxy[centermost_idx][:4]))  # crop to the centermost detection
+         # the last column of a prediction row is the class id; mirror one ear class so all crops share one orientation
+         if output.pred[0].tolist()[centermost_idx][-1] >= 0.5:
+             image = image.transpose(Image.FLIP_LEFT_RIGHT)
+
+         image = transforms(image).unsqueeze(0).to(device)
+
+         embedding = model(image)
+         similarity = torch.matmul(F.normalize(embedding), F.normalize(features).T)  # cosine similarity
+         # rank gallery entries by similarity, most similar first
+         similarity_sorted_idx = torch.argsort(similarity[0], descending=True).cpu().numpy().reshape(-1)
+         candidates = identities.reshape(-1)[similarity_sorted_idx].tolist()
+         candidates_similarity = similarity[0, similarity_sorted_idx].tolist()
+         # identity labels appear to be "<name>_<photo-id>"; drop the photo id
+         return f"We are about {max(0, candidates_similarity[0]):.0%} confident that this elephant is {'_'.join(candidates[0].split('_')[:-1])}"
+
+
+ gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Label(),
+ ).launch(share=True)
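app.py assumes database.npz holds a features matrix with one embedding per gallery image and a parallel identities array of string labels. The commit does not show how that file was produced; the sketch below reconstructs a plausible build script using the same backbone and preprocessing as app.py. The gallery/ folder and the "<name>_<photo-id>" file naming are assumptions inferred from the label parsing in predict:

import os
import numpy as np
import torch
import torchvision.transforms as T
import timm
from PIL import Image

model = timm.create_model("hf-hub:BVRA/MegaDescriptor-T-224", pretrained=True, num_classes=0)
model.eval()

transforms = T.Compose([
    T.Resize([224, 224]),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

features, identities = [], []
with torch.inference_mode():
    for fname in sorted(os.listdir("gallery")):  # hypothetical folder of cropped ear images
        image = Image.open(os.path.join("gallery", fname)).convert("RGB")
        features.append(model(transforms(image).unsqueeze(0)).squeeze(0).cpu().numpy())
        identities.append(os.path.splitext(fname)[0])  # e.g. "Tembo_0001" (hypothetical)

np.savez("database.npz", features=np.stack(features), identities=np.array(identities, dtype=object))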
database.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1985b902ee2826f7215bced0c330c0a6b0bda13559779071cfb34a9ee0342c03
+ size 23072727
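The binary files are stored as Git LFS pointers, so the actual payload (about 23 MB here) must be fetched, e.g. with git lfs pull, before app.py can load it. A quick sanity check that the real database file is in place and carries the arrays app.py expects:

import numpy as np

db = np.load("database.npz", allow_pickle=True)
assert {"features", "identities"} <= set(db.files)
print(db["features"].shape, db["identities"].shape)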
weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da82dcf2c50dec71cafefd803cc8d078cdcf9226cbf571b7adde0dd4b14c6e7a
+ size 224710970
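The oid in an LFS pointer is the SHA-256 of the file's contents, so a fetched object can be verified directly against its pointer. A generic sketch, streaming the roughly 225 MB checkpoint instead of reading it in one piece:

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # hash the file in 1 MiB chunks to keep memory use flat
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk_size):
            h.update(block)
    return h.hexdigest()

expected = "da82dcf2c50dec71cafefd803cc8d078cdcf9226cbf571b7adde0dd4b14c6e7a"
print(sha256_of("weights.pt") == expected)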