Spaces:
Running
Running
Delete model
Browse files- model/ProTrek/protein_encoder.py +0 -95
- model/ProTrek/protrek_trimodal_model.py +0 -874
- model/ProTrek/structure_encoder.py +0 -86
- model/ProTrek/text_encoder.py +0 -81
- model/abstract_model.py +0 -401
- model/model_interface.py +0 -104
model/ProTrek/protein_encoder.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from tqdm import tqdm
|
4 |
-
from torch.nn.functional import normalize
|
5 |
-
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
|
6 |
-
|
7 |
-
|
8 |
-
class ProteinEncoder(torch.nn.Module):
    """ESM-based encoder that maps protein sequences to fixed-size vectors.

    The representation is the hidden state of the first token, projected to
    ``out_dim`` by a linear layer.
    """

    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim    : Output dimension of the protein representation

            load_pretrained: Whether to load pretrained weights

            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        if load_pretrained:
            self.model = EsmForMaskedLM.from_pretrained(config_path)
        else:
            self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # Remove contact head
        self.model.esm.contact_head = None

        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute protein representation for the given proteins
        Args:
            proteins: A list of protein sequences
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized representations, one row per input sequence. An empty
            input list yields an empty ``[0, out_dim]`` tensor.
        """
        device = next(self.parameters()).device

        # ``torch.cat`` cannot concatenate an empty list, so answer the empty case directly.
        if len(proteins) == 0:
            return torch.zeros(0, self.out.out_features, dtype=self.out.weight.dtype, device=device)

        protein_repr = []
        if verbose:
            iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
        else:
            iterator = range(0, len(proteins), batch_size)

        for i in iterator:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
            output, _ = self.forward(protein_inputs)

            protein_repr.append(output)

        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein sequence into protein representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits : [batch, seq_len, vocab_size], or None when
                          ``get_mask_logits`` is False
        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state
        # The first token's hidden state is used as the sequence summary.
        reprs = last_hidden_state[:, 0, :]
        reprs = self.out(reprs)

        # Get logits for masked tokens
        if get_mask_logits:
            mask_logits = self.model.lm_head(last_hidden_state)
        else:
            mask_logits = None

        return reprs, mask_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/ProTrek/protrek_trimodal_model.py
DELETED
@@ -1,874 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.distributed as dist
|
3 |
-
import torchmetrics
|
4 |
-
import json
|
5 |
-
import math
|
6 |
-
import numpy as np
|
7 |
-
import os
|
8 |
-
import copy
|
9 |
-
import faiss
|
10 |
-
import time
|
11 |
-
import pandas as pd
|
12 |
-
import random
|
13 |
-
|
14 |
-
from tqdm import tqdm
|
15 |
-
from .protein_encoder import ProteinEncoder
|
16 |
-
from .structure_encoder import StructureEncoder
|
17 |
-
from .text_encoder import TextEncoder
|
18 |
-
from ..abstract_model import AbstractModel
|
19 |
-
from ..model_interface import register_model
|
20 |
-
from utils.mpr import MultipleProcessRunnerSimplifier
|
21 |
-
from torch.nn.functional import normalize, cross_entropy
|
22 |
-
from utils.constants import residue_level, sequence_level
|
23 |
-
from sklearn.metrics import roc_auc_score
|
24 |
-
|
25 |
-
|
26 |
-
def multilabel_cross_entropy(logits, labels):
    """
    Compute cross entropy loss for multilabel classification. See "https://arxiv.org/pdf/2208.02955.pdf"

    For every sample the loss is ``log(1 + sum_{n, p} exp(neg_n - pos_p))`` over all
    (negative, positive) logit pairs; the result is averaged over samples.

    Args:
        logits: [num_samples, num_classes]
        labels: [num_samples, num_classes], 1 marks a positive class and 0 a negative one

    Returns:
        The mean per-sample loss.
    """

    loss = 0
    for pred, label in zip(logits, labels):
        pos_logits = pred[label == 1]
        neg_logits = pred[label == 0]

        # All pairwise (negative - positive) gaps, flattened to one vector.
        diff = (neg_logits.unsqueeze(-1) - pos_logits).flatten()

        # log(1 + sum(exp(diff))) == logsumexp([0, diff...]); logsumexp stays finite
        # where a direct exp() would overflow for large gaps.
        zero = diff.new_zeros(1)
        loss += torch.logsumexp(torch.cat([zero, diff]), dim=0)

    return loss / len(logits)
|
43 |
-
|
44 |
-
# pred = (1 - 2 * labels) * logits
|
45 |
-
# pred_neg = pred - labels * 1e12
|
46 |
-
# pred_pos = pred - (1 - labels) * 1e12
|
47 |
-
#
|
48 |
-
# zeros = torch.zeros_like(logits[..., :1], dtype=logits.dtype)
|
49 |
-
# pred_neg = torch.cat([pred_neg, zeros], dim=-1)
|
50 |
-
# pred_pos = torch.cat([pred_pos, zeros], dim=-1)
|
51 |
-
#
|
52 |
-
# neg_loss = torch.logsumexp(pred_neg, dim=-1)
|
53 |
-
# pos_loss = torch.logsumexp(pred_pos, dim=-1)
|
54 |
-
#
|
55 |
-
# return (neg_loss + pos_loss).mean()
|
56 |
-
|
57 |
-
|
58 |
-
@register_model
|
59 |
-
class ProTrekTrimodalModel(AbstractModel):
|
60 |
-
def __init__(self,
             protein_config: str,
             text_config: str,
             structure_config: str = None,
             repr_dim: int = 1024,
             temperature: float = 0.07,
             load_protein_pretrained: bool = True,
             load_text_pretrained: bool = True,
             use_mlm_loss: bool = False,
             use_zlpr_loss: bool = False,
             use_saprot: bool = False,
             gradient_checkpointing: bool = False,
             **kwargs):
    """
    Store the model settings, then hand the remaining keyword arguments to the parent class.

    Args:
        protein_config: Path to the config file for the protein sequence encoder
        text_config: Path to the config file for the text encoder
        structure_config: Path to the config file for the structure encoder
        repr_dim: Output dimension of the protein and text representation
        temperature: Temperature for softmax
        load_protein_pretrained: Whether to load pretrained weights for the protein encoder
        load_text_pretrained: Whether to load pretrained weights for the text encoder
        use_mlm_loss: Whether to use masked language modeling loss
        use_zlpr_loss: Whether to use zlpr loss. See "https://arxiv.org/pdf/2208.02955.pdf"
        use_saprot: Whether to use SaProt as protein encoder
        gradient_checkpointing: Whether to use gradient checkpointing for protein encoder
        **kwargs: Forwarded unchanged to the parent constructor
    """
    # Encoder configuration paths
    self.protein_config = protein_config
    self.text_config = text_config
    self.structure_config = structure_config

    # Representation hyper-parameters
    self.repr_dim = repr_dim
    self.temperature = temperature

    # Weight-loading switches
    self.load_protein_pretrained = load_protein_pretrained
    self.load_text_pretrained = load_text_pretrained

    # Training options
    self.use_mlm_loss = use_mlm_loss
    self.use_zlpr_loss = use_zlpr_loss
    self.use_saprot = use_saprot
    self.gradient_checkpointing = gradient_checkpointing

    # All settings are assigned before the parent constructor runs.
    super().__init__(**kwargs)
|
109 |
-
|
110 |
-
def initialize_metrics(self, stage: str) -> dict:
    """
    Build the accuracy metrics tracked for one stage.

    Args:
        stage: Prefix used in every metric name

    Returns:
        A dict mapping ``"{stage}_{a}_{b}_acc"`` (retrieval from a to b) and, when
        the MLM loss is enabled, ``"{stage}_{modality}_mask_acc"`` to metric objects.
    """
    has_structure = self.structure_config is not None

    # Protein <-> text retrieval accuracy is always tracked.
    metrics = {
        f"{stage}_{src}_{dst}_acc": torchmetrics.Accuracy()
        for src, dst in (("protein", "text"), ("text", "protein"))
    }

    # Masked-token accuracy; label -1 is ignored.
    if self.use_mlm_loss:
        masked_modalities = ["protein"]
        if has_structure:
            masked_modalities.append("structure")
        for modality in masked_modalities:
            metrics[f"{stage}_{modality}_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)

    # Retrieval pairs that involve the structure encoder.
    if has_structure:
        structure_pairs = (("structure", "protein"),
                           ("structure", "text"),
                           ("text", "structure"),
                           ("protein", "structure"))
        for src, dst in structure_pairs:
            metrics[f"{stage}_{src}_{dst}_acc"] = torchmetrics.Accuracy()

    return metrics
|
128 |
-
|
129 |
-
def initialize_model(self):
    """
    Create the encoders and the learnable temperature, and collect them in ``self.model``.

    The structure encoder is created only when ``structure_config`` is set.
    """
    # Initialize encoders
    protein_encoder = ProteinEncoder(self.protein_config,
                                     self.repr_dim,
                                     self.load_protein_pretrained,
                                     self.gradient_checkpointing)
    self.protein_encoder = protein_encoder

    text_encoder = TextEncoder(self.text_config,
                               self.repr_dim,
                               self.load_text_pretrained,
                               self.gradient_checkpointing)
    self.text_encoder = text_encoder

    # Learnable temperature: the float set in __init__ is replaced by a Parameter
    initial_temperature = torch.tensor(self.temperature)
    self.temperature = torch.nn.Parameter(initial_temperature)

    # self.model is used for saving and loading
    components = [self.temperature, self.protein_encoder, self.text_encoder]
    self.model = torch.nn.ParameterList(components)

    # If the structure encoder is specified
    if self.structure_config is None:
        return

    self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim)
    self.model.append(self.structure_encoder)
|
153 |
-
|
154 |
-
def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
    """Return the text encoder's representations for ``texts`` (delegates to its ``get_repr``)."""
    encoder = self.text_encoder
    return encoder.get_repr(texts, batch_size, verbose)
