Update SAE/sae.py
SAE/sae.py  (+5 -9)
```diff
@@ -41,13 +41,6 @@ class SparseAutoencoder(nn.Module):
         self.stats_last_nonzero: torch.Tensor
         self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
 
-        def auxk_mask_fn(x):
-            dead_mask = self.stats_last_nonzero > dead_steps_threshold
-            x.data *= dead_mask  # inplace to save memory
-            return x
-
-        self.auxk_mask_fn = auxk_mask_fn
-
         ## initialization
 
         # "tied" init
@@ -58,6 +51,11 @@ class SparseAutoencoder(nn.Module):
 
         unit_norm_decoder_(self)
 
+    def auxk_mask_fn(self, x):
+        dead_mask = self.stats_last_nonzero > dead_steps_threshold
+        x.data *= dead_mask  # inplace to save memory
+        return x
+
     def save_to_disk(self, path: str):
         PATH_TO_CFG = 'config.json'
         PATH_TO_WEIGHTS = 'state_dict.pth'
@@ -122,7 +120,6 @@ class SparseAutoencoder(nn.Module):
 
         return latents
 
-    @spaces.GPU
     def forward(self, x):
         x = x - self.pre_bias
         latents_pre_act = self.encoder(x) + self.latent_bias
@@ -182,7 +179,6 @@ class SparseAutoencoder(nn.Module):
             "auxk_vals": auxk_vals,
         }
 
-    @spaces.GPU
     def decode_sparse(self, inds, vals):
         rows, cols = inds.shape[0], self.n_dirs
 
```
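In summary, the commit turns the `auxk_mask_fn` closure created inside `__init__` into a regular instance method and drops the `@spaces.GPU` (ZeroGPU) decorators from `forward` and `decode_sparse`. One detail the diff leaves open is where `dead_steps_threshold` comes from once the function is a method: as a closure it captured a local variable of `__init__`, but the method form reads a bare `dead_steps_threshold`, which only resolves if that name exists at module scope. Below is a minimal sketch, not the repository's code, of one way to keep the value reachable by storing it on the instance; the constructor argument and its default are assumptions for illustration.

```python
# Minimal sketch, assuming `dead_steps_threshold` is passed to the constructor;
# this is not the repository's actual __init__ signature.
import torch
import torch.nn as nn


class SparseAutoencoderSketch(nn.Module):
    def __init__(self, n_dirs_local: int, dead_steps_threshold: int = 1000):
        super().__init__()
        # Keep the threshold on the instance so methods (not just closures
        # created inside __init__) can see it.
        self.dead_steps_threshold = dead_steps_threshold
        # Per-direction counter of steps since each latent direction last fired.
        self.stats_last_nonzero: torch.Tensor
        self.register_buffer(
            "stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long)
        )

    def auxk_mask_fn(self, x: torch.Tensor) -> torch.Tensor:
        # Zero the pre-activations of directions that fired recently, so the
        # auxiliary top-k pass only considers "dead" directions.
        dead_mask = self.stats_last_nonzero > self.dead_steps_threshold
        x.data *= dead_mask  # inplace to save memory
        return x
```

Note that the sketch reads the threshold from `self`, whereas the committed method keeps the bare name; if `dead_steps_threshold` is not defined at module level in the actual file, the attribute approach above is one way to avoid a `NameError`.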