leonorv committed
Commit 4436d63 · verified · 1 Parent(s): 52df79a

Upload featurizers.py

Files changed (1):
  1. featurizers.py +489 -0
featurizers.py ADDED
"""
featurizers.py
==============
Utility classes for defining *invertible* feature spaces on top of a model’s
hidden-state tensors, together with intervention helpers that operate inside
those spaces.

Key ideas
---------

* **Featurizer** – a lightweight wrapper holding:
  • a forward `featurizer` module that maps a tensor **x → (f, error)**
    where *error* is the reconstruction residual (useful for lossy
    featurizers such as sparse auto-encoders);
  • an `inverse_featurizer` that re-assembles the original space
    **(f, error) → x̂**.

* **Interventions** – three higher-order factory functions build PyVENE
  interventions that work in the featurized space:
  - *interchange*
  - *collect*
  - *mask* (differential binary masking)

All public classes / functions below carry PEP-257-style doc-strings.
"""

from typing import Optional, Tuple

import torch
import pyvene as pv

# --------------------------------------------------------------------------- #
# Basic identity featurizers                                                   #
# --------------------------------------------------------------------------- #
class IdentityFeaturizerModule(torch.nn.Module):
    """A no-op featurizer: *x → (x, None)*."""

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
        return x, None


class IdentityInverseFeaturizerModule(torch.nn.Module):
    """Inverse of :class:`IdentityFeaturizerModule`."""

    def forward(self, x: torch.Tensor, error: None) -> torch.Tensor:  # noqa: D401
        return x
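
# A minimal illustrative sketch (the `_demo_*` helpers here and below are
# hypothetical, not part of the public API): every featurizer / inverse pair
# is expected to satisfy x ≈ inverse_featurizer(*featurizer(x)); for the
# identity pair the round-trip is exact.
def _demo_identity_roundtrip() -> None:
    x = torch.randn(2, 8)
    f, err = IdentityFeaturizerModule()(x)
    assert torch.equal(IdentityInverseFeaturizerModule()(f, err), x)
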

# --------------------------------------------------------------------------- #
# High-level Featurizer wrapper                                                #
# --------------------------------------------------------------------------- #
class Featurizer:
    """Container object holding paired featurizer and inverse modules.

    Parameters
    ----------
    featurizer :
        A `torch.nn.Module` mapping **x → (features, error)**. Defaults to a
        fresh :class:`IdentityFeaturizerModule`.
    inverse_featurizer :
        A `torch.nn.Module` mapping **(features, error) → x̂**. Defaults to a
        fresh :class:`IdentityInverseFeaturizerModule`.
    n_features :
        Dimensionality of the feature space. **Required** when you intend to
        build a *mask* intervention; optional otherwise.
    id :
        Human-readable identifier used by the `__str__` methods of the
        generated interventions.
    """

    # --------------------------------------------------------------------- #
    # Construction / public accessors                                        #
    # --------------------------------------------------------------------- #
    def __init__(
        self,
        featurizer: Optional[torch.nn.Module] = None,
        inverse_featurizer: Optional[torch.nn.Module] = None,
        *,
        n_features: Optional[int] = None,
        id: str = "null",
    ):
        # Instantiate defaults per-instance rather than sharing one module
        # object across all Featurizers via mutable default arguments.
        self.featurizer = (
            featurizer if featurizer is not None else IdentityFeaturizerModule()
        )
        self.inverse_featurizer = (
            inverse_featurizer
            if inverse_featurizer is not None
            else IdentityInverseFeaturizerModule()
        )
        self.n_features = n_features
        self.id = id
    # -------------------- Intervention builders -------------------------- #
    def get_interchange_intervention(self):
        if not hasattr(self, "_interchange_intervention"):
            self._interchange_intervention = build_feature_interchange_intervention(
                self.featurizer, self.inverse_featurizer, self.id
            )
        return self._interchange_intervention

    def get_collect_intervention(self):
        if not hasattr(self, "_collect_intervention"):
            self._collect_intervention = build_feature_collect_intervention(
                self.featurizer, self.id
            )
        return self._collect_intervention

    def get_mask_intervention(self):
        if self.n_features is None:
            raise ValueError(
                "`n_features` must be provided on the Featurizer "
                "to construct a mask intervention."
            )
        if not hasattr(self, "_mask_intervention"):
            self._mask_intervention = build_feature_mask_intervention(
                self.featurizer,
                self.inverse_featurizer,
                self.n_features,
                self.id,
            )
        return self._mask_intervention
    # ------------------------- Convenience I/O --------------------------- #
    def featurize(self, x: torch.Tensor):
        return self.featurizer(x)

    def inverse_featurize(self, x: torch.Tensor, error):
        return self.inverse_featurizer(x, error)
    # --------------------------------------------------------------------- #
    # (De)serialisation helpers                                              #
    # --------------------------------------------------------------------- #
    def save_modules(self, path: str) -> Tuple[Optional[str], Optional[str]]:
        """Serialise featurizer & inverse to `<path>_{featurizer, inverse}`.

        Notes
        -----
        * **SAE featurizers** are *not* serialised here: they are meant to be
          re-loaded from `sae_lens`, so ``(None, None)`` is returned.
        * Existing files will be *silently overwritten*.
        """
        featurizer_class = self.featurizer.__class__.__name__

        if featurizer_class == "SAEFeaturizerModule":
            # SAE featurizers are to be loaded from sae_lens.
            return None, None

        inverse_featurizer_class = self.inverse_featurizer.__class__.__name__

        # Extra config needed for Subspace featurizers
        additional_config = {}
        if featurizer_class == "SubspaceFeaturizerModule":
            additional_config["rotation_matrix"] = (
                self.featurizer.rotate.weight.detach().clone()
            )
            additional_config["requires_grad"] = (
                self.featurizer.rotate.weight.requires_grad
            )

        model_info = {
            "featurizer_class": featurizer_class,
            "inverse_featurizer_class": inverse_featurizer_class,
            "featurizer_id": self.id,  # read back by :meth:`load_modules`
            "n_features": self.n_features,
            "additional_config": additional_config,
        }

        torch.save(
            {"model_info": model_info, "state_dict": self.featurizer.state_dict()},
            f"{path}_featurizer",
        )
        torch.save(
            {
                "model_info": model_info,
                "state_dict": self.inverse_featurizer.state_dict(),
            },
            f"{path}_inverse_featurizer",
        )
        return f"{path}_featurizer", f"{path}_inverse_featurizer"
    @classmethod
    def load_modules(cls, path: str) -> "Featurizer":
        """Inverse of :meth:`save_modules`.

        Returns
        -------
        Featurizer
            A *new* instance with reconstructed modules and metadata.
        """
        featurizer_data = torch.load(f"{path}_featurizer")
        inverse_data = torch.load(f"{path}_inverse_featurizer")

        model_info = featurizer_data["model_info"]
        featurizer_class = model_info["featurizer_class"]

        if featurizer_class == "SubspaceFeaturizerModule":
            rot = model_info["additional_config"]["rotation_matrix"]
            requires_grad = model_info["additional_config"]["requires_grad"]

            # Re-build a parametrised orthogonal layer with identical shape.
            in_dim, out_dim = rot.shape
            rotate_layer = pv.models.layers.LowRankRotateLayer(
                in_dim, out_dim, init_orth=False
            )
            rotate_layer.weight.data.copy_(rot)
            rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
            rotate_layer.requires_grad_(requires_grad)

            featurizer = SubspaceFeaturizerModule(rotate_layer)
            inverse = SubspaceInverseFeaturizerModule(rotate_layer)

            # Sanity-check weight shape
            assert (
                featurizer.rotate.weight.shape == rot.shape
            ), "Rotation-matrix shape mismatch after deserialisation."
        elif featurizer_class == "IdentityFeaturizerModule":
            featurizer = IdentityFeaturizerModule()
            inverse = IdentityInverseFeaturizerModule()
        else:
            raise ValueError(f"Unknown featurizer class '{featurizer_class}'.")

        featurizer.load_state_dict(featurizer_data["state_dict"])
        inverse.load_state_dict(inverse_data["state_dict"])

        return cls(
            featurizer,
            inverse,
            n_features=model_info["n_features"],
            id=model_info.get("featurizer_id", "loaded"),
        )
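
# Illustrative sketch of the (de)serialisation round-trip, assuming a
# writable, hypothetical path prefix "tmp/feat"; `save_modules` writes
# "tmp/feat_featurizer" and "tmp/feat_inverse_featurizer" beside it.
def _demo_save_load_roundtrip() -> None:
    feat = Featurizer(n_features=16, id="demo")
    feat.save_modules("tmp/feat")
    restored = Featurizer.load_modules("tmp/feat")
    assert restored.n_features == 16 and restored.id == "demo"
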

# --------------------------------------------------------------------------- #
# Intervention factory helpers                                                 #
# --------------------------------------------------------------------------- #
def build_feature_interchange_intervention(
    featurizer: torch.nn.Module,
    inverse_featurizer: torch.nn.Module,
    featurizer_id: str,
):
    """Return a class implementing PyVENE’s TrainableIntervention."""

    class FeatureInterchangeIntervention(
        pv.TrainableIntervention, pv.DistributedRepresentationIntervention
    ):
        """Swap features between *base* and *source* in the featurized space."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer
            self._inverse = inverse_featurizer

        def forward(self, base, source, subspaces=None):
            f_base, base_err = self._featurizer(base)
            f_src, _ = self._featurizer(source)

            if subspaces is None or _subspace_is_all_none(subspaces):
                f_out = f_src
            else:
                f_out = pv.models.intervention_utils._do_intervention_by_swap(
                    f_base,
                    f_src,
                    "interchange",
                    self.interchange_dim,
                    subspaces,
                    subspace_partition=self.subspace_partition,
                    use_fast=self.use_fast,
                )
            return self._inverse(f_out, base_err).to(base.dtype)

        def __str__(self):  # noqa: D401
            return f"FeatureInterchangeIntervention(id={featurizer_id})"

    return FeatureInterchangeIntervention
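
# Illustrative sketch: when no subspace indices are given, the interchange
# above amounts to decoding the source's features with the base's
# reconstruction error. `h_base` / `h_src` stand in for hypothetical
# hidden-state tensors of matching shape.
def _demo_full_interchange(feat: Featurizer, h_base, h_src):
    f_src, _ = feat.featurize(h_src)
    _, err_base = feat.featurize(h_base)
    return feat.inverse_featurize(f_src, err_base)
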

def build_feature_collect_intervention(
    featurizer: torch.nn.Module, featurizer_id: str
):
    """Return a `CollectIntervention` operating in feature space."""

    class FeatureCollectIntervention(pv.CollectIntervention):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer

        def forward(self, base, source=None, subspaces=None):
            f_base, _ = self._featurizer(base)
            return pv.models.intervention_utils._do_intervention_by_swap(
                f_base,
                source,
                "collect",
                self.interchange_dim,
                subspaces,
                subspace_partition=self.subspace_partition,
                use_fast=self.use_fast,
            )

        def __str__(self):  # noqa: D401
            return f"FeatureCollectIntervention(id={featurizer_id})"

    return FeatureCollectIntervention
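
# Illustrative sketch: collect is read-only. It featurizes the base
# activation and returns feature coordinates without altering the forward
# pass; reading out every feature of a hypothetical hidden state `h`:
def _demo_collect_all(feat: Featurizer, h: torch.Tensor) -> torch.Tensor:
    f, _ = feat.featurize(h)
    return f
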

def build_feature_mask_intervention(
    featurizer: torch.nn.Module,
    inverse_featurizer: torch.nn.Module,
    n_features: int,
    featurizer_id: str,
):
    """Return a trainable mask intervention."""

    class FeatureMaskIntervention(pv.TrainableIntervention):
        """Differential-binary masking in the featurized space."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self._featurizer = featurizer
            self._inverse = inverse_featurizer

            # Learnable parameters
            self.mask = torch.nn.Parameter(torch.zeros(n_features), requires_grad=True)
            self.temperature: Optional[torch.Tensor] = None  # must be set by user

        # -------------------- API helpers -------------------- #
        def get_temperature(self) -> torch.Tensor:
            if self.temperature is None:
                raise ValueError("Temperature has not been set.")
            return self.temperature

        def set_temperature(self, temp: float | torch.Tensor):
            self.temperature = torch.as_tensor(temp, dtype=self.mask.dtype).to(
                self.mask.device
            )

        def _nonlinear_transform(self, f: torch.Tensor) -> torch.Tensor:
            # You can swap this for a real MLP if desired
            return torch.tanh(f)

        # ------------------------- forward ------------------- #
        def forward(self, base, source, subspaces=None):
            if self.temperature is None:
                raise ValueError("Cannot run forward without a temperature.")

            f_base, base_err = self._featurizer(base)
            f_src, _ = self._featurizer(source)

            # Align devices / dtypes
            mask = self.mask.to(f_base.device)
            temp = self.temperature.to(f_base.device)

            f_base = f_base.to(mask.dtype)
            f_src = f_src.to(mask.dtype)

            if self.training:
                gate = torch.sigmoid(mask / temp)
            else:
                gate = (torch.sigmoid(mask) > 0.5).float()

            f_out = (1.0 - gate) * f_base + gate * f_src

            # === Apply nonlinearity during training only ===
            if self.training:
                f_out = self._nonlinear_transform(f_out)

            return self._inverse(f_out.to(base.dtype), base_err).to(base.dtype)

        # ---------------- Sparsity regulariser --------------- #
        def get_sparsity_loss(self) -> torch.Tensor:
            if self.temperature is None:
                raise ValueError("Temperature has not been set.")
            gate = torch.sigmoid(self.mask / self.temperature)
            return torch.norm(gate, p=1)

        def __str__(self):  # noqa: D401
            return f"FeatureMaskIntervention(id={featurizer_id})"

    return FeatureMaskIntervention
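
# Illustrative sketch of the gating used above: during training each feature
# i is mixed with weight g_i = sigmoid(mask_i / T), which approaches a hard
# 0/1 gate as the temperature T anneals toward 0; at eval time the gate is
# thresholded, and the L1 norm of the soft gate is the sparsity penalty.
def _demo_gates(mask: torch.Tensor, temperature: float):
    soft = torch.sigmoid(mask / temperature)    # training-time gate
    hard = (torch.sigmoid(mask) > 0.5).float()  # eval-time gate
    sparsity = soft.norm(p=1)                   # matches get_sparsity_loss
    return soft, hard, sparsity
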

# --------------------------------------------------------------------------- #
# Concrete featurizer implementations                                          #
# --------------------------------------------------------------------------- #
class SubspaceFeaturizerModule(torch.nn.Module):
    """Linear projector onto an orthogonal *rotation* sub-space."""

    def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, x: torch.Tensor):
        r = self.rotate.weight.T  # (out, in)ᵀ
        f = x.to(r.dtype) @ r.T
        error = x - (f @ r).to(x.dtype)
        return f, error


class SubspaceInverseFeaturizerModule(torch.nn.Module):
    """Inverse of :class:`SubspaceFeaturizerModule`."""

    def __init__(self, rotate_layer: pv.models.layers.LowRankRotateLayer):
        super().__init__()
        self.rotate = rotate_layer

    def forward(self, f, error):
        r = self.rotate.weight.T
        return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype)
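
# Illustrative sketch of the decomposition implemented above: with an
# orthonormal R of shape (d, k), f = x @ R gives subspace coordinates,
# x @ R @ R.T is the projection back, and error = x - x @ R @ R.T lies in
# the orthogonal complement, so the inverse reconstructs x exactly.
def _demo_subspace_roundtrip(d: int = 8, k: int = 3) -> None:
    rotate = torch.nn.utils.parametrizations.orthogonal(
        pv.models.layers.LowRankRotateLayer(d, k, init_orth=True)
    )
    x = torch.randn(2, d)
    f, err = SubspaceFeaturizerModule(rotate)(x)
    x_hat = SubspaceInverseFeaturizerModule(rotate)(f, err)
    assert torch.allclose(x_hat, x, atol=1e-5)
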

class SubspaceFeaturizer(Featurizer):
    """Orthogonal linear sub-space featurizer."""

    def __init__(
        self,
        *,
        shape: Tuple[int, int] | None = None,
        rotation_subspace: torch.Tensor | None = None,
        trainable: bool = True,
        id: str = "subspace",
    ):
        assert (
            shape is not None or rotation_subspace is not None
        ), "Provide either `shape` or `rotation_subspace`."

        if shape is not None:
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True)
        else:
            shape = rotation_subspace.shape
            rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False)
            rotate.weight.data.copy_(rotation_subspace)

        rotate = torch.nn.utils.parametrizations.orthogonal(rotate)
        rotate.requires_grad_(trainable)

        super().__init__(
            SubspaceFeaturizerModule(rotate),
            SubspaceInverseFeaturizerModule(rotate),
            n_features=rotate.weight.shape[1],
            id=id,
        )

class SAEFeaturizerModule(torch.nn.Module):
    """Wrapper around a *Sparse Autoencoder*’s encode() / decode() pair."""

    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def forward(self, x):
        features = self.sae.encode(x.to(self.sae.dtype))
        error = x - self.sae.decode(features).to(x.dtype)
        return features.to(x.dtype), error


class SAEInverseFeaturizerModule(torch.nn.Module):
    """Inverse for :class:`SAEFeaturizerModule`."""

    def __init__(self, sae):
        super().__init__()
        self.sae = sae

    def forward(self, features, error):
        return (
            self.sae.decode(features.to(self.sae.dtype)).to(features.dtype)
            + error.to(features.dtype)
        )

class SAEFeaturizer(Featurizer):
    """Featurizer backed by a pre-trained sparse auto-encoder.

    Notes
    -----
    Serialisation is *disabled* for SAE featurizers – :meth:`save_modules`
    is a no-op that returns ``(None, None)``; reload the SAE itself from
    `sae_lens` instead.
    """

    def __init__(self, sae, *, trainable: bool = False):
        sae.requires_grad_(trainable)
        super().__init__(
            SAEFeaturizerModule(sae),
            SAEInverseFeaturizerModule(sae),
            n_features=sae.cfg.to_dict()["d_sae"],
            id="sae",
        )
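
# Illustrative sketch of wiring in a pre-trained SAE. The `sae_lens` loading
# call and the release / id strings below are assumptions about that
# library's API and may differ across versions:
#
#     from sae_lens import SAE
#     sae = SAE.from_pretrained(
#         release="gpt2-small-res-jb", sae_id="blocks.6.hook_resid_pre"
#     )[0]
#     featurizer = SAEFeaturizer(sae)
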

# --------------------------------------------------------------------------- #
# Utility helpers                                                              #
# --------------------------------------------------------------------------- #
def _subspace_is_all_none(subspaces) -> bool:
    """Return ``True`` if *every* element of *subspaces* is ``None``."""
    return subspaces is None or all(
        inner is None or all(elem is None for elem in inner) for inner in subspaces
    )
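
# For example (assuming the nested-list layout PyVENE uses for subspace
# arguments): _subspace_is_all_none([[None, None]]) is True, while
# _subspace_is_all_none([[0, 1]]) is False and _subspace_is_all_none(None)
# is True.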