|
156 |
-
|
157 |
-
def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
    """Return the structure encoder's representations for ``proteins`` (delegates to its ``get_repr``)."""
    encoder = self.structure_encoder
    return encoder.get_repr(proteins, batch_size, verbose)
|
159 |
-
|
160 |
-
def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
    """Return the protein encoder's representations for ``proteins`` (delegates to its ``get_repr``)."""
    encoder = self.protein_encoder
    return encoder.get_repr(proteins, batch_size, verbose)
|
162 |
-
|
163 |
-
def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None):
    """
    Run every configured encoder on its inputs.

    Args:
        protein_inputs: A dictionary for protein encoder
        text_inputs : A dictionary for text encoder
        structure_inputs: A dictionary for structure encoder (used only when a
                          structure encoder is configured)

    Returns:
        ``[text_repr, protein_repr, protein_mask_logits]``, followed by
        ``structure_repr, structure_mask_logits`` when ``structure_config`` is set.
    """
    want_mask_logits = self.use_mlm_loss

    protein_out = self.protein_encoder(protein_inputs, want_mask_logits)
    protein_repr, protein_mask_logits = protein_out
    text_repr = self.text_encoder(text_inputs)

    outputs = [text_repr, protein_repr, protein_mask_logits]
    if self.structure_config is None:
        return outputs

    structure_out = self.structure_encoder(structure_inputs, want_mask_logits)
    structure_repr, structure_mask_logits = structure_out
    outputs.extend([structure_repr, structure_mask_logits])
    return outputs
|
180 |
-
|
181 |
-
def loss_func(self, stage: str, outputs, labels):
    """
    Contrastive loss between every ordered pair of modalities, plus optional MLM losses.

    For each (source, target) pair, the local source representations are compared
    with the target representations gathered from all processes; the matching
    item of the same sample is the positive class.

    Args:
        stage: Name of the current stage; selects the metric group and, when equal
               to "train", triggers logging and a metric reset
        outputs: The list returned by ``forward``
        labels: Dict of labels; ``"seq_labels"`` (and ``"struc_labels"`` when a
                structure encoder is configured) are read only if the MLM loss is on

    Returns:
        The mean of all individual loss terms.

    NOTE: the multi-label (ZLPR) variant is not wired in here, so ``use_zlpr_loss``
    has no effect in this method.
    """
    has_structure = self.structure_config is not None
    if has_structure:
        text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs
    else:
        text_repr, protein_repr, protein_mask_logits = outputs

    device = text_repr.device

    # Local (this process) L2-normalized representations, keyed by modality name.
    local_repr = {
        "protein": normalize(protein_repr, dim=-1),
        "text": normalize(text_repr, dim=-1),
    }
    mask_logits = {"protein": protein_mask_logits}
    if has_structure:
        local_repr["structure"] = normalize(structure_repr, dim=-1)
        mask_logits["structure"] = structure_mask_logits

    # Gather representations from all GPUs. The gathered copies are detached, so
    # gradients only flow through the local side of each similarity matrix.
    all_repr = {
        name: self.all_gather(rep).view(-1, rep.shape[-1]).detach()
        for name, rep in local_repr.items()
    }

    # Fall back to rank 0 when no process group is initialised (single process).
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Batch size
    bs = local_repr["text"].shape[0]

    # Row i of this process matches column rank * bs + i of the gathered matrix.
    bs_labels = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=device)

    if has_structure:
        pairs = {
            "protein": ["structure", "text"],
            "structure": ["protein", "text"],
            "text": ["protein", "structure"]
        }
    else:
        pairs = {
            "protein": ["text"],
            "text": ["protein"]
        }

    loss_list = []
    for k, values in pairs.items():
        for v in values:
            # Only calculate the similarity for the current batch
            sim = torch.matmul(local_repr[k], all_repr[v].T).div(self.temperature)

            loss = cross_entropy(sim, bs_labels)
            self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
            loss_list.append(loss)

    # Masked language modeling loss
    if self.use_mlm_loss:
        k_label = [("protein", labels["seq_labels"])]
        if has_structure:
            k_label.append(("structure", labels["struc_labels"]))

        for k, label in k_label:
            logits = mask_logits[k]
            # merge the first and second dimension of logits
            logits = logits.view(-1, logits.shape[-1])
            label = label.flatten().to(device)
            # Positions labelled -1 are excluded from the loss.
            mlm_loss = cross_entropy(logits, label, ignore_index=-1)
            loss_list.append(mlm_loss)
            self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label)

    loss = sum(loss_list) / len(loss_list)

    if stage == "train":
        log_dict = self.get_log_dict("train")
        log_dict["train_loss"] = loss
        self.log_info(log_dict)

        # Reset train metrics
        self.reset_metrics("train")

    return loss
|
325 |
-
|
326 |
-
def padded_gather(self, tensor: torch.Tensor):
    """
    Gather tensors from all GPUs, allowing different shapes at the batch dimension.

    Each process pads its tensor with zero rows up to the largest batch size, the
    padded tensors are gathered, and the padding of every process is then removed
    before the pieces are concatenated in process order.

    Args:
        tensor: Local tensor; only its first dimension may differ between processes

    Returns:
        The concatenation of every process's rows, without padding.
    """

    # Get the size of the tensor on every process
    size = tensor.shape[0]
    size_tensor = torch.tensor(size, device=tensor.device)
    # reshape(-1) also covers a gather that returns a 0-d tensor (single process).
    all_sizes = self.all_gather(size_tensor).reshape(-1).tolist()
    max_size = max(all_sizes)
    trailing = tuple(tensor.shape[1:])

    # Pad the tensor
    if size != max_size:
        padded = tensor.new_zeros((max_size, *trailing))
        padded[:size] = tensor
        tensor = padded

    gathered = self.all_gather(tensor).reshape(len(all_sizes), max_size, *trailing)

    # Strip the padding per process; slicing the flattened gather would keep
    # padding rows whenever a process other than the last one is short.
    pieces = [chunk[:n] for chunk, n in zip(gathered, all_sizes)]
    return torch.cat(pieces, dim=0)
