lhallee committed (verified)
Commit 3826ba8 · 1 Parent(s): 0b7ba98

Upload modeling_fastesm.py with huggingface_hub

Files changed (1):
modeling_fastesm.py (+985, -0)
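
For context, a minimal sketch of how a checkpoint that ships this file is typically loaded from the Hub. The repo id is a placeholder, and it is assumed (not shown in this commit) that the repo's config.json registers these classes under auto_map so that trust_remote_code resolves to FastEsmModel:

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "<namespace>/<fastesm-repo>"  # placeholder, substitute the actual model repo
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)  # resolves to FastEsmModel if auto_map is set

inputs = tokenizer("MSEQWENCE", return_tensors="pt")  # hypothetical protein sequence
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, sequence_length, hidden_size)
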
modeling_fastesm.py ADDED
@@ -0,0 +1,985 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Tuple, Union
from einops import rearrange
from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
from transformers.modeling_outputs import (
    MaskedLMOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.models.esm.modeling_esm import (
    EsmIntermediate,
    EsmOutput,
    EsmPooler,
    EsmLMHead,
    EsmSelfOutput,
    EsmClassificationHead,
)
from tqdm.auto import tqdm


class FastEsmConfig(PretrainedConfig):
    model_type = "fast_esm"

    def __init__(
        self,
        vocab_size=None,
        mask_token_id=None,
        pad_token_id=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="absolute",
        emb_layer_norm_before=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.emb_layer_norm_before = emb_layer_norm_before

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        return output


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(x, cos, sin):
    cos = cos[:, :, : x.shape[-2], :]
    sin = sin[:, :, : x.shape[-2], :]

    return (x * cos) + (rotate_half(x) * sin)


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def average_product_correct(x):
    "Perform average product correction, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized


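# Note (illustrative, not used at runtime): average product correction subtracts the
# expected co-activation over the last two dimensions,
#   APC(x)_ij = x_ij - (sum_k x_ik) * (sum_k x_kj) / (sum_kl x_kl),
# and symmetrize + APC is the standard preprocessing of attention maps for
# ESM-style contact prediction, as used by the head below.

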
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # remove eos token attentions
        eos_mask = tokens.ne(self.eos_idx).to(attentions)
        eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = average_product_correct(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
            self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )


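# Illustrative sketch (comments only, not executed): RotaryEmbedding expects query/key
# tensors shaped [batch, heads, seq_len, head_dim] and returns tensors of the same shape
# with positions encoded by rotation. The shapes below are assumptions for illustration.
#
#   rope = RotaryEmbedding(dim=64)
#   q = torch.randn(2, 12, 128, 64)
#   k = torch.randn(2, 12, 128, 64)
#   q_rot, k_rot = rope(q, k)   # shapes unchanged: [2, 12, 128, 64]

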
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.padding_idx = config.pad_token_id  # needed by create_position_ids_from_inputs_embeds
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)


class EsmSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.scale = self.attention_head_size**-0.5

        self.dropout_prob = config.attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for self attention.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        if output_attentions:
            # Manual attention computation to get attention weights
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            if attention_mask is not None:
                if attention_mask.dtype == torch.bool:
                    # boolean masks mark positions to keep; convert to an additive mask
                    attention_scores = attention_scores.masked_fill(~attention_mask, torch.finfo(attention_scores.dtype).min)
                else:
                    attention_scores = attention_scores + attention_mask
            attention_probs = F.softmax(attention_scores, dim=-1)
            if self.dropout_prob > 0:
                attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer, attention_probs
        else:
            context_layer = F.scaled_dot_product_attention(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=attention_mask,
                dropout_p=self.dropout_prob if self.training else 0.0,
                scale=1.0
            )
            context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
            return context_layer


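# Note on the two attention paths above (illustrative, comments only): the query is
# pre-scaled by head_dim ** -0.5 in EsmSelfAttention.forward, so the fused path calls
# F.scaled_dot_product_attention(..., scale=1.0) to avoid scaling twice. A rough sketch
# of the equivalence, with hypothetical shapes:
#
#   attn = EsmSelfAttention(config)            # config is a FastEsmConfig instance
#   x = torch.randn(2, 128, config.hidden_size)
#   fused = attn(x)                            # SDPA path, no attention weights returned
#   manual, weights = attn(x, output_attentions=True)
#   # fused and manual agree up to small floating point differences

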
class EsmAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for attention layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        hidden_states_ln = self.LayerNorm(hidden_states)
        self_outputs = self.self(
            hidden_states_ln,
            attention_mask,
            output_attentions,
        )
        if output_attentions:
            attention_output, attention_weights = self_outputs
            attention_output = self.output(attention_output, hidden_states)
            return attention_output, attention_weights
        else:
            attention_output = self_outputs
            return self.output(attention_output, hidden_states)


class EsmLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass for transformer layer.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_attentions: Whether to return attention weights

        Returns:
            Output tensor and optionally attention weights
        """
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            output_attentions,
        )
        if output_attentions:
            attention_output, attention_weights = attention_outputs
        else:
            attention_output = attention_outputs
            attention_weights = None

        layer_output = self.feed_forward_chunk(attention_output)

        if output_attentions:
            return layer_output, attention_weights
        return layer_output

    def feed_forward_chunk(self, attention_output):
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class EsmEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Forward pass for transformer encoder.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            BaseModelOutputWithPastAndCrossAttentions containing model outputs
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for layer_module in self.layer:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )

            if output_attentions:
                hidden_states, attention_weights = layer_outputs
                all_attentions = all_attentions + (attention_weights,)
            else:
                hidden_states = layer_outputs

        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )


### Dataset for Embedding
class ProteinDataset(Dataset):
    """Simple dataset for protein sequences."""
    def __init__(self, sequences: list[str]):
        self.sequences = sequences

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> str:
        return self.sequences[idx]


class FastEsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """
    config_class = FastEsmConfig
    base_model_prefix = "fastesm"
    supports_gradient_checkpointing = True
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_input_embeddings(self) -> nn.Module:
        try:
            return self.embeddings.word_embeddings
        except AttributeError:
            return self.esm.embeddings.word_embeddings

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply mean pooling to sequence outputs."""
        if attention_mask is None:
            return x.mean(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def _collate_fn(self, sequences: list[str]) -> dict[str, torch.Tensor]:
        """Collate function for batching sequences."""
        return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """Read sequences from SQLite database."""
        import sqlite3
        sequences = []
        with sqlite3.connect(db_path) as conn:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            while True:
                row = c.fetchone()
                if row is None:
                    break
                sequences.append(row[0])
        return set(sequences)

    def embed_dataset(
        self,
        sequences: list[str],
        batch_size: int = 2,
        max_len: int = 512,
        full_embeddings: bool = False,
        full_precision: bool = False,
        pooling_type: str = 'mean',
        num_workers: int = 0,
        sql: bool = False,
        sql_db_path: str = 'embeddings.db',
    ) -> Optional[dict[str, torch.Tensor]]:
        """Embed a dataset of protein sequences.

        Args:
            sequences: List of protein sequences
            batch_size: Batch size for processing
            max_len: Maximum sequence length
            full_embeddings: Whether to return full residue-wise embeddings (True) or pooled embeddings (False)
            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
            pooling_type: Type of pooling ('mean' or 'cls')
            num_workers: Number of workers for data loading, 0 for the main process
            sql: Whether to store embeddings in a SQLite database - always stored in float32
            sql_db_path: Path to SQLite database

        Returns:
            Dictionary mapping sequences to embeddings, or None if sql=True
        """
        sequences = list(set([seq[:max_len] for seq in sequences]))
        sequences = sorted(sequences, key=len, reverse=True)
        dataset = ProteinDataset(sequences)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
        device = self.device

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            if full_embeddings:
                return residue_embeddings
            elif pooling_type == 'mean':
                return self.mean_pooling(residue_embeddings, attention_mask)
            else:
                return residue_embeddings[:, 0, :]

        if sql:
            import sqlite3
            conn = sqlite3.connect(sql_db_path)
            c = conn.cursor()
            c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
            already_embedded = self._read_sequences_from_db(sql_db_path)
            to_embed = [seq for seq in sequences if seq not in already_embedded]
            print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
            print(f"Embedding {len(to_embed)} new sequences")
            if len(to_embed) > 0:
                with torch.no_grad():
                    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                        seqs = sequences[i * batch_size:(i + 1) * batch_size]
                        input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                        residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()  # required for sql (stored as float32)
                        embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()

                        for seq, emb in zip(seqs, embeddings):
                            c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                      (seq, emb.cpu().numpy().tobytes()))

                        if (i + 1) % 100 == 0:
                            conn.commit()

            conn.commit()
            conn.close()
            return None

        embeddings_dict = {}
        with torch.no_grad():
            for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                seqs = sequences[i * batch_size:(i + 1) * batch_size]
                input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach()
                if full_precision:
                    residue_embeddings = residue_embeddings.float()
                embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
                for seq, emb in zip(seqs, embeddings):
                    embeddings_dict[seq] = emb

        return embeddings_dict


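# Illustrative sketch (comments only) of embed_dataset, assuming `model` is any FastEsm*
# model instance and the sequences are hypothetical:
#
#   sequences = ["MSEQVENCE", "MPEPTIDE"]
#   # mean-pooled per-sequence embeddings returned as a dict {sequence: tensor}
#   embeddings = model.embed_dataset(sequences, batch_size=2, pooling_type='mean')
#   # or stream float32 embeddings into SQLite instead of returning a dict
#   model.embed_dataset(sequences, sql=True, sql_db_path='embeddings.db')

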
class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )

        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :].expand(
                batch_size, 1, seq_length, seq_length
            ).bool()
        else:
            extended_attention_mask = None

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_outputs.last_hidden_state

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class FastEsmModel(FastEsmPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.esm = FAST_ESM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.esm.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Forward pass for base model.

        Args:
            input_ids: Input token IDs
            attention_mask: Optional attention mask
            position_ids: Optional position IDs
            inputs_embeds: Optional input embeddings
            output_hidden_states: Whether to return all hidden states
            output_attentions: Whether to return attention weights

        Returns:
            Model outputs including hidden states and optionally attention weights
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FastEsmForMaskedLM(FastEsmPreTrainedModel):
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
    ) -> Union[Tuple, MaskedLMOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = outputs.last_hidden_state
        prediction_scores = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")


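# Illustrative sketch (comments only) of masked-token prediction with the class above;
# the checkpoint path is a placeholder, not a real repo id:
#
#   tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   model = FastEsmForMaskedLM.from_pretrained("<path-or-repo-with-fastesm-weights>")
#   inputs = tokenizer("MV<mask>LSPADKTNVK", return_tensors="pt")
#   with torch.no_grad():
#       logits = model(**inputs).logits
#   predicted_ids = logits.argmax(dim=-1)      # per-position vocabulary predictions

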
class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FastEsmForTokenClassification(FastEsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.init_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


if __name__ == "__main__":
    """
    Test the hidden state differences between the FastEsm models and the HF ESM implementation.
    In full precision the differences are very small but nonzero, due to the floating point behavior
    of F.scaled_dot_product_attention.
    In PyTorch 2.5+ (on Linux), this implementation is very fast and uses less memory than the HF implementation.
    """
    import random
    from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer

    model_paths = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        #"facebook/esm2_t30_150M_UR50D",
        #"facebook/esm2_t33_650M_UR50D",
    ]
    canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    length = 64
    seq_count = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    def generate_random_sequence(length: int) -> str:
        return 'M' + "".join(random.choices(canonical_amino_acids, k=length))

    print("Percentage of hidden states that are within the tolerance:")
    for model_path in model_paths:
        print(f"Testing {model_path}...")
        tokenizer = EsmTokenizer.from_pretrained(model_path)
        config = FastEsmConfig.from_pretrained(model_path)
        fast_model = FastEsmForMaskedLM.from_pretrained(model_path, config=config).to(device)
        model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)

        counts = [0] * len(tolerances)
        for _ in range(seq_count):
            example_seq = generate_random_sequence(length)
            fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
            model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()

            for i, atol in enumerate(tolerances):
                if torch.allclose(fast_output, model_output, atol=atol):
                    counts[i] += 1

        print(f"{model_path}:")
        for i, atol in enumerate(tolerances):
            print(f"  tolerance={atol}: {counts[i] / seq_count * 100}%")

        model.cpu()
        fast_model.cpu()
        del model
        del fast_model
        torch.cuda.empty_cache()