import torch
import torch.distributed as dist
import torchmetrics
import json
import math
import numpy as np
import os
import copy
import faiss
import time
import pandas as pd
import random

from tqdm import tqdm
from .protein_encoder import ProteinEncoder
from .structure_encoder import StructureEncoder
from .text_encoder import TextEncoder
from ..abstract_model import AbstractModel
from ..model_interface import register_model
from utils.mpr import MultipleProcessRunnerSimplifier
from torch.nn.functional import normalize, cross_entropy
from utils.constants import residue_level, sequence_level
from sklearn.metrics import roc_auc_score


def multilabel_cross_entropy(logits, labels):
    """
    Compute cross entropy loss for multilabel classification. See "https://arxiv.org/pdf/2208.02955.pdf"

    Args:
        logits: [num_samples, num_classes]
        labels: [num_samples, num_classes]
    """
    loss = 0
    for pred, label in zip(logits, labels):
        pos_logits = pred[label == 1]
        neg_logits = pred[label == 0]

        # Pairwise margins: every negative logit against every positive logit
        diff = neg_logits.unsqueeze(-1) - pos_logits
        loss += torch.log(1 + torch.exp(diff).sum())

    return loss / len(logits)
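

# Illustrative example of the loss above (a sketch, not part of the model):
# for logits [[2.0, 0.5, -1.0]] with labels [[1, 0, 0]], the loss is
# log(1 + e^(0.5 - 2.0) + e^(-1.0 - 2.0)), i.e. every negative logit is
# penalized according to how closely it approaches a positive one.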


@register_model
class ProTrekTrimodalModel(AbstractModel):
    def __init__(self,
                 protein_config: str,
                 text_config: str,
                 structure_config: str = None,
                 repr_dim: int = 1024,
                 temperature: float = 0.07,
                 load_protein_pretrained: bool = True,
                 load_text_pretrained: bool = True,
                 use_mlm_loss: bool = False,
                 use_zlpr_loss: bool = False,
                 use_saprot: bool = False,
                 gradient_checkpointing: bool = False,
                 **kwargs):
        """
        Args:
            protein_config: Path to the config file for the protein sequence encoder

            text_config: Path to the config file for the text encoder

            structure_config: Path to the config file for the structure encoder

            repr_dim: Output dimension of the protein and text representations

            temperature: Temperature for softmax

            load_protein_pretrained: Whether to load pretrained weights for the protein encoder

            load_text_pretrained: Whether to load pretrained weights for the text encoder

            use_mlm_loss: Whether to use masked language modeling loss

            use_zlpr_loss: Whether to use the ZLPR loss. See "https://arxiv.org/pdf/2208.02955.pdf"

            use_saprot: Whether to use SaProt as the protein encoder

            gradient_checkpointing: Whether to use gradient checkpointing for the protein encoder
        """
        self.protein_config = protein_config
        self.structure_config = structure_config
        self.text_config = text_config
        self.repr_dim = repr_dim
        self.temperature = temperature
        self.load_protein_pretrained = load_protein_pretrained
        self.load_text_pretrained = load_text_pretrained
        self.use_mlm_loss = use_mlm_loss
        self.use_zlpr_loss = use_zlpr_loss
        self.use_saprot = use_saprot
        self.gradient_checkpointing = gradient_checkpointing
        super().__init__(**kwargs)
    def initialize_metrics(self, stage: str) -> dict:
        return_dict = {
            f"{stage}_protein_text_acc": torchmetrics.Accuracy(),
            f"{stage}_text_protein_acc": torchmetrics.Accuracy(),
        }

        if self.use_mlm_loss:
            return_dict[f"{stage}_protein_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
            if self.structure_config is not None:
                return_dict[f"{stage}_structure_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)

        if self.structure_config is not None:
            return_dict[f"{stage}_structure_protein_acc"] = torchmetrics.Accuracy()
            return_dict[f"{stage}_structure_text_acc"] = torchmetrics.Accuracy()
            return_dict[f"{stage}_text_structure_acc"] = torchmetrics.Accuracy()
            return_dict[f"{stage}_protein_structure_acc"] = torchmetrics.Accuracy()

        return return_dict
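    # Version note (an observation, not original code): torchmetrics.Accuracy()
    # with no arguments follows the pre-1.0 torchmetrics API; recent releases
    # require an explicit task, e.g. Accuracy(task="multiclass", num_classes=N).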
    def initialize_model(self):
        self.protein_encoder = ProteinEncoder(self.protein_config,
                                              self.repr_dim,
                                              self.load_protein_pretrained,
                                              self.gradient_checkpointing)

        self.text_encoder = TextEncoder(self.text_config,
                                        self.repr_dim,
                                        self.load_text_pretrained,
                                        self.gradient_checkpointing)

        # Learnable temperature for the contrastive loss
        self.temperature = torch.nn.Parameter(torch.tensor(self.temperature))

        # Collect the trainable components in one container
        self.model = torch.nn.ParameterList([self.temperature,
                                             self.protein_encoder,
                                             self.text_encoder])

        if self.structure_config is not None:
            self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim)
            self.model.append(self.structure_encoder)
    def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        return self.text_encoder.get_repr(texts, batch_size, verbose)

    def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        return self.structure_encoder.get_repr(proteins, batch_size, verbose)

    def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        return self.protein_encoder.get_repr(proteins, batch_size, verbose)
    def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None):
        """
        Args:
            protein_inputs: A dictionary for the protein encoder

            text_inputs: A dictionary for the text encoder

            structure_inputs: A dictionary for the structure encoder
        """
        protein_repr, protein_mask_logits = self.protein_encoder(protein_inputs, self.use_mlm_loss)
        text_repr = self.text_encoder(text_inputs)

        outputs = [text_repr, protein_repr, protein_mask_logits]

        if self.structure_config is not None:
            structure_repr, structure_mask_logits = self.structure_encoder(structure_inputs, self.use_mlm_loss)
            outputs += [structure_repr, structure_mask_logits]

        return outputs
    def loss_func(self, stage: str, outputs, labels):
        if self.structure_config is not None:
            text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs
        else:
            text_repr, protein_repr, protein_mask_logits = outputs

        device = text_repr.device

        text_repr = normalize(text_repr, dim=-1)
        protein_repr = normalize(protein_repr, dim=-1)

        # Gather representations from all GPUs to serve as negatives. The gathered
        # copies are detached, so gradients only flow through the local batch.
        all_protein_repr = self.all_gather(protein_repr).view(-1, protein_repr.shape[-1]).detach()
        all_text_repr = self.all_gather(text_repr).view(-1, text_repr.shape[-1]).detach()

        if self.structure_config is not None:
            structure_repr = normalize(structure_repr, dim=-1)
            all_structure_repr = self.all_gather(structure_repr).view(-1, structure_repr.shape[-1]).detach()

        rank = dist.get_rank()
        bs = text_repr.shape[0]

        # Positions of the local samples in the gathered (global) batch
        bs_labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(device)
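        # Example (illustrative): with world_size = 4, bs = 8 and rank = 1, the
        # gathered representations have 32 rows and the positives for this GPU
        # sit at global rows 8..15, so bs_labels = [8, 9, ..., 15]. An
        # equivalent construction would be
        # torch.arange(rank * bs, (rank + 1) * bs, device=device).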
        if self.structure_config is not None:
            pairs = {
                "protein": ["structure", "text"],
                "structure": ["protein", "text"],
                "text": ["protein", "structure"]
            }
        else:
            pairs = {
                "protein": ["text"],
                "text": ["protein"]
            }

        loss_list = []
        for k, values in pairs.items():
            for v in values:
                # Similarity between local k-modality samples and all gathered
                # v-modality samples, scaled by the learnable temperature
                sim = torch.matmul(eval(f"{k}_repr"), eval(f"all_{v}_repr").T).div(self.temperature)
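                # sim has shape [bs, world_size * bs]; cross-entropy against
                # bs_labels is the InfoNCE objective, where each sample's
                # positive partner competes against every gathered negative.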
                loss = cross_entropy(sim, bs_labels)
                self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
                loss_list.append(loss)

        if self.use_mlm_loss:
            k_label = [("protein", labels["seq_labels"])]
            if self.structure_config is not None:
                k_label.append(("structure", labels["struc_labels"]))

            for k, label in k_label:
                logits = eval(f"{k}_mask_logits")
                logits = logits.view(-1, logits.shape[-1])
                label = label.flatten().to(device)
                mlm_loss = cross_entropy(logits, label, ignore_index=-1)
                loss_list.append(mlm_loss)
                self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label)

        # Average the contrastive and (optional) MLM terms
        loss = sum(loss_list) / len(loss_list)

        if stage == "train":
            log_dict = self.get_log_dict("train")
            log_dict["train_loss"] = loss
            self.log_info(log_dict)
            self.reset_metrics("train")

        return loss
    def padded_gather(self, tensor: torch.Tensor):
        """
        Gather tensors from all GPUs, allowing different shapes at the batch dimension.
        """
        size = tensor.shape[0]
        all_sizes = self.all_gather(torch.tensor(size, device=tensor.device))
        max_size = max(all_sizes)

        # Pad the local tensor to the maximum batch size across GPUs
        if size != max_size:
            tmp = torch.zeros(max_size, tensor.shape[-1], dtype=tensor.dtype, device=tensor.device)
            tmp[:size] = tensor
            tensor = tmp

        padded_tensor = self.all_gather(tensor).view(-1, tensor.shape[-1])
        # Drop the padded rows. This assumes only trailing ranks can be short,
        # which holds for the ceil-based data splits used in this file.
        tensor = padded_tensor[:sum(all_sizes)]

        return tensor
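    # Illustrative behavior of padded_gather (a sketch, not code): with two
    # GPUs holding 5 and 3 rows, both ranks pad to 5 rows, all_gather returns
    # 10 rows, and slicing the first 5 + 3 = 8 rows recovers the real data
    # because only the trailing rank is short under the ceil-based splits.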
    def _get_protein_indices(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if self.use_saprot:
            # SaProt input: interleave each amino acid with its Foldseek structure token
            proteins = []
            for sub_dict in self.uniprot2label.values():
                aa_seq = sub_dict["seq"]
                foldseek_seq = sub_dict["foldseek"]
                assert len(aa_seq) == len(foldseek_seq)
                seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
                proteins.append(seq)

        else:
            proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]

        # Split the proteins across GPUs
        span = math.ceil(len(proteins) / world_size)
        sub_proteins = proteins[rank * span: (rank + 1) * span]

        verbose = self.trainer.local_rank == 0
        sub_protein_repr = self.protein_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
        protein_repr = self.padded_gather(sub_protein_repr)

        d = protein_repr.shape[-1]
        protein_indices = faiss.IndexFlatIP(d)
        protein_indices.add(protein_repr.cpu().numpy())
        return protein_indices
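    # Note: faiss.IndexFlatIP performs exact maximum inner product search. If
    # the stored representations are L2-normalized (as the text embeddings
    # below are), inner product coincides with cosine similarity.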
    def _get_structure_indices(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        proteins = [sub_dict["foldseek"] for sub_dict in self.uniprot2label.values()]
        span = math.ceil(len(proteins) / world_size)
        sub_proteins = proteins[rank * span: (rank + 1) * span]

        verbose = self.trainer.local_rank == 0
        sub_protein_repr = self.structure_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
        protein_repr = self.padded_gather(sub_protein_repr)

        d = protein_repr.shape[-1]
        structure_indices = faiss.IndexFlatIP(d)
        structure_indices.add(protein_repr.cpu().numpy())
        return structure_indices
    def _get_text_indices(self):
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        verbose = self.trainer.local_rank == 0
        if verbose:
            iterator = tqdm(self.label2text.keys(), desc="Get text representations")
        else:
            iterator = self.label2text.keys()

        text_embeddings = {}
        for subsection in iterator:
            if subsection == "Total":
                continue

            texts = []
            for text_list in self.label2text[subsection].values():
                # Only the first text description of each label is used
                texts.append(text_list[0:1])

            # Split the texts across GPUs
            span = math.ceil(len(texts) / world_size)
            texts = texts[rank * span: (rank + 1) * span]
            embeddings = []
            for text_list in texts:
                text_repr = self.text_encoder.get_repr(text_list)
                mean_repr = text_repr.mean(dim=0, keepdim=True)
                norm_repr = torch.nn.functional.normalize(mean_repr, dim=-1)
                embeddings.append(norm_repr)

            if len(embeddings) > 0:
                embeddings = torch.cat(embeddings, dim=0)
            else:
                embeddings = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device)

            text_repr = self.padded_gather(embeddings)
            text_embeddings[subsection] = text_repr

        # The "Total" subsection stores references of the form "subsection|i",
        # so its embeddings are looked up from the per-subsection tables above
        total_embeddings = []
        for idx in self.label2text["Total"].values():
            subsection, i = idx.split("|")
            total_embeddings.append(text_embeddings[subsection][int(i)])

        text_embeddings["Total"] = torch.stack(total_embeddings)

        # Build a faiss index for each subsection
        text_indices = {}
        for subsection, text_repr in text_embeddings.items():
            d = text_repr.shape[-1]
            text_indices[subsection] = faiss.IndexFlatIP(d)
            text_indices[subsection].add(text_repr.cpu().numpy())

        return text_indices
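    # Assumed layout of the evaluation data, inferred from the accesses above:
    # label2text maps subsection -> {label -> [text, ...]}, plus a special
    # "Total" subsection whose values are "subsection|i" references, and
    # uniprot2label maps uniprot_id -> {"seq": ..., "foldseek": ...,
    # subsection: [label indices]}.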
    def _protein2text(self, modality: str, protein_indices, text_indices: dict):
        def do(process_id, idx, row, writer):
            subsection, uniprot_id, prob_idx, label = row

            # Reconstruct the stored protein embedding and rank all texts of
            # this subsection by similarity
            p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
            text_inds = text_indices[subsection]
            sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal)
            sim_scores, rank_inds = sim_scores[0], rank_inds[0]

            # Record the 1-based positions of the true labels in the ranking
            ranks = []
            label = set(label)
            for i, rk in enumerate(rank_inds):
                if rk in label:
                    ranks.append(i + 1)

            ranks = np.array(ranks)
            ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])

            best_rank = ranks[0]
            mrr = 1 / best_rank
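            # Metric sanity check: (i + 1) / rank is the precision at the
            # position of the i-th hit, so ap is average precision, and mrr
            # keeps only the best-ranked hit, matching mean reciprocal rank
            # once averaged over queries.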
            # AUC over the full ranking, with the true labels as positives
            true_labels = np.zeros_like(sim_scores)
            true_labels[ranks - 1] = 1
            if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
                auc = 0
            else:
                auc = roc_auc_score(true_labels, sim_scores)

            output = json.dumps([ap, mrr, auc])
            writer.write(output + "\n")

        # Collect one query per (subsection, protein) pair for Swiss-Prot proteins
        inputs = []
        swissprot_subsections = set()
        for subsection in text_indices.keys():
            for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()):
                if uniprot_id in self.swissprot_ids:
                    if subsection in labels:
                        swissprot_subsections.add(subsection)
                        label = labels[subsection]
                        inputs.append((subsection, uniprot_id, i, label))

        random.seed(20000812)
        random.shuffle(inputs)

        # Split the queries across GPUs
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        span = math.ceil(len(inputs) / world_size)
        sub_inputs = inputs[rank * span: (rank + 1) * span]

        verbose = self.trainer.local_rank == 0
        if verbose:
            print("Evaluating on each subsection...")

        tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                              return_results=True)
        outputs = mpr.run()
        os.remove(tmp_path)

        tensor_outputs = []
        for output in outputs:
            ap, mrr, auc = json.loads(output)
            tensor_outputs.append([float(ap), float(mrr), float(auc)])

        tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
        tensor_outputs = self.padded_gather(tensor_outputs)

        # Average the metrics within each subsection
        avg_results = {}
        for subsection in swissprot_subsections:
            avg_results[subsection] = {"map": [],
                                       "mrr": [],
                                       "auc": []}

        for input, output in zip(inputs, tensor_outputs):
            ap, mrr, auc = output
            subsection, _, _, _ = input

            avg_results[subsection]["map"].append(ap.cpu().item())
            avg_results[subsection]["mrr"].append(mrr.cpu().item())
            avg_results[subsection]["auc"].append(auc.cpu().item())

        results = {
            f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
            f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]),
            f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]),
        }

        # Aggregate metrics at the residue level, the sequence level and overall
        for level, labels in [("residue-level", residue_level),
                              ("sequence-level", sequence_level),
                              ("all", residue_level | sequence_level)]:
            mrrs = []
            maps = []
            aucs = []
            for subsection in labels:
                if subsection in avg_results:
                    mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                    maps.append(np.mean(avg_results[subsection]["map"]))
                    aucs.append(np.mean(avg_results[subsection]["auc"]))

            results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs)
            results[f"{modality}2Text_{level}_map"] = np.mean(maps)
            results[f"{modality}2Text_{level}_auc"] = np.mean(aucs)

        return results
    def _text2protein(self, modality: str, protein_indices, text_indices: dict):
        def do(process_id, idx, row, writer):
            subsection, text_id, label = row

            # Reconstruct the stored text embedding and rank all proteins by similarity
            t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1)
            sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal)
            sim_scores, rank_inds = sim_scores[0], rank_inds[0]

            # Record the 1-based positions of the true proteins in the ranking
            ranks = []
            label = set(label)
            for i, rk in enumerate(rank_inds):
                if rk in label:
                    ranks.append(i + 1)

            ranks = np.array(ranks)
            ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])

            best_rank = ranks[0]
            mrr = 1 / best_rank

            true_labels = np.zeros_like(sim_scores)
            true_labels[ranks - 1] = 1
            if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
                auc = 0
            else:
                auc = roc_auc_score(true_labels, sim_scores)

            output = json.dumps([ap, mrr, auc])
            writer.write(output + "\n")

        # Invert uniprot2label: for each subsection, map text id -> protein indices
        text2label = {}
        swissprot_subsections = set()
        for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()):
            if uniprot_id not in self.swissprot_ids:
                continue

            for subsection, text_ids in subsections.items():
                if subsection == "seq" or subsection == "foldseek":
                    continue

                swissprot_subsections.add(subsection)
                if subsection not in text2label:
                    text2label[subsection] = {}

                for text_id in text_ids:
                    text2label[subsection][text_id] = text2label[subsection].get(text_id, []) + [i]

        inputs = []
        for subsection in swissprot_subsections:
            for i, (text_id, label) in enumerate(text2label[subsection].items()):
                inputs.append((subsection, text_id, label))

        random.seed(20000812)
        random.shuffle(inputs)

        # Split the queries across GPUs
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        span = math.ceil(len(inputs) / world_size)
        sub_inputs = inputs[rank * span: (rank + 1) * span]

        verbose = self.trainer.local_rank == 0
        if verbose:
            print("Evaluating on each text...")

        tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                              return_results=True)
        outputs = mpr.run()
        os.remove(tmp_path)

        tensor_outputs = []
        for output in outputs:
            ap, mrr, auc = json.loads(output)
            tensor_outputs.append([float(ap), float(mrr), float(auc)])

        tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
        tensor_outputs = self.padded_gather(tensor_outputs)

        # Average the metrics within each subsection
        avg_results = {}
        for subsection in swissprot_subsections:
            avg_results[subsection] = {"map": [],
                                       "mrr": [],
                                       "auc": []}

        for input, output in zip(inputs, tensor_outputs):
            ap, mrr, auc = output
            subsection, _, _ = input

            avg_results[subsection]["map"].append(ap.cpu().item())
            avg_results[subsection]["mrr"].append(mrr.cpu().item())
            avg_results[subsection]["auc"].append(auc.cpu().item())

        results = {
            f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
            f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]),
            f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]),
        }

        # Aggregate metrics at the residue level, the sequence level and overall
        for level, labels in [("residue-level", residue_level),
                              ("sequence-level", sequence_level),
                              ("all", residue_level | sequence_level)]:
            mrrs = []
            maps = []
            aucs = []
            for subsection in labels:
                if subsection in avg_results:
                    mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                    maps.append(np.mean(avg_results[subsection]["map"]))
                    aucs.append(np.mean(avg_results[subsection]["auc"]))

            results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs)
            results[f"Text2{modality}_{level}_map"] = np.mean(maps)
            results[f"Text2{modality}_{level}_auc"] = np.mean(aucs)

        return results
    def retrieval_eval(self) -> dict:
        protein_indices = self._get_protein_indices()
        text_indices = self._get_text_indices()

        results = {}
        results.update(self._protein2text("Sequence", protein_indices, text_indices))
        results.update(self._text2protein("Sequence", protein_indices, text_indices))

        return results
    def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
        # Retry until at least one position is masked
        while True:
            masked_tokens = copy.copy(tokens)
            # +2 accounts for the special tokens added around the sequence
            labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)
            vocab = list(tokenizer.get_vocab().keys())

            for i in range(len(tokens)):
                token = tokens[i]

                prob = random.random()
                if prob < mask_ratio:
                    prob /= mask_ratio
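                    # Renormalizing prob maps [0, mask_ratio) uniformly onto
                    # [0, 1), so the thresholds below reproduce BERT's
                    # 80% mask / 10% random / 10% unchanged split.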
                    labels[i + 1] = tokenizer.convert_tokens_to_ids(token)

                    if prob < 0.8:
                        # 80%: replace with the mask token. For SaProt tokens,
                        # only the amino acid part is masked ("#" + structure token)
                        if self.use_saprot:
                            token = "#" + token[-1]
                        else:
                            token = tokenizer.mask_token
                    elif prob < 0.9:
                        # 10%: replace with a random token from the vocabulary
                        token = random.choice(vocab)
                    else:
                        # 10%: keep the original token
                        pass

                    masked_tokens[i] = token

            if (labels != -1).any():
                return masked_tokens, labels
    def mlm_eval(self) -> float:
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if self.use_saprot:
            proteins = []
            for sub_dict in self.uniprot2label.values():
                aa_seq = sub_dict["seq"]
                foldseek_seq = sub_dict["foldseek"]
                assert len(aa_seq) == len(foldseek_seq)
                seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
                proteins.append(seq)

        else:
            proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]

        span = math.ceil(len(proteins) / world_size)
        sub_proteins = proteins[rank * span: (rank + 1) * span]

        if self.trainer.local_rank == 0:
            iterator = tqdm(sub_proteins, desc="Computing mlm...")
        else:
            iterator = sub_proteins

        total = torch.tensor([0], dtype=torch.long, device=self.device)
        correct = torch.tensor([0], dtype=torch.long, device=self.device)
        for seq in iterator:
            tokens = self.protein_encoder.tokenizer.tokenize(seq)
            masked_tokens, labels = self._apply_bert_mask(tokens, self.protein_encoder.tokenizer, 0.15)
            seq = " ".join(masked_tokens)

            inputs = self.protein_encoder.tokenizer(seq, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            _, logits = self.protein_encoder(inputs, get_mask_logits=True)

            logits = logits.squeeze(0)
            labels = labels.to(self.device)

            # Only evaluate on the masked positions
            selector = labels != -1
            preds = logits.argmax(dim=-1)[selector]
            labels = labels[selector]

            total += len(preds)
            correct += (preds == labels).sum()

        # Reduce the counts across GPUs
        total = self.padded_gather(total).sum()
        correct = self.padded_gather(correct).sum()

        acc = correct / total
        return acc.cpu().item()
    def _load_eval_data(self, stage):
        lmdb_dir = getattr(self.trainer.datamodule, f"{stage}_lmdb")
        uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
        label2text_path = os.path.join(lmdb_dir, "label2text.json")
        swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")

        self.uniprot2label = json.load(open(uniprot2label_path, "r"))
        self.label2text = json.load(open(label2text_path, "r"))
        self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist())
        self.k = 3
    def on_test_start(self):
        self._load_eval_data("test")

        log_dict = self.retrieval_eval()
        log_dict = {"test_" + k: v for k, v in log_dict.items()}
        if self.use_mlm_loss:
            log_dict["test_mask_acc"] = self.mlm_eval()
        self.log_info(log_dict)
        print(log_dict)

    def on_validation_start(self):
        torch.cuda.empty_cache()

        self._load_eval_data("valid")

        log_dict = self.retrieval_eval()
        log_dict = {"valid_" + k: v for k, v in log_dict.items()}
        if self.use_mlm_loss:
            log_dict["valid_mask_acc"] = self.mlm_eval()
        self.log_info(log_dict)

        self.check_save_condition(self.step, mode="max")

    def test_step(self, batch, batch_idx):
        return

    def validation_step(self, batch, batch_idx):
        return

    def on_train_epoch_end(self):
        super().on_train_epoch_end()

        # Resample the training subset if a fixed dataset size is configured
        if self.trainer.datamodule.train_dataset.fixed_dataset_num is not None:
            self.trainer.datamodule.train_dataset.sample_subset()
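

# Minimal usage sketch (illustrative only; the checkpoint paths, the extra
# kwargs expected by AbstractModel and the Lightning trainer wiring are
# assumptions, not part of this file):
#
#   model = ProTrekTrimodalModel(
#       protein_config="weights/protein_encoder",      # hypothetical path
#       text_config="weights/text_encoder",            # hypothetical path
#       structure_config="weights/structure_encoder",  # hypothetical path
#   )
#   seq_repr = model.get_protein_repr(["MEEVLK"])            # [1, repr_dim]
#   txt_repr = model.get_text_repr(["ATP-binding protein"])  # [1, repr_dim]
#   sim = seq_repr @ txt_repr.T / model.temperature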