import torch import torch.distributed as dist import torchmetrics import json import math import numpy as np import os import copy import faiss import time import pandas as pd import random from tqdm import tqdm from .protein_encoder import ProteinEncoder from .structure_encoder import StructureEncoder from .text_encoder import TextEncoder from ..abstract_model import AbstractModel from ..model_interface import register_model from utils.mpr import MultipleProcessRunnerSimplifier from torch.nn.functional import normalize, cross_entropy from utils.constants import residue_level, sequence_level from sklearn.metrics import roc_auc_score def multilabel_cross_entropy(logits, labels): """ Compute cross entropy loss for multilabel classification。 See "https://arxiv.org/pdf/2208.02955.pdf" Args: logits: [num_samples, num_classes] labels: [num_samples, num_classes] """ loss = 0 for pred, label in zip(logits, labels): pos_logits = pred[label == 1] neg_logits = pred[label == 0] diff = neg_logits.unsqueeze(-1) - pos_logits loss += torch.log(1 + torch.exp(diff).sum()) return loss / len(logits) # pred = (1 - 2 * labels) * logits # pred_neg = pred - labels * 1e12 # pred_pos = pred - (1 - labels) * 1e12 # # zeros = torch.zeros_like(logits[..., :1], dtype=logits.dtype) # pred_neg = torch.cat([pred_neg, zeros], dim=-1) # pred_pos = torch.cat([pred_pos, zeros], dim=-1) # # neg_loss = torch.logsumexp(pred_neg, dim=-1) # pos_loss = torch.logsumexp(pred_pos, dim=-1) # # return (neg_loss + pos_loss).mean() @register_model class ProTrekTrimodalModel(AbstractModel): def __init__(self, protein_config: str, text_config: str, structure_config: str = None, repr_dim: int = 1024, temperature: float = 0.07, load_protein_pretrained: bool = True, load_text_pretrained: bool = True, use_mlm_loss: bool = False, use_zlpr_loss: bool = False, use_saprot: bool = False, gradient_checkpointing: bool = False, **kwargs): """ Args: protein_config: Path to the config file for protein sequence encoder text_config: Path to the config file for text encoder structure_config: Path to the config file for structure encoder repr_dim: Output dimension of the protein and text representation temperature: Temperature for softmax load_protein_pretrained: Whether to load pretrained weights for protein encoder load_text_pretrained: Whether to load pretrained weights for text encoder use_mlm_loss: Whether to use masked language modeling loss use_zlpr_loss: Whether to use zlpr loss. See "https://arxiv.org/pdf/2208.02955.pdf" use_saprot: Whether to use SaProt as protein encoder gradient_checkpointing: Whether to use gradient checkpointing for protein encoder """ self.protein_config = protein_config self.structure_config = structure_config self.text_config = text_config self.repr_dim = repr_dim self.temperature = temperature self.load_protein_pretrained = load_protein_pretrained self.load_text_pretrained = load_text_pretrained self.use_mlm_loss = use_mlm_loss self.use_zlpr_loss = use_zlpr_loss self.use_saprot = use_saprot self.gradient_checkpointing = gradient_checkpointing super().__init__(**kwargs) def initialize_metrics(self, stage: str) -> dict: return_dict = { f"{stage}_protein_text_acc": torchmetrics.Accuracy(), f"{stage}_text_protein_acc": torchmetrics.Accuracy(), } if self.use_mlm_loss: return_dict[f"{stage}_protein_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1) if self.structure_config is not None: return_dict[f"{stage}_structure_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1) if self.structure_config is not None: return_dict[f"{stage}_structure_protein_acc"] = torchmetrics.Accuracy() return_dict[f"{stage}_structure_text_acc"] = torchmetrics.Accuracy() return_dict[f"{stage}_text_structure_acc"] = torchmetrics.Accuracy() return_dict[f"{stage}_protein_structure_acc"] = torchmetrics.Accuracy() return return_dict def initialize_model(self): # Initialize encoders self.protein_encoder = ProteinEncoder(self.protein_config, self.repr_dim, self.load_protein_pretrained, self.gradient_checkpointing) self.text_encoder = TextEncoder(self.text_config, self.repr_dim, self.load_text_pretrained, self.gradient_checkpointing) # Learnable temperature self.temperature = torch.nn.Parameter(torch.tensor(self.temperature)) # self.model is used for saving and loading self.model = torch.nn.ParameterList([self.temperature, self.protein_encoder, self.text_encoder]) # If the structure encoder is specified if self.structure_config is not None: self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim) self.model.append(self.structure_encoder) def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor: return self.text_encoder.get_repr(texts, batch_size, verbose) def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor: return self.structure_encoder.get_repr(proteins, batch_size, verbose) def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor: return self.protein_encoder.get_repr(proteins, batch_size, verbose) def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None): """ Args: protein_inputs: A dictionary for protein encoder structure_inputs: A dictionary for structure encoder text_inputs : A dictionary for text encoder """ protein_repr, protein_mask_logits = self.protein_encoder(protein_inputs, self.use_mlm_loss) text_repr = self.text_encoder(text_inputs) outputs = [text_repr, protein_repr, protein_mask_logits] if self.structure_config is not None: structure_repr, structure_mask_logits = self.structure_encoder(structure_inputs, self.use_mlm_loss) outputs += [structure_repr, structure_mask_logits] return outputs def loss_func(self, stage: str, outputs, labels): if self.structure_config is not None: text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs else: text_repr, protein_repr, protein_mask_logits = outputs device = text_repr.device text_repr = normalize(text_repr, dim=-1) protein_repr = normalize(protein_repr, dim=-1) # Gather representations from all GPUs all_protein_repr = self.all_gather(protein_repr).view(-1, protein_repr.shape[-1]).detach() all_text_repr = self.all_gather(text_repr).view(-1, text_repr.shape[-1]).detach() if self.structure_config is not None: structure_repr = normalize(structure_repr, dim=-1) all_structure_repr = self.all_gather(structure_repr).view(-1, structure_repr.shape[-1]).detach() # text_idx = labels["text_idx"] # text_candidates = labels["text_candidates"] # # # Gather all text ids # text_inds = self.all_gather(text_idx).flatten() # # Create text classification labels # text_labels = torch.zeros(len(text_candidates), len(text_inds), dtype=int).to(device) # for i, candidate in enumerate(text_candidates): # for j, idx in enumerate(text_inds): # if idx.item() in candidate: # text_labels[i, j] = 1 # # # Gather text labels from all GPUs # text_labels = self.all_gather(text_labels).view(-1, text_labels.shape[-1]) # # # Protein classification labels are the transpose of text labels # protein_labels = text_labels.T # Batch size rank = dist.get_rank() bs = text_repr.shape[0] # Get current labels # protein_labels = protein_labels[rank * bs: rank * bs + bs] # text_labels = text_labels[rank * bs: rank * bs + bs] # Create classification labels between structure and sequence bs_labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(device) if self.structure_config is not None: pairs = { "protein": ["structure", "text"], "structure": ["protein", "text"], "text": ["protein", "structure"] } else: pairs = { "protein": ["text"], "text": ["protein"] } loss_list = [] for k, values in pairs.items(): for v in values: # Only calculate the similarity for the current batch sim = torch.matmul(eval(f"{k}_repr"), eval(f"all_{v}_repr").T).div(self.temperature) # if k == "text": # if self.use_zlpr_loss: # loss = multilabel_cross_entropy(sim, protein_labels) # else: # loss = cross_entropy(sim, bs_labels) # # pred = [] # for s, l in zip(sim, protein_labels): # n_label = l.sum() # topk = torch.topk(s, k=n_label).indices # if l[topk].sum() == n_label: # pred.append(1) # else: # pred.append(0) # # pred = torch.tensor(pred).to(device) # label = torch.ones_like(pred) # self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(pred.detach(), label) # # if v == "protein": # # acc = self.metrics[stage][f"{stage}_{k}_{v}_acc"].compute() # # print(f"{stage}_{k}_{v}_acc: {acc:.4f}") # # elif v == "text": # if self.use_zlpr_loss: # loss = multilabel_cross_entropy(sim, text_labels) # else: # loss = cross_entropy(sim, bs_labels) # # pred = [] # for s, l in zip(sim, text_labels): # n_label = l.sum() # topk = torch.topk(s, k=n_label).indices # if l[topk].sum() == n_label: # pred.append(1) # else: # pred.append(0) # # pred = torch.tensor(pred).to(device) # label = torch.ones_like(pred) # # if k == "protein": # # acc = pred.sum() / len(pred) # # print(f"{stage}_{k}_{v}_acc: {acc:.4f}") # self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(pred.detach(), label) # # else: # loss = cross_entropy(sim, bs_labels) # self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels) loss = cross_entropy(sim, bs_labels) self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels) loss_list.append(loss) # Masked language modeling loss if self.use_mlm_loss: k_label = [("protein", labels["seq_labels"])] if self.structure_config is not None: k_label.append(("structure", labels["struc_labels"])) for k, label in k_label: logits = eval(f"{k}_mask_logits") # merge the first and second dimension of logits logits = logits.view(-1, logits.shape[-1]) label = label.flatten().to(device) mlm_loss = cross_entropy(logits, label, ignore_index=-1) loss_list.append(mlm_loss) self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label) loss = sum(loss_list) / len(loss_list) if stage == "train": log_dict = self.get_log_dict("train") log_dict["train_loss"] = loss self.log_info(log_dict) # Reset train metrics self.reset_metrics("train") return loss def padded_gather(self, tensor: torch.Tensor): """ Gather tensors from all GPUs, allowing different shapes at the batch dimension. """ # Get the size of the tensor size = tensor.shape[0] all_sizes = self.all_gather(torch.tensor(size, device=tensor.device)) max_size = max(all_sizes) # Pad the tensor if size != max_size: tmp = torch.zeros(max_size, tensor.shape[-1], dtype=tensor.dtype, device=tensor.device) tmp[:size] = tensor tensor = tmp padded_tensor = self.all_gather(tensor).view(-1, tensor.shape[-1]) tensor = padded_tensor[:sum(all_sizes)] return tensor def _get_protein_indices(self): world_size = dist.get_world_size() rank = dist.get_rank() if self.use_saprot: proteins = [] for sub_dict in self.uniprot2label.values(): aa_seq = sub_dict["seq"] foldseek_seq = sub_dict["foldseek"] assert len(aa_seq) == len(foldseek_seq) seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)]) proteins.append(seq) else: proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()] span = math.ceil(len(proteins) / world_size) sub_proteins = proteins[rank * span: (rank + 1) * span] # Display the progress bar on the rank 0 process verbose = self.trainer.local_rank == 0 # Get protein representations sub_protein_repr = self.protein_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose) protein_repr = self.padded_gather(sub_protein_repr) # Construct faiss index d = protein_repr.shape[-1] protein_indices = faiss.IndexFlatIP(d) protein_indices.add(protein_repr.cpu().numpy()) return protein_indices def _get_structure_indices(self): world_size = dist.get_world_size() rank = dist.get_rank() proteins = [sub_dict["foldseek"] for sub_dict in self.uniprot2label.values()] span = math.ceil(len(proteins) / world_size) sub_proteins = proteins[rank * span: (rank + 1) * span] # Display the progress bar on the rank 0 process verbose = self.trainer.local_rank == 0 # Get protein representations sub_protein_repr = self.structure_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose) protein_repr = self.padded_gather(sub_protein_repr) # Construct faiss index d = protein_repr.shape[-1] structure_indices = faiss.IndexFlatIP(d) structure_indices.add(protein_repr.cpu().numpy()) return structure_indices def _get_text_indices(self): world_size = dist.get_world_size() rank = dist.get_rank() # Display the progress bar on the rank 0 process verbose = self.trainer.local_rank == 0 if verbose: iterator = tqdm(self.label2text.keys(), desc="Get text representations") else: iterator = self.label2text.keys() text_embeddings = {} for subsection in iterator: if subsection == "Total": continue texts = [] for text_list in self.label2text[subsection].values(): # Only use the first text for efficiency texts.append(text_list[0:1]) span = math.ceil(len(texts) / world_size) texts = texts[rank * span: (rank + 1) * span] embeddings = [] for text_list in texts: text_repr = self.text_encoder.get_repr(text_list) mean_repr = text_repr.mean(dim=0, keepdim=True) norm_repr = torch.nn.functional.normalize(mean_repr, dim=-1) embeddings.append(norm_repr) if len(embeddings) > 0: embeddings = torch.cat(embeddings, dim=0) else: embeddings = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device) text_repr = self.padded_gather(embeddings) text_embeddings[subsection] = text_repr # Aggregate text embeddings for global retrieval total_embeddings = [] for idx in self.label2text["Total"].values(): subsection, i = idx.split("|") total_embeddings.append(text_embeddings[subsection][int(i)]) text_embeddings["Total"] = torch.stack(total_embeddings) # Construct faiss index text_indices = {} for subsection, text_repr in text_embeddings.items(): d = text_repr.shape[-1] text_indices[subsection] = faiss.IndexFlatIP(d) text_indices[subsection].add(text_repr.cpu().numpy()) return text_indices def _protein2text(self, modality: str, protein_indices, text_indices: dict): def do(process_id, idx, row, writer): subsection, uniprot_id, prob_idx, label = row # Retrieve ranking results p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1) text_inds = text_indices[subsection] sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal) sim_scores, rank_inds = sim_scores[0], rank_inds[0] # Calculate Average Precision(AP) ranks = [] label = set(label) for i, rk in enumerate(rank_inds): # Find the rank of this label in all labels if rk in label: ranks.append(i + 1) ranks = np.array(ranks) ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)]) # Calculate Mean Reciprocal Rank(MRR) best_rank = ranks[0] mrr = 1 / best_rank # Calculate the AUC true_labels = np.zeros_like(sim_scores) true_labels[ranks - 1] = 1 if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]: auc = 0 else: auc = roc_auc_score(true_labels, sim_scores) output = json.dumps([ap, mrr, auc]) writer.write(output + "\n") inputs = [] swissprot_subsections = set() for subsection in text_indices.keys(): for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()): if uniprot_id in self.swissprot_ids: if subsection in labels: swissprot_subsections.add(subsection) label = labels[subsection] inputs.append((subsection, uniprot_id, i, label)) # Randomly shuffle the inputs random.seed(20000812) random.shuffle(inputs) # Split inputs into chunks for parallel processing world_size = dist.get_world_size() rank = dist.get_rank() span = math.ceil(len(inputs) / world_size) sub_inputs = inputs[rank * span: (rank + 1) * span] # Display the progress bar on the rank 0 process verbose = self.trainer.local_rank == 0 if verbose: print("Evaluating on each subsection...") tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv" mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose, return_results=True) outputs = mpr.run() os.remove(tmp_path) # Aggregate results tensor_outputs = [] for output in outputs: ap, mrr, auc = json.loads(output) tensor_outputs.append([float(ap), float(mrr), float(auc)]) tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device) tensor_outputs = self.padded_gather(tensor_outputs) # Record results avg_results = {} for subsection in swissprot_subsections: avg_results[subsection] = {"map": [], "mrr": [], "auc": []} for input, output in zip(inputs, tensor_outputs): ap, mrr, auc = output subsection, _, _, _ = input avg_results[subsection]["map"].append(ap.cpu().item()) avg_results[subsection]["mrr"].append(mrr.cpu().item()) avg_results[subsection]["auc"].append(auc.cpu().item()) results = { f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]), f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]), f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]), } # Average the precision and recall for each level for level, labels in [("residue-level", residue_level), ("sequence-level", sequence_level), ("all", residue_level | sequence_level)]: mrrs = [] maps = [] aucs = [] for subsection in labels: if subsection in avg_results: mrrs.append(np.mean(avg_results[subsection]["mrr"])) maps.append(np.mean(avg_results[subsection]["map"])) aucs.append(np.mean(avg_results[subsection]["auc"])) results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs) results[f"{modality}2Text_{level}_map"] = np.mean(maps) results[f"{modality}2Text_{level}_auc"] = np.mean(aucs) return results def _text2protein(self, modality: str, protein_indices, text_indices: dict): def do(process_id, idx, row, writer): subsection, text_id, label = row # Retrieve ranking results t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1) sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal) sim_scores, rank_inds = sim_scores[0], rank_inds[0] # Calculate Average Precision(AP) ranks = [] label = set(label) for i, rk in enumerate(rank_inds): # Find the rank of this label in all labels if rk in label: ranks.append(i + 1) ranks = np.array(ranks) ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)]) # Calculate Mean Reciprocal Rank(MRR) best_rank = ranks[0] mrr = 1 / best_rank # Calculate the AUC true_labels = np.zeros_like(sim_scores) true_labels[ranks - 1] = 1 if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]: auc = 0 else: auc = roc_auc_score(true_labels, sim_scores) output = json.dumps([ap, mrr, auc]) writer.write(output + "\n") text2label = {} swissprot_subsections = set() for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()): # Only evaluate the texts in Swiss-Prot if uniprot_id not in self.swissprot_ids: continue for subsection, text_ids in subsections.items(): if subsection == "seq" or subsection == "foldseek": continue swissprot_subsections.add(subsection) if subsection not in text2label: text2label[subsection] = {} for text_id in text_ids: text2label[subsection][text_id] = text2label[subsection].get(text_id, []) + [i] inputs = [] for subsection in swissprot_subsections: for i, (text_id, label) in enumerate(text2label[subsection].items()): inputs.append((subsection, text_id, label)) # Randomly shuffle the inputs random.seed(20000812) random.shuffle(inputs) # Split inputs into chunks for parallel processing world_size = dist.get_world_size() rank = dist.get_rank() span = math.ceil(len(inputs) / world_size) sub_inputs = inputs[rank * span: (rank + 1) * span] # Display the progress bar on the rank 0 process verbose = self.trainer.local_rank == 0 if verbose: print("Evaluating on each text...") # Add time stamp to the temporary file name to avoid conflicts tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv" mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose, return_results=True) outputs = mpr.run() os.remove(tmp_path) # Aggregate results tensor_outputs = [] for output in outputs: ap, mrr, auc = json.loads(output) tensor_outputs.append([float(ap), float(mrr), float(auc)]) tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device) tensor_outputs = self.padded_gather(tensor_outputs) # Record results avg_results = {} for subsection in swissprot_subsections: avg_results[subsection] = {"map": [], "mrr": [], "auc": []} for input, output in zip(inputs, tensor_outputs): ap, mrr, auc = output subsection, _, _ = input avg_results[subsection]["map"].append(ap.cpu().item()) avg_results[subsection]["mrr"].append(mrr.cpu().item()) avg_results[subsection]["auc"].append(auc.cpu().item()) results = { f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]), f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]), f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]), } # Average the precision and recall for each level for level, labels in [("residue-level", residue_level), ("sequence-level", sequence_level), ("all", residue_level | sequence_level)]: mrrs = [] maps = [] aucs = [] for subsection in labels: if subsection in avg_results: mrrs.append(np.mean(avg_results[subsection]["mrr"])) maps.append(np.mean(avg_results[subsection]["map"])) aucs.append(np.mean(avg_results[subsection]["auc"])) results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs) results[f"Text2{modality}_{level}_map"] = np.mean(maps) results[f"Text2{modality}_{level}_auc"] = np.mean(aucs) return results def retrieval_eval(self) -> dict: # Get protein representations protein_indices = self._get_protein_indices() # Get structure representations # if self.structure_config is not None: # structure_embeddings = self._get_structure_embeddings() # Get text representations text_indices = self._get_text_indices() # Retrieve texts for each protein results = {} results.update(self._protein2text("Sequence", protein_indices, text_indices)) # if self.structure_config is not None: # results.update(self._protein2text("Structure", structure_embeddings, text_embeddings)) # results.update(self._text2protein("Structure", structure_embeddings, text_embeddings)) # Retrieve proteins for each text results.update(self._text2protein("Sequence", protein_indices, text_indices)) return results def _apply_bert_mask(self, tokens, tokenizer, mask_ratio): while True: masked_tokens = copy.copy(tokens) labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long) vocab = [k for k in tokenizer.get_vocab().keys()] for i in range(len(tokens)): token = tokens[i] prob = random.random() if prob < mask_ratio: prob /= mask_ratio labels[i + 1] = tokenizer.convert_tokens_to_ids(token) if prob < 0.8: # 80% random change to mask token if self.use_saprot: token = "#" + token[-1] else: token = tokenizer.mask_token elif prob < 0.9: # 10% chance to change to random token token = random.choice(vocab) else: # 10% chance to keep current token pass masked_tokens[i] = token # Check if there is at least one masked token if (labels != -1).any(): return masked_tokens, labels def mlm_eval(self) -> float: world_size = dist.get_world_size() rank = dist.get_rank() if self.use_saprot: proteins = [] for sub_dict in self.uniprot2label.values(): aa_seq = sub_dict["seq"] foldseek_seq = sub_dict["foldseek"] assert len(aa_seq) == len(foldseek_seq) seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)]) proteins.append(seq) else: proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()] span = math.ceil(len(proteins) / world_size) sub_proteins = proteins[rank * span: (rank + 1) * span] # Display the progress bar on the rank 0 process if self.trainer.local_rank == 0: iterator = tqdm(sub_proteins, desc="Computing mlm...") else: iterator = sub_proteins total = torch.tensor([0], dtype=torch.long, device=self.device) correct = torch.tensor([0], dtype=torch.long, device=self.device) for seq in iterator: tokens = self.protein_encoder.tokenizer.tokenize(seq) masked_tokens, labels = self._apply_bert_mask(tokens, self.protein_encoder.tokenizer, 0.15) seq = " ".join(masked_tokens) inputs = self.protein_encoder.tokenizer(seq, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} _, logits = self.protein_encoder(inputs, get_mask_logits=True) logits = logits.squeeze(0) labels = labels.to(self.device) selecor = labels != -1 preds = logits.argmax(dim=-1)[selecor] labels = labels[selecor] total += len(preds) correct += (preds == labels).sum() # Gather all results total = self.padded_gather(total).sum() correct = self.padded_gather(correct).sum() acc = correct / total return acc.cpu().item() def _load_eval_data(self, stage): # Load the data lmdb_dir = eval(f"self.trainer.datamodule.{stage}_lmdb") uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json") label2text_path = os.path.join(lmdb_dir, "label2text.json") swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv") self.uniprot2label = json.load(open(uniprot2label_path, "r")) self.label2text = json.load(open(label2text_path, "r")) self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist()) self.k = 3 def on_test_start(self): self._load_eval_data("test") log_dict = self.retrieval_eval() log_dict = {"test_" + k: v for k, v in log_dict.items()} if self.use_mlm_loss: log_dict["test_mask_acc"] = self.mlm_eval() self.log_info(log_dict) print(log_dict) def on_validation_start(self): # Clear the cache torch.cuda.empty_cache() self._load_eval_data("valid") log_dict = self.retrieval_eval() log_dict = {"valid_" + k: v for k, v in log_dict.items()} if self.use_mlm_loss: log_dict["valid_mask_acc"] = self.mlm_eval() self.log_info(log_dict) self.check_save_condition(self.step, mode="max") def test_step(self, batch, batch_idx): return def validation_step(self, batch, batch_idx): return def on_train_epoch_end(self): super().on_train_epoch_end() # Re-sample the subset of the training data if self.trainer.datamodule.train_dataset.fixed_dataset_num is not None: self.trainer.datamodule.train_dataset.sample_subset() # def test_epoch_end(self, outputs): # log_dict = self.get_log_dict("test") # log_dict["test_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean() # # print(log_dict) # self.log_info(log_dict) # # self.reset_metrics("test") # # def validation_epoch_end(self, outputs): # log_dict = self.get_log_dict("valid") # log_dict["valid_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean() # # self.log_info(log_dict) # self.reset_metrics("valid") # self.check_save_condition(log_dict["valid_loss"], mode="min")