LTEnjoy commited on
Commit
9606143
·
verified ·
1 Parent(s): ee76aa3

Delete model

Browse files
model/ProTrek/protein_encoder.py DELETED
@@ -1,95 +0,0 @@
1
- import torch
2
-
3
- from tqdm import tqdm
4
- from torch.nn.functional import normalize
5
- from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
6
-
7
-
8
- class ProteinEncoder(torch.nn.Module):
9
- def __init__(self,
10
- config_path: str,
11
- out_dim: int,
12
- load_pretrained: bool = True,
13
- gradient_checkpointing: bool = False):
14
- """
15
- Args:
16
- config_path: Path to the config file
17
-
18
- out_dim : Output dimension of the protein representation
19
-
20
- load_pretrained: Whether to load pretrained weights
21
-
22
- gradient_checkpointing: Whether to use gradient checkpointing
23
- """
24
- super().__init__()
25
- config = EsmConfig.from_pretrained(config_path)
26
- if load_pretrained:
27
- self.model = EsmForMaskedLM.from_pretrained(config_path)
28
- else:
29
- self.model = EsmForMaskedLM(config)
30
- self.out = torch.nn.Linear(config.hidden_size, out_dim)
31
-
32
- # Set gradient checkpointing
33
- self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing
34
-
35
- # Remove contact head
36
- self.model.esm.contact_head = None
37
-
38
- # Remove position embedding if the embedding type is ``rotary``
39
- if config.position_embedding_type == "rotary":
40
- self.model.esm.embeddings.position_embeddings = None
41
-
42
- self.tokenizer = EsmTokenizer.from_pretrained(config_path)
43
-
44
- def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
45
- """
46
- Compute protein representation for the given proteins
47
- Args:
48
- protein: A list of protein sequences
49
- batch_size: Batch size for inference
50
- verbose: Whether to print progress
51
- """
52
- device = next(self.parameters()).device
53
-
54
- protein_repr = []
55
- if verbose:
56
- iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
57
- else:
58
- iterator = range(0, len(proteins), batch_size)
59
-
60
- for i in iterator:
61
- protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
62
- return_tensors="pt",
63
- padding=True)
64
- protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
65
- output, _ = self.forward(protein_inputs)
66
-
67
- protein_repr.append(output)
68
-
69
- protein_repr = torch.cat(protein_repr, dim=0)
70
- return normalize(protein_repr, dim=-1)
71
-
72
- def forward(self, inputs: dict, get_mask_logits: bool = False):
73
- """
74
- Encode protein sequence into protein representation
75
- Args:
76
- inputs: A dictionary containing the following keys:
77
- - input_ids: [batch, seq_len]
78
- - attention_mask: [batch, seq_len]
79
- get_mask_logits: Whether to return the logits for masked tokens
80
-
81
- Returns:
82
- protein_repr: [batch, protein_repr_dim]
83
- mask_logits : [batch, seq_len, vocab_size]
84
- """
85
- last_hidden_state = self.model.esm(**inputs).last_hidden_state
86
- reprs = last_hidden_state[:, 0, :]
87
- reprs = self.out(reprs)
88
-
89
- # Get logits for masked tokens
90
- if get_mask_logits:
91
- mask_logits = self.model.lm_head(last_hidden_state)
92
- else:
93
- mask_logits = None
94
-
95
- return reprs, mask_logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/ProTrek/protrek_trimodal_model.py DELETED
@@ -1,874 +0,0 @@
1
- import torch
2
- import torch.distributed as dist
3
- import torchmetrics
4
- import json
5
- import math
6
- import numpy as np
7
- import os
8
- import copy
9
- import faiss
10
- import time
11
- import pandas as pd
12
- import random
13
-
14
- from tqdm import tqdm
15
- from .protein_encoder import ProteinEncoder
16
- from .structure_encoder import StructureEncoder
17
- from .text_encoder import TextEncoder
18
- from ..abstract_model import AbstractModel
19
- from ..model_interface import register_model
20
- from utils.mpr import MultipleProcessRunnerSimplifier
21
- from torch.nn.functional import normalize, cross_entropy
22
- from utils.constants import residue_level, sequence_level
23
- from sklearn.metrics import roc_auc_score
24
-
25
-
26
- def multilabel_cross_entropy(logits, labels):
27
- """
28
- Compute cross entropy loss for multilabel classification。 See "https://arxiv.org/pdf/2208.02955.pdf"
29
- Args:
30
- logits: [num_samples, num_classes]
31
- labels: [num_samples, num_classes]
32
- """
33
-
34
- loss = 0
35
- for pred, label in zip(logits, labels):
36
- pos_logits = pred[label == 1]
37
- neg_logits = pred[label == 0]
38
-
39
- diff = neg_logits.unsqueeze(-1) - pos_logits
40
- loss += torch.log(1 + torch.exp(diff).sum())
41
-
42
- return loss / len(logits)
43
-
44
- # pred = (1 - 2 * labels) * logits
45
- # pred_neg = pred - labels * 1e12
46
- # pred_pos = pred - (1 - labels) * 1e12
47
- #
48
- # zeros = torch.zeros_like(logits[..., :1], dtype=logits.dtype)
49
- # pred_neg = torch.cat([pred_neg, zeros], dim=-1)
50
- # pred_pos = torch.cat([pred_pos, zeros], dim=-1)
51
- #
52
- # neg_loss = torch.logsumexp(pred_neg, dim=-1)
53
- # pos_loss = torch.logsumexp(pred_pos, dim=-1)
54
- #
55
- # return (neg_loss + pos_loss).mean()
56
-
57
-
58
- @register_model
59
- class ProTrekTrimodalModel(AbstractModel):
60
- def __init__(self,
61
- protein_config: str,
62
- text_config: str,
63
- structure_config: str = None,
64
- repr_dim: int = 1024,
65
- temperature: float = 0.07,
66
- load_protein_pretrained: bool = True,
67
- load_text_pretrained: bool = True,
68
- use_mlm_loss: bool = False,
69
- use_zlpr_loss: bool = False,
70
- use_saprot: bool = False,
71
- gradient_checkpointing: bool = False,
72
- **kwargs):
73
- """
74
- Args:
75
- protein_config: Path to the config file for protein sequence encoder
76
-
77
- text_config: Path to the config file for text encoder
78
-
79
- structure_config: Path to the config file for structure encoder
80
-
81
- repr_dim: Output dimension of the protein and text representation
82
-
83
- temperature: Temperature for softmax
84
-
85
- load_protein_pretrained: Whether to load pretrained weights for protein encoder
86
-
87
- load_text_pretrained: Whether to load pretrained weights for text encoder
88
-
89
- use_mlm_loss: Whether to use masked language modeling loss
90
-
91
- use_zlpr_loss: Whether to use zlpr loss. See "https://arxiv.org/pdf/2208.02955.pdf"
92
-
93
- use_saprot: Whether to use SaProt as protein encoder
94
-
95
- gradient_checkpointing: Whether to use gradient checkpointing for protein encoder
96
- """
97
- self.protein_config = protein_config
98
- self.structure_config = structure_config
99
- self.text_config = text_config
100
- self.repr_dim = repr_dim
101
- self.temperature = temperature
102
- self.load_protein_pretrained = load_protein_pretrained
103
- self.load_text_pretrained = load_text_pretrained
104
- self.use_mlm_loss = use_mlm_loss
105
- self.use_zlpr_loss = use_zlpr_loss
106
- self.use_saprot = use_saprot
107
- self.gradient_checkpointing = gradient_checkpointing
108
- super().__init__(**kwargs)
109
-
110
- def initialize_metrics(self, stage: str) -> dict:
111
- return_dict = {
112
- f"{stage}_protein_text_acc": torchmetrics.Accuracy(),
113
- f"{stage}_text_protein_acc": torchmetrics.Accuracy(),
114
- }
115
-
116
- if self.use_mlm_loss:
117
- return_dict[f"{stage}_protein_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
118
- if self.structure_config is not None:
119
- return_dict[f"{stage}_structure_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
120
-
121
- if self.structure_config is not None:
122
- return_dict[f"{stage}_structure_protein_acc"] = torchmetrics.Accuracy()
123
- return_dict[f"{stage}_structure_text_acc"] = torchmetrics.Accuracy()
124
- return_dict[f"{stage}_text_structure_acc"] = torchmetrics.Accuracy()
125
- return_dict[f"{stage}_protein_structure_acc"] = torchmetrics.Accuracy()
126
-
127
- return return_dict
128
-
129
- def initialize_model(self):
130
- # Initialize encoders
131
- self.protein_encoder = ProteinEncoder(self.protein_config,
132
- self.repr_dim,
133
- self.load_protein_pretrained,
134
- self.gradient_checkpointing)
135
-
136
- self.text_encoder = TextEncoder(self.text_config,
137
- self.repr_dim,
138
- self.load_text_pretrained,
139
- self.gradient_checkpointing)
140
-
141
- # Learnable temperature
142
- self.temperature = torch.nn.Parameter(torch.tensor(self.temperature))
143
-
144
- # self.model is used for saving and loading
145
- self.model = torch.nn.ParameterList([self.temperature,
146
- self.protein_encoder,
147
- self.text_encoder])
148
-
149
- # If the structure encoder is specified
150
- if self.structure_config is not None:
151
- self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim)
152
- self.model.append(self.structure_encoder)
153
-
154
- def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
155
- return self.text_encoder.get_repr(texts, batch_size, verbose)
156
-
157
- def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
158
- return self.structure_encoder.get_repr(proteins, batch_size, verbose)
159
-
160
- def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
161
- return self.protein_encoder.get_repr(proteins, batch_size, verbose)
162
-
163
- def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None):
164
- """
165
- Args:
166
- protein_inputs: A dictionary for protein encoder
167
- structure_inputs: A dictionary for structure encoder
168
- text_inputs : A dictionary for text encoder
169
- """
170
- protein_repr, protein_mask_logits = self.protein_encoder(protein_inputs, self.use_mlm_loss)
171
- text_repr = self.text_encoder(text_inputs)
172
-
173
- outputs = [text_repr, protein_repr, protein_mask_logits]
174
-
175
- if self.structure_config is not None:
176
- structure_repr, structure_mask_logits = self.structure_encoder(structure_inputs, self.use_mlm_loss)
177
- outputs += [structure_repr, structure_mask_logits]
178
-
179
- return outputs
180
-
181
- def loss_func(self, stage: str, outputs, labels):
182
- if self.structure_config is not None:
183
- text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs
184
- else:
185
- text_repr, protein_repr, protein_mask_logits = outputs
186
-
187
- device = text_repr.device
188
-
189
- text_repr = normalize(text_repr, dim=-1)
190
- protein_repr = normalize(protein_repr, dim=-1)
191
-
192
- # Gather representations from all GPUs
193
- all_protein_repr = self.all_gather(protein_repr).view(-1, protein_repr.shape[-1]).detach()
194
- all_text_repr = self.all_gather(text_repr).view(-1, text_repr.shape[-1]).detach()
195
-
196
- if self.structure_config is not None:
197
- structure_repr = normalize(structure_repr, dim=-1)
198
- all_structure_repr = self.all_gather(structure_repr).view(-1, structure_repr.shape[-1]).detach()
199
-
200
- # text_idx = labels["text_idx"]
201
- # text_candidates = labels["text_candidates"]
202
- #
203
- # # Gather all text ids
204
- # text_inds = self.all_gather(text_idx).flatten()
205
- # # Create text classification labels
206
- # text_labels = torch.zeros(len(text_candidates), len(text_inds), dtype=int).to(device)
207
- # for i, candidate in enumerate(text_candidates):
208
- # for j, idx in enumerate(text_inds):
209
- # if idx.item() in candidate:
210
- # text_labels[i, j] = 1
211
- #
212
- # # Gather text labels from all GPUs
213
- # text_labels = self.all_gather(text_labels).view(-1, text_labels.shape[-1])
214
- #
215
- # # Protein classification labels are the transpose of text labels
216
- # protein_labels = text_labels.T
217
-
218
- # Batch size
219
- rank = dist.get_rank()
220
- bs = text_repr.shape[0]
221
-
222
- # Get current labels
223
- # protein_labels = protein_labels[rank * bs: rank * bs + bs]
224
- # text_labels = text_labels[rank * bs: rank * bs + bs]
225
-
226
- # Create classification labels between structure and sequence
227
- bs_labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(device)
228
-
229
- if self.structure_config is not None:
230
- pairs = {
231
- "protein": ["structure", "text"],
232
- "structure": ["protein", "text"],
233
- "text": ["protein", "structure"]
234
- }
235
- else:
236
- pairs = {
237
- "protein": ["text"],
238
- "text": ["protein"]
239
- }
240
-
241
- loss_list = []
242
- for k, values in pairs.items():
243
- for v in values:
244
- # Only calculate the similarity for the current batch
245
- sim = torch.matmul(eval(f"{k}_repr"), eval(f"all_{v}_repr").T).div(self.temperature)
246
-
247
- # if k == "text":
248
- # if self.use_zlpr_loss:
249
- # loss = multilabel_cross_entropy(sim, protein_labels)
250
- # else:
251
- # loss = cross_entropy(sim, bs_labels)
252
- #
253
- # pred = []
254
- # for s, l in zip(sim, protein_labels):
255
- # n_label = l.sum()
256
- # topk = torch.topk(s, k=n_label).indices
257
- # if l[topk].sum() == n_label:
258
- # pred.append(1)
259
- # else:
260
- # pred.append(0)
261
- #
262
- # pred = torch.tensor(pred).to(device)
263
- # label = torch.ones_like(pred)
264
- # self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(pred.detach(), label)
265
- # # if v == "protein":
266
- # # acc = self.metrics[stage][f"{stage}_{k}_{v}_acc"].compute()
267
- # # print(f"{stage}_{k}_{v}_acc: {acc:.4f}")
268
- #
269
- # elif v == "text":
270
- # if self.use_zlpr_loss:
271
- # loss = multilabel_cross_entropy(sim, text_labels)
272
- # else:
273
- # loss = cross_entropy(sim, bs_labels)
274
- #
275
- # pred = []
276
- # for s, l in zip(sim, text_labels):
277
- # n_label = l.sum()
278
- # topk = torch.topk(s, k=n_label).indices
279
- # if l[topk].sum() == n_label:
280
- # pred.append(1)
281
- # else:
282
- # pred.append(0)
283
- #
284
- # pred = torch.tensor(pred).to(device)
285
- # label = torch.ones_like(pred)
286
- # # if k == "protein":
287
- # # acc = pred.sum() / len(pred)
288
- # # print(f"{stage}_{k}_{v}_acc: {acc:.4f}")
289
- # self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(pred.detach(), label)
290
- #
291
- # else:
292
- # loss = cross_entropy(sim, bs_labels)
293
- # self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
294
-
295
- loss = cross_entropy(sim, bs_labels)
296
- self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
297
- loss_list.append(loss)
298
-
299
- # Masked language modeling loss
300
- if self.use_mlm_loss:
301
- k_label = [("protein", labels["seq_labels"])]
302
- if self.structure_config is not None:
303
- k_label.append(("structure", labels["struc_labels"]))
304
-
305
- for k, label in k_label:
306
- logits = eval(f"{k}_mask_logits")
307
- # merge the first and second dimension of logits
308
- logits = logits.view(-1, logits.shape[-1])
309
- label = label.flatten().to(device)
310
- mlm_loss = cross_entropy(logits, label, ignore_index=-1)
311
- loss_list.append(mlm_loss)
312
- self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label)
313
-
314
- loss = sum(loss_list) / len(loss_list)
315
-
316
- if stage == "train":
317
- log_dict = self.get_log_dict("train")
318
- log_dict["train_loss"] = loss
319
- self.log_info(log_dict)
320
-
321
- # Reset train metrics
322
- self.reset_metrics("train")
323
-
324
- return loss
325
-
326
- def padded_gather(self, tensor: torch.Tensor):
327
- """
328
- Gather tensors from all GPUs, allowing different shapes at the batch dimension.
329
- """
330
-
331
- # Get the size of the tensor
332
- size = tensor.shape[0]
333
- all_sizes = self.all_gather(torch.tensor(size, device=tensor.device))
334
- max_size = max(all_sizes)
335
-
336
- # Pad the tensor
337
- if size != max_size:
338
- tmp = torch.zeros(max_size, tensor.shape[-1], dtype=tensor.dtype, device=tensor.device)
339
- tmp[:size] = tensor
340
- tensor = tmp
341
-
342
- padded_tensor = self.all_gather(tensor).view(-1, tensor.shape[-1])
343
- tensor = padded_tensor[:sum(all_sizes)]
344
-
345
- return tensor
346
-
347
- def _get_protein_indices(self):
348
- world_size = dist.get_world_size()
349
- rank = dist.get_rank()
350
-
351
- if self.use_saprot:
352
- proteins = []
353
- for sub_dict in self.uniprot2label.values():
354
- aa_seq = sub_dict["seq"]
355
- foldseek_seq = sub_dict["foldseek"]
356
- assert len(aa_seq) == len(foldseek_seq)
357
- seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
358
- proteins.append(seq)
359
-
360
- else:
361
- proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]
362
-
363
- span = math.ceil(len(proteins) / world_size)
364
- sub_proteins = proteins[rank * span: (rank + 1) * span]
365
-
366
- # Display the progress bar on the rank 0 process
367
- verbose = self.trainer.local_rank == 0
368
- # Get protein representations
369
- sub_protein_repr = self.protein_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
370
- protein_repr = self.padded_gather(sub_protein_repr)
371
-
372
- # Construct faiss index
373
- d = protein_repr.shape[-1]
374
- protein_indices = faiss.IndexFlatIP(d)
375
- protein_indices.add(protein_repr.cpu().numpy())
376
- return protein_indices
377
-
378
- def _get_structure_indices(self):
379
- world_size = dist.get_world_size()
380
- rank = dist.get_rank()
381
-
382
- proteins = [sub_dict["foldseek"] for sub_dict in self.uniprot2label.values()]
383
- span = math.ceil(len(proteins) / world_size)
384
- sub_proteins = proteins[rank * span: (rank + 1) * span]
385
-
386
- # Display the progress bar on the rank 0 process
387
- verbose = self.trainer.local_rank == 0
388
- # Get protein representations
389
- sub_protein_repr = self.structure_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
390
- protein_repr = self.padded_gather(sub_protein_repr)
391
-
392
- # Construct faiss index
393
- d = protein_repr.shape[-1]
394
- structure_indices = faiss.IndexFlatIP(d)
395
- structure_indices.add(protein_repr.cpu().numpy())
396
- return structure_indices
397
-
398
- def _get_text_indices(self):
399
- world_size = dist.get_world_size()
400
- rank = dist.get_rank()
401
-
402
- # Display the progress bar on the rank 0 process
403
- verbose = self.trainer.local_rank == 0
404
- if verbose:
405
- iterator = tqdm(self.label2text.keys(), desc="Get text representations")
406
- else:
407
- iterator = self.label2text.keys()
408
-
409
- text_embeddings = {}
410
- for subsection in iterator:
411
- if subsection == "Total":
412
- continue
413
-
414
- texts = []
415
- for text_list in self.label2text[subsection].values():
416
- # Only use the first text for efficiency
417
- texts.append(text_list[0:1])
418
-
419
- span = math.ceil(len(texts) / world_size)
420
- texts = texts[rank * span: (rank + 1) * span]
421
- embeddings = []
422
- for text_list in texts:
423
- text_repr = self.text_encoder.get_repr(text_list)
424
- mean_repr = text_repr.mean(dim=0, keepdim=True)
425
- norm_repr = torch.nn.functional.normalize(mean_repr, dim=-1)
426
- embeddings.append(norm_repr)
427
-
428
- if len(embeddings) > 0:
429
- embeddings = torch.cat(embeddings, dim=0)
430
- else:
431
- embeddings = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device)
432
-
433
- text_repr = self.padded_gather(embeddings)
434
- text_embeddings[subsection] = text_repr
435
-
436
- # Aggregate text embeddings for global retrieval
437
- total_embeddings = []
438
- for idx in self.label2text["Total"].values():
439
- subsection, i = idx.split("|")
440
- total_embeddings.append(text_embeddings[subsection][int(i)])
441
-
442
- text_embeddings["Total"] = torch.stack(total_embeddings)
443
-
444
- # Construct faiss index
445
- text_indices = {}
446
- for subsection, text_repr in text_embeddings.items():
447
- d = text_repr.shape[-1]
448
- text_indices[subsection] = faiss.IndexFlatIP(d)
449
- text_indices[subsection].add(text_repr.cpu().numpy())
450
-
451
- return text_indices
452
-
453
- def _protein2text(self, modality: str, protein_indices, text_indices: dict):
454
- def do(process_id, idx, row, writer):
455
- subsection, uniprot_id, prob_idx, label = row
456
-
457
- # Retrieve ranking results
458
- p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
459
- text_inds = text_indices[subsection]
460
- sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal)
461
- sim_scores, rank_inds = sim_scores[0], rank_inds[0]
462
-
463
- # Calculate Average Precision(AP)
464
- ranks = []
465
- label = set(label)
466
- for i, rk in enumerate(rank_inds):
467
- # Find the rank of this label in all labels
468
- if rk in label:
469
- ranks.append(i + 1)
470
-
471
- ranks = np.array(ranks)
472
- ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])
473
-
474
- # Calculate Mean Reciprocal Rank(MRR)
475
- best_rank = ranks[0]
476
- mrr = 1 / best_rank
477
-
478
- # Calculate the AUC
479
- true_labels = np.zeros_like(sim_scores)
480
- true_labels[ranks - 1] = 1
481
- if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
482
- auc = 0
483
- else:
484
- auc = roc_auc_score(true_labels, sim_scores)
485
-
486
- output = json.dumps([ap, mrr, auc])
487
- writer.write(output + "\n")
488
-
489
- inputs = []
490
- swissprot_subsections = set()
491
- for subsection in text_indices.keys():
492
- for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()):
493
- if uniprot_id in self.swissprot_ids:
494
- if subsection in labels:
495
- swissprot_subsections.add(subsection)
496
- label = labels[subsection]
497
- inputs.append((subsection, uniprot_id, i, label))
498
-
499
- # Randomly shuffle the inputs
500
- random.seed(20000812)
501
- random.shuffle(inputs)
502
-
503
- # Split inputs into chunks for parallel processing
504
- world_size = dist.get_world_size()
505
- rank = dist.get_rank()
506
-
507
- span = math.ceil(len(inputs) / world_size)
508
- sub_inputs = inputs[rank * span: (rank + 1) * span]
509
-
510
- # Display the progress bar on the rank 0 process
511
- verbose = self.trainer.local_rank == 0
512
- if verbose:
513
- print("Evaluating on each subsection...")
514
- tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
515
- mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
516
- return_results=True)
517
- outputs = mpr.run()
518
- os.remove(tmp_path)
519
-
520
- # Aggregate results
521
- tensor_outputs = []
522
- for output in outputs:
523
- ap, mrr, auc = json.loads(output)
524
- tensor_outputs.append([float(ap), float(mrr), float(auc)])
525
-
526
- tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
527
- tensor_outputs = self.padded_gather(tensor_outputs)
528
-
529
- # Record results
530
- avg_results = {}
531
- for subsection in swissprot_subsections:
532
- avg_results[subsection] = {"map": [],
533
- "mrr": [],
534
- "auc": []}
535
-
536
- for input, output in zip(inputs, tensor_outputs):
537
- ap, mrr, auc = output
538
- subsection, _, _, _ = input
539
-
540
- avg_results[subsection]["map"].append(ap.cpu().item())
541
- avg_results[subsection]["mrr"].append(mrr.cpu().item())
542
- avg_results[subsection]["auc"].append(auc.cpu().item())
543
-
544
- results = {
545
- f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
546
- f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]),
547
- f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]),
548
- }
549
-
550
- # Average the precision and recall for each level
551
- for level, labels in [("residue-level", residue_level),
552
- ("sequence-level", sequence_level),
553
- ("all", residue_level | sequence_level)]:
554
-
555
- mrrs = []
556
- maps = []
557
- aucs = []
558
- for subsection in labels:
559
- if subsection in avg_results:
560
- mrrs.append(np.mean(avg_results[subsection]["mrr"]))
561
- maps.append(np.mean(avg_results[subsection]["map"]))
562
- aucs.append(np.mean(avg_results[subsection]["auc"]))
563
-
564
- results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs)
565
- results[f"{modality}2Text_{level}_map"] = np.mean(maps)
566
- results[f"{modality}2Text_{level}_auc"] = np.mean(aucs)
567
-
568
- return results
569
-
570
- def _text2protein(self, modality: str, protein_indices, text_indices: dict):
571
- def do(process_id, idx, row, writer):
572
- subsection, text_id, label = row
573
-
574
- # Retrieve ranking results
575
- t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1)
576
- sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal)
577
- sim_scores, rank_inds = sim_scores[0], rank_inds[0]
578
-
579
- # Calculate Average Precision(AP)
580
- ranks = []
581
- label = set(label)
582
- for i, rk in enumerate(rank_inds):
583
- # Find the rank of this label in all labels
584
- if rk in label:
585
- ranks.append(i + 1)
586
-
587
- ranks = np.array(ranks)
588
- ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])
589
-
590
- # Calculate Mean Reciprocal Rank(MRR)
591
- best_rank = ranks[0]
592
- mrr = 1 / best_rank
593
-
594
- # Calculate the AUC
595
- true_labels = np.zeros_like(sim_scores)
596
- true_labels[ranks - 1] = 1
597
- if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
598
- auc = 0
599
- else:
600
- auc = roc_auc_score(true_labels, sim_scores)
601
-
602
- output = json.dumps([ap, mrr, auc])
603
- writer.write(output + "\n")
604
-
605
- text2label = {}
606
- swissprot_subsections = set()
607
- for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()):
608
- # Only evaluate the texts in Swiss-Prot
609
- if uniprot_id not in self.swissprot_ids:
610
- continue
611
-
612
- for subsection, text_ids in subsections.items():
613
- if subsection == "seq" or subsection == "foldseek":
614
- continue
615
-
616
- swissprot_subsections.add(subsection)
617
- if subsection not in text2label:
618
- text2label[subsection] = {}
619
-
620
- for text_id in text_ids:
621
- text2label[subsection][text_id] = text2label[subsection].get(text_id, []) + [i]
622
-
623
- inputs = []
624
- for subsection in swissprot_subsections:
625
- for i, (text_id, label) in enumerate(text2label[subsection].items()):
626
- inputs.append((subsection, text_id, label))
627
-
628
- # Randomly shuffle the inputs
629
- random.seed(20000812)
630
- random.shuffle(inputs)
631
-
632
- # Split inputs into chunks for parallel processing
633
- world_size = dist.get_world_size()
634
- rank = dist.get_rank()
635
-
636
- span = math.ceil(len(inputs) / world_size)
637
- sub_inputs = inputs[rank * span: (rank + 1) * span]
638
-
639
- # Display the progress bar on the rank 0 process
640
- verbose = self.trainer.local_rank == 0
641
- if verbose:
642
- print("Evaluating on each text...")
643
-
644
- # Add time stamp to the temporary file name to avoid conflicts
645
- tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
646
- mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
647
- return_results=True)
648
- outputs = mpr.run()
649
- os.remove(tmp_path)
650
-
651
- # Aggregate results
652
- tensor_outputs = []
653
- for output in outputs:
654
- ap, mrr, auc = json.loads(output)
655
- tensor_outputs.append([float(ap), float(mrr), float(auc)])
656
-
657
- tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
658
- tensor_outputs = self.padded_gather(tensor_outputs)
659
-
660
- # Record results
661
- avg_results = {}
662
- for subsection in swissprot_subsections:
663
- avg_results[subsection] = {"map": [],
664
- "mrr": [],
665
- "auc": []}
666
-
667
- for input, output in zip(inputs, tensor_outputs):
668
- ap, mrr, auc = output
669
- subsection, _, _ = input
670
-
671
- avg_results[subsection]["map"].append(ap.cpu().item())
672
- avg_results[subsection]["mrr"].append(mrr.cpu().item())
673
- avg_results[subsection]["auc"].append(auc.cpu().item())
674
-
675
- results = {
676
- f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
677
- f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]),
678
- f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]),
679
- }
680
-
681
- # Average the precision and recall for each level
682
- for level, labels in [("residue-level", residue_level),
683
- ("sequence-level", sequence_level),
684
- ("all", residue_level | sequence_level)]:
685
-
686
- mrrs = []
687
- maps = []
688
- aucs = []
689
- for subsection in labels:
690
- if subsection in avg_results:
691
- mrrs.append(np.mean(avg_results[subsection]["mrr"]))
692
- maps.append(np.mean(avg_results[subsection]["map"]))
693
- aucs.append(np.mean(avg_results[subsection]["auc"]))
694
-
695
- results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs)
696
- results[f"Text2{modality}_{level}_map"] = np.mean(maps)
697
- results[f"Text2{modality}_{level}_auc"] = np.mean(aucs)
698
-
699
- return results
700
-
701
- def retrieval_eval(self) -> dict:
702
- # Get protein representations
703
- protein_indices = self._get_protein_indices()
704
-
705
- # Get structure representations
706
- # if self.structure_config is not None:
707
- # structure_embeddings = self._get_structure_embeddings()
708
-
709
- # Get text representations
710
- text_indices = self._get_text_indices()
711
-
712
- # Retrieve texts for each protein
713
- results = {}
714
- results.update(self._protein2text("Sequence", protein_indices, text_indices))
715
- # if self.structure_config is not None:
716
- # results.update(self._protein2text("Structure", structure_embeddings, text_embeddings))
717
- # results.update(self._text2protein("Structure", structure_embeddings, text_embeddings))
718
-
719
- # Retrieve proteins for each text
720
- results.update(self._text2protein("Sequence", protein_indices, text_indices))
721
-
722
- return results
723
-
724
- def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
725
- while True:
726
- masked_tokens = copy.copy(tokens)
727
- labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)
728
- vocab = [k for k in tokenizer.get_vocab().keys()]
729
-
730
- for i in range(len(tokens)):
731
- token = tokens[i]
732
-
733
- prob = random.random()
734
- if prob < mask_ratio:
735
- prob /= mask_ratio
736
- labels[i + 1] = tokenizer.convert_tokens_to_ids(token)
737
-
738
- if prob < 0.8:
739
- # 80% random change to mask token
740
- if self.use_saprot:
741
- token = "#" + token[-1]
742
- else:
743
- token = tokenizer.mask_token
744
- elif prob < 0.9:
745
- # 10% chance to change to random token
746
- token = random.choice(vocab)
747
- else:
748
- # 10% chance to keep current token
749
- pass
750
-
751
- masked_tokens[i] = token
752
-
753
- # Check if there is at least one masked token
754
- if (labels != -1).any():
755
- return masked_tokens, labels
756
-
757
- def mlm_eval(self) -> float:
758
- world_size = dist.get_world_size()
759
- rank = dist.get_rank()
760
-
761
- if self.use_saprot:
762
- proteins = []
763
- for sub_dict in self.uniprot2label.values():
764
- aa_seq = sub_dict["seq"]
765
- foldseek_seq = sub_dict["foldseek"]
766
- assert len(aa_seq) == len(foldseek_seq)
767
- seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
768
- proteins.append(seq)
769
-
770
- else:
771
- proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]
772
-
773
- span = math.ceil(len(proteins) / world_size)
774
- sub_proteins = proteins[rank * span: (rank + 1) * span]
775
-
776
- # Display the progress bar on the rank 0 process
777
- if self.trainer.local_rank == 0:
778
- iterator = tqdm(sub_proteins, desc="Computing mlm...")
779
- else:
780
- iterator = sub_proteins
781
-
782
- total = torch.tensor([0], dtype=torch.long, device=self.device)
783
- correct = torch.tensor([0], dtype=torch.long, device=self.device)
784
- for seq in iterator:
785
- tokens = self.protein_encoder.tokenizer.tokenize(seq)
786
- masked_tokens, labels = self._apply_bert_mask(tokens, self.protein_encoder.tokenizer, 0.15)
787
- seq = " ".join(masked_tokens)
788
-
789
- inputs = self.protein_encoder.tokenizer(seq, return_tensors="pt")
790
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
791
- _, logits = self.protein_encoder(inputs, get_mask_logits=True)
792
-
793
- logits = logits.squeeze(0)
794
- labels = labels.to(self.device)
795
-
796
- selecor = labels != -1
797
- preds = logits.argmax(dim=-1)[selecor]
798
- labels = labels[selecor]
799
-
800
- total += len(preds)
801
- correct += (preds == labels).sum()
802
-
803
- # Gather all results
804
- total = self.padded_gather(total).sum()
805
- correct = self.padded_gather(correct).sum()
806
-
807
- acc = correct / total
808
- return acc.cpu().item()
809
-
810
- def _load_eval_data(self, stage):
811
- # Load the data
812
- lmdb_dir = eval(f"self.trainer.datamodule.{stage}_lmdb")
813
- uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
814
- label2text_path = os.path.join(lmdb_dir, "label2text.json")
815
- swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")
816
-
817
- self.uniprot2label = json.load(open(uniprot2label_path, "r"))
818
- self.label2text = json.load(open(label2text_path, "r"))
819
- self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist())
820
- self.k = 3
821
-
822
- def on_test_start(self):
823
- self._load_eval_data("test")
824
-
825
- log_dict = self.retrieval_eval()
826
- log_dict = {"test_" + k: v for k, v in log_dict.items()}
827
- if self.use_mlm_loss:
828
- log_dict["test_mask_acc"] = self.mlm_eval()
829
- self.log_info(log_dict)
830
- print(log_dict)
831
-
832
- def on_validation_start(self):
833
- # Clear the cache
834
- torch.cuda.empty_cache()
835
-
836
- self._load_eval_data("valid")
837
-
838
- log_dict = self.retrieval_eval()
839
- log_dict = {"valid_" + k: v for k, v in log_dict.items()}
840
- if self.use_mlm_loss:
841
- log_dict["valid_mask_acc"] = self.mlm_eval()
842
- self.log_info(log_dict)
843
-
844
- self.check_save_condition(self.step, mode="max")
845
-
846
- def test_step(self, batch, batch_idx):
847
- return
848
-
849
- def validation_step(self, batch, batch_idx):
850
- return
851
-
852
- def on_train_epoch_end(self):
853
- super().on_train_epoch_end()
854
- # Re-sample the subset of the training data
855
- if self.trainer.datamodule.train_dataset.fixed_dataset_num is not None:
856
- self.trainer.datamodule.train_dataset.sample_subset()
857
-
858
- # def test_epoch_end(self, outputs):
859
- # log_dict = self.get_log_dict("test")
860
- # log_dict["test_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
861
- #
862
- # print(log_dict)
863
- # self.log_info(log_dict)
864
- #
865
- # self.reset_metrics("test")
866
- #
867
- # def validation_epoch_end(self, outputs):
868
- # log_dict = self.get_log_dict("valid")
869
- # log_dict["valid_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
870
- #
871
- # self.log_info(log_dict)
872
- # self.reset_metrics("valid")
873
- # self.check_save_condition(log_dict["valid_loss"], mode="min")
874
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/ProTrek/structure_encoder.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
-
3
- from tqdm import tqdm
4
- from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
5
- from torch.nn.functional import normalize
6
-
7
-
8
- class StructureEncoder(torch.nn.Module):
9
- def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
10
- """
11
- Args:
12
- config_path: Path to the config file
13
-
14
- out_dim: Output dimension of the structure representation
15
-
16
- gradient_checkpointing: Whether to use gradient checkpointing
17
- """
18
- super().__init__()
19
- config = EsmConfig.from_pretrained(config_path)
20
- self.model = EsmForMaskedLM(config)
21
- self.out = torch.nn.Linear(config.hidden_size, out_dim)
22
-
23
- # Set gradient checkpointing
24
- self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing
25
-
26
- # Remove contact head
27
- self.model.esm.contact_head = None
28
-
29
- # Remove position embedding if the embedding type is ``rotary``
30
- if config.position_embedding_type == "rotary":
31
- self.model.esm.embeddings.position_embeddings = None
32
-
33
- self.tokenizer = EsmTokenizer.from_pretrained(config_path)
34
-
35
- def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
36
- """
37
- Compute protein structure representation for the given proteins
38
- Args:
39
- protein: A list of protein structural sequences
40
- batch_size: Batch size for inference
41
- verbose: Whether to print progress
42
- """
43
- device = next(self.parameters()).device
44
-
45
- protein_repr = []
46
- if verbose:
47
- iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
48
- else:
49
- iterator = range(0, len(proteins), batch_size)
50
-
51
- for i in iterator:
52
- protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
53
- return_tensors="pt",
54
- padding=True)
55
- protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
56
- output, _ = self.forward(protein_inputs)
57
-
58
- protein_repr.append(output)
59
-
60
- protein_repr = torch.cat(protein_repr, dim=0)
61
- return normalize(protein_repr, dim=-1)
62
-
63
- def forward(self, inputs: dict, get_mask_logits: bool = False):
64
- """
65
- Encode protein structure into protein representation
66
- Args:
67
- inputs: A dictionary containing the following keys:
68
- - input_ids: [batch, seq_len]
69
- - attention_mask: [batch, seq_len]
70
- get_mask_logits: Whether to return the logits for masked tokens
71
-
72
- Returns:
73
- protein_repr: [batch, protein_repr_dim]
74
- mask_logits : [batch, seq_len, vocab_size]
75
- """
76
- last_hidden_state = self.model.esm(**inputs).last_hidden_state
77
- reprs = last_hidden_state[:, 0, :]
78
- reprs = self.out(reprs)
79
-
80
- # Get logits for masked tokens
81
- if get_mask_logits:
82
- mask_logits = self.model.lm_head(last_hidden_state)
83
- else:
84
- mask_logits = None
85
-
86
- return reprs, mask_logits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/ProTrek/text_encoder.py DELETED
@@ -1,81 +0,0 @@
1
- import torch
2
-
3
- from tqdm import tqdm
4
- from torch.nn.functional import normalize
5
- from transformers import BertConfig, BertModel, BertTokenizer
6
-
7
-
8
- class TextEncoder(torch.nn.Module):
9
- def __init__(self,
10
- config_path: str,
11
- out_dim: int,
12
- load_pretrained: bool = True,
13
- gradient_checkpointing: bool = False):
14
- """
15
- Args:
16
- config_path: Path to the config file
17
-
18
- out_dim: Output dimension of the text representation
19
-
20
- load_pretrained: Whether to load pretrained weights
21
-
22
- gradient_checkpointing: Whether to enable gradient checkpointing
23
- """
24
- super().__init__()
25
- config = BertConfig.from_pretrained(config_path)
26
- if load_pretrained:
27
- self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
28
- else:
29
- self.model = BertModel(config, add_pooling_layer=False)
30
- self.out = torch.nn.Linear(config.hidden_size, out_dim)
31
-
32
- # Set gradient checkpointing
33
- self.model.encoder.gradient_checkpointing = gradient_checkpointing
34
-
35
- self.tokenizer = BertTokenizer.from_pretrained(config_path)
36
-
37
- def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
38
- """
39
- Compute text representation for the given texts
40
- Args:
41
- texts: A list of strings
42
- batch_size: Batch size for inference
43
- verbose: Whether to print progress
44
- """
45
- device = next(self.parameters()).device
46
-
47
- text_repr = []
48
- if verbose:
49
- iterator = tqdm(range(0, len(texts), batch_size), desc="Computing text embeddings")
50
- else:
51
- iterator = range(0, len(texts), batch_size)
52
-
53
- for i in iterator:
54
- text_inputs = self.tokenizer.batch_encode_plus(texts[i: i+batch_size],
55
- return_tensors="pt",
56
- truncation=True,
57
- max_length=512,
58
- padding=True)
59
- text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
60
- output = self(text_inputs)
61
-
62
- text_repr.append(output)
63
-
64
- text_repr = torch.cat(text_repr, dim=0)
65
- return normalize(text_repr, dim=-1)
66
-
67
- def forward(self, inputs: dict):
68
- """
69
- Encode text into text representation
70
- Args:
71
- inputs: A dictionary containing the following keys:
72
- - input_ids: [batch, seq_len]
73
- - attention_mask: [batch, seq_len]
74
- - token_type_ids: [batch, seq_len]
75
-
76
- Returns:
77
- text_repr: [batch, text_repr_dim]
78
- """
79
- reprs = self.model(**inputs).last_hidden_state[:, 0, :]
80
- reprs = self.out(reprs)
81
- return reprs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/abstract_model.py DELETED
@@ -1,401 +0,0 @@
1
- import torch
2
- import abc
3
- import os
4
- import copy
5
-
6
- import pytorch_lightning as pl
7
- from utils.lr_scheduler import *
8
- from torch import distributed as dist
9
-
10
-
11
- class AbstractModel(pl.LightningModule):
12
- def __init__(self,
13
- lr_scheduler_kwargs: dict = None,
14
- optimizer_kwargs: dict = None,
15
- save_path: str = None,
16
- from_checkpoint: str = None,
17
- load_prev_scheduler: bool = False,
18
- save_weights_only: bool = True,):
19
- """
20
-
21
- Args:
22
- lr_scheduler: Kwargs for lr_scheduler
23
- optimizer_kwargs: Kwargs for optimizer_kwargs
24
- save_path: Save trained model
25
- from_checkpoint: Load model from checkpoint
26
- load_prev_scheduler: Whether load previous scheduler from checkpoint
27
- load_strict: Whether load model strictly
28
- save_weights_only: Whether save only weights or also optimizer and lr_scheduler
29
-
30
- """
31
- super().__init__()
32
- self.initialize_model()
33
-
34
- self.metrics = {}
35
- for stage in ["train", "valid", "test"]:
36
- stage_metrics = self.initialize_metrics(stage)
37
- # Rigister metrics as attributes
38
- for metric_name, metric in stage_metrics.items():
39
- setattr(self, metric_name, metric)
40
-
41
- self.metrics[stage] = stage_metrics
42
-
43
- if lr_scheduler_kwargs is None:
44
- # Default lr_scheduler
45
- self.lr_scheduler_kwargs = {
46
- "class": "ConstantLRScheduler",
47
- "init_lr": 0,
48
- }
49
- print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
50
-
51
- else:
52
- self.lr_scheduler_kwargs = lr_scheduler_kwargs
53
-
54
- if optimizer_kwargs is None:
55
- # Default optimizer
56
- self.optimizer_kwargs = {
57
- "class": "AdamW",
58
- "betas": (0.9, 0.98),
59
- "weight_decay": 0.01,
60
- }
61
- print("No optimizer_kwargs provided. The default optimizer is AdamW.")
62
- else:
63
- self.optimizer_kwargs = optimizer_kwargs
64
- self.init_optimizers()
65
-
66
- self.save_path = save_path
67
- self.save_weights_only = save_weights_only
68
-
69
- # temp_step is used for accumulating gradients
70
- self.temp_step = 0
71
- self.step = 0
72
- self.epoch = 0
73
-
74
- self.load_prev_scheduler = load_prev_scheduler
75
- self.from_checkpoint = from_checkpoint
76
- if from_checkpoint:
77
- self.load_checkpoint(from_checkpoint)
78
-
79
- @abc.abstractmethod
80
- def initialize_model(self) -> None:
81
- """
82
- All model initialization should be done here
83
- Note that the whole model must be named as "self.model" for model saving and loading
84
- """
85
- raise NotImplementedError
86
-
87
- @abc.abstractmethod
88
- def forward(self, *args, **kwargs):
89
- """
90
- Forward propagation
91
- """
92
- raise NotImplementedError
93
-
94
- @abc.abstractmethod
95
- def initialize_metrics(self, stage: str) -> dict:
96
- """
97
- Initialize metrics for each stage
98
- Args:
99
- stage: "train", "valid" or "test"
100
-
101
- Returns:
102
- A dictionary of metrics for the stage. Keys are metric names and values are metric objects
103
- """
104
- raise NotImplementedError
105
-
106
- @abc.abstractmethod
107
- def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
108
- """
109
-
110
- Args:
111
- stage: "train", "valid" or "test"
112
- outputs: model outputs for calculating loss
113
- labels: labels for calculating loss
114
-
115
- Returns:
116
- loss
117
-
118
- """
119
- raise NotImplementedError
120
-
121
- @staticmethod
122
- def load_weights(model, weights):
123
- model_dict = model.state_dict()
124
-
125
- unused_params = []
126
- missed_params = list(model_dict.keys())
127
-
128
- for k, v in weights.items():
129
- if k in model_dict.keys():
130
- model_dict[k] = v
131
- missed_params.remove(k)
132
-
133
- else:
134
- unused_params.append(k)
135
-
136
- if len(missed_params) > 0:
137
- print(f"\033[31mSome weights of {type(model).__name__} were not "
138
- f"initialized from the model checkpoint: {missed_params}\033[0m")
139
-
140
- if len(unused_params) > 0:
141
- print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")
142
-
143
- model.load_state_dict(model_dict)
144
-
145
- def optimizer_step(
146
- self,
147
- epoch: int,
148
- batch_idx: int,
149
- optimizer,
150
- optimizer_closure=None,
151
- ) -> None:
152
- super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)
153
-
154
- self.temp_step += 1
155
- if self.temp_step == self.trainer.accumulate_grad_batches:
156
- self.step += 1
157
- self.temp_step = 0
158
-
159
- # For pytorch-lightning 1.9.5
160
- # def optimizer_step(
161
- # self,
162
- # epoch: int,
163
- # batch_idx: int,
164
- # optimizer,
165
- # optimizer_idx: int = 0,
166
- # optimizer_closure=None,
167
- # on_tpu: bool = False,
168
- # using_native_amp: bool = False,
169
- # using_lbfgs: bool = False,
170
- # ) -> None:
171
- # super().optimizer_step(
172
- # epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
173
- # )
174
- # self.temp_step += 1
175
- # if self.temp_step == self.trainer.accumulate_grad_batches:
176
- # self.step += 1
177
- # self.temp_step = 0
178
-
179
- def on_train_epoch_end(self):
180
- self.epoch += 1
181
-
182
- def training_step(self, batch, batch_idx):
183
- inputs, labels = batch
184
-
185
- # optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-4, weight_decay=0.01, betas=(0.9, 0.98))
186
- # for _ in range(1000):
187
- # outputs = self(**inputs)
188
- # loss = self.loss_func('train', outputs, labels)
189
- # loss.backward()
190
- # optimizer.step()
191
- # optimizer.zero_grad()
192
- #
193
- # raise
194
-
195
- outputs = self(**inputs)
196
- loss = self.loss_func('train', outputs, labels)
197
-
198
- self.log("loss", loss, prog_bar=True)
199
- return loss
200
-
201
- def validation_step(self, batch, batch_idx):
202
- inputs, labels = batch
203
- outputs = self(**inputs)
204
- loss = self.loss_func('valid', outputs, labels)
205
- self.valid_outputs.append(loss)
206
- return loss
207
-
208
- def test_step(self, batch, batch_idx):
209
- inputs, labels = batch
210
- outputs = self(**inputs)
211
-
212
- loss = self.loss_func('test', outputs, labels)
213
- self.test_outputs.append(loss)
214
- return loss
215
-
216
- def on_train_start(self) -> None:
217
- # Load previous scheduler
218
- if getattr(self, "prev_schechuler", None) is not None:
219
- try:
220
- self.step = self.prev_schechuler["global_step"]
221
- self.epoch = self.prev_schechuler["epoch"]
222
- self.best_value = self.prev_schechuler["best_value"]
223
- self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
224
- print(f"Previous training global step: {self.step}")
225
- print(f"Previous training epoch: {self.epoch}")
226
- print(f"Previous best value: {self.best_value}")
227
- print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")
228
-
229
- # Load optimizer state
230
- if hasattr(self.trainer.strategy, "deepspeed_engine"):
231
- # For DeepSpeed strategy
232
- try:
233
- self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
234
- except Exception as e:
235
- print(e)
236
-
237
- else:
238
- # For DDP strategy
239
- self.optimizer.load_state_dict(self.prev_schechuler["optimizer"])
240
-
241
- except Exception as e:
242
- print(e)
243
- raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False")
244
-
245
- def on_validation_epoch_start(self) -> None:
246
- setattr(self, "valid_outputs", [])
247
-
248
- def on_test_epoch_start(self) -> None:
249
- setattr(self, "test_outputs", [])
250
-
251
- def load_checkpoint(self, from_checkpoint: str) -> None:
252
- """
253
- Args:
254
- from_checkpoint: Path to checkpoint.
255
- """
256
-
257
- # If ``from_checkpoint`` is a directory, load the checkpoint in it
258
- if os.path.isdir(from_checkpoint):
259
- basename = os.path.basename(from_checkpoint)
260
- from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")
261
-
262
- state_dict = torch.load(from_checkpoint, map_location=self.device)
263
- self.load_weights(self.model, state_dict["model"])
264
-
265
- if self.load_prev_scheduler:
266
- state_dict.pop("model")
267
- self.prev_schechuler = state_dict
268
-
269
- def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
270
- """
271
- Save model to save_path
272
- Args:
273
- save_path: Path to save model
274
- save_info: Other info to save
275
- save_weights_only: Whether only save model weights
276
- """
277
- dir = os.path.dirname(save_path)
278
- os.makedirs(dir, exist_ok=True)
279
-
280
- state_dict = {} if save_info is None else save_info
281
- state_dict["model"] = self.model.state_dict()
282
-
283
- # Convert model weights to fp32
284
- for k, v in state_dict["model"].items():
285
- state_dict["model"][k] = v.float()
286
-
287
- if not save_weights_only:
288
- state_dict["global_step"] = self.step
289
- state_dict["epoch"] = self.epoch
290
- state_dict["best_value"] = getattr(self, f"best_value", None)
291
- state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()
292
-
293
- # If not using DeepSpeed, save optimizer state
294
- if not hasattr(self.trainer.strategy, "deepspeed_engine"):
295
- state_dict["optimizer"] = self.optimizers().optimizer.state_dict()
296
-
297
- torch.save(state_dict, save_path)
298
-
299
- def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
300
- """
301
- Check whether to save model. If save_path is not None and now_value is the best, save model.
302
- Args:
303
- now_value: Current metric value
304
- mode: "min" or "max", meaning whether the lower the better or the higher the better
305
- save_info: Other info to save
306
- """
307
-
308
- assert mode in ["min", "max"], "mode should be 'min' or 'max'"
309
-
310
- if self.save_path is not None:
311
- # In case there are variables to be included in the save path
312
- save_path = eval(f"f'{self.save_path}'")
313
-
314
- dir = os.path.dirname(save_path)
315
- os.makedirs(dir, exist_ok=True)
316
-
317
- # Check whether to save model
318
- best_value = getattr(self, f"best_value", None)
319
- if best_value is not None:
320
- if mode == "min" and now_value >= best_value or mode == "max" and now_value <= best_value:
321
- return
322
-
323
- setattr(self, "best_value", now_value)
324
-
325
- # For DeepSpeed strategy
326
- if hasattr(self.trainer.strategy, "deepspeed_engine"):
327
- if not self.save_weights_only:
328
- self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")
329
-
330
- # Save a complete checkpoint
331
- if dist.get_rank() == 0:
332
- basename = os.path.basename(save_path)
333
- ckpt_path = os.path.join(save_path, f"{basename}.pt")
334
- self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)
335
-
336
- # For normal situation
337
- else:
338
- if dist.get_rank() == 0:
339
- self.save_checkpoint(save_path, save_info, self.save_weights_only)
340
-
341
- def reset_metrics(self, stage) -> None:
342
- """
343
- Reset metrics for given stage
344
- Args:
345
- stage: "train", "valid" or "test"
346
- """
347
- for metric in self.metrics[stage].values():
348
- metric.reset()
349
-
350
- def get_log_dict(self, stage: str) -> dict:
351
- """
352
- Get log dict for the stage
353
- Args:
354
- stage: "train", "valid" or "test"
355
-
356
- Returns:
357
- A dictionary of metrics for the stage. Keys are metric names and values are metric values
358
-
359
- """
360
- return {name: metric.compute() for name, metric in self.metrics[stage].items()}
361
-
362
- def log_info(self, info: dict) -> None:
363
- """
364
- Record metrics during training and testing
365
- Args:
366
- info: dict of metrics
367
- """
368
- if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
369
- info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
370
- info["epoch"] = self.epoch
371
- self.logger.log_metrics(info, step=self.step)
372
-
373
- def init_optimizers(self):
374
- copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)
375
-
376
- # No decay for layer norm and bias
377
- no_decay = ['LayerNorm.weight', 'bias']
378
- weight_decay = copy_optimizer_kwargs.pop("weight_decay")
379
-
380
- optimizer_grouped_parameters = [
381
- {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
382
- 'weight_decay': weight_decay},
383
- {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
384
- 'weight_decay': 0.0}
385
- ]
386
-
387
- optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
388
- self.optimizer = optimizer_cls(optimizer_grouped_parameters,
389
- lr=self.lr_scheduler_kwargs['init_lr'],
390
- **copy_optimizer_kwargs)
391
-
392
- tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
393
- lr_scheduler = tmp_kwargs.pop("class")
394
- self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
395
-
396
- def configure_optimizers(self):
397
- return {"optimizer": self.optimizer,
398
- "lr_scheduler": {"scheduler": self.lr_scheduler,
399
- "interval": "step",
400
- "frequency": 1}
401
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/model_interface.py DELETED
@@ -1,104 +0,0 @@
1
- import os
2
- import yaml
3
- import glob
4
-
5
-
6
- # register all available models through *_model.py files
7
- # def construct_model():
8
- # model_dir = os.path.dirname(__file__)
9
- #
10
- # # lists all model files
11
- # model_list = []
12
- # for root, _, names in os.walk(model_dir):
13
- # for name in names:
14
- # if name.endswith('_model.py'):
15
- # sub_dirs = root.replace(model_dir, '').split(os.sep)
16
- # model_list.append((sub_dirs, name[:-3]))
17
- #
18
- # # load model_config.yaml, controlling which models to be loaded
19
- # model_config = yaml.safe_load(open(f"{model_dir}/model_config.yaml", "r"))
20
- #
21
- # if model_config["verbose"]:
22
- # print("*" * 30 + f" Loading model " + "*" * 30)
23
- #
24
- # # register models
25
- # for sub_dirs, name in model_list:
26
- # if name in model_config["models"]:
27
- # if len(sub_dirs) > 1:
28
- # cmd = f"from {'.'.join(sub_dirs)} import {name}"
29
- # else:
30
- # cmd = f"from . import {name}"
31
- #
32
- # exec(cmd)
33
- #
34
- # if model_config["verbose"]:
35
- # info = f"Loaded model: {name}"
36
- # print(f"\033[32m{info}\033[0m")
37
- # else:
38
- # if model_config["verbose"]:
39
- # info = f"Skipped model: {name}"
40
- # print(f"\033[31m{info}\033[0m")
41
- #
42
- # if model_config["verbose"]:
43
- # print("*" * 75)
44
- #
45
- #
46
- # # register function as a wrapper for all models
47
- # def register_model(cls):
48
- # model_dict[cls.__name__] = cls
49
- # return cls
50
- #
51
- #
52
- # model_dict = {}
53
- # construct_model()
54
- #
55
- #
56
- # class ModelInterface:
57
- # @classmethod
58
- # def get_available_models(cls):
59
- # return model_dict.keys()
60
- #
61
- # @classmethod
62
- # def init_model(cls, model: str, **kwargs):
63
- # """
64
- #
65
- # Args:
66
- # model : Class name of model you want to use. Must be in model_dict.keys()
67
- # **kwargs: Kwargs for model initialization
68
- #
69
- # Returns: Corresponding model
70
- #
71
- # """
72
- # assert model in model_dict.keys(), f"class {model} doesn't exist!"
73
- # return model_dict[model](**kwargs)
74
-
75
-
76
- ########################################################################
77
- # Version 2 #
78
- ########################################################################
79
- # register function as a wrapper for all models
80
- def register_model(cls):
81
- global now_cls
82
- now_cls = cls
83
- return cls
84
-
85
-
86
- now_cls = None
87
-
88
-
89
- class ModelInterface:
90
- @classmethod
91
- def init_model(cls, model_py_path: str, **kwargs):
92
- """
93
-
94
- Args:
95
- model_py_path: Py file Path of model you want to use.
96
- **kwargs: Kwargs for model initialization
97
-
98
- Returns: Corresponding model
99
- """
100
- sub_dirs = model_py_path.split(os.sep)
101
- cmd = f"from {'.' + '.'.join(sub_dirs[:-1])} import {sub_dirs[-1]}"
102
- exec(cmd)
103
-
104
- return now_cls(**kwargs)