Samuel Stevens committed · 6e5adf0 · Parent(s): d4005aa

add open-domain classification back

Browse files:
- .gitattributes +1 -1
- app.py +115 -112
- make_txt_embedding.py +21 -0
- txt_emb_species.json +3 -0
- txt_emb_species.npy +3 -0
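In short, as read from the diffs below: app.py drops the lib.TaxonomicTree tree walk, which descended the taxonomy one rank at a time, in favor of a single softmax over precomputed species-level text embeddings; for ranks above species, species probabilities are summed up to the requested rank. make_txt_embedding.py gains a step that extracts the species-only embedding columns, and the resulting txt_emb_species.npy / txt_emb_species.json are checked in via Git LFS.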
.gitattributes CHANGED
@@ -34,6 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 
-
+*.json filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
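The net effect of this one-line swap is a new LFS rule for *.json, needed because this commit checks in txt_emb_species.json, which at roughly 50 MB (per its pointer file below) is too large to store as a regular Git blob.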
app.py CHANGED
@@ -1,3 +1,5 @@
+import collections
+import heapq
 import json
 import os
 
@@ -8,15 +10,18 @@ import torch.nn.functional as F
 from open_clip import create_model, get_tokenizer
 from torchvision import transforms
 
-import lib
 from templates import openai_imagenet_template
 
 hf_token = os.getenv("HF_TOKEN")
 
 model_str = "hf-hub:imageomics/bioclip"
 tokenizer_str = "ViT-B-16"
-
-txt_emb_npy = "…
+
+txt_emb_npy = "txt_emb_species.npy"
+txt_names_json = "txt_emb_species.json"
+
+min_prob = 1e-9
+k = 5
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -33,12 +38,12 @@ preprocess_img = transforms.Compose(
 
 ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
 
-…
+open_domain_examples = [
+    ["examples/Ursus-arctos.jpeg", "Species"],
+    ["examples/Phoca-vitulina.png", "Species"],
+    ["examples/Felis-catus.jpeg", "Genus"],
+    ["examples/Sarcoscypha-coccinea.jpeg", "Order"],
+]
 zero_shot_examples = [
     [
         "examples/Ursus-arctos.jpeg",
@@ -73,6 +78,10 @@ zero_shot_examples = [
 ]
 
 
+def indexed(lst, indices):
+    return [lst[i] for i in indices]
+
+
 @torch.no_grad()
 def get_txt_features(classnames, templates):
     all_features = []
@@ -102,52 +111,38 @@ def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
 
 
 @torch.no_grad()
-def open_domain_classification(img, rank: int) -> …
+def open_domain_classification(img, rank: int) -> dict[str, float]:
     """
-    Predicts from the …
+    Predicts from the entire tree of life.
+    If targeting a higher rank than species, then this function predicts among all
+    species, then sums up species-level probabilities for the given rank.
     """
     img = preprocess_img(img).to(device)
     img_features = model.encode_image(img.unsqueeze(0))
     img_features = F.normalize(img_features, dim=-1)
 
-…
-    name = []
-    for _ in range(rank + 1):
-        children = tuple(zip(*name_lookup.children(name)))
-        if not children:
-            break
-        values, indices = children
-        txt_features = txt_emb[:, indices].to(device)
-        logits = (model.logit_scale.exp() * img_features @ txt_features).view(-1)
-
-        probs = F.softmax(logits, dim=0).to("cpu").tolist()
-        parent = " ".join(name)
-        outputs.append(
-            {f"{parent} {value}": prob for value, prob in zip(values, probs)}
-        )
-
-        top = values[logits.argmax()]
-        name.append(top)
-
-    return [
-        gr.Label(
-            num_top_classes=5, label=rank, show_label=True, visible=(6 - i <= choice)
-        )
-        for i, rank in enumerate(reversed(ranks))
-    ]
+    logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
+    probs = F.softmax(logits, dim=0)
+
+    # If predicting species, no need to sum probabilities.
+    if rank + 1 == len(ranks):
+        topk = probs.topk(k)
+        return {
+            " ".join(txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
+        }
+
+    # Sum up by the rank
+    output = collections.defaultdict(float)
+    for i in torch.nonzero(probs > min_prob).squeeze():
+        output[" ".join(txt_names[i][: rank + 1])] += probs[i]
+
+    topk_names = heapq.nlargest(k, output, key=output.get)
+
+    return {name: output[name] for name in topk_names}
 
 
-def …
-    …
-    return lib.TaxonomicTree.from_dict(json.load(fd))
+def change_output(choice):
+    return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
 
 
 if __name__ == "__main__":
@@ -161,8 +156,9 @@ if __name__ == "__main__":
 
     tokenizer = get_tokenizer(tokenizer_str)
 
-…
-…
+    txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r")).to(device)
+    with open(txt_names_json) as fd:
+        txt_names = json.load(fd)
 
    done = txt_emb.any(axis=0).sum().item()
    total = txt_emb.shape[1]
@@ -173,69 +169,76 @@ if __name__ == "__main__":
     with gr.Blocks() as app:
         img_input = gr.Image(height=512)
 
-…
-        # [img_input, *open_domain_outputs], flagging_dir="logs/flagged"
-        # )
-        # open_domain_flag_btn.click(
-        #     lambda *args: open_domain_callback.flag(args),
-        #     [img_input, *open_domain_outputs],
-        #     None,
-        #     preprocess=False,
-        # )
-
-        # with gr.Tab("Zero-Shot"):
-        with gr.Row():
-            with gr.Column():
-                classes_txt = gr.Textbox(
-                    placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
-                    lines=3,
-                    label="Classes",
-                    show_label=True,
-                    info="Use taxonomic names where possible; include common names if possible.",
-                )
-                zero_shot_btn = gr.Button("Submit", variant="primary")
-
-…
-                outputs=[zero_shot_output],
-            )
+        with gr.Tab("Open-Ended"):
+            with gr.Row():
+                with gr.Column():
+                    rank_dropdown = gr.Dropdown(
+                        label="Taxonomic Rank",
+                        info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
+                        choices=ranks,
+                        value="Species",
+                        type="index",
+                    )
+                    open_domain_btn = gr.Button("Submit", variant="primary")
+                with gr.Column():
+                    open_domain_output = gr.Label(
+                        num_top_classes=k,
+                        label="Prediction",
+                        show_label=True,
+                        value=None,
+                    )
+                    open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
+
+            with gr.Row():
+                gr.Examples(
+                    examples=open_domain_examples,
+                    inputs=[img_input, rank_dropdown],
+                    cache_examples=True,
+                    fn=open_domain_classification,
+                    outputs=[open_domain_output],
+                )
+
+        open_domain_callback = gr.HuggingFaceDatasetSaver(
+            hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
+        )
+        open_domain_callback.setup(
+            [img_input, rank_dropdown, open_domain_output],
+            flagging_dir="logs/flagged",
+        )
+        open_domain_flag_btn.click(
+            lambda *args: open_domain_callback.flag(args),
+            [img_input, rank_dropdown, open_domain_output],
+            None,
+            preprocess=False,
+        )
+
+        with gr.Tab("Zero-Shot"):
+            with gr.Row():
+                with gr.Column():
+                    classes_txt = gr.Textbox(
+                        placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
+                        lines=3,
+                        label="Classes",
+                        show_label=True,
+                        info="Use taxonomic names where possible; include common names if possible.",
+                    )
+                    zero_shot_btn = gr.Button("Submit", variant="primary")
+
+                with gr.Column():
+                    zero_shot_output = gr.Label(
+                        num_top_classes=k, label="Prediction", show_label=True
+                    )
+                    zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
+
+            with gr.Row():
+                gr.Examples(
+                    examples=zero_shot_examples,
+                    inputs=[img_input, classes_txt],
+                    cache_examples=True,
+                    fn=zero_shot_classification,
+                    outputs=[zero_shot_output],
+                )
 
         zero_shot_callback = gr.HuggingFaceDatasetSaver(
             hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
         )
@@ -249,15 +252,15 @@ if __name__ == "__main__":
             preprocess=False,
         )
 
-…
+        rank_dropdown.change(
+            fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
+        )
 
-…
+        open_domain_btn.click(
+            fn=open_domain_classification,
+            inputs=[img_input, rank_dropdown],
+            outputs=[open_domain_output],
+        )
 
         zero_shot_btn.click(
             fn=zero_shot_classification,
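The heart of the new open_domain_classification is the sum-to-rank step: score all species at once, then collapse the species distribution onto a coarser rank by summing the probabilities of every species that shares the same name prefix. Below is a minimal, self-contained sketch of that aggregation; the taxa names, logits, and sizes are toy stand-ins, not the demo's real data.

import collections
import heapq

import torch

ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
k = 2  # top-k entries to report

# Toy stand-ins: each name is a 7-tuple, one part per rank.
txt_names = [
    ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "catus"),
    ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "chaus"),
    ("Animalia", "Chordata", "Mammalia", "Carnivora", "Ursidae", "Ursus", "arctos"),
]
logits = torch.tensor([2.0, 1.0, 0.5])  # stand-in image-text similarities
probs = torch.softmax(logits, dim=0)

rank = 4  # ranks[4] == "Family"

# Sum species-level probabilities that share a prefix up to `rank`.
output = collections.defaultdict(float)
for name, p in zip(txt_names, probs.tolist()):
    output[" ".join(name[: rank + 1])] += p

# Report the k most probable taxa at the requested rank.
for taxon in heapq.nlargest(k, output, key=output.get):
    print(taxon, round(output[taxon], 3))

In the real function, min_prob is an optimization: species below 1e-9 probability cannot meaningfully change the top five, so they are skipped before summing. The species fast path (rank + 1 == len(ranks)) avoids building the dictionary entirely and just takes probs.topk(k).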
make_txt_embedding.py CHANGED
@@ -112,6 +112,26 @@ def convert_txt_features_to_avgs(name_lookup):
     )
 
 
+def convert_txt_features_to_species_only(name_lookup):
+    assert os.path.isfile(args.out_path)
+
+    all_features = np.load(args.out_path)
+    logger.info("Loaded text features from disk.")
+
+    species = [(d, i) for d, i in name_lookup.descendants() if len(d) == 7]
+    species_features = np.zeros((512, len(species)), dtype=np.float32)
+    species_names = [""] * len(species)
+
+    for new_i, (name, old_i) in enumerate(tqdm(species)):
+        species_features[:, new_i] = all_features[:, old_i]
+        species_names[new_i] = name
+
+    out_path, ext = os.path.splitext(args.out_path)
+    np.save(f"{out_path}_species{ext}", species_features)
+    with open(f"{out_path}_species.json", "w") as fd:
+        json.dump(species_names, fd, indent=2)
+
+
 def get_name_lookup(catalog_path, cache_path):
     if os.path.isfile(cache_path):
         with open(cache_path) as fd:
@@ -170,3 +190,4 @@ if __name__ == "__main__":
     tokenizer = get_tokenizer(tokenizer_str)
     write_txt_features(name_lookup)
     convert_txt_features_to_avgs(name_lookup)
+    convert_txt_features_to_species_only(name_lookup)
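A detail worth noting in convert_txt_features_to_species_only: the feature matrix is laid out as (512, N), one column per name, and species are exactly the names with seven parts. Keeping the class axis as columns is what lets app.py score an image against every species in a single img_features @ txt_emb product. A toy version of the column re-pack follows; the dimension 512 comes from the script, but the column indices here are made up.

import numpy as np

d, n_all = 512, 10  # embedding dim, total taxa across all ranks (toy count)
all_features = np.random.rand(d, n_all).astype(np.float32)

# Suppose columns 3, 7, and 9 are the species-level entries.
species_cols = [3, 7, 9]

species_features = np.zeros((d, len(species_cols)), dtype=np.float32)
for new_i, old_i in enumerate(species_cols):
    species_features[:, new_i] = all_features[:, old_i]  # copy one column over

assert species_features.shape == (512, 3)

NumPy fancy indexing (all_features[:, species_cols]) would do the same copy in one step; the explicit loop simply mirrors the script's tqdm-wrapped form.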
txt_emb_species.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c71babd1b7bc275a1dbb12fd36e6329bcc2487784c0b7be10c2f4d0031d34211
+size 50445969
txt_emb_species.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
+size 787435648
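Both ADDED files are Git LFS pointers: the repository stores only this three-line stub, where oid is the SHA-256 of the real payload and size is its byte count. The sizes line up with the code above: assuming the standard 128-byte .npy header, (787435648 - 128) / (512 * 4) gives exactly 384,490 float32 columns, i.e. roughly 384k species embeddings. A hypothetical integrity check, assuming the payloads have been fetched (for example with git lfs pull):

import hashlib

def lfs_oid(path: str) -> str:
    # SHA-256 of the file contents, streamed in 1 MiB chunks.
    h = hashlib.sha256()
    with open(path, "rb") as fd:
        for chunk in iter(lambda: fd.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

assert lfs_oid("txt_emb_species.npy") == (
    "91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed"
)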