|
346 |
-
|
347 |
-
def _get_protein_indices(self):
    """
    Embed every protein in ``self.uniprot2label`` and return an inner-product faiss index.

    The proteins are split into contiguous shards, one per process; each process
    encodes its shard and the results are combined with ``padded_gather``.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    entries = self.uniprot2label.values()
    if not self.use_saprot:
        proteins = [entry["seq"] for entry in entries]
    else:
        # SaProt input: interleave each residue with its foldseek token.
        proteins = []
        for entry in entries:
            aa_seq = entry["seq"]
            foldseek_seq = entry["foldseek"]
            assert len(aa_seq) == len(foldseek_seq)
            proteins.append("".join(aa + st for aa, st in zip(aa_seq, foldseek_seq)))

    # Contiguous shard for this process
    span = math.ceil(len(proteins) / world_size)
    start = rank * span
    shard = proteins[start: start + span]

    # Display the progress bar on the rank 0 process
    show_progress = self.trainer.local_rank == 0

    # Get protein representations
    local_repr = self.protein_encoder.get_repr(shard, batch_size=1, verbose=show_progress)
    gathered_repr = self.padded_gather(local_repr)

    # Construct faiss index
    index = faiss.IndexFlatIP(gathered_repr.shape[-1])
    index.add(gathered_repr.cpu().numpy())
    return index
|
377 |
-
|
378 |
-
def _get_structure_indices(self):
    """
    Embed the foldseek sequence of every protein in ``self.uniprot2label`` with the
    structure encoder and return an inner-product faiss index.

    The sequences are split into contiguous shards, one per process; the shard
    results are combined with ``padded_gather``.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    foldseek_seqs = [entry["foldseek"] for entry in self.uniprot2label.values()]

    # Contiguous shard for this process
    span = math.ceil(len(foldseek_seqs) / world_size)
    start = rank * span
    shard = foldseek_seqs[start: start + span]

    # Display the progress bar on the rank 0 process
    show_progress = self.trainer.local_rank == 0

    # Get structure representations
    local_repr = self.structure_encoder.get_repr(shard, batch_size=1, verbose=show_progress)
    gathered_repr = self.padded_gather(local_repr)

    # Construct faiss index
    index = faiss.IndexFlatIP(gathered_repr.shape[-1])
    index.add(gathered_repr.cpu().numpy())
    return index
