Samuel Stevens committed on
Commit
0ab58fa
·
1 Parent(s): e508563

Use cloudflare for ade20k images

Browse files
Files changed (3) hide show
  1. app.py +9 -12
  2. constants.py +0 -1
  3. data.py +15 -96
app.py CHANGED
@@ -143,9 +143,9 @@ def load_tensor(path: str | pathlib.Path) -> Tensor:
143
  return torch.load(path, weights_only=True, map_location="cpu")
144
 
145
 
146
- top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
147
- top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
148
- sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
149
 
150
 
151
  # mask = torch.ones((sae.cfg.d_sae), dtype=bool)
@@ -231,14 +231,13 @@ class SaeActivation(typing.TypedDict):
231
 
232
 
233
  @beartype.beartype
234
- def get_image(image_i: int) -> tuple[str, str, int]:
235
- sample = data.get_sample(image_i)
236
- img_sized = data.to_sized(sample["image"])
237
- seg_sized = data.to_sized(sample["segmentation"])
238
  seg_u8_sized = data.to_u8(seg_sized)
239
  seg_img_sized = data.u8_to_img(seg_u8_sized)
240
 
241
- return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized), image_i
242
 
243
 
244
  @beartype.beartype
@@ -253,9 +252,9 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
253
  vit, vit_transform = load_vit()
254
  sae = load_sae()
255
 
256
- sample = data.get_sample(image_i)
257
 
258
- x = vit_transform(sample["image"])[None, ...].to(DEVICE)
259
 
260
  _, vit_acts_BLPD = vit(x)
261
  vit_acts_PD = (
@@ -268,8 +267,6 @@ def get_sae_activations(image_i: int, patches: list[int]) -> list[SaeActivation]
268
  acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
269
  logger.info("Got SAE activations.")
270
 
271
- breakpoint()
272
-
273
  top_img_i, top_values = load_tensors(model_cfg)
274
  logger.info("Loaded top SAE activations for '%s'.", model_name)
275
 
 
143
  return torch.load(path, weights_only=True, map_location="cpu")
144
 
145
 
146
+ # top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
147
+ # top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
148
+ # sparsity = load_tensor(CWD / "assets" / "sparsity.pt")
149
 
150
 
151
  # mask = torch.ones((sae.cfg.d_sae), dtype=bool)
 
231
 
232
 
233
  @beartype.beartype
234
+ def get_image(i: int) -> tuple[str, str, int]:
235
+ img_sized = data.to_sized(data.get_image(i))
236
+ seg_sized = data.to_sized(data.get_seg(i))
 
237
  seg_u8_sized = data.to_u8(seg_sized)
238
  seg_img_sized = data.u8_to_img(seg_u8_sized)
239
 
240
+ return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized), i
241
 
242
 
243
  @beartype.beartype
 
252
  vit, vit_transform = load_vit()
253
  sae = load_sae()
254
 
255
+ img = data.get_image(image_i)
256
 
257
+ x = vit_transform(img)[None, ...].to(DEVICE)
258
 
259
  _, vit_acts_BLPD = vit(x)
