sync with upstream codebase
Browse files
app.py
CHANGED
@@ -44,7 +44,7 @@ def strings2ints(a):
|
|
44 |
idx = {v: i for i, v in enumerate(set([*a]))}
|
45 |
return torch.Tensor([idx[e] for e in a]).to(dtype=torch.int64), {v: k for k, v in idx.items()}
|
46 |
|
47 |
-
def predict(image):
|
48 |
with torch.inference_mode():
|
49 |
|
50 |
output = ear_detector(image)
|
@@ -68,14 +68,15 @@ def predict(image):
|
|
68 |
similarity = torch.matmul(F.normalize(embedding), F.normalize(features).T)
|
69 |
|
70 |
similarity_sorted_idx = torch.argsort(similarity[0], descending=True).cpu().numpy().reshape(-1)
|
71 |
-
candidates = identities.reshape(-1)[similarity_sorted_idx]
|
72 |
-
candidates_similarity = similarity[0, similarity_sorted_idx].tolist()
|
73 |
|
74 |
-
|
|
|
75 |
|
|
|
76 |
|
77 |
gr.Interface(
|
78 |
fn=predict,
|
79 |
inputs=gr.Image(type="pil"),
|
80 |
-
outputs=gr.
|
81 |
-
).launch(share=True)
|
|
|
44 |
idx = {v: i for i, v in enumerate(set([*a]))}
|
45 |
return torch.Tensor([idx[e] for e in a]).to(dtype=torch.int64), {v: k for k, v in idx.items()}
|
46 |
|
47 |
+
def predict(image, n_individuals_to_return=20):
|
48 |
with torch.inference_mode():
|
49 |
|
50 |
output = ear_detector(image)
|
|
|
68 |
similarity = torch.matmul(F.normalize(embedding), F.normalize(features).T)
|
69 |
|
70 |
similarity_sorted_idx = torch.argsort(similarity[0], descending=True).cpu().numpy().reshape(-1)
|
71 |
+
candidates, candidates_unique_idx = np.unique(identities.reshape(-1)[similarity_sorted_idx], return_index=True)
|
|
|
72 |
|
73 |
+
candidates = candidates[np.argsort(candidates_unique_idx)]
|
74 |
+
candidates_similarity = similarity[0, similarity_sorted_idx].numpy()[np.argsort(candidates_unique_idx)]
|
75 |
|
76 |
+
return "Individuals ranked by similarity:\n\n" + "\n\n".join([f"{'_'.join(candidate.split('_')[:-1])}" for candidate, candidate_similarity in zip(candidates[:n_individuals_to_return], candidates_similarity[:n_individuals_to_return])])
|
77 |
|
78 |
gr.Interface(
|
79 |
fn=predict,
|
80 |
inputs=gr.Image(type="pil"),
|
81 |
+
outputs=gr.Markdown(),
|
82 |
+
).launch(share=True, auth=("user", "mkb6makCBRYGUAd3"))
|