|
397 |
-
|
398 |
-
def _get_text_indices(self):
    """
    Embed the texts of every subsection in ``self.label2text`` and build one
    inner-product faiss index per subsection, plus a combined "Total" index.

    Returns:
        Dict mapping subsection name (and "Total") to its faiss index.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Display the progress bar on the rank 0 process
    show_progress = self.trainer.local_rank == 0
    subsections = self.label2text.keys()
    if show_progress:
        subsections = tqdm(subsections, desc="Get text representations")

    text_embeddings = {}
    for subsection in subsections:
        # "Total" is assembled from the other subsections further below.
        if subsection == "Total":
            continue

        # Only use the first text for efficiency
        texts = [text_list[0:1] for text_list in self.label2text[subsection].values()]

        # Contiguous shard for this process
        span = math.ceil(len(texts) / world_size)
        shard = texts[rank * span: (rank + 1) * span]

        embeddings = []
        for text_list in shard:
            text_repr = self.text_encoder.get_repr(text_list)
            mean_repr = text_repr.mean(dim=0, keepdim=True)
            embeddings.append(torch.nn.functional.normalize(mean_repr, dim=-1))

        if len(embeddings) > 0:
            local_repr = torch.cat(embeddings, dim=0)
        else:
            # This process received no texts; contribute an empty block.
            local_repr = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device)

        text_embeddings[subsection] = self.padded_gather(local_repr)

    # Aggregate text embeddings for global retrieval. Each "Total" value has the
    # form "<subsection>|<row>" and points at a row of that subsection's embeddings.
    total_embeddings = []
    for pointer in self.label2text["Total"].values():
        subsection, row = pointer.split("|")
        total_embeddings.append(text_embeddings[subsection][int(row)])

    text_embeddings["Total"] = torch.stack(total_embeddings)

    # Construct faiss index
    text_indices = {}
    for subsection, text_repr in text_embeddings.items():
        index = faiss.IndexFlatIP(text_repr.shape[-1])
        text_indices[subsection] = index
        index.add(text_repr.cpu().numpy())

    return text_indices
|
452 |
-
|
453 |
-
def _protein2text(self, modality: str, protein_indices, text_indices: dict):
    """
    Evaluate protein-to-text retrieval on the Swiss-Prot proteins.

    For every (subsection, protein) pair the texts of that subsection are ranked by
    similarity to the protein embedding, and AP, reciprocal rank and AUC are computed.

    Args:
        modality: Name used as the prefix of the result keys (``"{modality}2Text_..."``)
        protein_indices: faiss index of protein embeddings, ordered like ``self.uniprot2label``
        text_indices: Dict mapping subsection name to a faiss index of text embeddings

    Returns:
        Dict with mrr/map/auc for "Total" and averaged over the residue-level,
        sequence-level and combined subsection groups.
    """
    # Local import keeps this method self-contained.
    import tempfile

    def do(process_id, idx, row, writer):
        """Score one (subsection, protein) pair and write ``[ap, mrr, auc]`` as a JSON line."""
        subsection, uniprot_id, prob_idx, label = row

        # Retrieve ranking results
        p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
        text_inds = text_indices[subsection]
        sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal)
        sim_scores, rank_inds = sim_scores[0], rank_inds[0]

        # Calculate Average Precision(AP)
        ranks = []
        label = set(label)
        for i, rk in enumerate(rank_inds):
            # Find the rank of this label in all labels
            if rk in label:
                ranks.append(i + 1)

        ranks = np.array(ranks)
        ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])

        # Calculate Mean Reciprocal Rank(MRR)
        best_rank = ranks[0]
        mrr = 1 / best_rank

        # Calculate the AUC
        true_labels = np.zeros_like(sim_scores)
        true_labels[ranks - 1] = 1
        # AUC is undefined when only one class is present; report 0 in that case.
        if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
            auc = 0
        else:
            auc = roc_auc_score(true_labels, sim_scores)

        output = json.dumps([ap, mrr, auc])
        writer.write(output + "\n")

    inputs = []
    swissprot_subsections = set()
    for subsection in text_indices.keys():
        for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()):
            if uniprot_id in self.swissprot_ids:
                if subsection in labels:
                    swissprot_subsections.add(subsection)
                    label = labels[subsection]
                    inputs.append((subsection, uniprot_id, i, label))

    # Randomly shuffle the inputs (fixed seed, so every process gets the same order)
    random.seed(20000812)
    random.shuffle(inputs)

    # Split inputs into chunks for parallel processing
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    span = math.ceil(len(inputs) / world_size)
    sub_inputs = inputs[rank * span: (rank + 1) * span]

    # Display the progress bar on the rank 0 process
    verbose = self.trainer.local_rank == 0
    if verbose:
        print("Evaluating on each subsection...")

    # Time stamp and rank in the name avoid conflicts; the system temp directory
    # replaces a machine-specific absolute path.
    tmp_path = os.path.join(tempfile.gettempdir(), f"{time.time()}_{rank}.tsv")
    try:
        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                              return_results=True)
        outputs = mpr.run()
    finally:
        # Remove the scratch file even when the runner fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Aggregate results
    tensor_outputs = []
    for output in outputs:
        ap, mrr, auc = json.loads(output)
        tensor_outputs.append([float(ap), float(mrr), float(auc)])

    tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
    tensor_outputs = self.padded_gather(tensor_outputs)

    # Record results
    avg_results = {}
    for subsection in swissprot_subsections:
        avg_results[subsection] = {"map": [],
                                   "mrr": [],
                                   "auc": []}

    # NOTE(review): this pairing assumes the gathered rows come back in the same
    # order as ``inputs`` - confirm against MultipleProcessRunnerSimplifier.
    for item, output in zip(inputs, tensor_outputs):
        ap, mrr, auc = output
        subsection, _, _, _ = item

        avg_results[subsection]["map"].append(ap.cpu().item())
        avg_results[subsection]["mrr"].append(mrr.cpu().item())
        avg_results[subsection]["auc"].append(auc.cpu().item())

    results = {
        f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
        f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]),
        f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]),
    }

    # Average the metrics over the subsections of each level
    for level, labels in [("residue-level", residue_level),
                          ("sequence-level", sequence_level),
                          ("all", residue_level | sequence_level)]:

        mrrs = []
        maps = []
        aucs = []
        for subsection in labels:
            if subsection in avg_results:
                mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                maps.append(np.mean(avg_results[subsection]["map"]))
                aucs.append(np.mean(avg_results[subsection]["auc"]))

        results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs)
        results[f"{modality}2Text_{level}_map"] = np.mean(maps)
        results[f"{modality}2Text_{level}_auc"] = np.mean(aucs)

    return results
|
569 |
-
|
570 |
-
def _text2protein(self, modality: str, protein_indices, text_indices: dict):
    """
    Evaluate text-to-protein retrieval on the texts attached to Swiss-Prot proteins.

    For every (subsection, text) pair all proteins are ranked by similarity to the
    text embedding, and AP, reciprocal rank and AUC are computed.

    Args:
        modality: Name used in the result keys (``"Text2{modality}_..."``)
        protein_indices: faiss index of protein embeddings, ordered like ``self.uniprot2label``
        text_indices: Dict mapping subsection name to a faiss index of text embeddings

    Returns:
        Dict with mrr/map/auc for "Total" and averaged over the residue-level,
        sequence-level and combined subsection groups.
    """
    # Local import keeps this method self-contained.
    import tempfile

    def do(process_id, idx, row, writer):
        """Score one (subsection, text) pair and write ``[ap, mrr, auc]`` as a JSON line."""
        subsection, text_id, label = row

        # Retrieve ranking results
        t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1)
        sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal)
        sim_scores, rank_inds = sim_scores[0], rank_inds[0]

        # Calculate Average Precision(AP)
        ranks = []
        label = set(label)
        for i, rk in enumerate(rank_inds):
            # Find the rank of this label in all labels
            if rk in label:
                ranks.append(i + 1)

        ranks = np.array(ranks)
        ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])

        # Calculate Mean Reciprocal Rank(MRR)
        best_rank = ranks[0]
        mrr = 1 / best_rank

        # Calculate the AUC
        true_labels = np.zeros_like(sim_scores)
        true_labels[ranks - 1] = 1
        # AUC is undefined when only one class is present; report 0 in that case.
        if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
            auc = 0
        else:
            auc = roc_auc_score(true_labels, sim_scores)

        output = json.dumps([ap, mrr, auc])
        writer.write(output + "\n")

    # subsection -> {text_id: [indices of the proteins annotated with that text]}
    text2label = {}
    for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()):
        # Only evaluate the texts in Swiss-Prot
        if uniprot_id not in self.swissprot_ids:
            continue

        for subsection, text_ids in subsections.items():
            # "seq" and "foldseek" hold sequences, not text ids.
            if subsection == "seq" or subsection == "foldseek":
                continue

            per_text = text2label.setdefault(subsection, {})
            for text_id in text_ids:
                per_text.setdefault(text_id, []).append(i)

    # Iterate the dict rather than a set: set order of strings can differ between
    # processes, which would give each process a differently ordered work list.
    inputs = []
    for subsection, per_text in text2label.items():
        for text_id, label in per_text.items():
            inputs.append((subsection, text_id, label))

    # Randomly shuffle the inputs (fixed seed, so every process gets the same order)
    random.seed(20000812)
    random.shuffle(inputs)

    # Split inputs into chunks for parallel processing
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    span = math.ceil(len(inputs) / world_size)
    sub_inputs = inputs[rank * span: (rank + 1) * span]

    # Display the progress bar on the rank 0 process
    verbose = self.trainer.local_rank == 0
    if verbose:
        print("Evaluating on each text...")

    # Add time stamp to the temporary file name to avoid conflicts; the system temp
    # directory replaces a machine-specific absolute path.
    tmp_path = os.path.join(tempfile.gettempdir(), f"{time.time()}_{rank}.tsv")
    try:
        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
                                              return_results=True)
        outputs = mpr.run()
    finally:
        # Remove the scratch file even when the runner fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # Aggregate results
    tensor_outputs = []
    for output in outputs:
        ap, mrr, auc = json.loads(output)
        tensor_outputs.append([float(ap), float(mrr), float(auc)])

    tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
    tensor_outputs = self.padded_gather(tensor_outputs)

    # Record results
    avg_results = {}
    for subsection in text2label:
        avg_results[subsection] = {"map": [],
                                   "mrr": [],
                                   "auc": []}

    # NOTE(review): this pairing assumes the gathered rows come back in the same
    # order as ``inputs`` - confirm against MultipleProcessRunnerSimplifier.
    for item, output in zip(inputs, tensor_outputs):
        ap, mrr, auc = output
        subsection, _, _ = item

        avg_results[subsection]["map"].append(ap.cpu().item())
        avg_results[subsection]["mrr"].append(mrr.cpu().item())
        avg_results[subsection]["auc"].append(auc.cpu().item())

    results = {
        f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
        f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]),
        f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]),
    }

    # Average the metrics over the subsections of each level
    for level, labels in [("residue-level", residue_level),
                          ("sequence-level", sequence_level),
                          ("all", residue_level | sequence_level)]:

        mrrs = []
        maps = []
        aucs = []
        for subsection in labels:
            if subsection in avg_results:
                mrrs.append(np.mean(avg_results[subsection]["mrr"]))
                maps.append(np.mean(avg_results[subsection]["map"]))
                aucs.append(np.mean(avg_results[subsection]["auc"]))

        results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs)
        results[f"Text2{modality}_{level}_map"] = np.mean(maps)
        results[f"Text2{modality}_{level}_auc"] = np.mean(aucs)

    return results
|
700 |
-
|
701 |
-
def retrieval_eval(self) -> dict:
    """
    Run sequence <-> text retrieval in both directions and merge the metrics.

    Returns:
        One dictionary holding the entries produced by ``_protein2text``
        followed by the entries produced by ``_text2protein`` (on a key clash
        the latter wins).
    """
    # Build both search indices first: protein side, then text side
    sequence_index = self._get_protein_indices()
    description_index = self._get_text_indices()

    # NOTE: only the "Sequence" modality is evaluated; structure retrieval is disabled

    # Protein -> text retrieval
    metrics = dict(self._protein2text("Sequence", sequence_index, description_index))

    # Text -> protein retrieval
    metrics.update(self._text2protein("Sequence", sequence_index, description_index))

    return metrics
|
723 |
-
|
724 |
-
def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
|
725 |
-
while True:
|
726 |
-
masked_tokens = copy.copy(tokens)
|
727 |
-
labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)
|
728 |
-
vocab = [k for k in tokenizer.get_vocab().keys()]
|
729 |
-
|
730 |
-
for i in range(len(tokens)):
|
731 |
-
token = tokens[i]
|
732 |
-
|
733 |
-
prob = random.random()
|
734 |
-
if prob < mask_ratio:
|
735 |
-
prob /= mask_ratio
|
736 |
-
labels[i + 1] = tokenizer.convert_tokens_to_ids(token)
|
737 |
-
|
738 |
-
if prob < 0.8:
|
739 |
-
# 80% random change to mask token
|
740 |
-
if self.use_saprot:
|
741 |
-
token = "#" + token[-1]
|
742 |
-
else:
|
743 |
-
token = tokenizer.mask_token
|
744 |
-
elif prob < 0.9:
|
745 |
-
# 10% chance to change to random token
|
746 |
-
token = random.choice(vocab)
|
747 |
-
else:
|
748 |
-
# 10% chance to keep current token
|
749 |
-
pass
|
750 |
-
|
751 |
-
masked_tokens[i] = token
|
752 |
-
|
753 |
-
# Check if there is at least one masked token
|
754 |
-
if (labels != -1).any():
|
755 |
-
return masked_tokens, labels
|
756 |
-
|
757 |
-
def mlm_eval(self) -> float:
    """
    Measure masked-token prediction accuracy over the proteins in ``self.uniprot2label``.

    The protein list is cut into contiguous shards, one per distributed rank. Each rank
    masks its proteins with ``_apply_bert_mask`` (ratio 0.15), runs the protein encoder
    with ``get_mask_logits=True`` and counts how many selected positions are predicted
    correctly. The per-rank counters are combined through ``self.padded_gather``.

    Returns:
        Number of correct predictions divided by the number of selected positions
    """
    n_ranks = dist.get_world_size()
    my_rank = dist.get_rank()

    if self.use_saprot:
        # Interleave amino-acid and foldseek letters into one structure-aware sequence
        proteins = []
        for entry in self.uniprot2label.values():
            residues = entry["seq"]
            structure = entry["foldseek"]
            assert len(residues) == len(structure)
            proteins.append("".join(aa + fs for aa, fs in zip(residues, structure)))
    else:
        proteins = [entry["seq"] for entry in self.uniprot2label.values()]

    # Contiguous shard handled by this rank
    shard_size = math.ceil(len(proteins) / n_ranks)
    shard = proteins[my_rank * shard_size: (my_rank + 1) * shard_size]

    # Display the progress bar on the rank 0 process
    iterator = tqdm(shard, desc="Computing mlm...") if self.trainer.local_rank == 0 else shard

    n_selected = torch.tensor([0], dtype=torch.long, device=self.device)
    n_correct = torch.tensor([0], dtype=torch.long, device=self.device)
    for seq in iterator:
        pieces = self.protein_encoder.tokenizer.tokenize(seq)
        corrupted, targets = self._apply_bert_mask(pieces, self.protein_encoder.tokenizer, 0.15)

        encoded = self.protein_encoder.tokenizer(" ".join(corrupted), return_tensors="pt")
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        _, logits = self.protein_encoder(encoded, get_mask_logits=True)

        targets = targets.to(self.device)

        # Only positions with a real label (!= -1) take part in the accuracy
        selector = targets != -1
        guesses = logits.squeeze(0).argmax(dim=-1)[selector]

        n_selected += len(guesses)
        n_correct += (guesses == targets[selector]).sum()

    # Gather all results
    n_selected = self.padded_gather(n_selected).sum()
    n_correct = self.padded_gather(n_correct).sum()

    return (n_correct / n_selected).cpu().item()
|
809 |
-
|
810 |
-
def _load_eval_data(self, stage):
|
811 |
-
# Load the data
|
812 |
-
lmdb_dir = eval(f"self.trainer.datamodule.{stage}_lmdb")
|
813 |
-
uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
|
814 |
-
label2text_path = os.path.join(lmdb_dir, "label2text.json")
|
815 |
-
swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")
|
816 |
-
|
817 |
-
self.uniprot2label = json.load(open(uniprot2label_path, "r"))
|
818 |
-
self.label2text = json.load(open(label2text_path, "r"))
|
819 |
-
self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist())
|
820 |
-
self.k = 3
|
821 |
-
|
822 |
-
def on_test_start(self):
    """
    Run the whole test-time evaluation when testing starts.

    Loads the "test" evaluation data, computes the retrieval metrics (plus the masked-token
    accuracy when ``self.use_mlm_loss`` is set), prefixes every retrieval metric with
    ``test_``, then logs and prints the resulting dictionary.
    """
    self._load_eval_data("test")

    log_dict = {}
    for name, value in self.retrieval_eval().items():
        log_dict["test_" + name] = value

    if self.use_mlm_loss:
        log_dict["test_mask_acc"] = self.mlm_eval()

    # The same dict object is logged and printed
    self.log_info(log_dict)
    print(log_dict)
|
831 |
-
|
832 |
-
def on_validation_start(self):
    """
    Run the whole validation-time evaluation when validation starts.

    Frees cached CUDA memory, loads the "valid" evaluation data, computes the retrieval
    metrics (plus the masked-token accuracy when ``self.use_mlm_loss`` is set), logs them
    with a ``valid_`` prefix and finally calls ``check_save_condition`` with the current
    step as the value to maximize.
    """
    # Clear the cache
    torch.cuda.empty_cache()

    self._load_eval_data("valid")

    log_dict = {}
    for name, value in self.retrieval_eval().items():
        log_dict["valid_" + name] = value

    if self.use_mlm_loss:
        log_dict["valid_mask_acc"] = self.mlm_eval()

    self.log_info(log_dict)

    self.check_save_condition(self.step, mode="max")
|
845 |
-
|
846 |
-
def test_step(self, batch, batch_idx):
|
847 |
-
return
|
848 |
-
|
849 |
-
def validation_step(self, batch, batch_idx):
    """
    Do nothing for the individual validation batch.

    Args:
        batch: Ignored

        batch_idx: Ignored
    """
    return None
|
851 |
-
|
852 |
-
def on_train_epoch_end(self):
    """
    Finish a training epoch.

    Runs the parent hook first. When the training dataset has a ``fixed_dataset_num``,
    a new subset of the training data is drawn through ``sample_subset()``.
    """
    super().on_train_epoch_end()

    train_set = self.trainer.datamodule.train_dataset
    if train_set.fixed_dataset_num is None:
        return

    # Re-sample the subset of the training data
    train_set.sample_subset()
|
857 |
-
|
858 |
-
# def test_epoch_end(self, outputs):
|
859 |
-
# log_dict = self.get_log_dict("test")
|
860 |
-
# log_dict["test_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
|
861 |
-
#
|
862 |
-
# print(log_dict)
|
863 |
-
# self.log_info(log_dict)
|
864 |
-
#
|
865 |
-
# self.reset_metrics("test")
|
866 |
-
#
|
867 |
-
# def validation_epoch_end(self, outputs):
|
868 |
-
# log_dict = self.get_log_dict("valid")
|
869 |
-
# log_dict["valid_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
|
870 |
-
#
|
871 |
-
# self.log_info(log_dict)
|
872 |
-
# self.reset_metrics("valid")
|
873 |
-
# self.check_save_condition(log_dict["valid_loss"], mode="min")
|
874 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/ProTrek/structure_encoder.py
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from tqdm import tqdm
|
4 |
-
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
|
5 |
-
from torch.nn.functional import normalize
|
6 |
-
|
7 |
-
|
8 |
-
class StructureEncoder(torch.nn.Module):
    """
    ESM-based encoder that maps protein structural sequences to fixed-size vectors.

    The representation of a protein is the hidden state of the first token, projected
    to ``out_dim`` by a linear layer.
    """

    def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim: Output dimension of the structure representation

            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        # The backbone is built from the config only; no pretrained weights are loaded here
        self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # Remove contact head
        self.model.esm.contact_head = None

        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute protein structure representation for the given proteins
        Args:
            proteins: A list of protein structural sequences
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized representations of shape [len(proteins), out_dim]. An empty input
            list yields an empty tensor of shape [0, out_dim].
        """
        device = next(self.parameters()).device

        # ``torch.cat`` cannot concatenate an empty list, so handle "no proteins" explicitly
        if len(proteins) == 0:
            return torch.empty((0, self.out.out_features), dtype=self.out.weight.dtype, device=device)

        iterator = range(0, len(proteins), batch_size)
        if verbose:
            iterator = tqdm(iterator, desc="Computing protein embeddings")

        protein_repr = []
        for i in iterator:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
            # Call the module itself (not ``forward``) so registered hooks are executed
            output, _ = self(protein_inputs)

            protein_repr.append(output)

        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein structure into protein representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits : [batch, seq_len, vocab_size], or None if ``get_mask_logits`` is False
        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state

        # The first token of every sequence serves as the summary of the whole protein
        reprs = self.out(last_hidden_state[:, 0, :])

        # Get logits for masked tokens
        mask_logits = self.model.lm_head(last_hidden_state) if get_mask_logits else None

        return reprs, mask_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/ProTrek/text_encoder.py
