---
library_name: transformers
tags: []
---

ESM++ is a faithful implementation of [ESMC](https://www.evolutionaryscale.ai/blog/esm-cambrian) that supports batching and standard Hugging Face compatibility without requiring the ESM package.

Use with the `transformers` library:
```python
from transformers import AutoModelForMaskedLM #AutoModel also works
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt') # pad to the longest sequence in the batch

# For MLM training, mask a fraction of input_ids and set the unmasked positions of
# labels to -100 so the loss is computed only on masked tokens (see the sketch below)
# tokenized['labels'] = tokenized['input_ids'].clone()

output = model(**tokenized) # pass output_hidden_states=True to also return all hidden states
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
```
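The commented `labels` line above is where MLM supervision would go. Below is a minimal masking sketch; the 15% mask rate and the pad-token handling are illustrative assumptions rather than part of this model card (it also assumes the tokenizer exposes the standard `mask_token_id` and `pad_token_id` attributes; real pipelines typically exclude BOS/EOS tokens as well and use an 80/10/10 replacement scheme).

```python
import torch

# Illustrative MLM masking sketch (15% mask rate is an assumption).
input_ids = tokenized['input_ids'].clone()
labels = input_ids.clone()
masked = torch.bernoulli(torch.full(labels.shape, 0.15)).bool()
masked &= input_ids != tokenizer.pad_token_id  # never mask padding
labels[~masked] = -100                         # loss is computed only on masked positions
input_ids[masked] = tokenizer.mask_token_id

output = model(input_ids=input_ids, attention_mask=tokenized['attention_mask'], labels=labels)
print(output.loss)  # now populated, since labels were provided
```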

ESM++ also supports sequence- and token-level classification tasks, like ESM2. Simply pass the number of labels during initialization.

```python
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification

model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
```
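The `AutoModelForTokenClassification` import above works the same way for per-residue tasks, with one prediction per token. A brief sketch (the `num_labels=2` choice is illustrative):

```python
model = AutoModelForTokenClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, seq_len, num_labels), (2, 11, 2)
```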



Measured difference between this implementation and the version loaded with the official ESM package (1000 random sequences): average MSE of 7.742734737803403e-10.
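
A sketch of how such a comparison can be reproduced is below. The ESM-package calls (`ESMC.from_pretrained`, `ESMProtein`, `LogitsConfig`) follow EvolutionaryScale's published examples, and the `esmc_300m` checkpoint name is an assumption for the small model; treat this as illustrative, not the exact benchmark script.

```python
import torch
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

# Reference embeddings from the official ESM package
# (API per EvolutionaryScale's examples; checkpoint name assumed).
esmc = ESMC.from_pretrained('esmc_300m')
seq = 'MPRTEIN'
ref = esmc.logits(
    esmc.encode(ESMProtein(sequence=seq)),
    LogitsConfig(sequence=True, return_embeddings=True),
)

# Same sequence through ESM++
tokenized = tokenizer([seq], return_tensors='pt')
with torch.no_grad():
    ours = model(**tokenized).last_hidden_state

print(torch.mean((ref.embeddings.float() - ours.float()) ** 2))  # per-sequence MSE
```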