|
import functools |
|
import io |
|
import json |
|
import logging |
|
import os.path |
|
import pathlib |
|
import typing |
|
|
|
import beartype |
|
import einops |
|
import einops.layers.torch |
|
import gradio as gr |
|
import saev.activations |
|
import saev.config

import saev.imaging
|
import saev.nn |
|
import saev.visuals |
|
import torch |
|
from jaxtyping import Float, Int, UInt8, jaxtyped |
|
from PIL import Image |
|
from torch import Tensor |
|
|
|
import constants |
|
import data |
|
|
|
logger = logging.getLogger("app.py") |
|
|
DEBUG = False |
|
"""Whether we are debugging.""" |
|
|
|
max_frequency = 1e-2 |
|
"""Maximum frequency. Any feature that fires more than this is ignored.""" |
|
|
|
n_sae_latents = 3 |
|
"""Number of SAE latents to show.""" |
|
|
|
n_sae_examples = 4 |
|
"""Number of SAE examples per latent to show.""" |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
"""Hardware accelerator, if any.""" |
|
|
|
RESIZE_SIZE = 512 |
|
"""Resize shorter size to this size in pixels.""" |
|
|
|
CROP_SIZE = (448, 448) |
|
"""Crop size in pixels.""" |
|
|
|
DEVICE = device
"""Alias for `device`; both names are used below."""
|
|
|
CWD = pathlib.Path(".") |
|
"""Current working directory.""" |
|
|
@functools.cache |
|
def load_vit() -> tuple[saev.activations.WrappedVisionTransformer, typing.Callable]:
    """Load the DINOv2 ViT, wrapped to record activations at the second-to-last layer, along with its image transform."""
    vit = (
|
saev.activations.WrappedVisionTransformer( |
|
saev.config.Activations( |
|
model_family="dinov2", |
|
model_ckpt="dinov2_vitb14_reg", |
|
layers=[-2], |
|
n_patches_per_img=256, |
|
) |
|
) |
|
.to(DEVICE) |
|
.eval() |
|
) |
|
vit_transform = saev.activations.make_img_transform("dinov2", "dinov2_vitb14_reg") |
|
logger.info("Loaded ViT.") |
|
|
|
return vit, vit_transform |
|
|
|
|
|
@functools.cache |
|
def load_sae() -> saev.nn.SparseAutoencoder: |
|
""" |
|
Loads a sparse autoencoder from disk. |
|
""" |
|
sae_ckpt_fpath = CWD / "assets" / "sae.pt" |
|
sae = saev.nn.load(str(sae_ckpt_fpath)) |
|
sae.to(device).eval() |
|
return sae |
|
|
|
|
|
@functools.cache |
|
def load_clf() -> torch.nn.Module:
    """Load the linear classification head from disk.

    The checkpoint's first line is a JSON object with the `torch.nn.Linear` kwargs; the rest of the file is the serialized state dict.
    """
    head_ckpt_fpath = CWD / "assets" / "clf.pt"
|
with open(head_ckpt_fpath, "rb") as fd: |
|
kwargs = json.loads(fd.readline().decode()) |
|
buffer = io.BytesIO(fd.read()) |
|
|
|
model = torch.nn.Linear(**kwargs) |
|
state_dict = torch.load(buffer, weights_only=True, map_location=device) |
|
model.load_state_dict(state_dict) |
|
model = model.to(device).eval() |
|
return model |
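

# Not part of the original app: a minimal sketch of how a checkpoint in the format
# `load_clf` expects could be written (one JSON line of `torch.nn.Linear` kwargs,
# followed by the serialized state dict). The function name is an assumption.
def save_clf(model: torch.nn.Linear, fpath: str | pathlib.Path) -> None:
    kwargs = {
        "in_features": model.in_features,
        "out_features": model.out_features,
        "bias": model.bias is not None,
    }
    with open(fpath, "wb") as fd:
        # First line: constructor kwargs as JSON.
        fd.write((json.dumps(kwargs) + "\n").encode())
        # Rest of the file: the model weights.
        torch.save(model.state_dict(), fd)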
|
|
|
|
|
class RestOfDinoV2(torch.nn.Module):
    """Runs DINOv2 in two halves so that patch activations can be inspected or edited between `forward_start` and `forward_end`."""

    def __init__(self, *, n_end_layers: int):
        super().__init__()
        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
        self.n_end_layers = n_end_layers
|
|
|
def forward_start(self, x: Float[Tensor, "batch channels width height"]): |
|
x_BPD = self.vit.prepare_tokens_with_masks(x) |
|
for blk in self.vit.blocks[: -self.n_end_layers]: |
|
x_BPD = blk(x_BPD) |
|
|
|
return x_BPD |
|
|
|
def forward_end(self, x_BPD: Float[Tensor, "batch n_patches dim"]): |
|
for blk in self.vit.blocks[-self.n_end_layers :]: |
|
x_BPD = blk(x_BPD) |
|
|
|
        x_BPD = self.vit.norm(x_BPD)
        # Drop the [CLS] and register tokens, keeping only the patch tokens.
        return x_BPD[:, self.vit.num_register_tokens + 1 :]
|
|
|
|
|
rest_of_vit = RestOfDinoV2(n_end_layers=1) |
|
rest_of_vit = rest_of_vit.to(device) |
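
# Splitting the ViT like this lets the demo run the SAE on the activations produced
# by `forward_start`, edit selected latents, and push the modified reconstruction
# through `forward_end` (see `get_modified_labels` below).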
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@beartype.beartype |
|
def load_tensor(path: str | pathlib.Path) -> Tensor: |
|
return torch.load(path, weights_only=True, map_location="cpu") |
|
|
@beartype.beartype |
|
class Example(typing.TypedDict): |
|
"""Represents an example image and its associated label. |
|
|
|
Used to store examples of SAE latent activations for visualization. |
|
""" |
|
|
|
orig_url: str |
|
"""The URL or path to access the original example image.""" |
|
highlighted_url: str |
|
"""The URL or path to access the SAE-highlighted image.""" |
|
index: int |
|
"""Dataset index.""" |
|
|
|
|
|
@beartype.beartype |
|
class SaeActivation(typing.TypedDict): |
|
"""Represents the activation pattern of a single SAE latent across patches. |
|
|
|
This captures how strongly a particular SAE latent fires on different patches of an input image. |
|
""" |
|
|
|
latent: int |
|
"""The index of the SAE latent being measured.""" |
|
|
|
highlighted_url: str |
|
"""The image with the colormaps applied.""" |
|
|
|
activations: list[float] |
|
"""The activation values of this latent across different patches. Each value represents how strongly this latent fired on a particular patch.""" |
|
|
|
examples: list[Example] |
|
"""Top examples for this latent.""" |
|
|
|
|
|
@beartype.beartype |
|
def get_image(i: int) -> tuple[str, str, int]:
    """Return validation image `i` and its segmentation map as base64-encoded strings, plus the index."""
    img_sized = data.to_sized(data.get_image(i))
|
seg_sized = data.to_sized(data.get_seg(i)) |
|
seg_u8_sized = data.to_u8(seg_sized) |
|
seg_img_sized = data.u8_to_img(seg_u8_sized) |
|
|
|
return data.img_to_base64(img_sized), data.img_to_base64(seg_img_sized), i |
|
|
|
|
|
@beartype.beartype
@torch.inference_mode
def get_sae_activations(
    image_i: int, patches: list[int]
) -> list[Image.Image | None | int]:
    """
    Given the patches selected in an image, return example images highlighting the SAE latents that fire most strongly on those patches, followed by the latent indices themselves.
    """
    if not patches:
        return []

    vit, vit_transform = load_vit()
    sae = load_sae()

    img = data.get_image(image_i)

    x = vit_transform(img)[None, ...].to(DEVICE)

    # ViT activations for every patch (drop the [CLS] token), normalized the same
    # way the SAE's training activations were.
    _, vit_acts_BLPD = vit(x)
    vit_acts_PD = (
        vit_acts_BLPD[0, 0, 1:].to(DEVICE).clamp(-1e-5, 1e5)
        - constants.DINOV2_IMAGENET1K_MEAN.to(DEVICE)
    ) / constants.DINOV2_IMAGENET1K_SCALAR

    # SAE latent activations for every patch in the image.
    _, f_x_PS, _ = sae(vit_acts_PD)
    logger.info("Got SAE activations.")

    # Precomputed top-activating examples for each latent. `load_tensors` and
    # `model_cfg` are assumed to be defined elsewhere; they are not in this file.
    top_img_i, top_values = load_tensors(model_cfg)
    logger.info("Loaded top SAE activations.")

    # Sum latent activations over just the selected patches, then rank latents.
    f_x_S = f_x_PS[patches].sum(axis=0)

    latents = torch.argsort(f_x_S, descending=True).cpu()
    # `mask` (latents that fire less often than `max_frequency`) is assumed to be
    # computed elsewhere.
    latents = latents[mask[latents]][:n_sae_latents].tolist()

    images = []
    for latent in latents:
        elems, seen_i_im = [], set()
        for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
            if i_im in seen_i_im:
                continue

            # `in1k_dataset` (the dataset the top examples come from) is assumed
            # to be loaded elsewhere.
            example = in1k_dataset[i_im]
            elems.append(
                saev.visuals.GridElement(example["image"], example["label"], values_p)
            )
            seen_i_im.add(i_im)

        # Upper bound for the highlight colormap: the latent's largest observed value.
        upper = None
        if top_values[latent].numel() > 0:
            upper = top_values[latent].max().item()

        latent_images = [make_img(elem, upper=upper) for elem in elems[:n_sae_examples]]

        # Pad so every latent contributes exactly `n_sae_examples` slots.
        while len(latent_images) < n_sae_examples:
            latent_images.append(None)

        images.extend(latent_images)

    return images + latents
|
|
|
|
|
@torch.inference_mode |
|
def get_true_labels(image_i: int) -> Image.Image:
    """Render the ground-truth segmentation for validation image `image_i`."""
    # `human_dataset` and `seg_to_img` are assumed to be defined elsewhere; they are not in this file.
    seg = human_dataset[image_i]["segmentation"]
|
image = seg_to_img(seg) |
|
return image |
|
|
|
|
|
@torch.inference_mode |
|
def get_pred_labels(i: int) -> list[Image.Image | list[int]]:
    """Predict patch-level labels for validation image `i` using the linear head."""
    head = load_clf()

    # `vit_dataset` and `seg_to_img` are assumed to be defined elsewhere; they are not in this file.
    sample = vit_dataset[i]
    x = sample["image"][None, ...].to(device)
|
x_BPD = rest_of_vit.forward_start(x) |
|
x_BPD = rest_of_vit.forward_end(x_BPD) |
|
|
|
x_WHD = einops.rearrange(x_BPD, "() (w h) dim -> w h dim", w=16, h=16) |
|
|
|
logits_WHC = head(x_WHD) |
|
|
|
pred_WH = logits_WHC.argmax(axis=-1) |
|
preds = einops.rearrange(pred_WH, "w h -> (w h)").tolist() |
|
return [seg_to_img(upsample(pred_WH)), preds] |
|
|
|
|
|
@beartype.beartype |
|
def unscaled(x: float, max_obs: float) -> float: |
|
"""Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs].""" |
|
return map_range(x, (-10.0, 10.0), (-10.0 * max_obs, 10.0 * max_obs)) |
|
|
|
|
|
@beartype.beartype |
|
def map_range(
    x: float,
    domain: tuple[float | int, float | int],
    range: tuple[float | int, float | int],
) -> float:
|
a, b = domain |
|
c, d = range |
|
if not (a <= x <= b): |
|
raise ValueError(f"x={x:.3f} must be in {[a, b]}.") |
|
return c + (x - a) * (d - c) / (b - a) |
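

# Illustrative sanity checks for the slider-to-activation mapping; only run when
# debugging. The numbers below are examples, not values from the app.
if DEBUG:
    assert map_range(5.0, (-10.0, 10.0), (-20.0, 20.0)) == 10.0
    assert unscaled(5.0, 2.0) == 10.0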
|
|
|
|
|
@torch.inference_mode |
|
def get_modified_labels( |
|
i: int, |
|
latent1: int, |
|
latent2: int, |
|
latent3: int, |
|
value1: float, |
|
value2: float, |
|
value3: float, |
|
) -> list[Image.Image | list[int]]:
    """Set three SAE latents to user-chosen values, decode, and re-run the last ViT block and head to see how the predicted segmentation changes."""
    sae = load_sae()
    head = load_clf()

    # `vit_dataset`, `top_values`, and `seg_to_img` are assumed to be defined elsewhere; they are not in this file.
    sample = vit_dataset[i]
    x = sample["image"][None, ...].to(device)
    x_BPD = rest_of_vit.forward_start(x)

    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
|
|
|
    # Keep the SAE's reconstruction error so everything except the edited latents passes through unchanged.
    err_BPD = x_BPD - x_hat_BPD
|
|
|
values = torch.tensor( |
|
[ |
|
unscaled(float(value), top_values[latent].max().item()) |
|
for value, latent in [ |
|
(value1, latent1), |
|
(value2, latent2), |
|
(value3, latent3), |
|
] |
|
], |
|
device=device, |
|
) |
|
f_x_BPS[..., torch.tensor([latent1, latent2, latent3], device=device)] = values |
|
|
|
|
|
modified_x_hat_BPD = ( |
|
einops.einsum( |
|
f_x_BPS, |
|
sae.W_dec, |
|
"batch patches d_sae, d_sae d_vit -> batch patches d_vit", |
|
) |
|
+ sae.b_dec |
|
) |
|
modified_BPD = err_BPD + modified_x_hat_BPD |
|
|
|
modified_BPD = rest_of_vit.forward_end(modified_BPD) |
|
|
|
logits_BPC = head(modified_BPD) |
|
pred_P = logits_BPC[0].argmax(axis=-1) |
|
pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16) |
|
    return [seg_to_img(upsample(pred_WH)), pred_P.tolist()]
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype) |
|
@torch.inference_mode |
|
def upsample(
    x_WH: Int[Tensor, "width_ps height_ps"],
) -> UInt8[Tensor, "width_px height_px"]:
    """Upsample a 16x16 grid of patch predictions to 448x448 pixels (28 px per patch) with nearest-neighbor interpolation."""
|
return ( |
|
torch.nn.functional.interpolate( |
|
x_WH.view((1, 1, 16, 16)).float(), |
|
scale_factor=28, |
|
) |
|
.view((448, 448)) |
|
.type(torch.uint8) |
|
) |
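

# Quick shape check for `upsample`: a 16x16 patch grid becomes a 448x448 image.
# Only runs when debugging.
if DEBUG:
    assert upsample(torch.zeros(16, 16, dtype=torch.int64)).shape == (448, 448)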
|
|
|
|
|
@beartype.beartype |
|
def make_img(
    elem: saev.visuals.GridElement, *, upper: float | None = None
) -> Image.Image:
    """Resize and center-crop a grid element's image to the model input size, then overlay its patch activations as highlights."""
    resize_size_px = (RESIZE_SIZE, RESIZE_SIZE)
    resize_w_px, resize_h_px = resize_size_px
    crop_size_px = CROP_SIZE
    crop_w_px, crop_h_px = crop_size_px
|
crop_coords_px = ( |
|
(resize_w_px - crop_w_px) // 2, |
|
(resize_h_px - crop_h_px) // 2, |
|
(resize_w_px + crop_w_px) // 2, |
|
(resize_h_px + crop_h_px) // 2, |
|
) |
|
|
|
img = elem.img.resize(resize_size_px).crop(crop_coords_px) |
|
img = saev.imaging.add_highlights( |
|
img, elem.patches.numpy(), upper=upper, opacity=0.5 |
|
) |
|
return img |
|
|
|
|
|
with gr.Blocks() as demo: |
|
image_number = gr.Number(label="Validation Example") |
|
|
|
input_image_base64 = gr.Text(label="Image in Base64") |
|
true_labels_base64 = gr.Text(label="Labels in Base64") |
|
|
|
get_input_image_btn = gr.Button(value="Get Input Image") |
|
get_input_image_btn.click( |
|
get_image, |
|
inputs=[image_number], |
|
outputs=[input_image_base64, true_labels_base64, image_number], |
|
api_name="get-image", |
|
) |
|
|
patches_json = gr.JSON(label="Patches", value=[]) |
|
activations_json = gr.JSON(label="Activations", value=[]) |
|
|
|
get_sae_activations_btn = gr.Button(value="Get SAE Activations") |
|
get_sae_activations_btn.click( |
|
get_sae_activations, |
|
inputs=[image_number, patches_json], |
|
outputs=[activations_json], |
|
api_name="get-sae-examples", |
|
) |
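
    # The buttons above expose this demo over the Gradio API under the names
    # "get-image" and "get-sae-examples". A minimal client sketch, assuming the
    # app is served locally on Gradio's default port (adjust the URL to wherever
    # `demo.launch()` reports):
    #
    #   from gradio_client import Client
    #   client = Client("http://127.0.0.1:7860")
    #   img_b64, seg_b64, i = client.predict(0, api_name="/get-image")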
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|