DELETED
@@ -1,81 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from tqdm import tqdm
|
4 |
-
from torch.nn.functional import normalize
|
5 |
-
from transformers import BertConfig, BertModel, BertTokenizer
|
6 |
-
|
7 |
-
|
8 |
-
class TextEncoder(torch.nn.Module):
    """
    BERT-based encoder that maps texts to fixed-size vectors.

    The representation of a text is the hidden state of the first token, projected
    to ``out_dim`` by a linear layer.
    """

    def __init__(self,
                 config_path: str,
                 out_dim: int,
                 load_pretrained: bool = True,
                 gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file

            out_dim: Output dimension of the text representation

            load_pretrained: Whether to load pretrained weights

            gradient_checkpointing: Whether to enable gradient checkpointing
        """
        super().__init__()
        config = BertConfig.from_pretrained(config_path)
        if load_pretrained:
            self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
        else:
            self.model = BertModel(config, add_pooling_layer=False)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.encoder.gradient_checkpointing = gradient_checkpointing

        self.tokenizer = BertTokenizer.from_pretrained(config_path)

    def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute text representation for the given texts
        Args:
            texts: A list of strings
            batch_size: Batch size for inference
            verbose: Whether to print progress

        Returns:
            L2-normalized representations of shape [len(texts), out_dim]. An empty input
            list yields an empty tensor of shape [0, out_dim].
        """
        device = next(self.parameters()).device

        # ``torch.cat`` cannot concatenate an empty list, so handle "no texts" explicitly
        if len(texts) == 0:
            return torch.empty((0, self.out.out_features), dtype=self.out.weight.dtype, device=device)

        iterator = range(0, len(texts), batch_size)
        if verbose:
            iterator = tqdm(iterator, desc="Computing text embeddings")

        text_repr = []
        for i in iterator:
            # Texts longer than 512 tokens are truncated by the tokenizer
            text_inputs = self.tokenizer.batch_encode_plus(texts[i: i+batch_size],
                                                           return_tensors="pt",
                                                           truncation=True,
                                                           max_length=512,
                                                           padding=True)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            output = self(text_inputs)

            text_repr.append(output)

        text_repr = torch.cat(text_repr, dim=0)
        return normalize(text_repr, dim=-1)

    def forward(self, inputs: dict):
        """
        Encode text into text representation
        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
                - token_type_ids: [batch, seq_len]

        Returns:
            text_repr: [batch, text_repr_dim]
        """
        # The first token of every sequence serves as the summary of the whole text
        first_token_state = self.model(**inputs).last_hidden_state[:, 0, :]
        return self.out(first_token_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/abstract_model.py
DELETED
@@ -1,401 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import abc
|
3 |
-
import os
|
4 |
-
import copy
|
5 |
-
|
6 |
-
import pytorch_lightning as pl
|
7 |
-
from utils.lr_scheduler import *
|
8 |
-
from torch import distributed as dist
|
9 |
-
|
10 |
-
|
11 |
-
class AbstractModel(pl.LightningModule):
|
12 |
-
def __init__(self,
|
13 |
-
lr_scheduler_kwargs: dict = None,
|
14 |
-
optimizer_kwargs: dict = None,
|
15 |
-
save_path: str = None,
|
16 |
-
from_checkpoint: str = None,
|
17 |
-
load_prev_scheduler: bool = False,
|
18 |
-
save_weights_only: bool = True,):
|
19 |
-
"""
|
20 |
-
|
21 |
-
Args:
|
22 |
-
lr_scheduler: Kwargs for lr_scheduler
|
23 |
-
optimizer_kwargs: Kwargs for optimizer_kwargs
|
24 |
-
save_path: Save trained model
|
25 |
-
from_checkpoint: Load model from checkpoint
|
26 |
-
load_prev_scheduler: Whether load previous scheduler from checkpoint
|
27 |
-
load_strict: Whether load model strictly
|
28 |
-
save_weights_only: Whether save only weights or also optimizer and lr_scheduler
|
29 |
-
|
30 |
-
"""
|
31 |
-
super().__init__()
|
32 |
-
self.initialize_model()
|
33 |
-
|
34 |
-
self.metrics = {}
|
35 |
-
for stage in ["train", "valid", "test"]:
|
36 |
-
stage_metrics = self.initialize_metrics(stage)
|
37 |
-
# Rigister metrics as attributes
|
38 |
-
for metric_name, metric in stage_metrics.items():
|
39 |
-
setattr(self, metric_name, metric)
|
40 |
-
|
41 |
-
self.metrics[stage] = stage_metrics
|
42 |
-
|
43 |
-
if lr_scheduler_kwargs is None:
|
44 |
-
# Default lr_scheduler
|
45 |
-
self.lr_scheduler_kwargs = {
|
46 |
-
"class": "ConstantLRScheduler",
|
47 |
-
"init_lr": 0,
|
48 |
-
}
|
49 |
-
print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
|
50 |
-
|
51 |
-
else:
|
52 |
-
self.lr_scheduler_kwargs = lr_scheduler_kwargs
|
53 |
-
|
54 |
-
if optimizer_kwargs is None:
|
55 |
-
# Default optimizer
|
56 |
-
self.optimizer_kwargs = {
|
57 |
-
"class": "AdamW",
|
58 |
-
"betas": (0.9, 0.98),
|
59 |
-
"weight_decay": 0.01,
|
60 |
-
}
|
61 |
-
print("No optimizer_kwargs provided. The default optimizer is AdamW.")
|
62 |
-
else:
|
63 |
-
self.optimizer_kwargs = optimizer_kwargs
|
64 |
-
self.init_optimizers()
|
65 |
-
|
66 |
-
self.save_path = save_path
|
67 |
-
self.save_weights_only = save_weights_only
|
68 |
-
|
69 |
-
# temp_step is used for accumulating gradients
|
70 |
-
self.temp_step = 0
|
71 |
-
self.step = 0
|
72 |
-
self.epoch = 0
|
73 |
-
|
74 |
-
self.load_prev_scheduler = load_prev_scheduler
|
75 |
-
self.from_checkpoint = from_checkpoint
|
76 |
-
if from_checkpoint:
|
77 |
-
self.load_checkpoint(from_checkpoint)
|
78 |
-
|
79 |
-
@abc.abstractmethod
|
80 |
-
def initialize_model(self) -> None:
|
81 |
-
"""
|
82 |
-
All model initialization should be done here
|
83 |
-
Note that the whole model must be named as "self.model" for model saving and loading
|
84 |
-
"""
|
85 |
-
raise NotImplementedError
|
86 |
-
|
87 |
-
@abc.abstractmethod
|
88 |
-
def forward(self, *args, **kwargs):
|
89 |
-
"""
|
90 |
-
Forward propagation
|
91 |
-
"""
|
92 |
-
raise NotImplementedError
|
93 |
-
|
94 |
-
@abc.abstractmethod
|
95 |
-
def initialize_metrics(self, stage: str) -> dict:
|
96 |
-
"""
|
97 |
-
Initialize metrics for each stage
|
98 |
-
Args:
|
99 |
-
stage: "train", "valid" or "test"
|
100 |
-
|
101 |
-
Returns:
|
102 |
-
A dictionary of metrics for the stage. Keys are metric names and values are metric objects
|
103 |
-
"""
|
104 |
-
raise NotImplementedError
|
105 |
-
|
106 |
-
@abc.abstractmethod
|
107 |
-
def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
|
108 |
-
"""
|
109 |
-
|
110 |
-
Args:
|
111 |
-
stage: "train", "valid" or "test"
|
112 |
-
outputs: model outputs for calculating loss
|
113 |
-
labels: labels for calculating loss
|
114 |
-
|
115 |
-
Returns:
|
116 |
-
loss
|
117 |
-
|
118 |
-
"""
|
119 |
-
raise NotImplementedError
|
120 |
-
|
121 |
-
@staticmethod
|
122 |
-
def load_weights(model, weights):
|
123 |
-
model_dict = model.state_dict()
|
124 |
-
|
125 |
-
unused_params = []
|
126 |
-
missed_params = list(model_dict.keys())
|
127 |
-
|
128 |
-
for k, v in weights.items():
|
129 |
-
if k in model_dict.keys():
|
130 |
-
model_dict[k] = v
|
131 |
-
missed_params.remove(k)
|
132 |
-
|
133 |
-
else:
|
134 |
-
unused_params.append(k)
|
135 |
-
|
136 |
-
if len(missed_params) > 0:
|
137 |
-
print(f"\033[31mSome weights of {type(model).__name__} were not "
|
138 |
-
f"initialized from the model checkpoint: {missed_params}\033[0m")
|
139 |
-
|
140 |
-
if len(unused_params) > 0:
|
141 |
-
print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")
|
142 |
-
|
143 |
-
model.load_state_dict(model_dict)
|
144 |
-
|
145 |
-
def optimizer_step(
|
146 |
-
self,
|
147 |
-
epoch: int,
|
148 |
-
batch_idx: int,
|
149 |
-
optimizer,
|
150 |
-
optimizer_closure=None,
|
151 |
-
) -> None:
|
152 |
-
super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)
|
153 |
-
|
154 |
-
self.temp_step += 1
|
155 |
-
if self.temp_step == self.trainer.accumulate_grad_batches:
|
156 |
-
self.step += 1
|
157 |
-
self.temp_step = 0
|
158 |
-
|
159 |
-
# For pytorch-lightning 1.9.5
|
160 |
-
# def optimizer_step(
|
161 |
-
# self,
|
162 |
-
# epoch: int,
|
163 |
-
# batch_idx: int,
|
164 |
-
# optimizer,
|
165 |
-
# optimizer_idx: int = 0,
|
166 |
-
# optimizer_closure=None,
|
167 |
-
# on_tpu: bool = False,
|
168 |
-
# using_native_amp: bool = False,
|
169 |
-
# using_lbfgs: bool = False,
|
170 |
-
# ) -> None:
|
171 |
-
# super().optimizer_step(
|
172 |
-
# epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
|
173 |
-
# )
|
174 |
-
# self.temp_step += 1
|
175 |
-
# if self.temp_step == self.trainer.accumulate_grad_batches:
|
176 |
-
# self.step += 1
|
177 |
-
# self.temp_step = 0
|
178 |
-
|
179 |
-
def on_train_epoch_end(self):
|
180 |
-
self.epoch += 1
|
181 |
-
|
182 |
-
def training_step(self, batch, batch_idx):
|
183 |
-
inputs, labels = batch
|
184 |
-
|
185 |
-
# optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-4, weight_decay=0.01, betas=(0.9, 0.98))
|
186 |
-
# for _ in range(1000):
|
187 |
-
# outputs = self(**inputs)
|
188 |
-
# loss = self.loss_func('train', outputs, labels)
|
189 |
-
# loss.backward()
|
190 |
-
# optimizer.step()
|
191 |
-
# optimizer.zero_grad()
|
192 |
-
#
|
193 |
-
# raise
|
194 |
-
|
195 |
-
outputs = self(**inputs)
|
196 |
-
loss = self.loss_func('train', outputs, labels)
|
197 |
-
|
198 |
-
self.log("loss", loss, prog_bar=True)
|
199 |
-
return loss
|
200 |
-
|
201 |
-
def validation_step(self, batch, batch_idx):
|
202 |
-
inputs, labels = batch
|
203 |
-
outputs = self(**inputs)
|
204 |
-
loss = self.loss_func('valid', outputs, labels)
|
205 |
-
self.valid_outputs.append(loss)
|
206 |
-
return loss
|
207 |
-
|
208 |
-
def test_step(self, batch, batch_idx):
|
209 |
-
inputs, labels = batch
|
210 |
-
outputs = self(**inputs)
|
211 |
-
|
212 |
-
loss = self.loss_func('test', outputs, labels)
|
213 |
-
self.test_outputs.append(loss)
|
214 |
-
return loss
|
215 |
-
|
216 |
-
def on_train_start(self) -> None:
|
217 |
-
# Load previous scheduler
|
218 |
-
if getattr(self, "prev_schechuler", None) is not None:
|
219 |
-
try:
|
220 |
-
self.step = self.prev_schechuler["global_step"]
|
221 |
-
self.epoch = self.prev_schechuler["epoch"]
|
222 |
-
self.best_value = self.prev_schechuler["best_value"]
|
223 |
-
self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
|
224 |
-
print(f"Previous training global step: {self.step}")
|
225 |
-
print(f"Previous training epoch: {self.epoch}")
|
226 |
-
print(f"Previous best value: {self.best_value}")
|
227 |
-
print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")
|
228 |
-
|
229 |
-
# Load optimizer state
|
230 |
-
if hasattr(self.trainer.strategy, "deepspeed_engine"):
|
231 |
-
# For DeepSpeed strategy
|
232 |
-
try:
|
233 |
-
self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
|
234 |
-
except Exception as e:
|
235 |
-
print(e)
|
236 |
-
|
237 |
-
else:
|
238 |
-
# For DDP strategy
|
239 |
-
self.optimizer.load_state_dict(self.prev_schechuler["optimizer"])
|
240 |
-
|
241 |
-
except Exception as e:
|
242 |
-
print(e)
|
243 |
-
raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False")
|
244 |
-
|
245 |
-
def on_validation_epoch_start(self) -> None:
|
246 |
-
setattr(self, "valid_outputs", [])
|
247 |
-
|
248 |
-
def on_test_epoch_start(self) -> None:
|
249 |
-
setattr(self, "test_outputs", [])
|
250 |
-
|
251 |
-
def load_checkpoint(self, from_checkpoint: str) -> None:
|
252 |
-
"""
|
253 |
-
Args:
|
254 |
-
from_checkpoint: Path to checkpoint.
|
255 |
-
"""
|
256 |
-
|
257 |
-
# If ``from_checkpoint`` is a directory, load the checkpoint in it
|
258 |
-
if os.path.isdir(from_checkpoint):
|
259 |
-
basename = os.path.basename(from_checkpoint)
|
260 |
-
from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")
|
261 |
-
|
262 |
-
state_dict = torch.load(from_checkpoint, map_location=self.device)
|
263 |
-
self.load_weights(self.model, state_dict["model"])
|
264 |
-
|
265 |
-
if self.load_prev_scheduler:
|
266 |
-
state_dict.pop("model")
|
267 |
-
self.prev_schechuler = state_dict
|
268 |
-
|
269 |
-
def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
|
270 |
-
"""
|
271 |
-
Save model to save_path
|
272 |
-
Args:
|
273 |
-
save_path: Path to save model
|
274 |
-
save_info: Other info to save
|
275 |
-
save_weights_only: Whether only save model weights
|
276 |
-
"""
|
277 |
-
dir = os.path.dirname(save_path)
|
278 |
-
os.makedirs(dir, exist_ok=True)
|
279 |
-
|
280 |
-
state_dict = {} if save_info is None else save_info
|
281 |
-
state_dict["model"] = self.model.state_dict()
|
282 |
-
|
283 |
-
# Convert model weights to fp32
|
284 |
-
for k, v in state_dict["model"].items():
|
285 |
-
state_dict["model"][k] = v.float()
|
286 |
-
|
287 |
-
if not save_weights_only:
|
288 |
-
state_dict["global_step"] = self.step
|
289 |
-
state_dict["epoch"] = self.epoch
|
290 |
-
state_dict["best_value"] = getattr(self, f"best_value", None)
|
291 |
-
state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()
|
292 |
-
|
293 |
-
# If not using DeepSpeed, save optimizer state
|
294 |
-
if not hasattr(self.trainer.strategy, "deepspeed_engine"):
|
295 |
-
state_dict["optimizer"] = self.optimizers().optimizer.state_dict()
|
296 |
-
|
297 |
-
torch.save(state_dict, save_path)
|
298 |
-
|
299 |
-
def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
|
300 |
-
"""
|
301 |
-
Check whether to save model. If save_path is not None and now_value is the best, save model.
|
302 |
-
Args:
|
303 |
-
now_value: Current metric value
|
304 |
-
mode: "min" or "max", meaning whether the lower the better or the higher the better
|
305 |
-
save_info: Other info to save
|
306 |
-
"""
|
307 |
-
|
308 |
-
assert mode in ["min", "max"], "mode should be 'min' or 'max'"
|
309 |
-
|
310 |
-
if self.save_path is not None:
|
311 |
-
# In case there are variables to be included in the save path
|
312 |
-
save_path = eval(f"f'{self.save_path}'")
|
313 |
-
|
314 |
-
dir = os.path.dirname(save_path)
|
315 |
-
os.makedirs(dir, exist_ok=True)
|
316 |
-
|
317 |
-
# Check whether to save model
|
318 |
-
best_value = getattr(self, f"best_value", None)
|
319 |
-
if best_value is not None:
|
320 |
-
if mode == "min" and now_value >= best_value or mode == "max" and now_value <= best_value:
|
321 |
-
return
|
322 |
-
|
323 |
-
setattr(self, "best_value", now_value)
|
324 |
-
|
325 |
-
# For DeepSpeed strategy
|
326 |
-
if hasattr(self.trainer.strategy, "deepspeed_engine"):
|
327 |
-
if not self.save_weights_only:
|
328 |
-
self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")
|
329 |
-
|
330 |
-
# Save a complete checkpoint
|
331 |
-
if dist.get_rank() == 0:
|
332 |
-
basename = os.path.basename(save_path)
|
333 |
-
ckpt_path = os.path.join(save_path, f"{basename}.pt")
|
334 |
-
self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)
|
335 |
-
|
336 |
-
# For normal situation
|
337 |
-
else:
|
338 |
-
if dist.get_rank() == 0:
|
339 |
-
self.save_checkpoint(save_path, save_info, self.save_weights_only)
|
340 |
-
|
341 |
-
def reset_metrics(self, stage) -> None:
|
342 |
-
"""
|
343 |
-
Reset metrics for given stage
|
344 |
-
Args:
|
345 |
-
stage: "train", "valid" or "test"
|
346 |
-
"""
|
347 |
-
for metric in self.metrics[stage].values():
|
348 |
-
metric.reset()
|
349 |
-
|
350 |
-
def get_log_dict(self, stage: str) -> dict:
|
351 |
-
"""
|
352 |
-
Get log dict for the stage
|
353 |
-
Args:
|
354 |
-
stage: "train", "valid" or "test"
|
355 |
-
|
356 |
-
Returns:
|
357 |
-
A dictionary of metrics for the stage. Keys are metric names and values are metric values
|
358 |
-
|
359 |
-
"""
|
360 |
-
return {name: metric.compute() for name, metric in self.metrics[stage].items()}
|
361 |
-
|
362 |
-
def log_info(self, info: dict) -> None:
|
363 |
-
"""
|
364 |
-
Record metrics during training and testing
|
365 |
-
Args:
|
366 |
-
info: dict of metrics
|
367 |
-
"""
|
368 |
-
if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
|
369 |
-
info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
|
370 |
-
info["epoch"] = self.epoch
|
371 |
-
self.logger.log_metrics(info, step=self.step)
|
372 |
-
|
373 |
-
def init_optimizers(self):
|
374 |
-
copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)
|
375 |
-
|
376 |
-
# No decay for layer norm and bias
|
377 |
-
no_decay = ['LayerNorm.weight', 'bias']
|
378 |
-
weight_decay = copy_optimizer_kwargs.pop("weight_decay")
|
379 |
-
|
380 |
-
optimizer_grouped_parameters = [
|
381 |
-
{'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
382 |
-
'weight_decay': weight_decay},
|
383 |
-
{'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
384 |
-
'weight_decay': 0.0}
|
385 |
-
]
|
386 |
-
|
387 |
-
optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
|
388 |
-
self.optimizer = optimizer_cls(optimizer_grouped_parameters,
|
389 |
-
lr=self.lr_scheduler_kwargs['init_lr'],
|
390 |
-
**copy_optimizer_kwargs)
|
391 |
-
|
392 |
-
tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
|
393 |
-
lr_scheduler = tmp_kwargs.pop("class")
|
394 |
-
self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
|
395 |
-
|
396 |
-
def configure_optimizers(self):
    """
    Return the optimizer together with its scheduler configuration.

    Returns:
        A dict with ``self.optimizer`` under "optimizer" and, under "lr_scheduler",
        a dict holding ``self.lr_scheduler`` plus an "interval" of "step" and a
        "frequency" of 1.
    """
    scheduler_config = {
        "scheduler": self.lr_scheduler,
        "interval": "step",
        "frequency": 1,
    }
    optimizer_config = {"optimizer": self.optimizer}
    optimizer_config["lr_scheduler"] = scheduler_config
    return optimizer_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/model_interface.py
