Upload featurizers.py
Browse files- featurizers.py +489 -0
featurizers.py
ADDED
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
featurizers.py
|
3 |
+
==============
|
4 |
+
Utility classes for defining *invertible* feature spaces on top of a model’s
|
5 |
+
hidden-state tensors, together with intervention helpers that operate inside
|
6 |
+
those spaces.
|
7 |
+
|
8 |
+
Key ideas
|
9 |
+
---------
|
10 |
+
|
11 |
+
* **Featurizer** – a lightweight wrapper holding:
|
12 |
+
• a forward `featurizer` module that maps a tensor **x → (f, error)**
|
13 |
+
where *error* is the reconstruction residual (useful for lossy
|
14 |
+
featurizers such as sparse auto-encoders);
|
15 |
+
• an `inverse_featurizer` that re-assembles the original space
|
16 |
+
**(f, error) → x̂**.
|
17 |
+
|
18 |
+
* **Interventions** – three higher-order factory functions build PyVENE
|
19 |
+
interventions that work in the featurized space:
|
20 |
+
- *interchange*
|
21 |
+
- *collect*
|
22 |
+
- *mask* (differential binary masking)
|
23 |
+
|
24 |
+
All public classes / functions below carry PEP-257-style doc-strings.
|
25 |
+
"""
|
26 |
+
|
27 |
+
from typing import Optional, Tuple
|
28 |
+
|
29 |
+
import torch
|
30 |
+
import pyvene as pv
|
31 |
+
|
32 |
+
|
33 |
+
# --------------------------------------------------------------------------- #
|
34 |
+
# Basic identity featurizers #
|
35 |
+
# --------------------------------------------------------------------------- #
|
36 |
+
class IdentityFeaturizerModule(torch.nn.Module):
|
37 |
+
"""A no-op featurizer: *x → (x, None)*."""
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
|
40 |
+
return x, None
|
41 |
+
|
42 |
+
|
43 |
+
class IdentityInverseFeaturizerModule(torch.nn.Module):
|
44 |
+
"""Inverse of :class:`IdentityFeaturizerModule`."""
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor, error: None) -> torch.Tensor: # noqa: D401
|
47 |
+
return x
|
48 |
+
|
49 |
+
|
50 |
+
# --------------------------------------------------------------------------- #
|
51 |
+
# High-level Featurizer wrapper #
|
52 |
+
# --------------------------------------------------------------------------- #
|
53 |
+
class Featurizer:
|
54 |
+
"""Container object holding paired featurizer and inverse modules.
|
55 |
+
|
56 |
+
Parameters
|
57 |
+
----------
|
58 |
+
featurizer :
|
59 |
+
A `torch.nn.Module` mapping **x → (features, error)**.
|
60 |
+
inverse_featurizer :
|
61 |
+
A `torch.nn.Module` mapping **(features, error) → x̂**.
|
62 |
+
n_features :
|
63 |
+
Dimensionality of the feature space. **Required** when you intend to
|
64 |
+
build a *mask* intervention; optional otherwise.
|
65 |
+
id :
|
66 |
+
Human-readable identifier used by `__str__` methods of the generated
|
67 |
+
interventions.
|
68 |
+
"""
|
69 |
+
|
70 |
+
# --------------------------------------------------------------------- #
|
71 |
+
# Construction / public accessors #
|
72 |
+
# --------------------------------------------------------------------- #
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
featurizer: torch.nn.Module = IdentityFeaturizerModule(),
|
76 |
+
inverse_featurizer: torch.nn.Module = IdentityInverseFeaturizerModule(),
|
77 |
+
*,
|
78 |
+
n_features: Optional[int] = None,
|
79 |
+
id: str = "null",
|
80 |
+
):
|
81 |
+
self.featurizer = featurizer
|
82 |
+
self.inverse_featurizer = inverse_featurizer
|
83 |
+
self.n_features = n_features
|
84 |
+
self.id = id
|
85 |
+
|
86 |
+
# -------------------- Intervention builders -------------------------- #
|
87 |
+
def get_interchange_intervention(self):
|
88 |
+
if not hasattr(self, "_interchange_intervention"):
|
89 |
+
self._interchange_intervention = build_feature_interchange_intervention(
|
90 |
+
self.featurizer, self.inverse_featurizer, self.id
|
91 |
+
)
|
92 |
+
return self._interchange_intervention
|
93 |
+
|
94 |
+
def get_collect_intervention(self):
|
95 |
+
if not hasattr(self, "_collect_intervention"):
|
96 |
+
self._collect_intervention = build_feature_collect_intervention(
|
97 |
+
self.featurizer, self.id
|
98 |
+
)
|
99 |
+
return self._collect_intervention
|
100 |
+
|
101 |
+
def get_mask_intervention(self):
|
102 |
+
if self.n_features is None:
|
103 |
+
raise ValueError(
|
104 |
+
"`n_features` must be provided on the Featurizer "
|
105 |
+
"to construct a mask intervention."
|
106 |
+
)
|
107 |
+
if not hasattr(self, "_mask_intervention"):
|
108 |
+
self._mask_intervention = build_feature_mask_intervention(
|
109 |
+
self.featurizer,
|
110 |
+
self.inverse_featurizer,
|
111 |
+
self.n_features,
|
112 |
+
self.id,
|
113 |
+
)
|
114 |
+
return self._mask_intervention
|
115 |
+
|
116 |
+
# ------------------------- Convenience I/O --------------------------- #
|
117 |
+
def featurize(self, x: torch.Tensor):
|
118 |
+
return self.featurizer(x)
|
119 |
+
|
120 |
+
def inverse_featurize(self, x: torch.Tensor, error):
|
121 |
+
return self.inverse_featurizer(x, error)
|
122 |
+
|
123 |
+
# --------------------------------------------------------------------- #
|
124 |
+
# (De)serialisation helpers #
|
125 |
+
# --------------------------------------------------------------------- #
|
126 |
+
def save_modules(self, path: str) -> Tuple[str, str]:
|
127 |
+
"""Serialise featurizer & inverse to `<path>_{featurizer, inverse}`.
|
128 |
+
|
129 |
+
Notes
|
130 |
+
-----
|
131 |
+
* **SAE featurizers** are *not* serialisable: a
|
132 |
+
:class:`NotImplementedError` is raised.
|
133 |
+
* Existing files will be *silently overwritten*.
|
134 |
+
"""
|
135 |
+
featurizer_class = self.featurizer.__class__.__name__
|
136 |
+
|
137 |
+
if featurizer_class == "SAEFeaturizerModule":
|
138 |
+
#SAE featurizers are to be loaded from sae_lens
|
139 |
+
return None, None
|
140 |
+
|
141 |
+
inverse_featurizer_class = self.inverse_featurizer.__class__.__name__
|
142 |
+
|
143 |
+
# Extra config needed for Subspace featurizers
|
144 |
+
additional_config = {}
|
145 |
+
if featurizer_class == "SubspaceFeaturizerModule":
|
146 |
+
additional_config["rotation_matrix"] = (
|
147 |
+
self.featurizer.rotate.weight.detach().clone()
|
148 |
+
)
|
149 |
+
additional_config["requires_grad"] = (
|
150 |
+
self.featurizer.rotate.weight.requires_grad
|
151 |
+
)
|
152 |
+
|
153 |
+
model_info = {
|
154 |
+
"featurizer_class": featurizer_class,
|
155 |
+
"inverse_featurizer_class": inverse_featurizer_class,
|
156 |
+
"n_features": self.n_features,
|
157 |
+
"additional_config": additional_config,
|
158 |
+
}
|
159 |
+
|
160 |
+
torch.save(
|
161 |
+
{"model_info": model_info, "state_dict": self.featurizer.state_dict()},
|
162 |
+
f"{path}_featurizer",
|
163 |
+
)
|
164 |
+
torch.save(
|
165 |
+
{
|
166 |
+
"model_info": model_info,
|
167 |
+
"state_dict": self.inverse_featurizer.state_dict(),
|
168 |
+
},
|
169 |
+
f"{path}_inverse_featurizer",
|
170 |
+
)
|
171 |
+
return f"{path}_featurizer", f"{path}_inverse_featurizer"
|
172 |
+
|
173 |
+
@classmethod
|
174 |
+
def load_modules(cls, path: str) -> "Featurizer":
|
175 |
+
"""Inverse of :meth:`save_modules`.
|
176 |
+
|
177 |
+
Returns
|
178 |
+
-------
|
179 |
+
Featurizer
|
180 |
+
A *new* instance with reconstructed modules and metadata.
|
181 |
+
"""
|
182 |
+
featurizer_data = torch.load(f"{path}_featurizer")
|
183 |
+
inverse_data = torch.load(f"{path}_inverse_featurizer")
|
184 |
+
|
185 |
+
model_info = featurizer_data["model_info"]
|
186 |
+
featurizer_class = model_info["featurizer_class"]
|
187 |
+
|
188 |
+
if featurizer_class == "SubspaceFeaturizerModule":
|
189 |
+
rot = model_info["additional_config"]["rotation_matrix"]
|
190 |
+
requires_grad = model_info["additional_config"]["requires_grad"]
|
191 |
+
|
192 |
+
# Re-build a parametrised orthogonal layer with identical shape.
|
193 |
+
in_dim, out_dim = rot.shape
|
194 |
+
rotate_layer = pv.models.layers.LowRankRotateLayer(
|
195 |
+
in_dim, out_dim, init_orth=False
|
196 |
+
)
|
197 |
+
rotate_layer.weight.data.copy_(rot)
|
198 |
+
rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
|
199 |
+
rotate_layer.requires_grad_(requires_grad)
|
200 |
+
|
201 |
+
featurizer = SubspaceFeaturizerModule(rotate_layer)
|
202 |
+
inverse = SubspaceInverseFeaturizerModule(rotate_layer)
|
203 |
+
|
204 |
+
# Sanity-check weight shape
|
205 |
+
assert (
|
206 |
+
featurizer.rotate.weight.shape == rot.shape
|
207 |
+
), "Rotation-matrix shape mismatch after deserialisation."
|
208 |
+
elif featurizer_class == "IdentityFeaturizerModule":
|
209 |
+
featurizer = IdentityFeaturizerModule()
|
210 |
+
inverse = IdentityInverseFeaturizerModule()
|
211 |
+
else:
|
212 |
+
raise ValueError(f"Unknown featurizer class '{featurizer_class}'.")
|
213 |
+
|
214 |
+
featurizer.load_state_dict(featurizer_data["state_dict"])
|
215 |
+
inverse.load_state_dict(inverse_data["state_dict"])
|
216 |
+
|
217 |
+
return cls(
|
218 |
+
featurizer,
|
219 |
+
inverse,
|
220 |
+
n_features=model_info["n_features"],
|
221 |
+
id=model_info.get("featurizer_id", "loaded"),
|
222 |
+
)
|
223 |
+
|
224 |
+
|
225 |
+
# --------------------------------------------------------------------------- #
|
226 |
+
# Intervention factory helpers #
|
227 |
+
# --------------------------------------------------------------------------- #
|
228 |
+
def build_feature_interchange_intervention(
|
229 |
+
featurizer: torch.nn.Module,
|
230 |
+
inverse_featurizer: torch.nn.Module,
|
231 |
+
featurizer_id: str,
|
232 |
+
):
|
233 |
+
"""Return a class implementing PyVENE’s TrainableIntervention."""
|
234 |
+
|
235 |
+
class FeatureInterchangeIntervention(
|
236 |
+
pv.TrainableIntervention, pv.DistributedRepresentationIntervention
|
237 |
+
):
|
238 |
+
"""Swap features between *base* and *source* in the featurized space."""
|
239 |
+
|
240 |
+
def __init__(self, **kwargs):
|
241 |
+
super().__init__(**kwargs)
|
242 |
+
self._featurizer = featurizer
|
243 |
+
self._inverse = inverse_featurizer
|
244 |
+
|
245 |
+
def forward(self, base, source, subspaces=None):
|
246 |
+
f_base, base_err = self._featurizer(base)
|
247 |
+
f_src, _ = self._featurizer(source)
|
248 |
+
|
249 |
+
if subspaces is None or _subspace_is_all_none(subspaces):
|
250 |
+
f_out = f_src
|
251 |
+
else:
|
252 |
+
f_out = pv.models.intervention_utils._do_intervention_by_swap(
|
253 |
+
f_base,
|
254 |
+
f_src,
|
255 |
+
"interchange",
|
256 |
+
self.interchange_dim,
|
257 |
+
subspaces,
|
258 |
+
subspace_partition=self.subspace_partition,
|
259 |
+
use_fast=self.use_fast,
|
260 |
+
)
|
261 |
+
return self._inverse(f_out, base_err).to(base.dtype)
|
262 |
+
|
263 |
+
def __str__(self): # noqa: D401
|
264 |
+
return f"FeatureInterchangeIntervention(id={featurizer_id})"
|
265 |
+
|
266 |
+
return FeatureInterchangeIntervention
|
267 |
+
|
268 |
+
|
269 |
+
def build_feature_collect_intervention(
|
270 |
+
featurizer: torch.nn.Module, featurizer_id: str
|
271 |
+
):
|
272 |
+
"""Return a `CollectIntervention` operating in feature space."""
|
273 |
+
|
274 |
+
class FeatureCollectIntervention(pv.CollectIntervention):
|
275 |
+
def __init__(self, **kwargs):
|
276 |
+
super().__init__(**kwargs)
|
277 |
+
self._featurizer = featurizer
|
278 |
+
|
279 |
+
def forward(self, base, source=None, subspaces=None):
|
280 |
+
f_base, _ = self._featurizer(base)
|
281 |
+
return pv.models.intervention_utils._do_intervention_by_swap(
|
282 |
+
f_base,
|
283 |
+
source,
|
284 |
+
"collect",
|
285 |
+
self.interchange_dim,
|
286 |
+
subspaces,
|
287 |
+
subspace_partition=self.subspace_partition,
|
288 |
+
use_fast=self.use_fast,
|
289 |
+
)
|
290 |
+
|
291 |
+
def __str__(self): # noqa: D401
|
292 |
+
return f"FeatureCollectIntervention(id={featurizer_id})"
|
293 |
+
|
294 |
+
return FeatureCollectIntervention
|
295 |
+
|
296 |
+
|
297 |
+
def build_feature_mask_intervention(
|
298 |
+
featurizer: torch.nn.Module,
|
299 |
+
inverse_featurizer: torch.nn.Module,
|
300 |
+
n_features: int,
|
301 |
+
featurizer_id: str,
|
302 |
+
):
|
303 |
+
"""Return a trainable mask intervention."""
|
304 |
+
|
305 |
+
class FeatureMaskIntervention(pv.TrainableIntervention):
|
306 |
+
"""Differential-binary masking in the featurized space."""
|
307 |
+
|
308 |
+
def __init__(self, **kwargs):
|
309 |
+
super().__init__(**kwargs)
|
310 |
+
self._featurizer = featurizer
|
311 |
+
self._inverse = inverse_featurizer
|
312 |
+
|
313 |
+
# Learnable parameters
|
314 |
+
self.mask = torch.nn.Parameter(torch.zeros(n_features), requires_grad=True)
|
315 |
+
self.temperature: Optional[torch.Tensor] = None # must be set by user
|
316 |
+
|
317 |
+
# -------------------- API helpers -------------------- #
|
318 |
+
def get_temperature(self) -> torch.Tensor:
|
319 |
+
if self.temperature is None:
|
320 |
+
raise ValueError("Temperature has not been set.")
|
321 |
+
return self.temperature
|
322 |
+
|
323 |
+
def set_temperature(self, temp: float | torch.Tensor):
|
324 |
+
self.temperature = (
|
325 |
+
torch.as_tensor(temp, dtype=self.mask.dtype).to(self.mask.device)
|
326 |
+
)
|
327 |
+
|
328 |
+
def _nonlinear_transform(self, f: torch.Tensor) -> torch.Tensor:
|
329 |
+
# You can swap this for a real MLP if desired
|
330 |
+
return torch.tanh(f)
|
331 |
+
|
332 |
+
# ------------------------- forward ------------------- #
|
333 |
+
def forward(self, base, source, subspaces=None):
|
334 |
+
if self.temperature is None:
|
335 |
+
raise ValueError("Cannot run forward without a temperature.")
|
336 |
+
|
337 |
+
f_base, base_err = self._featurizer(base)
|
338 |
+
f_src, _ = self._featurizer(source)
|
339 |
+
|
340 |
+
# Align devices / dtypes
|
341 |
+
mask = self.mask.to(f_base.device)
|
342 |
+
temp = self.temperature.to(f_base.device)
|
343 |
+
|
344 |
+
f_base = f_base.to(mask.dtype)
|
345 |
+
f_src = f_src.to(mask.dtype)
|
346 |
+
|
347 |
+
if self.training:
|
348 |
+
gate = torch.sigmoid(mask / temp)
|
349 |
+
else:
|
350 |
+
gate = (torch.sigmoid(mask) > 0.5).float()
|
351 |
+
|
352 |
+
|
353 |
+
f_out = (1.0 - gate) * f_base + gate * f_src
|
354 |
+
|
355 |
+
# === Apply nonlinearity during training only ===
|
356 |
+
if self.training:
|
357 |
+
f_out = self._nonlinear_transform(f_out)
|
358 |
+
|
359 |
+
return self._inverse(f_out.to(base.dtype), base_err).to(base.dtype)
|
360 |
+
|
361 |
+
# ---------------- Sparsity regulariser --------------- #
|
362 |
+
def get_sparsity_loss(self) -> torch.Tensor:
|
363 |
+
if self.temperature is None:
|
364 |
+
raise ValueError("Temperature has not been set.")
|
365 |
+
gate = torch.sigmoid(self.mask / self.temperature)
|
366 |
+
return torch.norm(gate, p=1)
|
367 |
+
|
368 |
+
def __str__(self): # noqa: D401
|
369 |
+
return f"FeatureMaskIntervention(id={featurizer_id})"
|
370 |
+
|
371 |
+
return FeatureMaskIntervention
|
372 |
+
|
373 |
+
|
374 |
+
# --------------------------------------------------------------------------- #
|
375 |
+
# Concrete featurizer implementations #
|
376 |
+
# --------------------------------------------------------------------------- #
|
377 |
+
class SubspaceFeaturizerModule(torch.nn.Module):
|
378 |
+
"""Linear projector onto an orthogonal *rotation* sub-space."""
|
379 |
+
|
380 |
+
def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
|
381 |
+
super().__init__()
|
382 |
+
self.rotate = rotate_layer
|
383 |
+
|
384 |
+
def forward(self, x: torch.Tensor):
|
385 |
+
r = self.rotate.weight.T # (out, in)ᵀ
|
386 |
+
f = x.to(r.dtype) @ r.T
|
387 |
+
error = x - (f @ r).to(x.dtype)
|
388 |
+
return f, error
|
389 |
+
|
390 |
+
|
391 |
+
class SubspaceInverseFeaturizerModule(torch.nn.Module):
|
392 |
+
"""Inverse of :class:`SubspaceFeaturizerModule`."""
|
393 |
+
|
394 |
+
def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
|
395 |
+
super().__init__()
|
396 |
+
self.rotate = rotate_layer
|
397 |
+
|
398 |
+
def forward(self, f, error):
|
399 |
+
r = self.rotate.weight.T
|
400 |
+
return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)
|
401 |
+
|
402 |
+
|
403 |
+
class SubspaceFeaturizer(Featurizer):
|
404 |
+
"""Orthogonal linear sub-space featurizer."""
|
405 |
+
|
406 |
+
def __init__(
|
407 |
+
self,
|
408 |
+
*,
|
409 |
+
shape: Tuple[int, int] | None = None,
|
410 |
+
rotation_subspace: torch.Tensor | None = None,
|
411 |
+
trainable: bool = True,
|
412 |
+
id: str = "subspace",
|
413 |
+
):
|
414 |
+
assert (
|
415 |
+
shape is not None or rotation_subspace is not None
|
416 |
+
), "Provide either `shape` or `rotation_subspace`."
|
417 |
+
|
418 |
+
if shape is not None:
|
419 |
+
rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
|
420 |
+
else:
|
421 |
+
shape = rotation_subspace.shape
|
422 |
+
rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
|
423 |
+
rotate.weight.data.copy_(rotation_subspace)
|
424 |
+
|
425 |
+
rotate = torch.nn.utils.parametrizations.orthogonal(rotate)
|
426 |
+
rotate.requires_grad_(trainable)
|
427 |
+
|
428 |
+
super().__init__(
|
429 |
+
SubspaceFeaturizerModule(rotate),
|
430 |
+
SubspaceInverseFeaturizerModule(rotate),
|
431 |
+
n_features=rotate.weight.shape[1],
|
432 |
+
id=id,
|
433 |
+
)
|
434 |
+
|
435 |
+
|
436 |
+
class SAEFeaturizerModule(torch.nn.Module):
|
437 |
+
"""Wrapper around a *Sparse Autoencoder*’s encode() / decode() pair."""
|
438 |
+
|
439 |
+
def __init__(self, sae):
|
440 |
+
super().__init__()
|
441 |
+
self.sae = sae
|
442 |
+
|
443 |
+
def forward(self, x):
|
444 |
+
features = self.sae.encode(x.to(self.sae.dtype))
|
445 |
+
error = x - self.sae.decode(features).to(x.dtype)
|
446 |
+
return features.to(x.dtype), error
|
447 |
+
|
448 |
+
|
449 |
+
class SAEInverseFeaturizerModule(torch.nn.Module):
|
450 |
+
"""Inverse for :class:`SAEFeaturizerModule`."""
|
451 |
+
|
452 |
+
def __init__(self, sae):
|
453 |
+
super().__init__()
|
454 |
+
self.sae = sae
|
455 |
+
|
456 |
+
def forward(self, features, error):
|
457 |
+
return (
|
458 |
+
self.sae.decode(features.to(self.sae.dtype)).to(features.dtype)
|
459 |
+
+ error.to(features.dtype)
|
460 |
+
)
|
461 |
+
|
462 |
+
|
463 |
+
class SAEFeaturizer(Featurizer):
|
464 |
+
"""Featurizer backed by a pre-trained sparse auto-encoder.
|
465 |
+
|
466 |
+
Notes
|
467 |
+
-----
|
468 |
+
Serialisation is *disabled* for SAE featurizers – saving will raise
|
469 |
+
``NotImplementedError``.
|
470 |
+
"""
|
471 |
+
|
472 |
+
def __init__(self, sae, *, trainable: bool = False):
|
473 |
+
sae.requires_grad_(trainable)
|
474 |
+
super().__init__(
|
475 |
+
SAEFeaturizerModule(sae),
|
476 |
+
SAEInverseFeaturizerModule(sae),
|
477 |
+
n_features=sae.cfg.to_dict()["d_sae"],
|
478 |
+
id="sae",
|
479 |
+
)
|
480 |
+
|
481 |
+
|
482 |
+
# --------------------------------------------------------------------------- #
|
483 |
+
# Utility helpers #
|
484 |
+
# --------------------------------------------------------------------------- #
|
485 |
+
def _subspace_is_all_none(subspaces) -> bool:
|
486 |
+
"""Return ``True`` if *every* element of *subspaces* is ``None``."""
|
487 |
+
return subspaces is None or all(
|
488 |
+
inner is None or all(elem is None for elem in inner) for inner in subspaces
|
489 |
+
)
|