TabPFN committed
Commit e0f96a9 · 1 Parent(s): 9196c50

Upload encoders.py

Files changed (1):
  TabPFN/encoders.py  +90 −72
TabPFN/encoders.py CHANGED
@@ -8,26 +8,26 @@ from torch.nn import TransformerEncoder, TransformerEncoderLayer
 
 
 class StyleEncoder(nn.Module):
-    def __init__(self, em_size, hyperparameter_definitions):
+    def __init__(self, num_hyperparameters, em_size):
         super().__init__()
-        # self.embeddings = {}
         self.em_size = em_size
-        # self.hyperparameter_definitions = {}
-        # for hp in hyperparameter_definitions:
-        #     self.embeddings[hp] = nn.Linear(1, self.em_size)
-        # self.embeddings = nn.ModuleDict(self.embeddings)
-        self.embedding = nn.Linear(hyperparameter_definitions.shape[0], self.em_size)
-
-    def forward(self, hyperparameters):  # T x B x num_features
-        # Make faster by using matrices
-        # sampled_embeddings = [torch.stack([
-        #     self.embeddings[hp](torch.tensor([batch[hp]], device=self.embeddings[hp].weight.device, dtype=torch.float))
-        #     for hp in batch
-        # ], -1).sum(-1) for batch in hyperparameters]
-        # return torch.stack(sampled_embeddings, 0)
+        self.embedding = nn.Linear(num_hyperparameters, self.em_size)
+
+    def forward(self, hyperparameters):  # B x num_hps
         return self.embedding(hyperparameters)
 
 
+class StyleEmbEncoder(nn.Module):
+    def __init__(self, num_hyperparameters, em_size, num_embeddings=100):
+        super().__init__()
+        assert num_hyperparameters == 1
+        self.em_size = em_size
+        self.embedding = nn.Embedding(num_embeddings, self.em_size)
+
+    def forward(self, hyperparameters):  # B x num_hps
+        return self.embedding(hyperparameters.squeeze(1))
+
+
 class _PositionalEncoding(nn.Module):
     def __init__(self, d_model, dropout=0.):
         super().__init__()
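Note: the commit simplifies StyleEncoder to take (num_hyperparameters, em_size), drops the commented-out per-hyperparameter embedding dict, and adds StyleEmbEncoder for a single integer-coded hyperparameter. A minimal usage sketch of the new signatures; the shapes follow the B x num_hps comments, and the import path is assumed from the file's location in the repo:

    import torch
    from TabPFN.encoders import StyleEncoder, StyleEmbEncoder

    # Batch of 16 style vectors, each with 3 real-valued hyperparameters.
    style = StyleEncoder(num_hyperparameters=3, em_size=512)
    print(style(torch.rand(16, 3)).shape)                    # torch.Size([16, 512])

    # StyleEmbEncoder asserts num_hyperparameters == 1 and looks the value
    # up in an nn.Embedding table, so inputs must be integer codes < num_embeddings.
    style_emb = StyleEmbEncoder(num_hyperparameters=1, em_size=512)
    print(style_emb(torch.randint(0, 100, (16, 1))).shape)   # torch.Size([16, 512])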
@@ -97,6 +97,71 @@ def get_normalized_uniform_encoder(encoder_creator):
     return lambda in_dim, out_dim: nn.Sequential(Normalize(.5, math.sqrt(1/12)), encoder_creator(in_dim, out_dim))
 
 
+def get_normalized_encoder(encoder_creator, data_std):
+    return lambda in_dim, out_dim: nn.Sequential(Normalize(0., data_std), encoder_creator(in_dim, out_dim))
+
+
+class ZNormalize(nn.Module):
+    def forward(self, x):
+        return (x - x.mean(-1, keepdim=True)) / x.std(-1, keepdim=True)
+
+
+class AppendEmbeddingEncoder(nn.Module):
+    def __init__(self, base_encoder, num_features, emsize):
+        super().__init__()
+        self.num_features = num_features
+        self.base_encoder = base_encoder
+        self.emb = nn.Parameter(torch.zeros(emsize))
+
+    def forward(self, x):
+        if (x[-1] == 1.).all():
+            append_embedding = True
+        else:
+            assert (x[-1] == 0.).all(), "You need to specify as last position whether to append embedding. " \
+                                        "If you don't want this behavior, please use the wrapped encoder instead."
+            append_embedding = False
+        x = x[:-1]
+        encoded_x = self.base_encoder(x)
+        if append_embedding:
+            encoded_x = torch.cat([encoded_x, self.emb[None, None, :].repeat(1, encoded_x.shape[1], 1)], 0)
+        return encoded_x
+
+
+def get_append_embedding_encoder(encoder_creator):
+    return lambda num_features, emsize: AppendEmbeddingEncoder(encoder_creator(num_features, emsize), num_features, emsize)
+
+
+class VariableNumFeaturesEncoder(nn.Module):
+    def __init__(self, base_encoder, num_features):
+        super().__init__()
+        self.base_encoder = base_encoder
+        self.num_features = num_features
+
+    def forward(self, x):
+        x = x * (self.num_features / x.shape[-1])
+        x = torch.cat((x, torch.zeros(*x.shape[:-1], self.num_features - x.shape[-1], device=x.device)), -1)
+        return self.base_encoder(x)
+
+
+def get_variable_num_features_encoder(encoder_creator):
+    return lambda num_features, emsize: VariableNumFeaturesEncoder(encoder_creator(num_features, emsize), num_features)
+
+
+class NoMeanEncoder(nn.Module):
+    """
+    This can be useful for any prior that is translation invariant in x or y.
+    A standard GP for example is translation invariant in x.
+    That is, GP(x_test+const, x_train+const, y_train) = GP(x_test, x_train, y_train).
+    """
+    def __init__(self, base_encoder):
+        super().__init__()
+        self.base_encoder = base_encoder
+
+    def forward(self, x):
+        return self.base_encoder(x - x.mean(0, keepdim=True))
+
+
+def get_no_mean_encoder(encoder_creator):
+    return lambda num_features, emsize: NoMeanEncoder(encoder_creator(num_features, emsize))
+
+
 Linear = nn.Linear
 MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features+1, emsize*2),
                                                  nn.ReLU(),
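Note: all of the new wrappers compose with any `encoder_creator(num_features, emsize)` factory. A sketch of how the two less obvious ones behave, assuming `from TabPFN.encoders import ...` works and using a plain linear base encoder; tensor shapes follow the S, B, F convention from the file's comments:

    import torch
    import torch.nn as nn
    from TabPFN.encoders import get_variable_num_features_encoder, get_append_embedding_encoder

    base_creator = lambda num_features, emsize: nn.Linear(num_features, emsize)

    # VariableNumFeaturesEncoder: inputs with fewer than num_features features
    # are rescaled by num_features/actual and zero-padded before encoding.
    enc = get_variable_num_features_encoder(base_creator)(num_features=100, emsize=512)
    x = torch.rand(50, 8, 60)          # only 60 of 100 features present
    print(enc(x).shape)                # torch.Size([50, 8, 512])

    # AppendEmbeddingEncoder: the last sequence position must be an all-0. or
    # all-1. flag row saying whether to append the learned embedding to the
    # encoding of x[:-1].
    enc = get_append_embedding_encoder(base_creator)(num_features=100, emsize=512)
    x = torch.rand(51, 8, 100)
    x[-1] = 1.                         # request the extra embedding
    print(enc(x).shape)                # torch.Size([51, 8, 512])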
@@ -120,69 +185,23 @@ class NanHandlingEncoder(nn.Module):
         x = torch.nan_to_num(x, nan=0.0)
         return self.layer(x)
 
+
 class Linear(nn.Linear):
-    def __init__(self, num_features, emsize):
+    def __init__(self, num_features, emsize, replace_nan_by_zero=False):
         super().__init__(num_features, emsize)
         self.num_features = num_features
         self.emsize = emsize
+        self.replace_nan_by_zero = replace_nan_by_zero
 
     def forward(self, x):
-        x = torch.nan_to_num(x, nan=0.0)
+        if self.replace_nan_by_zero:
+            x = torch.nan_to_num(x, nan=0.0)
         return super().forward(x)
 
-class SequenceSpanningEncoder(nn.Module):
-    # A regular encoder transforms Seq_len, B, S -> Seq_len, B, E, attending only to the last dimension.
-    # This encoder additionally accesses the Seq_len dimension.
-
-    # Why would we want this? We can learn normalization and embedding of features;
-    # this might be more important for e.g. categorical/ordinal feats and nan detection.
-    # However, maybe this can be easily learned through the transformer as well?
-    # A problem is to make this work across any sequence length and be independent of ordering.
-
-    # We could use average and maximum pooling and use those with a linear layer.
-
-    # Another idea: similarly, we would like to encode features so that their number is variable;
-    # we would like to embed features, also using knowledge of the features in the entire sequence.
-
-    # We could use convolution or another transformer.
-    # Convolution:
-
-    # Transformer/Conv across the sequence dimension that encodes and normalizes features
-    #   -> Transformer across the feature dimension that encodes features to a constant size.
-
-    # Conv with flexible features but no sequence info: S,B,F -(reshape)-> S*B,1,F
-    #   -(Conv1d)-> S*B,N,F -(AvgPool,MaxPool)-> S*B,N,1 -> S,B,N
-    # This probably won't work since it's missing a way to recognize which feature is encoded.
-
-    # Transformer with flexible features: S,B,F -> F,B*S,1 -> F2,B*S,1 -> S,B,F2
-
-    def __init__(self, num_features, em_size):
-        super().__init__()
-        raise NotImplementedError()
-        # Seq_len, B, S -> Seq_len, B, E
-        self.convs = torch.nn.ModuleList([nn.Conv1d(64 if i else 1, 64, 3) for i in range(5)])
-        # self.linear = nn.Linear(64, emsize)
-
-class TransformerBasedFeatureEncoder(nn.Module):
-    def __init__(self, num_features, emsize):
-        super().__init__()
-        hidden_emsize = emsize
-        encoder = Linear(1, hidden_emsize)
-        n_out = emsize
-        nhid = 2 * emsize
-        dropout = 0.0
-        nhead = 4
-        nlayers = 4
-        model = nn.Transformer(nhead=nhead, num_encoder_layers=4, num_decoder_layers=4, d_model=1)
-
-    def forward(self, *input):
-        # S,B,F -> F,S*B,1 -> F2,S*B,1 -> S,B,F2
-        input = input.transpose()
-        self.model(input)
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self.__dict__.setdefault('replace_nan_by_zero', True)
+
 
 class Conv(nn.Module):
     def __init__(self, input_size, emsize):
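Note: NaN replacement in Linear is now opt-in via replace_nan_by_zero, and __setstate__ defaults the flag to True so checkpoints pickled before this commit keep the old always-replace behavior. A small sketch, with the same assumed import path as above:

    import torch
    from TabPFN.encoders import Linear

    enc = Linear(num_features=100, emsize=512, replace_nan_by_zero=True)
    x = torch.rand(50, 8, 100)
    x[0, 0, 0] = float('nan')
    y = enc(x)                         # the NaN is zeroed before the affine map
    assert not torch.isnan(y).any()    # with replace_nan_by_zero=False it would propagate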
@@ -190,7 +209,6 @@ class Conv(nn.Module):
         self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
         self.linear = nn.Linear(64, emsize)
 
-
     def forward(self, x):
         size = math.isqrt(x.shape[-1])
         assert size*size == x.shape[-1]
@@ -204,8 +222,6 @@ class Conv(nn.Module):
         return self.linear(x)
 
 
-
-
 class CanEmb(nn.Embedding):
     def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
         assert embedding_dim % num_features == 0
@@ -218,8 +234,10 @@ class CanEmb(nn.Embedding):
         x = super().forward(lx)
         return x.view(*x.shape[:-2], -1)
 
+
 def get_Canonical(num_classes):
     return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)
 
+
 def get_Embedding(num_embs_per_feature=100):
     return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)
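Note: the last three hunks only adjust blank lines; get_Canonical and CanEmb are otherwise unchanged. For completeness, a usage sketch of the CanEmb path shown in the context lines, assuming the elided start of forward casts whole-number inputs to embedding indices:

    import torch
    from TabPFN.encoders import get_Canonical

    # emsize must be divisible by num_features: each of the 4 categorical
    # features gets a 128-dim embedding, flattened back to emsize at the end.
    enc = get_Canonical(num_classes=10)(num_features=4, emsize=512)
    x = torch.randint(0, 10, (50, 8, 4)).float()   # whole-number category codes
    print(enc(x).shape)                            # torch.Size([50, 8, 512])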
 