DELETED
@@ -1,104 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import yaml
|
3 |
-
import glob
|
4 |
-
|
5 |
-
|
6 |
-
# register all available models through *_model.py files
|
7 |
-
# def construct_model():
|
8 |
-
# model_dir = os.path.dirname(__file__)
|
9 |
-
#
|
10 |
-
# # lists all model files
|
11 |
-
# model_list = []
|
12 |
-
# for root, _, names in os.walk(model_dir):
|
13 |
-
# for name in names:
|
14 |
-
# if name.endswith('_model.py'):
|
15 |
-
# sub_dirs = root.replace(model_dir, '').split(os.sep)
|
16 |
-
# model_list.append((sub_dirs, name[:-3]))
|
17 |
-
#
|
18 |
-
# # load model_config.yaml, controlling which models to be loaded
|
19 |
-
# model_config = yaml.safe_load(open(f"{model_dir}/model_config.yaml", "r"))
|
20 |
-
#
|
21 |
-
# if model_config["verbose"]:
|
22 |
-
# print("*" * 30 + f" Loading model " + "*" * 30)
|
23 |
-
#
|
24 |
-
# # register models
|
25 |
-
# for sub_dirs, name in model_list:
|
26 |
-
# if name in model_config["models"]:
|
27 |
-
# if len(sub_dirs) > 1:
|
28 |
-
# cmd = f"from {'.'.join(sub_dirs)} import {name}"
|
29 |
-
# else:
|
30 |
-
# cmd = f"from . import {name}"
|
31 |
-
#
|
32 |
-
# exec(cmd)
|
33 |
-
#
|
34 |
-
# if model_config["verbose"]:
|
35 |
-
# info = f"Loaded model: {name}"
|
36 |
-
# print(f"\033[32m{info}\033[0m")
|
37 |
-
# else:
|
38 |
-
# if model_config["verbose"]:
|
39 |
-
# info = f"Skipped model: {name}"
|
40 |
-
# print(f"\033[31m{info}\033[0m")
|
41 |
-
#
|
42 |
-
# if model_config["verbose"]:
|
43 |
-
# print("*" * 75)
|
44 |
-
#
|
45 |
-
#
|
46 |
-
# # register function as a wrapper for all models
|
47 |
-
# def register_model(cls):
|
48 |
-
# model_dict[cls.__name__] = cls
|
49 |
-
# return cls
|
50 |
-
#
|
51 |
-
#
|
52 |
-
# model_dict = {}
|
53 |
-
# construct_model()
|
54 |
-
#
|
55 |
-
#
|
56 |
-
# class ModelInterface:
|
57 |
-
# @classmethod
|
58 |
-
# def get_available_models(cls):
|
59 |
-
# return model_dict.keys()
|
60 |
-
#
|
61 |
-
# @classmethod
|
62 |
-
# def init_model(cls, model: str, **kwargs):
|
63 |
-
# """
|
64 |
-
#
|
65 |
-
# Args:
|
66 |
-
# model : Class name of model you want to use. Must be in model_dict.keys()
|
67 |
-
# **kwargs: Kwargs for model initialization
|
68 |
-
#
|
69 |
-
# Returns: Corresponding model
|
70 |
-
#
|
71 |
-
# """
|
72 |
-
# assert model in model_dict.keys(), f"class {model} doesn't exist!"
|
73 |
-
# return model_dict[model](**kwargs)
|
74 |
-
|
75 |
-
|
76 |
-
########################################################################
|
77 |
-
# Version 2 #
|
78 |
-
########################################################################
|
79 |
-
# Decorator for model classes: records the decorated class in `now_cls` and returns it unchanged
|
80 |
-
# Maps the name of the module a class was defined in (``cls.__module__``) to the
# class registered from it. A module that registers several classes keeps the last one.
_registered_models = {}


