Delete utils
- utils/constants.py +0 -54
- utils/downloader.py +0 -180
- utils/foldseek_util.py +0 -121
- utils/lr_scheduler.py +0 -187
- utils/mpr.py +0 -397
utils/constants.py
DELETED
@@ -1,54 +0,0 @@
```python
import itertools


aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
aa_list = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]

foldseek_seq_vocab = "ACDEFGHIKLMNPQRSTVWY#"
foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"

struc_unit = "abcdefghijklmnopqrstuvwxyz"


def create_vocab(size: int) -> dict:
    """
    Args:
        size: Size of the vocabulary

    Returns:
        vocab: Vocabulary
    """
    token_len = 1
    while size > len(struc_unit) ** token_len:
        token_len += 1

    vocab = {}
    for i, token in enumerate(itertools.product(struc_unit, repeat=token_len)):
        vocab[i] = "".join(token)
        if len(vocab) == size:
            vocab[i+1] = "#"
            return vocab


# ProTrek
residue_level = {"Active site", "Binding site", "Site", "DNA binding", "Natural variant", "Mutagenesis",
                 "Transmembrane", "Topological domain", "Intramembrane", "Signal peptide", "Propeptide",
                 "Transit peptide",
                 "Chain", "Peptide", "Modified residue", "Lipidation", "Glycosylation", "Disulfide bond",
                 "Cross-link",
                 "Domain", "Repeat", "Compositional bias", "Region", "Coiled coil", "Motif"}

sequence_level = {"Function", "Miscellaneous", "Caution", "Catalytic activity", "Cofactor", "Activity regulation",
                  "Biophysicochemical properties", "Pathway", "Involvement in disease", "Allergenic properties",
                  "Toxic dose", "Pharmaceutical use", "Disruption phenotype", "Subcellular location",
                  "Post-translational modification", "Subunit", "Domain (non-positional annotation)",
                  "Sequence similarities", "RNA Editing", "Tissue specificity", "Developmental stage", "Induction",
                  "Biotechnology", "Polymorphism", "GO annotation", "Proteomes", "Protein names", "Gene names",
                  "Organism", "Taxonomic lineage", "Virus host"}

raw_text_level = {"Function", "Subunit", "Tissue specificity", "Disruption phenotype", "Post-translational modification",
                  "Induction", "Miscellaneous", "Sequence similarities", "Developmental stage",
                  "Domain (non-positional annotation)", "Activity regulation", "Caution", "Polymorphism", "Toxic dose",
                  "Allergenic properties", "Pharmaceutical use", "Cofactor", "Biophysicochemical properties",
                  "Subcellular location", "RNA Editing"}
```
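For readers tracking what this commit removes: a minimal sketch of how `create_vocab` behaves, assuming a revision where `utils/constants.py` still exists. Once `size` exceeds the 26 single letters of `struc_unit`, tokens become two characters, and the mask token `#` is appended right after the last real token.

```python
# Hypothetical usage sketch; assumes utils/constants.py from an earlier revision.
from utils.constants import create_vocab, foldseek_struc_vocab

vocab = create_vocab(30)          # 30 > 26, so tokens are two characters: "aa", "ab", ...
print(vocab[0], vocab[29])        # "aa" "bd"
print(vocab[30])                  # "#" (mask token appended once size is reached)
print(len(foldseek_struc_vocab))  # 21: 20 Foldseek 3Di states plus the "#" mask
```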
utils/downloader.py
DELETED
@@ -1,180 +0,0 @@
```python
import os

from utils.mpr import MultipleProcessRunner
from tqdm import tqdm


class Downloader(MultipleProcessRunner):
    """
    Download files that have a unified resource locator
    """

    def __init__(self, base_url, save_path, overwrite=False, skip_error_info=False, **kwargs):
        """
        Args:
            base_url: Unified Resource Locator of the file to download
            save_path: Unified Resource Locator of the saving path
            overwrite: whether to overwrite existing files
        """
        super().__init__(**kwargs)

        self.base_url = base_url
        self.save_path = save_path
        self.overwrite = overwrite
        self.skip_error_info = skip_error_info

        if not overwrite:
            # remove existing files from data
            self.data = [uniprot for uniprot in tqdm(self.data, desc="Filtering out existing files...")
                         if not os.path.exists(self.save_path.format(uniprot))]

    def _aggregate(self, final_path: str, sub_paths):
        pass

    def _target_static(self, process_id, data, sub_path, *args):
        for i, uniprot in enumerate(data):
            url = self.base_url.format(uniprot)
            save_path = self.save_path.format(uniprot)

            # shell cmd to download files
            wget = f"wget -q -o /dev/null {url} -O {save_path}"

            rm = f"rm {save_path}"
            err = f"echo 'Error: {url} cannot be downloaded!'"
            if self.skip_error_info:
                err += ">/dev/null"

            os.system(f"{wget} || ({rm} && {err})")

            self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} Downloading files...")

    def run(self):
        """
        Run this function to download files
        """
        super().run()

    def __len__(self):
        return len(self.data)

    # Clear empty files in a specific directory
    @staticmethod
    def clear_empty_files(path):
        cnt = 0
        for file in tqdm(os.listdir(path), desc="Clearing empty files..."):
            if os.path.getsize(os.path.join(path, file)) == 0:
                os.remove(os.path.join(path, file))
                cnt += 1
        print(f"Removed {cnt} empty files")
        return cnt


class AlphaDBDownloader(Downloader):
    """
    Download files from the AlphaFold2 database
    """
    def __init__(self, uniprot_ids, type: str, save_dir: str, **kwargs):
        """
        Args:
            uniprot_ids: UniProt ids
            type: Which type of files to download. Must be one of ['pdb', 'mmcif', 'plddt', 'pae']
            save_dir: Saving directory
            **kwargs:
        """
        url_dict = {
            "pdb": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.pdb",
            "mmcif": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-model_v4.cif",
            "plddt": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-confidence_v4.json",
            "pae": "https://alphafold.ebi.ac.uk/files/AF-{}-F1-predicted_aligned_error_v4.json"
        }

        save_dict = {
            "pdb": "{}.pdb",
            "mmcif": "{}.cif",
            "plddt": "{}.json",
            "pae": "{}.json"
        }
        base_url = url_dict[type]
        save_path = os.path.join(save_dir, save_dict[type])

        super().__init__(data=uniprot_ids, base_url=base_url, save_path=save_path, **kwargs)


class PDBDownloader(Downloader):
    """
    Download files from PDB
    """
    def __init__(self, pdb_ids, type: str, save_dir: str, **kwargs):
        """
        Args:
            pdb_ids: PDB ids
            type: Which type of files to download. Must be one of ['pdb', 'mmcif']
            save_dir: Saving directory
        """
        url_dict = {
            "pdb": "https://files.rcsb.org/download/{}.pdb",
            "mmcif": "https://files.rcsb.org/download/{}.cif"
        }

        save_dict = {
            "pdb": "{}.pdb",
            "mmcif": "{}.cif"
        }

        base_url = url_dict[type]
        save_path = os.path.join(save_dir, save_dict[type])

        super().__init__(data=pdb_ids, base_url=base_url, save_path=save_path, **kwargs)


class CATHDownloader(Downloader):
    def __init__(self, cath_ids, save_dir, **kwargs):
        """
        Download files from CATH
        Args:
            cath_ids: CATH ids
            save_dir: Saving directory
        """
        url = "http://www.cathdb.info/version/v4_3_0/api/rest/id/{}.pdb"
        save_path = os.path.join(save_dir, "{}.pdb")

        super().__init__(data=cath_ids, base_url=url, save_path=save_path, **kwargs)


def download_pdb(pdb_id: str, format: str, save_path: str):
    """
    Download a pdb file from PDB
    Args:
        pdb_id: PDB id
        format: File format, must be one of ['pdb', 'cif']
        save_path: Saving path
    """
    url = f"https://files.rcsb.org/download/{pdb_id}.{format}"
    wget = f"wget -q -o /dev/null {url} -O {save_path}"
    rm = f"rm {save_path}"
    err = f"echo 'Error: {url} cannot be downloaded!'"
    os.system(f"{wget} || ({rm} && {err})")


def download_af2(uniprot_id: str, format: str, save_path: str):
    """
    Download files from the AlphaFold2 database
    Args:
        uniprot_id: UniProt id
        format: File format, must be one of ['pdb', 'cif', 'plddt', 'pae']
        save_path: Saving path
    """
    url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.{format}"
    wget = f"wget -q -o /dev/null {url} -O {save_path}"
    rm = f"rm {save_path}"
    err = f"echo 'Error: {url} cannot be downloaded!'"
    os.system(f"{wget} || ({rm} && {err})")
```
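A usage sketch of the module-level helpers being deleted; the ids and paths below are hypothetical, and both wrappers shell out to `wget`, so it must be on `PATH`:

```python
# Hypothetical usage sketch; assumes utils/downloader.py from an earlier revision.
from utils.downloader import download_af2, download_pdb

download_af2("P00734", format="pdb", save_path="P00734.pdb")  # AlphaFold model v4
download_pdb("1ubq", format="cif", save_path="1ubq.cif")      # RCSB mmCIF entry
```

For bulk downloads, the `AlphaDBDownloader`/`PDBDownloader`/`CATHDownloader` classes above wrap the same shell commands in the multi-process machinery from `utils/mpr.py`.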
utils/foldseek_util.py
DELETED
@@ -1,121 +0,0 @@
```python
import os
import time
import json
import numpy as np
import re
import sys

sys.path.append(".")


# Get structural seqs from pdb file
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = False,
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Args:
        foldseek: Binary executable file of foldseek

        path: Path to pdb file

        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.

        process_id: Process ID for temporary files. This is used for parallel processing.

        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.

        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq).
    """
    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    tmp_save_path = f"/tmp/get_struc_seq_{process_id}_{time.time()}.tsv"
    if foldseek_verbose:
        cmd = f"{foldseek} structureto3didescriptor --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
    else:
        cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
    os.system(cmd)

    seq_dict = {}
    name = os.path.basename(path)
    with open(tmp_save_path, "r") as r:
        for i, line in enumerate(r):
            desc, seq, struc_seq = line.split("\t")[:3]

            # Mask low plddt
            if plddt_mask:
                plddts = extract_plddt(path)
                assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                # Mask regions with plddt < threshold
                indices = np.where(plddts < plddt_threshold)[0]
                np_seq = np.array(list(struc_seq))
                np_seq[indices] = "#"
                struc_seq = "".join(np_seq)

            name_chain = desc.split(" ")[0]
            chain = name_chain.replace(name, "").split("_")[-1]

            if chains is None or chain in chains:
                if chain not in seq_dict:
                    combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                    seq_dict[chain] = (seq, struc_seq, combined_seq)

    os.remove(tmp_save_path)
    os.remove(tmp_save_path + ".dbtype")
    return seq_dict


def extract_plddt(pdb_path: str) -> np.ndarray:
    """
    Extract plddt scores from pdb file.
    Args:
        pdb_path: Path to pdb file.

    Returns:
        plddts: plddt scores.
    """
    with open(pdb_path, "r") as r:
        plddt_dict = {}
        for line in r:
            line = re.sub(' +', ' ', line).strip()
            splits = line.split(" ")

            if splits[0] == "ATOM":
                # If position < 1000
                if len(splits[4]) == 1:
                    pos = int(splits[5])

                # If position >= 1000, the blank will be removed, e.g. "A 999" -> "A1000"
                # So the length of splits[4] is not 1
                else:
                    pos = int(splits[4][1:])

                plddt = float(splits[-2])

                if pos not in plddt_dict:
                    plddt_dict[pos] = [plddt]
                else:
                    plddt_dict[pos].append(plddt)

    plddts = np.array([np.mean(v) for v in plddt_dict.values()])
    return plddts


if __name__ == '__main__':
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    plddt_path = "/sujin/Datasets/FLIP/meltome/af2_plddts/A0A061ACX4.json"
    # Note: get_struc_seq has no 'plddt_path' parameter; pLDDT is read from the
    # pdb file itself, so the masking flag is used here instead.
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())
```
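A usage sketch of the deleted `get_struc_seq`; the binary and file paths are placeholders. The function shells out to `foldseek structureto3didescriptor` and returns, per chain, the amino-acid sequence, the 3Di sequence, and their interleaved combination:

```python
# Hypothetical usage sketch; assumes utils/foldseek_util.py from an earlier
# revision and a foldseek binary at the given placeholder path.
from utils.foldseek_util import get_struc_seq

seq_dict = get_struc_seq(
    foldseek="/usr/local/bin/foldseek",  # placeholder path to the foldseek binary
    path="example.pdb",                  # placeholder structure file
    chains=["A"],
    plddt_mask=True,                     # mask 3Di states where pLDDT < threshold
    plddt_threshold=70.,
)
aa_seq, struc_seq, combined_seq = seq_dict["A"]
print(combined_seq)  # amino acids interleaved with lower-case 3Di states
```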
utils/lr_scheduler.py
DELETED
@@ -1,187 +0,0 @@
```python
import math

from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR


class ConstantLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 ):
        """
        This is an implementation of constant learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate
        """
        self.init_lr = init_lr
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        return [self.init_lr for group in self.optimizer.param_groups]


class CosineAnnealingLRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 cosine_steps: int = 10000,
                 ):
        """
        This is an implementation of cosine annealing learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            cosine_steps: Number of steps for cosine annealing
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        super(CosineAnnealingLRScheduler, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        else:
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
                 * (1 + math.cos(math.pi * (step_no - self.warmup_steps) / self.cosine_steps))

        return [lr for group in self.optimizer.param_groups]


class Esm2LRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer,
                 last_epoch: int = -1,
                 verbose: bool = False,
                 init_lr: float = 0.,
                 max_lr: float = 4e-4,
                 final_lr: float = 4e-5,
                 warmup_steps: int = 2000,
                 start_decay_after_n_steps: int = 500000,
                 end_decay_after_n_steps: int = 5000000,
                 on_use: bool = True,
                 ):
        """
        This is an implementation of ESM2's learning rate scheduler.
        Args:
            optimizer: Optimizer

            last_epoch: The index of last epoch. Default: -1

            verbose: If ``True``, prints a message to stdout for each update. Default: ``False``

            init_lr: Initial learning rate

            max_lr: Maximum learning rate after warmup

            final_lr: Final learning rate after decay

            warmup_steps: Number of steps for warmup

            start_decay_after_n_steps: Start decay after this number of steps

            end_decay_after_n_steps: End decay after this number of steps

            on_use: Whether to use this scheduler. If ``False``, the scheduler will not change the learning rate
                    and will only use the ``init_lr``. Default: ``True``
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        super(Esm2LRScheduler, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch
        if not self.on_use:
            return [base_lr for base_lr in self.base_lrs]

        if step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)

        elif step_no <= self.start_decay_after_n_steps:
            lr = self.max_lr

        elif step_no <= self.end_decay_after_n_steps:
            portion = (step_no - self.start_decay_after_n_steps) / (self.end_decay_after_n_steps - self.start_decay_after_n_steps)
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)

        else:
            lr = self.final_lr

        return [lr for group in self.optimizer.param_groups]
```
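As a sanity check on the schedule shape, a small sketch (assuming `torch` is installed) stepping `CosineAnnealingLRScheduler`: the rate ramps linearly from `init_lr` to `max_lr` over `warmup_steps`, then follows a cosine from `max_lr` down to `final_lr` over `cosine_steps`:

```python
# Minimal sketch; assumes utils/lr_scheduler.py from an earlier revision.
import torch
from utils.lr_scheduler import CosineAnnealingLRScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
scheduler = CosineAnnealingLRScheduler(
    optimizer, init_lr=0., max_lr=4e-4, final_lr=4e-5,
    warmup_steps=2000, cosine_steps=10000,
)

for step in range(12000):
    optimizer.step()
    scheduler.step()

# At warmup_steps + cosine_steps, cos(pi) = -1, so the lr has decayed to final_lr.
print(scheduler.get_last_lr())  # ~[4e-5]
```

Note that past `warmup_steps + cosine_steps` the cosine term keeps oscillating, so runs are expected to stop within that budget; `Esm2LRScheduler` instead holds `final_lr` after `end_decay_after_n_steps`.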
utils/mpr.py
DELETED
@@ -1,397 +0,0 @@
```python
import abc
import os
import time
import sys

from tqdm import tqdm
from math import ceil


class MultipleProcessRunner:
    """
    Abstract class for running tasks with multiple processes
    There are three abstract methods that should be implemented:
        1. __len__()    : return the length of data
        2. _target()    : target function for each process
        3. _aggregate() : aggregate results from each process
    """

    def __init__(self,
                 data,
                 save_path=None,
                 n_process=1,
                 verbose=True,
                 total_only=True,
                 log_step=1,
                 start_method='fork'):
        """
        Args:
            data      : data to be processed that can be sliced

            save_path : final output path

            n_process : number of processes

            verbose   : if True, display progress bar

            total_only: if True, only the total progress bar is displayed

            log_step  : for the total progress bar, the next log will be printed when
                        ``current iteration`` - ``last log iteration`` >= log_step

            start_method: start method for multiprocessing
        """
        self.data = data
        self.save_path = save_path
        self.n_process = n_process
        self.verbose = verbose
        self.total_only = total_only
        self.log_step = log_step
        self.start_method = start_method

        # get terminal width to format output
        try:
            self.terminal_y = os.get_terminal_size()[0]

        except Exception as e:
            print(e)
            print("Can't get terminal size, set terminal_y = None")
            self.terminal_y = None

    def _s2hms(self, seconds: float):
        """
        convert second format of time into hour:minute:second format
        """
        m, s = divmod(seconds, 60)
        h, m = divmod(m, 60)

        return "%02d:%02d:%02d" % (h, m, s)

    def _display_time(self, st_time, now, total):
        ed_time = time.time()
        running_time = ed_time - st_time
        rest_time = running_time * (total - now) / now
        iter_sec = f"{now / running_time:.2f}it/s" if now > running_time else f"{running_time / now:.2f}s/it"

        return f' [{self._s2hms(running_time)} < {self._s2hms(rest_time)}, {iter_sec}]'

    def _display_bar(self, now, total, length):
        now = now if now <= total else total
        num = now * length // total
        progress_bar = '[' + '#' * num + '_' * (length - num) + ']'
        return progress_bar

    def _display_all(self, now, total, desc, st_time):
        # make a progress bar
        length = 50
        progress_bar = self._display_bar(now, total, length)
        time_display = self._display_time(st_time, now, total)

        display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'

        # Clean a line
        width = self.terminal_y if self.terminal_y is not None else 100
        num_space = width - len(display)
        if num_space > 0:
            display += ' ' * num_space
        else:
            length += num_space
            progress_bar = self._display_bar(now, total, length)
            display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'

        # Set color
        display = f"\033[31m{display}\033[0m"

        return display

    # Print progress bar at specific position in terminal
    def terminal_progress_bar(self,
                              process_id: int,
                              now: int,
                              total: int,
                              desc: str = ''):
        """
        Args:
            process_id: process id
            now       : current iteration number
            total     : total iteration number
            desc      : description
        """
        st_time = self.process_st_time[process_id]

        # Aggregate total information
        self.counts[process_id] = now
        self._total_display(self.process_st_time["total"])

        if not self.total_only:
            process_display = self._display_all(now, total, desc, st_time)
            if self.terminal_y is not None:
                sys.stdout.write(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8")
                sys.stdout.flush()
            else:
                print(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8", flush=True)

    # Print global information
    def _total_display(self, st_time):
        if self.total_display_callable.value == 1:
            self.total_display_callable.value = 0

            cnt = sum([self.counts[i] for i in range(self.n_process)])
            if cnt - self.last_cnt.value >= self.log_step:
                total_display = self._display_all(cnt, self.__len__(), f"Total: ", st_time)
                self.last_cnt.value = cnt

                x = self.n_process + 1 if not self.total_only else 0
                # if self.terminal_y is not None:
                #     sys.stdout.write(f"\x1b7\x1b[{x};{0}f{total_display}\x1b8")
                #     sys.stdout.flush()
                # else:
                #     print(f"\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True)
                print(f"\r\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True, end="")

            self.total_display_callable.value = 1

    def run(self):
        """
        The function is used to run a multi-process task
        Returns: the result of function '_aggregate()'
        """
        import multiprocess as mp
        mp.set_start_method(self.start_method, force=True)

        # total number of data that is already processed
        self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})

        # record start time for each process
        self.process_st_time = {"total": time.time()}

        # set a lock to call total number display
        self.total_display_callable = mp.Value('d', 1)

        # Save last log iteration number
        self.last_cnt = mp.Value('d', 0)

        num_per_process = ceil(self.__len__() / self.n_process)

        if self.save_path is not None:
            file_name, suffix = os.path.splitext(self.save_path)

        process_list = []
        sub_paths = []
        for i in range(self.n_process):
            st = i * num_per_process
            ed = st + num_per_process

            # construct slice and sub path for sub process
            data_slice = self.data[st: ed]

            sub_path = None
            # Create a directory to save sub-results
            if self.save_path is not None:
                save_dir = f"{file_name}{suffix}_temp"
                os.makedirs(save_dir, exist_ok=True)
                sub_path = f"{save_dir}/temp_{i}{suffix}"

            # construct sub process
            input_args = (i, data_slice, sub_path)
            self.process_st_time[i] = time.time()
            p = mp.Process(target=self._target, args=input_args)
            p.start()

            process_list.append(p)
            sub_paths.append(sub_path)

        for p in process_list:
            p.join()

        # aggregate results and remove temporary directory
        results = self._aggregate(self.save_path, sub_paths)
        if self.save_path is not None:
            save_dir = f"{file_name}{suffix}_temp"
            os.rmdir(save_dir)

        return results

    def parallel_run(self):
        import multiprocess as mp
        from joblib import Parallel, delayed

        # total number of data that is already processed
        self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})

        # record start time for each process
        self.process_st_time = {"total": time.time()}

        # set a lock to call total number display
        self.total_display_callable = mp.Value('d', 1)

        # Save last log iteration number
        self.last_cnt = mp.Value('d', 0)

        num_per_process = ceil(self.__len__() / self.n_process)

        if self.save_path is not None:
            file_name, suffix = os.path.splitext(self.save_path)

        sub_paths = []
        input_arg_list = []
        for i in range(self.n_process):
            st = i * num_per_process
            ed = st + num_per_process

            # construct slice and sub path for sub process
            data_slice = self.data[st: ed]

            sub_path = None
            # Create a directory to save sub-results
            if self.save_path is not None:
                save_dir = f"{file_name}{suffix}_temp"
                os.makedirs(save_dir, exist_ok=True)
                sub_path = f"{save_dir}/temp_{i}{suffix}"

            # construct sub process
            input_args = (i, data_slice, sub_path)
            self.process_st_time[i] = time.time()

            sub_paths.append(sub_path)
            input_arg_list.append(input_args)

        # Start parallel processing (each argument tuple is unpacked into _target)
        Parallel(n_jobs=self.n_process)(delayed(self._target)(*input_args) for input_args in input_arg_list)

        # aggregate results and remove temporary directory
        results = self._aggregate(self.save_path, sub_paths)
        if self.save_path is not None:
            save_dir = f"{file_name}{suffix}_temp"
            os.rmdir(save_dir)

        return results

    @abc.abstractmethod
    def _aggregate(self, final_path: str, sub_paths):
        """
        This function is used to aggregate results from sub processes into a file

        Args:
            final_path: path to save final results
            sub_paths : list of sub paths

        Returns: None or desirable results specified by user
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _target(self, process_id, data, sub_path):
        """
        The main body to operate data in one process

        Args:
            process_id: process id
            data      : data slice
            sub_path  : sub path to save results
        """
        raise NotImplementedError

    @abc.abstractmethod
    def __len__(self):
        raise NotImplementedError


class MultipleProcessRunnerSimplifier(MultipleProcessRunner):
    """
    A simplified version of MultipleProcessRunner.
    Users only need to implement the function 'do'; it will then be executed automatically
    in every iteration after calling the function 'run'.
    If 'save_path' is specified, a file is opened at the 'sub_path' into which
    the user can write results, and results will be aggregated into 'save_path'.

    The procedure would be like:
        ...
        with open(sub_path, 'w') as w:
            for i, d in enumerate(data):
                self.do(process_id, i, d, w)  # You can write results into the file.
        ...

    The 'do' function should be like:
        def do(process_id, idx, data, writer):
            ...

    If 'save_path' is None, the argument 'writer' will be set to None.
    """

    def __init__(self,
                 data,
                 do,
                 save_path=None,
                 n_process=1,
                 verbose=True,
                 total_only=True,
                 log_step=1,
                 return_results=False,
                 start_method='fork'):

        super().__init__(data=data,
                         save_path=save_path,
                         n_process=n_process,
                         verbose=verbose,
                         total_only=total_only,
                         log_step=log_step,
                         start_method=start_method)
        self.do = do
        self.return_results = return_results

    def run(self):
        self.start_time = time.time()
        return super().run()

    def _aggregate(self, final_path: str, sub_paths):
        results = []

        w = open(final_path, 'w') if final_path is not None else None

        if self.verbose:
            iterator = tqdm(enumerate(sub_paths), "Aggregating results...")
        else:
            iterator = enumerate(sub_paths)

        for i, sub_path in iterator:
            if sub_path is None and self.return_results:
                sub_path = f"MultipleProcessRunnerSimplifier_{self.start_time}_{i}.tmp"

            if sub_path is not None:
                with open(sub_path, 'r') as r:
                    for line in r:
                        if w is not None:
                            w.write(line)

                        if self.return_results:
                            results.append(line[:-1])

                os.remove(sub_path)

        return results

    def _target(self, process_id, data, sub_path):
        if sub_path is None and self.return_results:
            sub_path = f"MultipleProcessRunnerSimplifier_{self.start_time}_{process_id}.tmp"

        w = open(sub_path, 'w') if sub_path is not None else None
        for i, d in enumerate(data):
            self.do(process_id, i, d, w)
            if self.verbose:
                self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} running...")

        if w is not None:
            w.close()

    def __len__(self):
        return len(self.data)
```
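A usage sketch of the deleted `MultipleProcessRunnerSimplifier`, following the `do` signature documented in its docstring; the data and output path below are hypothetical:

```python
# Hypothetical usage sketch; assumes utils/mpr.py from an earlier revision
# (and the 'multiprocess' package it imports at run time).
from utils.mpr import MultipleProcessRunnerSimplifier

def do(process_id, idx, d, writer):
    # writer is a per-process temp file handle; lines are merged by _aggregate
    writer.write(f"{d * d}\n")

runner = MultipleProcessRunnerSimplifier(
    data=list(range(100)),
    do=do,
    save_path="squares.txt",  # per-process files end up concatenated here
    n_process=4,
)
runner.run()
```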