import functools
import io
import json
import logging
import math
import pathlib
import typing

import beartype
import einops
import einops.layers.torch
import gradio as gr
import matplotlib
import numpy as np
import saev.activations
import saev.config
import saev.nn
import saev.visuals
import torch
from jaxtyping import Bool, Float, Int, UInt8, jaxtyped
from PIL import Image, ImageDraw
from torch import Tensor

import constants
import data
import modeling

logger = logging.getLogger("app.py")


MAX_FREQ = 3e-2
"""Maximum frequency. Any feature that fires more often than this is ignored."""

RESIZE_SIZE = 512
"""Resize the shorter side to this size in pixels."""

CROP_SIZE = (448, 448)
"""Crop size in pixels."""

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
"""Hardware accelerator, if any."""

CWD = pathlib.Path(".")
"""Current working directory."""

N_SAE_LATENTS = 4
"""Number of SAE latents to show."""

N_LATENT_EXAMPLES = 4
"""Number of examples per SAE latent to show."""

COLORMAP = matplotlib.colormaps.get_cmap("plasma")


@beartype.beartype
class Example(typing.TypedDict):
    """Represents an example image and its associated label.

    Used to store examples of SAE latent activations for visualization.
    """

    orig_url: str
    """The URL or path to access the original example image."""
    highlighted_url: typing.NotRequired[str]
    """The URL or path to access the SAE-highlighted image."""
    seg_url: str
    """Base64-encoded version of the colored segmentation map."""
    classes: list[int]
    """Unique list of all classes present in seg_url."""


@beartype.beartype
class SaeActivation(typing.TypedDict):
    """Represents the activation pattern of a single SAE latent across patches.

    This captures how strongly a particular SAE latent fires on different patches of an input image.
    """

    latent: int
    """The index of the SAE latent being measured."""

    highlighted_url: typing.NotRequired[str]
    """The image with the colormap applied."""

    activations: typing.NotRequired[list[float]]
    """The activation values of this latent across different patches. Each value represents how strongly this latent fired on a particular patch."""

    examples: list[Example]
    """Top examples for this latent."""


@functools.cache
def load_sae(device: str) -> saev.nn.SparseAutoencoder:
    """
    Loads a sparse autoencoder from disk.
    """
    sae_ckpt_fpath = CWD / "assets" / "sae.pt"
    sae = saev.nn.load(str(sae_ckpt_fpath))
    sae.to(device).eval()
    return sae


@functools.cache
def load_clf() -> torch.nn.Module:
    """
    Loads the linear classification head from disk.

    The checkpoint is a single file: the first line is a JSON object holding the
    `torch.nn.Linear` constructor kwargs, and the rest of the file is the
    serialized state dict.
    """
    head_ckpt_fpath = CWD / "assets" / "clf.pt"
    with open(head_ckpt_fpath, "rb") as fd:
        kwargs = json.loads(fd.readline().decode())
        buffer = io.BytesIO(fd.read())

    model = torch.nn.Linear(**kwargs)
    state_dict = torch.load(buffer, weights_only=True, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE).eval()
    return model
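
# For context, a checkpoint in the format `load_clf` expects could be produced
# with a sketch like the one below. This is illustrative only: the real head is
# trained elsewhere, and the in_features/out_features values here are made up.
#
#     head = torch.nn.Linear(in_features=768, out_features=151)
#     with open("assets/clf.pt", "wb") as fd:
#         kwargs = {"in_features": head.in_features, "out_features": head.out_features}
#         fd.write((json.dumps(kwargs) + "\n").encode())
#         torch.save(head.state_dict(), fd)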


@beartype.beartype
def load_tensor(path: str | pathlib.Path) -> Tensor:
    return torch.load(path, weights_only=True, map_location="cpu")


@functools.cache
def load_tensors() -> tuple[
    Int[Tensor, "d_sae k"],
    UInt8[Tensor, "d_sae k n_patches"],
    Bool[Tensor, " d_sae"],
]:
    """
    Loads the tensors for the SAE for ADE20K.
    """
    top_img_i = load_tensor(CWD / "assets" / "top_img_i.pt")
    top_values = load_tensor(CWD / "assets" / "top_values_uint8.pt")
    sparsity = load_tensor(CWD / "assets" / "sparsity.pt")

    # Keep only latents that fire less often than MAX_FREQ.
    mask = torch.ones(sparsity.shape, dtype=bool)
    mask = mask & (sparsity < MAX_FREQ)

    return top_img_i, top_values, mask
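
# Shapes of the precomputed tensors returned above, where d_sae is the number of
# SAE latents, k is the number of top images stored per latent, and n_patches is
# the number of patches per image:
#
#     top_img_i:  Int[d_sae, k]               dataset indices of each latent's top images
#     top_values: UInt8[d_sae, k, n_patches]  uint8 activation of the latent on each patch
#     mask:       Bool[d_sae]                 True for latents that fire less often than MAX_FREQ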


@jaxtyped(typechecker=beartype.beartype)
def add_highlights(
    img: Image.Image,
    patches: Float[np.ndarray, " n_patches"],
    *,
    upper: int | None = None,
    opacity: float = 0.9,
) -> Image.Image:
    if not len(patches):
        return img

    # Patches form a square grid; work out the grid size and the pixel size of each patch.
    iw_np, ih_np = int(math.sqrt(len(patches))), int(math.sqrt(len(patches)))
    iw_px, ih_px = img.size
    pw_px, ph_px = iw_px // iw_np, ih_px // ih_np
    assert iw_np * ih_np == len(patches)
    assert upper is not None

    overlay = Image.new("RGBA", img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # Map activation strength onto the red channel.
    colors = np.zeros((len(patches), 3), dtype=np.uint8)
    colors[:, 0] = ((patches / (upper + 1e-9)) * 255).astype(np.uint8)

    for p, (val, color) in enumerate(zip(patches, colors)):
        val /= upper + 1e-9
        x_np, y_np = p % iw_np, p // iw_np
        draw.rectangle(
            [
                (x_np * pw_px, y_np * ph_px),
                (x_np * pw_px + pw_px, y_np * ph_px + ph_px),
            ],
            fill=(*color, int(opacity * val * 255)),
        )

    return Image.alpha_composite(img.convert("RGBA"), overlay)
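
# A quick worked example of the grid math above, assuming a 448x448 crop and a
# 16x16 patch grid (256 patches): each patch covers 448 // 16 = 28 pixels per
# side, so patch index p = 35 falls at column 35 % 16 = 3 and row 35 // 16 = 2,
# i.e. the rectangle from (84, 56) to (112, 84).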


@beartype.beartype
def get_img(i: int) -> Example:
    img_sized = data.to_sized(data.get_img(i))
    seg_sized = data.to_sized(data.get_seg(i))
    seg_u8_sized = data.to_u8(seg_sized)
    seg_img_sized = data.u8_to_img(seg_u8_sized)

    return {
        "orig_url": data.img_to_base64(img_sized),
        "seg_url": data.img_to_base64(seg_img_sized),
        "classes": data.to_classes(seg_u8_sized),
    }


@beartype.beartype
@torch.inference_mode
def get_sae_latents(img: Image.Image, patches: list[int]) -> list[SaeActivation]:
    """
    Given the patches selected by the user, returns the SAE latents that fire most strongly on those patches, along with highlighted example images for each latent.
    """
    if not patches:
        return []

    split_vit, vit_transform = modeling.load_vit(DEVICE)
    sae = load_sae(DEVICE)

    x_BCWH = vit_transform(img.convert("RGB"))[None, ...].to(DEVICE)

    x_BPD = split_vit.forward_start(x_BCWH)
    # Shift and scale the ViT activations using the precomputed DINOv2 statistics.
    x_BPD = (
        x_BPD.clamp(-1e-5, 1e5) - (constants.DINOV2_IMAGENET1K_MEAN).to(DEVICE)
    ) / constants.DINOV2_IMAGENET1K_SCALAR

    # The first 1 + 4 tokens ([CLS] plus register tokens) are not image patches,
    # so shift the patch indices past them.
    x_PD = x_BPD[0, [p + 1 + 4 for p in patches]]
    _, f_x_PS, _ = sae(x_PD)

    # Aggregate each latent's activation over the selected patches.
    f_x_S = einops.reduce(f_x_PS, "patches n_latents -> n_latents", "sum")
    logger.info("Got SAE activations.")

    top_img_i, top_values, mask = load_tensors()

    # Sort latents by aggregate activation, drop those masked out for firing too
    # frequently, and keep the top N_SAE_LATENTS.
    latents = torch.argsort(f_x_S, descending=True).cpu()
    latents = latents[mask[latents]][:N_SAE_LATENTS].tolist()

    sae_activations = []
    for latent in latents:
        # Collect up to N_LATENT_EXAMPLES distinct example images for this latent.
        pairs, seen_i_im = [], set()
        for i_im, values_p in zip(top_img_i[latent].tolist(), top_values[latent]):
            if i_im in seen_i_im:
                continue

            pairs.append((i_im, values_p))
            seen_i_im.add(i_im)
            if len(pairs) >= N_LATENT_EXAMPLES:
                break

        upper = None
        if top_values[latent].numel() > 0:
            upper = top_values[latent].max().item()

        examples = []
        for i_im, values_p in pairs:
            seg_sized = data.to_sized(data.get_seg(i_im))
            img_sized = data.to_sized(data.get_img(i_im))

            seg_u8_sized = data.to_u8(seg_sized)
            seg_img_sized = data.u8_to_img(seg_u8_sized)

            highlighted_sized = add_highlights(
                img_sized, values_p.float().numpy(), upper=upper
            )

            examples.append({
                "orig_url": data.img_to_base64(img_sized),
                "highlighted_url": data.img_to_base64(highlighted_sized),
                "seg_url": data.img_to_base64(seg_img_sized),
                "classes": data.to_classes(seg_u8_sized),
            })

        sae_activations.append({
            "latent": latent,
            "examples": examples,
        })

    return sae_activations
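
# The latent-selection idiom above in miniature: argsort puts the strongest
# latents first, and indexing the boolean mask with the sorted indices drops
# masked-out latents while preserving that order. With illustrative values:
#
#     f_x_S = torch.tensor([0.1, 3.0, 2.0, 0.0])
#     mask = torch.tensor([True, False, True, True])
#     latents = torch.argsort(f_x_S, descending=True)  # tensor([1, 2, 0, 3])
#     latents[mask[latents]]                            # tensor([2, 0, 3])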


@beartype.beartype
@torch.inference_mode
def get_orig_preds(img: Image.Image) -> Example:
    split_vit, vit_transform = modeling.load_vit(DEVICE)

    x_BCWH = vit_transform(img.convert("RGB"))[None, ...].to(DEVICE)

    x_BPD = split_vit.forward_start(x_BCWH)
    x_BPD = split_vit.forward_end(x_BPD)

    x_WHD = einops.rearrange(x_BPD, "() (w h) dim -> w h dim", w=16, h=16)

    clf = load_clf()
    logits_WHC = clf(x_WHD)

    # Ignore class 0 when taking the argmax, then add 1 to map back to the original class indices.
    pred_WH = logits_WHC[:, :, 1:].argmax(axis=-1) + 1
    return {
        "orig_url": data.img_to_base64(data.to_sized(img)),
        "seg_url": data.img_to_base64(data.u8_to_overlay(pred_WH, img)),
        "classes": data.to_classes(pred_WH),
    }


@beartype.beartype
def unscaled(x: float, max_obs: float | int) -> float:
    """Scale x from [-10, 10] to [-10 * max_obs, 10 * max_obs]."""
    return map_range(x, (-10.0, 10.0), (-10.0 * max_obs, 10.0 * max_obs))


@beartype.beartype
def map_range(
    x: float,
    domain: tuple[float | int, float | int],
    range: tuple[float | int, float | int],
) -> float:
    """Linearly map x from `domain` onto `range`."""
    a, b = domain
    c, d = range
    if not (a <= x <= b):
        raise ValueError(f"x={x:.3f} must be in {[a, b]}.")
    return c + (x - a) * (d - c) / (b - a)
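
# A worked example of the two helpers above: a user-supplied value of 5.0 on the
# [-10, 10] scale, with a maximum observed activation of 12.0, becomes
# unscaled(5.0, 12.0) == map_range(5.0, (-10.0, 10.0), (-120.0, 120.0)) == 60.0.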


@beartype.beartype
@torch.inference_mode
def get_mod_preds(img: Image.Image, latents: dict[str, int | float]) -> Example:
    latents = {int(k): float(v) for k, v in latents.items()}

    split_vit, vit_transform = modeling.load_vit(DEVICE)
    sae = load_sae(DEVICE)
    _, top_values, _ = load_tensors()
    clf = load_clf()

    x_BCWH = vit_transform(img.convert("RGB"))[None, ...].to(DEVICE)
    x_BPD = split_vit.forward_start(x_BCWH)
    x_hat_BPD, f_x_BPS, _ = sae(x_BPD)

    # Keep the SAE's reconstruction error so everything the SAE does not explain is preserved.
    err_BPD = x_BPD - x_hat_BPD

    # Overwrite the chosen latents with the user-requested values, rescaled via
    # `unscaled` using each latent's maximum observed activation.
    values = torch.tensor(
        [
            unscaled(float(value), top_values[latent].max().item())
            for latent, value in latents.items()
        ],
        device=DEVICE,
    )
    f_x_BPS[..., torch.tensor(list(latents.keys()), device=DEVICE)] = values

    # Decode the modified latents and add the reconstruction error back in.
    mod_x_hat_BPD = (
        einops.einsum(
            f_x_BPS,
            sae.W_dec,
            "batch patches d_sae, d_sae d_vit -> batch patches d_vit",
        )
        + sae.b_dec
    )
    mod_BPD = err_BPD + mod_x_hat_BPD

    mod_BPD = split_vit.forward_end(mod_BPD)
    mod_WHD = einops.rearrange(mod_BPD, "() (w h) dim -> w h dim", w=16, h=16)

    logits_WHC = clf(mod_WHD)
    pred_WH = logits_WHC[:, :, 1:].argmax(axis=-1) + 1

    return {
        "orig_url": data.img_to_base64(data.to_sized(img)),
        "seg_url": data.img_to_base64(data.u8_to_overlay(pred_WH, img)),
        "classes": data.to_classes(pred_WH),
    }
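
# In equation form, the intervention above computes, per patch,
#
#     x_mod = (x - x_hat) + f_mod @ W_dec + b_dec
#
# where f_mod is the SAE code f(x) with the user-selected latents overwritten.
# Only the decoder directions of the edited latents change; everything else,
# including whatever the SAE fails to reconstruct, passes through unchanged.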


with gr.Blocks() as demo:
    # Look up an example image (and its segmentation) by dataset index.
    img_number = gr.Number(label="Example Index")
    get_img_out = gr.JSON(label="get_img_out", value={})

    get_input_img_btn = gr.Button(value="Get Input Image")
    get_input_img_btn.click(
        get_img,
        inputs=[img_number],
        outputs=[get_img_out],
        api_name="get-img",
        concurrency_limit=10,
    )

    # Given an uploaded image and a list of selected patch indices, return the
    # top SAE latents and their example images.
    patches_json = gr.JSON(label="Patches", value=[])
    input_img = gr.Image(
        label="Input Image",
        sources=["upload", "clipboard"],
        type="pil",
        interactive=True,
    )
    get_sae_latents_out = gr.JSON(label="get_sae_latents_out", value=[])

    get_sae_latents_btn = gr.Button(value="Get SAE Latents")
    get_sae_latents_btn.click(
        get_sae_latents,
        inputs=[input_img, patches_json],
        outputs=[get_sae_latents_out],
        api_name="get-sae-latents",
    )

    # Segmentation predictions from the unmodified ViT activations.
    get_orig_preds_out = gr.JSON(label="get_orig_preds_out", value=[])

    get_orig_preds_btn = gr.Button(value="Get Predictions")
    get_orig_preds_btn.click(
        get_orig_preds,
        inputs=[input_img],
        outputs=[get_orig_preds_out],
        api_name="get-orig-preds",
    )

    # Predictions after overwriting the requested SAE latents.
    latents_json = gr.JSON(label="Modified Latents", value={})
    get_mod_preds_out = gr.JSON(label="get_mod_preds_out", value=[])

    get_mod_preds_btn = gr.Button(value="Get Predictions")
    get_mod_preds_btn.click(
        get_mod_preds,
        inputs=[input_img, latents_json],
        outputs=[get_mod_preds_out],
        api_name="get-mod-preds",
    )
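
# Each api_name registered above also exposes its handler over the Gradio API,
# so the demo can be driven programmatically. A minimal client-side sketch using
# the gradio_client package (the URL is a placeholder for wherever this app is
# hosted; the other endpoints follow the same pattern with their own inputs):
#
#     from gradio_client import Client
#
#     client = Client("http://localhost:7860")
#     example = client.predict(0, api_name="/get-img")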


if __name__ == "__main__":
    demo.queue(default_concurrency_limit=2, max_size=32)
    demo.launch()