def register_model(cls):
    """
    Class decorator that registers ``cls`` as a model.

    The class is stored as the most recently registered model (``now_cls``) and
    also under its defining module's name, so it can still be found after other
    models have been registered.

    Args:
        cls: Model class to register.

    Returns:
        ``cls`` itself, unchanged.
    """
    global now_cls
    now_cls = cls
    _registered_models[cls.__module__] = cls
    return cls


# Most recently registered model class; None until register_model has been called.
now_cls = None


class ModelInterface:
    @classmethod
    def init_model(cls, model_py_path: str, **kwargs):
        """
        Import a model module relative to this package and instantiate its registered class.

        Args:
            model_py_path: Py file Path of model you want to use, relative to this
                package and separated by ``os.sep``. A trailing ".py" is ignored.
            **kwargs: Kwargs for model initialization

        Returns: Corresponding model

        Raises:
            TypeError: if no model class has been registered at all.
        """
        import importlib

        sub_dirs = model_py_path.split(os.sep)
        module_stem = sub_dirs[-1]
        if module_stem.endswith(".py"):
            module_stem = module_stem[:-3]

        # Same relative target as ``from .<dirs> import <module>``.
        prefix = '.' + '.'.join(sub_dirs[:-1])
        if not prefix.endswith('.'):
            prefix += '.'
        module = importlib.import_module(prefix + module_stem, package=__package__)

        # An already-imported module does not re-run its decorator, so ``now_cls``
        # may belong to a different model; look the class up by module first and
        # fall back to ``now_cls`` only when this module registered nothing itself.
        model_cls = _registered_models.get(module.__name__, now_cls)
        if model_cls is None:
            raise TypeError(f"No model class has been registered by '{model_py_path}'")

        return model_cls(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|