import torch

from tqdm import tqdm
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
from torch.nn.functional import normalize

class StructureEncoder(torch.nn.Module):
    def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file
            out_dim: Output dimension of the structure representation
            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # Remove contact head
        self.model.esm.contact_head = None

        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute protein structure representations for the given proteins
        Args:
            proteins: A list of protein structural sequences
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized protein representations, [num_proteins, out_dim]
        """
        device = next(self.parameters()).device

        protein_repr = []
        if verbose:
            iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
        else:
            iterator = range(0, len(proteins), batch_size)

        for i in iterator:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i: i + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
            output, _ = self.forward(protein_inputs)
            protein_repr.append(output)

        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein structure into protein representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits: [batch, seq_len, vocab_size]
        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state

        # Use the hidden state of the first token (<cls>) as the sequence-level representation
        reprs = last_hidden_state[:, 0, :]
        reprs = self.out(reprs)

        # Get logits for masked tokens
        if get_mask_logits:
            mask_logits = self.model.lm_head(last_hidden_state)
        else:
            mask_logits = None

        return reprs, mask_logits
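

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). The config path and the
# structural sequences below are hypothetical placeholders: point config_path
# at a real ESM-style config/vocab directory, and pass the structural token
# strings expected by your checkpoint. Note that EsmForMaskedLM(config) here
# builds a randomly initialized model; load trained weights (e.g. via
# load_state_dict) before using the representations for anything meaningful.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = StructureEncoder(config_path="path/to/esm_config", out_dim=512)
    encoder.eval()

    # Placeholder structural sequences
    proteins = ["dvvvvddpp", "pppddvvcc"]

    with torch.no_grad():
        embeddings = encoder.get_repr(proteins, batch_size=2, verbose=True)

    print(embeddings.shape)  # -> torch.Size([2, 512])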