lhallee committed on
Commit 35a553e · verified · 1 Parent(s): 167fb7f

Upload modeling_esm_plusplus.py

Files changed (1)
  1. modeling_esm_plusplus.py +634 -0
modeling_esm_plusplus.py ADDED
@@ -0,0 +1,634 @@
### Modified from https://github.com/evolutionaryscale/esm
### License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from transformers import PreTrainedModel, PretrainedConfig
from einops import rearrange, repeat
from functools import partial
from typing import Optional, Tuple
from transformers.modeling_outputs import ModelOutput


class ESMplusplusConfig(PretrainedConfig):
    """Configuration for ESM++ models. The defaults correspond to the 300M parameter model."""
    model_type = "ESMplusplus"
    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        num_labels: int = 2,
        problem_type: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.num_labels = num_labels
        self.problem_type = problem_type


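# Example sketch (editorial addition, not part of the upstream file): the default
# ESMplusplusConfig matches the 300M layout above, while the 600M checkpoint uses a
# wider and deeper stack (see ESMplusplus_600M below).
def _example_config():
    config_300m = ESMplusplusConfig()  # hidden_size=960, 15 heads, 30 layers
    config_600m = ESMplusplusConfig(
        hidden_size=1152,
        num_attention_heads=18,
        num_hidden_layers=36,
    )
    assert config_300m.hidden_size == 960 and config_600m.hidden_size == 1152
    return config_600m

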
### Rotary
# https://github.com/evolutionaryscale/esm/blob/main/esm/layers/rotary.py
# https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/08639a72e17836184096ae6a7e2766f2a34c3e36/modeling_flash_llama.py#L114
# Flash attention rotary implementation can be installed like so: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`
def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    cos = repeat(cos, "s d -> s 1 (2 d)")
    sin = repeat(sin, "s d -> s 1 (2 d)")
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


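# Example sketch (editorial addition): apply_rotary_emb_torch preserves the input
# shape and only rotates the first rotary_dim channels of each head. The sizes
# below are illustrative and not tied to any particular checkpoint.
def _example_rotary_shapes():
    batch, seqlen, nheads, headdim = 2, 8, 4, 16
    x = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, headdim, 2).float() / headdim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, headdim / 2)
    out = apply_rotary_emb_torch(x, torch.cos(freqs), torch.sin(freqs))
    assert out.shape == x.shape

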
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding (RoPE) with cos/sin tables cached up to the longest sequence seen."""
    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        scaling_factor=1.0,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non-trainable)
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        inv_freq = self._compute_inv_freq(self.device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device=None):
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,  # inplace=True
                ),
            )  # type: ignore
        else:
            assert False


### Feedforward
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    # round the expanded hidden size up to the nearest multiple of 256
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


class SwiGLU(nn.Module):
    def __init__(self):
        super(SwiGLU, self).__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2


def swiglu_ln_ffn(d_model: int, expansion_ratio: float):
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(
            d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
        ),
        SwiGLU(),
        nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
    )


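# Example sketch (editorial addition): with d_model=960 and the 8/3 expansion ratio
# used by the transformer blocks, the SwiGLU hidden width rounds to 2560 (a multiple
# of 256) and the feedforward maps (B, L, 960) back to (B, L, 960).
def _example_swiglu_ffn():
    assert swiglu_correction_fn(8 / 3, 960) == 2560
    ffn = swiglu_ln_ffn(d_model=960, expansion_ratio=8 / 3)
    x = torch.randn(2, 5, 960)
    assert ffn(x).shape == x.shape

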
### Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
        )
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.q_ln = nn.LayerNorm(d_model, bias=False)
        self.k_ln = nn.LayerNorm(d_model, bias=False)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(d_model // n_heads)

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor):
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x, attention_mask=None):
        # tensor name suffixes: B = batch, L = sequence length, H = heads, D = hidden dim
        qkv_BLD3 = self.layernorm_qkv(x)
        query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
        query_BLD, key_BLD = (
            self.q_ln(query_BLD).to(query_BLD.dtype),
            self.k_ln(key_BLD).to(query_BLD.dtype),
        )
        query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
        query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD, attention_mask
        )
        context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
        return self.out_proj(context_BLD)


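# Example sketch (editorial addition): a boolean mask passed through to
# F.scaled_dot_product_attention marks positions that may be attended to
# (True = keep); the module preserves the (batch, seq_len, d_model) shape.
def _example_attention_shapes():
    attn = MultiHeadAttention(d_model=64, n_heads=4)
    x = torch.randn(2, 10, 64)
    mask = torch.ones(2, 1, 10, 10, dtype=torch.bool)  # broadcastable over heads
    assert attn(x, mask).shape == x.shape

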
### LM Head
def RegressionHead(
    d_model: int, output_dim: int, hidden_dim: int | None = None
) -> nn.Module:
    hidden_dim = hidden_dim if hidden_dim is not None else d_model
    return nn.Sequential(
        nn.Linear(d_model, hidden_dim),
        nn.GELU(),
        nn.LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, output_dim),
    )


### Transformer Block
class UnifiedTransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        residue_scaling_factor: float = 1,
        expansion_ratio: float = 8 / 3,
    ):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
        self.scaling_factor = residue_scaling_factor

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r1 = self.attn(x, attention_mask)
        x = x + r1 / self.scaling_factor
        r3 = self.ffn(x) / self.scaling_factor
        x = x + r3
        return x


### Outputs
@dataclass
class TransformerOutput(ModelOutput):
    last_hidden_state: torch.Tensor | None = None
    hidden_states: tuple[torch.Tensor] | None = None


@dataclass
class ESMplusplusOutput(ModelOutput):
    loss: torch.Tensor | None = None
    logits: torch.Tensor | None = None
    last_hidden_state: torch.Tensor | None = None
    hidden_states: tuple[torch.Tensor] | None = None


### Transformer
class TransformerStack(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_layers: int,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    residue_scaling_factor=math.sqrt(n_layers / 36),
                )
                for i in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> TransformerOutput:
        batch_size, seq_len, _ = x.shape
        hidden_states = ()
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
        for block in self.blocks:
            x = block(x, attention_mask)
            if output_hidden_states:
                hidden_states += (x,)
        return TransformerOutput(last_hidden_state=self.norm(x), hidden_states=hidden_states)


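# Example sketch (editorial addition): the stack consumes pre-computed embeddings of
# shape (batch, seq_len, d_model) plus a 2D padding mask of shape (batch, seq_len);
# the mask is expanded internally before scaled_dot_product_attention.
def _example_transformer_stack():
    stack = TransformerStack(d_model=64, n_heads=4, n_layers=2)
    x = torch.randn(2, 10, 64)
    mask = torch.ones(2, 10, dtype=torch.long)
    out = stack(x, attention_mask=mask, output_hidden_states=True)
    assert out.last_hidden_state.shape == x.shape
    assert len(out.hidden_states) == 2

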
### Full model
class ESMplusplusForMaskedLM(PreTrainedModel):
    """
    ESM++ for masked language modeling.
    """
    def __init__(self, config: ESMplusplusConfig):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
        self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = EsmSequenceTokenizer()

    @classmethod
    def from_pretrained_esm(cls, model_name: str):
        if '300' in model_name:
            return ESMplusplus_300M()
        elif '600' in model_name:
            return ESMplusplus_600M()
        else:
            raise ValueError(f"Invalid model name: {model_name}")

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> ESMplusplusOutput:
        x = self.embed(input_ids)
        output = self.transformer(x, attention_mask, output_hidden_states)
        x = output.last_hidden_state
        logits = self.sequence_head(x)
        loss = None
        if labels is not None:
            loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


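# Example sketch (editorial addition): run a randomly initialized ESM++ masked LM
# end to end with the bundled tokenizer. Pretrained weights come from
# ESMplusplus_300M / ESMplusplus_600M defined below.
def _example_masked_lm_forward():
    model = ESMplusplusForMaskedLM(ESMplusplusConfig())
    batch = model.tokenizer(["MSEQ", "MPRTK"], padding=True, return_tensors="pt")
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    # logits: (batch, seq_len, vocab_size); last_hidden_state: (batch, seq_len, hidden_size)
    assert out.logits.shape[-1] == model.vocab_size

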
class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
    """
    ESM++ for sequence classification.
    """
    def __init__(self, config: ESMplusplusConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
        # we find that large intermediate projections help with sequence classification tasks (*4)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()

    def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x: (batch_size, seq_len, hidden_size)
        # attention_mask: (batch_size, seq_len)
        if attention_mask is None:
            return x.mean(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> ESMplusplusOutput:
        # do not pass classification labels to the masked-LM parent; its loss is not used here
        output = super().forward(input_ids, attention_mask, None, output_hidden_states)
        x = output.last_hidden_state
        cls_features = x[:, 0, :]
        mean_features = self.mean_pooling(x, attention_mask)
        # we include mean pooling features to help with early convergence, the cost of this is basically zero
        features = torch.cat([cls_features, mean_features], dim=-1)
        logits = self.classifier(features)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


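# Example sketch (editorial addition): the classifier consumes the concatenated
# <cls> embedding and mean-pooled embedding, so logits have shape (batch, num_labels).
def _example_sequence_classification():
    model = ESMplusplusForSequenceClassification(ESMplusplusConfig(num_labels=3))
    batch = model.tokenizer(["MSEQ", "MPRTK"], padding=True, return_tensors="pt")
    out = model(**batch, labels=torch.tensor([0, 2]))
    assert out.logits.shape == (2, 3) and out.loss is not None

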
class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
    """
    ESM++ for token classification.
    """
    def __init__(self, config: ESMplusplusConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
        # we find that large intermediate projections help with token classification tasks (*4)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> ESMplusplusOutput:
        # do not pass token labels to the masked-LM parent; its loss is not used here
        output = super().forward(input_ids, attention_mask, None, output_hidden_states)
        x = output.last_hidden_state
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return ESMplusplusOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=x,
            hidden_states=output.hidden_states,
        )


### Loading
import os
from functools import cache
from pathlib import Path
from huggingface_hub import snapshot_download


@cache
def data_root(model: str):
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    # Try to download from the Hugging Face Hub if the weights don't exist locally
    if model.startswith("esmc-300"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
    elif model.startswith("esmc-600"):
        path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
    else:
        raise ValueError(f"{model=} is an invalid model name.")
    return path


def ESMplusplus_300M(device: torch.device | str = "cpu"):
    with torch.device(device):
        config = ESMplusplusConfig(
            hidden_size=960,
            num_attention_heads=15,
            num_hidden_layers=30,
        )
        model = ESMplusplusForMaskedLM(config)
    state_dict = torch.load(
        data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model


def ESMplusplus_600M(device: torch.device | str = "cpu"):
    with torch.device(device):
        config = ESMplusplusConfig(
            hidden_size=1152,
            num_attention_heads=18,
            num_hidden_layers=36,
        )
        model = ESMplusplusForMaskedLM(config)
    state_dict = torch.load(
        data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
        map_location=device,
    )
    model.load_state_dict(state_dict)
    return model


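# Example sketch (editorial addition): ESMplusplus_300M / ESMplusplus_600M download
# the ESM Cambrian checkpoints from the Hugging Face Hub (network access and
# acceptance of the EvolutionaryScale license are required) and load them into the
# masked-LM wrapper defined above, e.g.:
#
#   model = ESMplusplus_300M(device="cpu")
#   batch = model.tokenizer(["MSEQVD"], return_tensors="pt")
#   embeddings = model(**batch).last_hidden_state

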
### Tokenization
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast


SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]

class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        return self.cls_token

    @property
    def bos_token_id(self):
        return self.cls_token_id

    @property
    def chain_break_token(self):
        return self.cb_token

    @property
    def chain_break_token_id(self):
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        return self.all_special_ids
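

# Example sketch (editorial addition): the tokenizer is character-level over the
# 33-token vocabulary above and wraps each sequence as <cls> ... <eos>, so a
# length-N protein becomes N + 2 input ids.
def _example_tokenizer():
    tokenizer = EsmSequenceTokenizer()
    encoded = tokenizer("MSEQVD")
    assert len(encoded["input_ids"]) == len("MSEQVD") + 2
    assert encoded["input_ids"][0] == tokenizer.cls_token_id
    assert encoded["input_ids"][-1] == tokenizer.eos_token_id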