Samuel Stevens
commited on
Commit
·
0ab58fa
1
Parent(s):
e508563
Use cloudflare for ade20k images
Browse files- app.py +9 -12
- constants.py +0 -1
- data.py +15 -96
app.py
CHANGED
@@ -143,9 +143,9 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
|
|
143 |
return torch.load(path, weights_only=True, map_location="cpu")
|
144 |
|
145 |
|
146 |
-
top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
|
147 |
-
top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
|
148 |
-
sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
|
149 |
|
150 |
|
151 |
# mask = torch.ones((sae.cfg.d_sae), dtype=bool)
|
@@ -231,14 +231,13 @@ class SaeActivation(typing.TypedDict):
|
|
231 |
|
232 |
|
233 |
@beartype.beartype
|
234 |
-
def get_image(
|
235 |
-
|
236 |
-
|
237 |
-
seg_sized = data.to_sized(sample["segmentation"])
|
238 |
seg_u8_sized = data.to_u8(seg_sized)
|
239 |
seg_img_sized = data.u8_to_img(seg_u8_sized)
|
240 |
|
241 |
-
return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized),
|
242 |
|
243 |
|
244 |
@beartype.beartype
|
@@ -253,9 +252,9 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
|
|
253 |
vit, vit_transform = load_vit()
|
254 |
sae = load_sae()
|
255 |
|
256 |
-
|
257 |
|
258 |
-
x = vit_transform(
|
259 |
|
260 |
_, vit_acts_BLPD = vit(x)
|
261 |
vit_acts_PD = (
|
@@ -268,8 +267,6 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
|
|
268 |
acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
|
269 |
logger.info("Got SAE activations.")
|
270 |
|
271 |
-
breakpoint()
|
272 |
-
|
273 |
top_img_i, top_values = load_tensors(model_cfg)
|
274 |
logger.info("Loaded top SAE activations for '%s'.", model_name)
|
275 |
|
|
|
143 |
return torch.load(path, weights_only=True, map_location="cpu")
|
144 |
|
145 |
|
146 |
+
# top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
|
147 |
+
# top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
|
148 |
+
# sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
|
149 |
|
150 |
|
151 |
# mask = torch.ones((sae.cfg.d_sae), dtype=bool)
|
|
|
231 |
|
232 |
|
233 |
@beartype.beartype
|
234 |
+
def get_image(i: int) -> tuple[str, str, int]:
|
235 |
+
img_sized = data.to_sized(data.get_image(i))
|
236 |
+
seg_sized = data.to_sized(data.get_seg(i))
|
|
|
237 |
seg_u8_sized = data.to_u8(seg_sized)
|
238 |
seg_img_sized = data.u8_to_img(seg_u8_sized)
|
239 |
|
240 |
+
return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized), i
|
241 |
|
242 |
|
243 |
@beartype.beartype
|
|
|
252 |
vit, vit_transform = load_vit()
|
253 |
sae = load_sae()
|
254 |
|
255 |
+
img = data.get_image(image_i)
|
256 |
|
257 |
+
x = vit_transform(img)[None, ...].to(DEVICE)
|
258 |
|
259 |
_, vit_acts_BLPD = vit(x)
|
260 |
vit_acts_PD = (
|
|
|
267 |
acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
|
268 |
logger.info("Got SAE activations.")
|
269 |
|
|
|
|
|
270 |
top_img_i, top_values = load_tensors(model_cfg)
|
271 |
logger.info("Loaded top SAE activations for '%s'.", model_name)
|
272 |
|
constants.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
import torch
|
2 |
|
3 |
-
|
4 |
DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
|
5 |
|
6 |
|
|
|
1 |
import torch
|
2 |
|
|
|
3 |
DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
|
4 |
|
5 |
|
data.py
CHANGED
@@ -1,15 +1,13 @@
|
|
1 |
import base64
|
2 |
-
import dataclasses
|
3 |
import functools
|
4 |
import io
|
5 |
import logging
|
6 |
-
import os.path
|
7 |
import random
|
8 |
|
9 |
import beartype
|
10 |
import einops.layers.torch
|
11 |
import numpy as np
|
12 |
-
import
|
13 |
from jaxtyping import UInt8, jaxtyped
|
14 |
from PIL import Image
|
15 |
from torch import Tensor
|
@@ -17,104 +15,25 @@ from torchvision.transforms import v2
|
|
17 |
|
18 |
logger = logging.getLogger("data.py")
|
19 |
|
|
|
|
|
20 |
|
21 |
@beartype.beartype
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
label: str
|
29 |
-
target: int
|
30 |
-
|
31 |
-
samples: list[Sample]
|
32 |
-
|
33 |
-
def __init__(self, root: str, split: str):
|
34 |
-
self.logger = logging.getLogger("ade20k")
|
35 |
-
self.root = root
|
36 |
-
self.split = split
|
37 |
-
self.img_dir = os.path.join(root, "images")
|
38 |
-
self.seg_dir = os.path.join(root, "annotations")
|
39 |
-
|
40 |
-
# Check that we have the right path.
|
41 |
-
for subdir in ("images", "annotations"):
|
42 |
-
if not os.path.isdir(os.path.join(root, subdir)):
|
43 |
-
# Something is missing.
|
44 |
-
if os.path.realpath(root).endswith(subdir):
|
45 |
-
self.logger.warning(
|
46 |
-
"The ADE20K root should contain 'images/' and 'annotations/' directories."
|
47 |
-
)
|
48 |
-
raise ValueError(f"Can't find path '{os.path.join(root, subdir)}'.")
|
49 |
-
|
50 |
-
_, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
|
51 |
-
split_lookup: dict[int, str] = {
|
52 |
-
value: key for key, value in split_mapping.items()
|
53 |
-
}
|
54 |
-
self.loader = torchvision.datasets.folder.default_loader
|
55 |
-
|
56 |
-
err_msg = f"Split '{split}' not in '{set(split_lookup.values())}'."
|
57 |
-
assert split in set(split_lookup.values()), err_msg
|
58 |
-
|
59 |
-
# Load all the image paths.
|
60 |
-
imgs: list[str] = [
|
61 |
-
path
|
62 |
-
for path, s in torchvision.datasets.folder.make_dataset(
|
63 |
-
self.img_dir,
|
64 |
-
split_mapping,
|
65 |
-
extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
|
66 |
-
)
|
67 |
-
if split_lookup[s] == split
|
68 |
-
]
|
69 |
-
|
70 |
-
segs: list[str] = [
|
71 |
-
path
|
72 |
-
for path, s in torchvision.datasets.folder.make_dataset(
|
73 |
-
self.seg_dir,
|
74 |
-
split_mapping,
|
75 |
-
extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
|
76 |
-
)
|
77 |
-
if split_lookup[s] == split
|
78 |
-
]
|
79 |
-
|
80 |
-
# Load all the targets, classes and mappings
|
81 |
-
with open(os.path.join(root, "sceneCategories.txt")) as fd:
|
82 |
-
img_labels: list[str] = [line.split()[1] for line in fd.readlines()]
|
83 |
-
|
84 |
-
label_set = sorted(set(img_labels))
|
85 |
-
label_to_idx = {label: i for i, label in enumerate(label_set)}
|
86 |
-
|
87 |
-
self.samples = [
|
88 |
-
self.Sample(img_path, seg_path, label, label_to_idx[label])
|
89 |
-
for img_path, seg_path, label in zip(imgs, segs, img_labels)
|
90 |
-
]
|
91 |
-
|
92 |
-
def __getitem__(self, index: int) -> dict[str, object]:
|
93 |
-
# Convert to dict.
|
94 |
-
sample = dataclasses.asdict(self.samples[index])
|
95 |
-
|
96 |
-
sample["image"] = self.loader(sample.pop("img_path"))
|
97 |
-
sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
|
98 |
-
sample["index"] = index
|
99 |
-
|
100 |
-
return sample
|
101 |
-
|
102 |
-
def __len__(self) -> int:
|
103 |
-
return len(self.samples)
|
104 |
-
|
105 |
-
|
106 |
-
@functools.cache
|
107 |
-
def get_dataset() -> Ade20k:
|
108 |
-
return Ade20k(
|
109 |
-
root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/",
|
110 |
-
split="validation",
|
111 |
-
)
|
112 |
|
113 |
|
114 |
@beartype.beartype
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
118 |
|
119 |
|
120 |
@jaxtyped(typechecker=beartype.beartype)
|
|
|
1 |
import base64
|
|
|
2 |
import functools
|
3 |
import io
|
4 |
import logging
|
|
|
5 |
import random
|
6 |
|
7 |
import beartype
|
8 |
import einops.layers.torch
|
9 |
import numpy as np
|
10 |
+
import requests
|
11 |
from jaxtyping import UInt8, jaxtyped
|
12 |
from PIL import Image
|
13 |
from torch import Tensor
|
|
|
15 |
|
16 |
logger = logging.getLogger("data.py")
|
17 |
|
18 |
+
R2_URL = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"
|
19 |
+
|
20 |
|
21 |
@beartype.beartype
|
22 |
+
@functools.lru_cache(maxsize=512)
|
23 |
+
def get_image(i: int) -> Image.Image:
|
24 |
+
fpath = f"/images/ADE_val_{i + 1:08}.jpg"
|
25 |
+
url = R2_URL + fpath
|
26 |
+
logger.info("Getting image from '%s'.", url)
|
27 |
+
return Image.open(requests.get(url, stream=True).raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
@beartype.beartype
|
31 |
+
@functools.lru_cache(maxsize=512)
|
32 |
+
def get_seg(i: int) -> Image.Image:
|
33 |
+
fpath = f"/annotations/ADE_val_{i + 1:08}.png"
|
34 |
+
url = R2_URL + fpath
|
35 |
+
logger.info("Getting annotations from '%s'.", url)
|
36 |
+
return Image.open(requests.get(url, stream=True).raw)
|
37 |
|
38 |
|
39 |
@jaxtyped(typechecker=beartype.beartype)
|