260
  vit_acts_PD = (
 
267
  acts_SP = einops.rearrange(f_x_PS, "patches n_latents -> n_latents patches")
268
  logger.info("Got SAE activations.")
269
 
 
 
270
  top_img_i, top_values = load_tensors(model_cfg)
271
  logger.info("Loaded top SAE activations for '%s'.", model_name)
272
 
constants.py CHANGED
@@ -1,6 +1,5 @@
1
  import torch
2
 
3
-
4
  DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
5
 
6
 
 
1
  import torch
2
 
 
3
  DINOV2_IMAGENET1K_SCALAR = 2.0181241035461426
4
 
5
 
data.py CHANGED
@@ -1,15 +1,13 @@
1
  import base64
2
- import dataclasses
3
  import functools
4
  import io
5
  import logging
6
- import os.path
7
  import random
8
 
9
  import beartype
10
  import einops.layers.torch
11
  import numpy as np
12
- import torchvision.datasets.folder
13
  from jaxtyping import UInt8, jaxtyped
14
  from PIL import Image
15
  from torch import Tensor
@@ -17,104 +15,25 @@ from torchvision.transforms import v2
17
 
18
  logger = logging.getLogger("data.py")
19
 
 
 
20
 
21
  @beartype.beartype
22
- class Ade20k:
23
- @beartype.beartype
24
- @dataclasses.dataclass(frozen=True)
25
- class Sample:
26
- img_path: str
27
- seg_path: str
28
- label: str
29
- target: int
30
-
31
- samples: list[Sample]
32
-
33
- def __init__(self, root: str, split: str):
34
- self.logger = logging.getLogger("ade20k")
35
- self.root = root
36
- self.split = split
37
- self.img_dir = os.path.join(root, "images")
38
- self.seg_dir = os.path.join(root, "annotations")
39
-
40
- # Check that we have the right path.
41
- for subdir in ("images", "annotations"):
42
- if not os.path.isdir(os.path.join(root, subdir)):
43
- # Something is missing.
44
- if os.path.realpath(root).endswith(subdir):
45
- self.logger.warning(
46
- "The ADE20K root should contain 'images/' and 'annotations/' directories."
47
- )
48
- raise ValueError(f"Can't find path '{os.path.join(root, subdir)}'.")
49
-
50
- _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
51
- split_lookup: dict[int, str] = {
52
- value: key for key, value in split_mapping.items()
53
- }
54
- self.loader = torchvision.datasets.folder.default_loader
55
-
56
- err_msg = f"Split '{split}' not in '{set(split_lookup.values())}'."
57
- assert split in set(split_lookup.values()), err_msg
58
-
59
- # Load all the image paths.
60
- imgs: list[str] = [
61
- path
62
- for path, s in torchvision.datasets.folder.make_dataset(
63
- self.img_dir,
64
- split_mapping,
65
- extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
66
- )
67
- if split_lookup[s] == split
68
- ]
69
-
70
- segs: list[str] = [
71
- path
72
- for path, s in torchvision.datasets.folder.make_dataset(
73
- self.seg_dir,
74
- split_mapping,
75
- extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
76
- )
77
- if split_lookup[s] == split
78
- ]
79
-
80
- # Load all the targets, classes and mappings
81
- with open(os.path.join(root, "sceneCategories.txt")) as fd:
82
- img_labels: list[str] = [line.split()[1] for line in fd.readlines()]
83
-
84
- label_set = sorted(set(img_labels))
85
- label_to_idx = {label: i for i, label in enumerate(label_set)}
86
-
87
- self.samples = [
88
- self.Sample(img_path, seg_path, label, label_to_idx[label])
89
- for img_path, seg_path, label in zip(imgs, segs, img_labels)
90
- ]
91
-
92
- def __getitem__(self, index: int) -> dict[str, object]:
93
- # Convert to dict.
94
- sample = dataclasses.asdict(self.samples[index])
95
-
96
- sample["image"] = self.loader(sample.pop("img_path"))
97
- sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
98
- sample["index"] = index
99
-
100
- return sample
101
-
102
- def __len__(self) -> int:
103
- return len(self.samples)
104
-
105
-
106
- @functools.cache
107
- def get_dataset() -> Ade20k:
108
- return Ade20k(
109
- root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k/",
110
- split="validation",
111
- )
112
 
113
 
114
  @beartype.beartype
115
- def get_sample(i: int) -> dict[str, object]:
116
- dataset = get_dataset()
117
- return dataset[i]
 
 
 
118
 
119
 
120
  @jaxtyped(typechecker=beartype.beartype)
 
1
  import base64
 
2
  import functools
3
  import io
4
  import logging
 
5
  import random
6
 
7
  import beartype
8
  import einops.layers.torch
9
  import numpy as np
10
+ import requests
11
  from jaxtyping import UInt8, jaxtyped
12
  from PIL import Image
13
  from torch import Tensor
 
15
 
16
  logger = logging.getLogger("data.py")
17
 
18
+ R2_URL = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"
19
+
20
 
21
  @beartype.beartype
22
+ @functools.lru_cache(maxsize=512)
23
+ def get_image(i: int) -> Image.Image:
24
+ fpath = f"/images/ADE_val_{i + 1:08}.jpg"
25
+ url = R2_URL + fpath
26
+ logger.info("Getting image from '%s'.", url)
27
+ return Image.open(requests.get(url, stream=True).raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  @beartype.beartype
31
+ @functools.lru_cache(maxsize=512)
32
+ def get_seg(i: int) -> Image.Image:
33
+ fpath = f"/annotations/ADE_val_{i + 1:08}.png"
34
+ url = R2_URL + fpath
35
+ logger.info("Getting annotations from '%s'.", url)
36
+ return Image.open(requests.get(url, stream=True).raw)
37
 
38
 
39
  @jaxtyped(typechecker=beartype.beartype)