LTEnjoy committed (verified)
Commit ef8c7ff · 1 Parent(s): 2bd60c8

Delete utils

utils/constants.py DELETED
@@ -1,54 +0,0 @@
- import itertools
-
-
- aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
- aa_list = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]
-
- foldseek_seq_vocab = "ACDEFGHIKLMNPQRSTVWY#"
- foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"
-
- struc_unit = "abcdefghijklmnopqrstuvwxyz"
-
-
- def create_vocab(size: int) -> dict:
-     """
-     Args:
-         size: Size of the vocabulary
-
-     Returns:
-         vocab: Vocabulary
-     """
-     token_len = 1
-     while size > len(struc_unit) ** token_len:
-         token_len += 1
-
-     vocab = {}
-     for i, token in enumerate(itertools.product(struc_unit, repeat=token_len)):
-         vocab[i] = "".join(token)
-         if len(vocab) == size:
-             vocab[i + 1] = "#"
-             return vocab
-
- # ProTrek
- residue_level = {"Active site", "Binding site", "Site", "DNA binding", "Natural variant", "Mutagenesis",
-                  "Transmembrane", "Topological domain", "Intramembrane", "Signal peptide", "Propeptide",
-                  "Transit peptide", "Chain", "Peptide", "Modified residue", "Lipidation", "Glycosylation",
-                  "Disulfide bond", "Cross-link", "Domain", "Repeat", "Compositional bias", "Region",
-                  "Coiled coil", "Motif"}
-
- sequence_level = {"Function", "Miscellaneous", "Caution", "Catalytic activity", "Cofactor", "Activity regulation",
-                   "Biophysicochemical properties", "Pathway", "Involvement in disease", "Allergenic properties",
-                   "Toxic dose", "Pharmaceutical use", "Disruption phenotype", "Subcellular location",
-                   "Post-translational modification", "Subunit", "Domain (non-positional annotation)",
-                   "Sequence similarities", "RNA Editing", "Tissue specificity", "Developmental stage", "Induction",
-                   "Biotechnology", "Polymorphism", "GO annotation", "Proteomes", "Protein names", "Gene names",
-                   "Organism", "Taxonomic lineage", "Virus host"}
-
- raw_text_level = {"Function", "Subunit", "Tissue specificity", "Disruption phenotype", "Post-translational modification",
-                   "Induction", "Miscellaneous", "Sequence similarities", "Developmental stage",
-                   "Domain (non-positional annotation)", "Activity regulation", "Caution", "Polymorphism", "Toxic dose",
-                   "Allergenic properties", "Pharmaceutical use", "Cofactor", "Biophysicochemical properties",
-                   "Subcellular location", "RNA Editing"}
 
utils/downloader.py DELETED
@@ -1,180 +0,0 @@
- import os
-
- from utils.mpr import MultipleProcessRunner
- from tqdm import tqdm
-
-
- class Downloader(MultipleProcessRunner):
-     """
-     Download files that have a uniform resource locator (URL)
-     """
-
-     def __init__(self, base_url, save_path, overwrite=False, skip_error_info=False, **kwargs):
-         """
-         Args:
-             base_url: URL template of the files to download
-             save_path: Template of the saving path
-             overwrite: Whether to overwrite existing files
-         """
-         super().__init__(**kwargs)
-
-         self.base_url = base_url
-         self.save_path = save_path
-         self.overwrite = overwrite
-         self.skip_error_info = skip_error_info
-
-         if not overwrite:
-             # Remove already downloaded files from data
-             self.data = [uniprot for uniprot in tqdm(self.data, desc="Filtering out existing files...")
-                          if not os.path.exists(self.save_path.format(uniprot))]
-
-     def _aggregate(self, final_path: str, sub_paths):
-         pass
-
-     def _target(self, process_id, data, sub_path, *args):
-         for i, uniprot in enumerate(data):
-             url = self.base_url.format(uniprot)
-             save_path = self.save_path.format(uniprot)
-
-             # Shell cmd to download a file; on failure, remove the partial file and report
-             wget = f"wget -q -o /dev/null {url} -O {save_path}"
-             rm = f"rm {save_path}"
-             err = f"echo 'Error: {url} cannot be downloaded!'"
-             if self.skip_error_info:
-                 err += " >/dev/null"
-
-             os.system(f"{wget} || ({rm} && {err})")
-
-             self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} Downloading files...")
-
-     def run(self):
-         """
-         Run this function to download files
-         """
-         super().run()
-
-     def __len__(self):
-         return len(self.data)
-
-     # Clear empty files in a specific directory
-     @staticmethod
-     def clear_empty_files(path):
-         cnt = 0
-         for file in tqdm(os.listdir(path), desc="Clearing empty files..."):
-             if os.path.getsize(os.path.join(path, file)) == 0:
-                 os.remove(os.path.join(path, file))
-                 cnt += 1
-         print(f"Removed {cnt} empty files")
-         return cnt
-
-
- class AlphaDBDownloader(Downloader):
-     """
-     Download files from the AlphaFold2 database
-     """
-     def __init__(self, uniprot_ids, type: str, save_dir: str, **kwargs):
-         """
-         Args:
-             uniprot_ids: UniProt ids
-             type: Which type of files to download. Must be one of ['pdb', 'mmcif', 'plddt', 'pae']
-             save_dir: Saving directory
-         """
-         url_dict = {
-             "pdb": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.pdb",
-             "mmcif": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.cif",
-             "plddt": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-confidence_v4.json",
-             "pae": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-predicted_aligned_error_v4.json"
-         }
-
-         save_dict = {
-             "pdb": "{}.pdb",
-             "mmcif": "{}.cif",
-             "plddt": "{}.json",
-             "pae": "{}.json"
-         }
-         base_url = url_dict[type]
-         save_path = os.path.join(save_dir, save_dict[type])
-
-         super().__init__(data=uniprot_ids, base_url=base_url, save_path=save_path, **kwargs)
-
-
- class PDBDownloader(Downloader):
-     """
-     Download files from PDB
-     """
-     def __init__(self, pdb_ids, type: str, save_dir: str, **kwargs):
-         """
-         Args:
-             pdb_ids: PDB ids
-             type: Which type of files to download. Must be one of ['pdb', 'mmcif']
-             save_dir: Saving directory
-         """
-         url_dict = {
-             "pdb": "https://files.rcsb.org/download/{}.pdb",
-             "mmcif": "https://files.rcsb.org/download/{}.cif"
-         }
-
-         save_dict = {
-             "pdb": "{}.pdb",
-             "mmcif": "{}.cif"
-         }
-
-         base_url = url_dict[type]
-         save_path = os.path.join(save_dir, save_dict[type])
-
-         super().__init__(data=pdb_ids, base_url=base_url, save_path=save_path, **kwargs)
-
-
- class CATHDownloader(Downloader):
-     def __init__(self, cath_ids, save_dir, **kwargs):
-         """
-         Download files from CATH
-         Args:
-             cath_ids: CATH ids
-             save_dir: Saving directory
-         """
-         url = "http://www.cathdb.info/version/v4_3_0/api/rest/id/{}.pdb"
-         save_path = os.path.join(save_dir, "{}.pdb")
-
-         super().__init__(data=cath_ids, base_url=url, save_path=save_path, **kwargs)
-
-
- def download_pdb(pdb_id: str, format: str, save_path: str):
-     """
-     Download a structure file from PDB
-     Args:
-         pdb_id: PDB id
-         format: File format, must be one of ['pdb', 'cif']
-         save_path: Saving path
-     """
-     url = f"https://files.rcsb.org/download/{pdb_id}.{format}"
-     wget = f"wget -q -o /dev/null {url} -O {save_path}"
-     rm = f"rm {save_path}"
-     err = f"echo 'Error: {url} cannot be downloaded!'"
-     os.system(f"{wget} || ({rm} && {err})")
-
-
- def download_af2(uniprot_id: str, format: str, save_path: str):
-     """
-     Download a model file from the AlphaFold2 database
-     Args:
-         uniprot_id: UniProt id
-         format: File format, must be one of ['pdb', 'cif']
-         save_path: Saving path
-     """
-     url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.{format}"
-     wget = f"wget -q -o /dev/null {url} -O {save_path}"
-     rm = f"rm {save_path}"
-     err = f"echo 'Error: {url} cannot be downloaded!'"
-     os.system(f"{wget} || ({rm} && {err})")
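
As a usage illustration (the UniProt ids, directory, and process count below are hypothetical), a sketch of the batch downloader:

    from utils.downloader import AlphaDBDownloader

    # Fetch AlphaFold2 structures for two UniProt ids using 2 processes;
    # already-downloaded files are skipped unless overwrite=True.
    downloader = AlphaDBDownloader(["P12345", "Q8WZ42"], "pdb",
                                   save_dir="./af2_structures", n_process=2)
    downloader.run()

    # Failed downloads leave zero-byte files behind; remove them afterwards.
    AlphaDBDownloader.clear_empty_files("./af2_structures")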
 
utils/foldseek_util.py DELETED
@@ -1,121 +0,0 @@
- import os
- import time
- import numpy as np
- import re
- import sys
- sys.path.append(".")
-
-
- # Get structural seqs from pdb file
- def get_struc_seq(foldseek,
-                   path,
-                   chains: list = None,
-                   process_id: int = 0,
-                   plddt_mask: bool = False,
-                   plddt_threshold: float = 70.,
-                   foldseek_verbose: bool = False) -> dict:
-     """
-     Args:
-         foldseek: Binary executable file of foldseek
-         path: Path to pdb file
-         chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
-         process_id: Process ID for temporary files. This is used for parallel processing.
-         plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.
-         plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.
-         foldseek_verbose: If True, foldseek will print verbose messages.
-
-     Returns:
-         seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
-                   (seq, struc_seq, combined_seq).
-     """
-     assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
-     assert os.path.exists(path), f"PDB file not found: {path}"
-
-     tmp_save_path = f"/tmp/get_struc_seq_{process_id}_{time.time()}.tsv"
-     if foldseek_verbose:
-         cmd = f"{foldseek} structureto3didescriptor --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
-     else:
-         cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
-     os.system(cmd)
-
-     seq_dict = {}
-     name = os.path.basename(path)
-     with open(tmp_save_path, "r") as r:
-         for i, line in enumerate(r):
-             desc, seq, struc_seq = line.split("\t")[:3]
-
-             # Mask low plddt
-             if plddt_mask:
-                 plddts = extract_plddt(path)
-                 assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"
-
-                 # Mask regions with plddt < threshold
-                 indices = np.where(plddts < plddt_threshold)[0]
-                 np_seq = np.array(list(struc_seq))
-                 np_seq[indices] = "#"
-                 struc_seq = "".join(np_seq)
-
-             name_chain = desc.split(" ")[0]
-             chain = name_chain.replace(name, "").split("_")[-1]
-
-             if chains is None or chain in chains:
-                 if chain not in seq_dict:
-                     combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
-                     seq_dict[chain] = (seq, struc_seq, combined_seq)
-
-     os.remove(tmp_save_path)
-     os.remove(tmp_save_path + ".dbtype")
-     return seq_dict
-
-
- def extract_plddt(pdb_path: str) -> np.ndarray:
-     """
-     Extract plddt scores from pdb file.
-     Args:
-         pdb_path: Path to pdb file.
-
-     Returns:
-         plddts: plddt scores.
-     """
-     with open(pdb_path, "r") as r:
-         plddt_dict = {}
-         for line in r:
-             line = re.sub(' +', ' ', line).strip()
-             splits = line.split(" ")
-
-             if splits[0] == "ATOM":
-                 # If position < 1000
-                 if len(splits[4]) == 1:
-                     pos = int(splits[5])
-
-                 # If position >= 1000, the blank is removed, e.g. "A 999" -> "A1000",
-                 # so the length of splits[4] is not 1
-                 else:
-                     pos = int(splits[4][1:])
-
-                 plddt = float(splits[-2])
-
-                 if pos not in plddt_dict:
-                     plddt_dict[pos] = [plddt]
-                 else:
-                     plddt_dict[pos].append(plddt)
-
-     plddts = np.array([np.mean(v) for v in plddt_dict.values()])
-     return plddts
-
-
- if __name__ == '__main__':
-     foldseek = "/sujin/bin/foldseek"
-     # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
-     test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
-     res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
-     print(res["A"][1].lower())
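
A minimal sketch of calling the deleted helper (the foldseek binary and pdb paths below are placeholders):

    from utils.foldseek_util import get_struc_seq

    # Returns {chain: (seq, struc_seq, combined_seq)}; combined_seq
    # interleaves each residue with its lowercased 3Di structure token.
    seq_dict = get_struc_seq("bin/foldseek", "example.pdb",
                             chains=["A"], plddt_mask=True, plddt_threshold=70.)
    aa_seq, struc_seq, combined_seq = seq_dict["A"]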
 
utils/lr_scheduler.py DELETED
@@ -1,187 +0,0 @@
- import math
-
- from torch.optim.lr_scheduler import _LRScheduler
-
-
- class ConstantLRScheduler(_LRScheduler):
-     def __init__(self,
-                  optimizer,
-                  last_epoch: int = -1,
-                  verbose: bool = False,
-                  init_lr: float = 0.,
-                  ):
-         """
-         This is an implementation of a constant learning rate scheduler.
-         Args:
-             optimizer: Optimizer
-             last_epoch: The index of the last epoch. Default: -1
-             verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
-             init_lr: Initial learning rate
-         """
-         self.init_lr = init_lr
-         super().__init__(optimizer, last_epoch, verbose)
-
-     def state_dict(self):
-         state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
-         return state_dict
-
-     def load_state_dict(self, state_dict):
-         self.__dict__.update(state_dict)
-
-     def get_lr(self):
-         if not self._get_lr_called_within_step:
-             raise RuntimeError(
-                 "To get the last learning rate computed by the scheduler, use "
-                 "get_last_lr()"
-             )
-
-         return [self.init_lr for group in self.optimizer.param_groups]
-
-
- class CosineAnnealingLRScheduler(_LRScheduler):
-     def __init__(self,
-                  optimizer,
-                  last_epoch: int = -1,
-                  verbose: bool = False,
-                  init_lr: float = 0.,
-                  max_lr: float = 4e-4,
-                  final_lr: float = 4e-5,
-                  warmup_steps: int = 2000,
-                  cosine_steps: int = 10000,
-                  ):
-         """
-         This is an implementation of a cosine annealing learning rate scheduler with linear warmup.
-         Args:
-             optimizer: Optimizer
-             last_epoch: The index of the last epoch. Default: -1
-             verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
-             init_lr: Initial learning rate
-             max_lr: Maximum learning rate after warmup
-             final_lr: Final learning rate after decay
-             warmup_steps: Number of steps for warmup
-             cosine_steps: Number of steps for cosine annealing
-         """
-         self.init_lr = init_lr
-         self.max_lr = max_lr
-         self.final_lr = final_lr
-         self.warmup_steps = warmup_steps
-         self.cosine_steps = cosine_steps
-         super(CosineAnnealingLRScheduler, self).__init__(optimizer, last_epoch, verbose)
-
-     def state_dict(self):
-         state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
-         return state_dict
-
-     def load_state_dict(self, state_dict):
-         self.__dict__.update(state_dict)
-
-     def get_lr(self):
-         if not self._get_lr_called_within_step:
-             raise RuntimeError(
-                 "To get the last learning rate computed by the scheduler, use "
-                 "get_last_lr()"
-             )
-
-         step_no = self.last_epoch
-
-         if step_no <= self.warmup_steps:
-             lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
-         else:
-             lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
-                  * (1 + math.cos(math.pi * (step_no - self.warmup_steps) / self.cosine_steps))
-
-         return [lr for group in self.optimizer.param_groups]
-
-
- class Esm2LRScheduler(_LRScheduler):
-     def __init__(self,
-                  optimizer,
-                  last_epoch: int = -1,
-                  verbose: bool = False,
-                  init_lr: float = 0.,
-                  max_lr: float = 4e-4,
-                  final_lr: float = 4e-5,
-                  warmup_steps: int = 2000,
-                  start_decay_after_n_steps: int = 500000,
-                  end_decay_after_n_steps: int = 5000000,
-                  on_use: bool = True,
-                  ):
-         """
-         This is an implementation of ESM2's learning rate scheduler.
-         Args:
-             optimizer: Optimizer
-             last_epoch: The index of the last epoch. Default: -1
-             verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
-             init_lr: Initial learning rate
-             max_lr: Maximum learning rate after warmup
-             final_lr: Final learning rate after decay
-             warmup_steps: Number of steps for warmup
-             start_decay_after_n_steps: Start decay after this number of steps
-             end_decay_after_n_steps: End decay after this number of steps
-             on_use: Whether to use this scheduler. If ``False``, the scheduler will not change the learning rate
-                     and will only use the ``init_lr``. Default: ``True``
-         """
-         self.init_lr = init_lr
-         self.max_lr = max_lr
-         self.final_lr = final_lr
-         self.warmup_steps = warmup_steps
-         self.start_decay_after_n_steps = start_decay_after_n_steps
-         self.end_decay_after_n_steps = end_decay_after_n_steps
-         self.on_use = on_use
-         super(Esm2LRScheduler, self).__init__(optimizer, last_epoch, verbose)
-
-     def state_dict(self):
-         state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
-         return state_dict
-
-     def load_state_dict(self, state_dict):
-         self.__dict__.update(state_dict)
-
-     def get_lr(self):
-         if not self._get_lr_called_within_step:
-             raise RuntimeError(
-                 "To get the last learning rate computed by the scheduler, use "
-                 "get_last_lr()"
-             )
-
-         step_no = self.last_epoch
-         if not self.on_use:
-             return [base_lr for base_lr in self.base_lrs]
-
-         if step_no <= self.warmup_steps:
-             lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
-         elif step_no <= self.start_decay_after_n_steps:
-             lr = self.max_lr
-         elif step_no <= self.end_decay_after_n_steps:
-             portion = (step_no - self.start_decay_after_n_steps) / (self.end_decay_after_n_steps - self.start_decay_after_n_steps)
-             lr = self.max_lr - portion * (self.max_lr - self.final_lr)
-         else:
-             lr = self.final_lr
-
-         return [lr for group in self.optimizer.param_groups]
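
A minimal sketch of plugging one of these schedulers into a torch training loop (the model, optimizer, and step count below are placeholders):

    import torch
    from utils.lr_scheduler import Esm2LRScheduler

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
    scheduler = Esm2LRScheduler(optimizer, max_lr=4e-4, final_lr=4e-5,
                                warmup_steps=2000)

    for step in range(10_000):
        optimizer.step()   # forward/backward omitted for brevity
        scheduler.step()   # linear warmup, plateau, then linear decay
    print(scheduler.get_last_lr())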
 
utils/mpr.py DELETED
@@ -1,397 +0,0 @@
- import abc
- import os
- import time
- import sys
-
- from tqdm import tqdm
- from math import ceil
-
-
- class MultipleProcessRunner:
-     """
-     Abstract class for running tasks with multiple processes.
-     There are three abstract methods that should be implemented:
-     1. __len__()   : return the length of data
-     2. _target()   : target function for each process
-     3. _aggregate(): aggregate results from each process
-     """
-
-     def __init__(self,
-                  data,
-                  save_path=None,
-                  n_process=1,
-                  verbose=True,
-                  total_only=True,
-                  log_step=1,
-                  start_method='fork'):
-         """
-         Args:
-             data      : data to be processed that can be sliced
-             save_path : final output path
-             n_process : number of processes
-             verbose   : if True, display progress bar
-             total_only: if True, only the total progress bar is displayed
-             log_step  : for the total progress bar, the next log will be printed when
-                         ``current iteration`` - ``last log iteration`` >= log_step
-             start_method: start method for multiprocessing
-         """
-         self.data = data
-         self.save_path = save_path
-         self.n_process = n_process
-         self.verbose = verbose
-         self.total_only = total_only
-         self.log_step = log_step
-         self.start_method = start_method
-
-         # Get terminal width to format output
-         try:
-             self.terminal_y = os.get_terminal_size()[0]
-         except Exception as e:
-             print(e)
-             print("Can't get terminal size, set terminal_y = None")
-             self.terminal_y = None
-
-     def _s2hms(self, seconds: float):
-         """
-         Convert seconds into hour:minute:second format
-         """
-         m, s = divmod(seconds, 60)
-         h, m = divmod(m, 60)
-         return "%02d:%02d:%02d" % (h, m, s)
-
-     def _display_time(self, st_time, now, total):
-         ed_time = time.time()
-         running_time = ed_time - st_time
-         rest_time = running_time * (total - now) / now
-         iter_sec = f"{now / running_time:.2f}it/s" if now > running_time else f"{running_time / now:.2f}s/it"
-         return f' [{self._s2hms(running_time)} < {self._s2hms(rest_time)}, {iter_sec}]'
-
-     def _display_bar(self, now, total, length):
-         now = now if now <= total else total
-         num = now * length // total
-         progress_bar = '[' + '#' * num + '_' * (length - num) + ']'
-         return progress_bar
-
-     def _display_all(self, now, total, desc, st_time):
-         # Make a progress bar
-         length = 50
-         progress_bar = self._display_bar(now, total, length)
-         time_display = self._display_time(st_time, now, total)
-         display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'
-
-         # Pad the line; if it is wider than the terminal, shrink the bar instead
-         width = self.terminal_y if self.terminal_y is not None else 100
-         num_space = width - len(display)
-         if num_space > 0:
-             display += ' ' * num_space
-         else:
-             length += num_space
-             progress_bar = self._display_bar(now, total, length)
-             display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'
-
-         # Set color
-         display = f"\033[31m{display}\033[0m"
-         return display
-
-     # Print progress bar at a specific position in the terminal
-     def terminal_progress_bar(self,
-                               process_id: int,
-                               now: int,
-                               total: int,
-                               desc: str = ''):
-         """
-         Args:
-             process_id: process id
-             now       : current iteration number
-             total     : total iteration number
-             desc      : description
-         """
-         st_time = self.process_st_time[process_id]
-
-         # Aggregate total information
-         self.counts[process_id] = now
-         self._total_display(self.process_st_time["total"])
-
-         if not self.total_only:
-             process_display = self._display_all(now, total, desc, st_time)
-             if self.terminal_y is not None:
-                 sys.stdout.write(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8")
-                 sys.stdout.flush()
-             else:
-                 print(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8", flush=True)
-
-     # Print global information
-     def _total_display(self, st_time):
-         if self.total_display_callable.value == 1:
-             self.total_display_callable.value = 0
-
-             cnt = sum([self.counts[i] for i in range(self.n_process)])
-             if cnt - self.last_cnt.value >= self.log_step:
-                 total_display = self._display_all(cnt, self.__len__(), "Total: ", st_time)
-                 self.last_cnt.value = cnt
-
-                 x = self.n_process + 1 if not self.total_only else 0
-                 print(f"\r\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True, end="")
-
-             self.total_display_callable.value = 1
-
-     def run(self):
-         """
-         Run a multi-process task.
-         Returns: the result of '_aggregate()'
-         """
-         import multiprocess as mp
-         mp.set_start_method(self.start_method, force=True)
-
-         # Number of data items that have already been processed, per process
-         self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})
-
-         # Record start time for each process
-         self.process_st_time = {"total": time.time()}
-
-         # Set a lock to guard the total number display
-         self.total_display_callable = mp.Value('d', 1)
-
-         # Save the last log iteration number
-         self.last_cnt = mp.Value('d', 0)
-
-         num_per_process = ceil(self.__len__() / self.n_process)
-
-         if self.save_path is not None:
-             file_name, suffix = os.path.splitext(self.save_path)
-
-         process_list = []
-         sub_paths = []
-         for i in range(self.n_process):
-             st = i * num_per_process
-             ed = st + num_per_process
-
-             # Construct the slice and sub path for the sub process
-             data_slice = self.data[st: ed]
-
-             sub_path = None
-             # Create a directory to save sub-results
-             if self.save_path is not None:
-                 save_dir = f"{file_name}{suffix}_temp"
-                 os.makedirs(save_dir, exist_ok=True)
-                 sub_path = f"{save_dir}/temp_{i}{suffix}"
-
-             # Construct the sub process
-             input_args = (i, data_slice, sub_path)
-             self.process_st_time[i] = time.time()
-             p = mp.Process(target=self._target, args=input_args)
-             p.start()
-
-             process_list.append(p)
-             sub_paths.append(sub_path)
-
-         for p in process_list:
-             p.join()
-
-         # Aggregate results and remove the temporary directory
-         results = self._aggregate(self.save_path, sub_paths)
-         if self.save_path is not None:
-             save_dir = f"{file_name}{suffix}_temp"
-             os.rmdir(save_dir)
-
-         return results
-
-     def parallel_run(self):
-         import multiprocess as mp
-         from joblib import Parallel, delayed
-
-         # Number of data items that have already been processed, per process
-         self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})
-
-         # Record start time for each process
-         self.process_st_time = {"total": time.time()}
-
-         # Set a lock to guard the total number display
-         self.total_display_callable = mp.Value('d', 1)
-
-         # Save the last log iteration number
-         self.last_cnt = mp.Value('d', 0)
-
-         num_per_process = ceil(self.__len__() / self.n_process)
-
-         if self.save_path is not None:
-             file_name, suffix = os.path.splitext(self.save_path)
-
-         sub_paths = []
-         input_arg_list = []
-         for i in range(self.n_process):
-             st = i * num_per_process
-             ed = st + num_per_process
-
-             # Construct the slice and sub path for the sub process
-             data_slice = self.data[st: ed]
-
-             sub_path = None
-             # Create a directory to save sub-results
-             if self.save_path is not None:
-                 save_dir = f"{file_name}{suffix}_temp"
-                 os.makedirs(save_dir, exist_ok=True)
-                 sub_path = f"{save_dir}/temp_{i}{suffix}"
-
-             # Construct the sub process arguments
-             input_args = (i, data_slice, sub_path)
-             self.process_st_time[i] = time.time()
-
-             sub_paths.append(sub_path)
-             input_arg_list.append(input_args)
-
-         # Start parallel processing
-         Parallel(n_jobs=self.n_process)(delayed(self._target)(*input_args) for input_args in input_arg_list)
-
-         # Aggregate results and remove the temporary directory
-         results = self._aggregate(self.save_path, sub_paths)
-         if self.save_path is not None:
-             save_dir = f"{file_name}{suffix}_temp"
-             os.rmdir(save_dir)
-
-         return results
-
-     @abc.abstractmethod
-     def _aggregate(self, final_path: str, sub_paths):
-         """
-         Aggregate results from sub processes into a file.
-
-         Args:
-             final_path: path to save final results
-             sub_paths : list of sub paths
-
-         Returns: None or desirable results specified by the user
-         """
-         raise NotImplementedError
-
-     @abc.abstractmethod
-     def _target(self, process_id, data, sub_path):
-         """
-         The main body to operate on data in one process.
-
-         Args:
-             process_id: process id
-             data      : data slice
-             sub_path  : sub path to save results
-         """
-         raise NotImplementedError
-
-     @abc.abstractmethod
-     def __len__(self):
-         raise NotImplementedError
-
-
- class MultipleProcessRunnerSimplifier(MultipleProcessRunner):
-     """
-     A simplified version of MultipleProcessRunner.
-     Users only need to implement the function 'do', which is then automatically executed
-     in every iteration after calling the function 'run'.
-     If 'save_path' is specified, a file will be opened at the 'sub_path' into which
-     users can write results, and the results will be aggregated into 'save_path'.
-
-     The procedure is like:
-         ...
-         with open(sub_path, 'w') as w:
-             for i, d in enumerate(data):
-                 self.do(process_id, i, d, w)  # You can write results into the file.
-         ...
-
-     The 'do' function should be like:
-         def do(process_id, idx, data, writer):
-             ...
-
-     If 'save_path' is None, the argument 'writer' will be set to None.
-     """
-
-     def __init__(self,
-                  data,
-                  do,
-                  save_path=None,
-                  n_process=1,
-                  verbose=True,
-                  total_only=True,
-                  log_step=1,
-                  return_results=False,
-                  start_method='fork'):
-
-         super().__init__(data=data,
-                          save_path=save_path,
-                          n_process=n_process,
-                          verbose=verbose,
-                          total_only=total_only,
-                          log_step=log_step,
-                          start_method=start_method)
-         self.do = do
-         self.return_results = return_results
-
-     def run(self):
-         self.start_time = time.time()
-         return super().run()
-
-     def _aggregate(self, final_path: str, sub_paths):
-         results = []
-         w = open(final_path, 'w') if final_path is not None else None
-
-         if self.verbose:
-             iterator = tqdm(enumerate(sub_paths), "Aggregating results...")
-         else:
-             iterator = enumerate(sub_paths)
-
-         for i, sub_path in iterator:
-             if sub_path is None and self.return_results:
-                 sub_path = f"MultipleProcessRunnerSimplifier_{self.start_time}_{i}.tmp"
-
-             if sub_path is not None:
-                 with open(sub_path, 'r') as r:
-                     for line in r:
-                         if w is not None:
-                             w.write(line)
-
-                         if self.return_results:
-                             results.append(line[:-1])
-
-                 os.remove(sub_path)
-
-         return results
-
-     def _target(self, process_id, data, sub_path):
-         if sub_path is None and self.return_results:
-             sub_path = f"MultipleProcessRunnerSimplifier_{self.start_time}_{process_id}.tmp"
-
-         w = open(sub_path, 'w') if sub_path is not None else None
-         for i, d in enumerate(data):
-             self.do(process_id, i, d, w)
-             if self.verbose:
-                 self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} running...")
-
-         if w is not None:
-             w.close()
-
-     def __len__(self):
-         return len(self.data)
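
A minimal sketch of the simplified runner (the data and the 'do' callback below are illustrative; the default 'fork' start method assumes a Unix-like OS):

    from utils.mpr import MultipleProcessRunnerSimplifier

    def do(process_id, idx, item, writer):
        # One result line per item; 'writer' is the per-process file handle
        writer.write(f"{item.upper()}\n")

    runner = MultipleProcessRunnerSimplifier(["a", "b", "c", "d"], do,
                                             save_path="results.txt", n_process=2)
    runner.run()  # per-process sub-results are merged into results.txt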