|
import contextlib
import json
import os
import re
import subprocess
import sys
import tempfile
import time
from typing import Optional

import numpy as np

sys.path.append(".")
|
|
|
|
|
|
|
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: bool = False,
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """Extract amino-acid and 3Di structural sequences from a PDB file.

    Runs ``foldseek structureto3didescriptor`` on ``path`` and parses its
    tab-separated output, producing one entry per chain.

    Args:
        foldseek: Path to the foldseek executable.
        path: Path to the PDB file.
        chains: Chain IDs to keep. If None, all chains are kept.
        process_id: Tag embedded in the temporary file name so that parallel
            workers do not collide.
        plddt_mask: If True, replace the 3Di token of every residue whose
            pLDDT (read from the PDB file via ``extract_plddt``) is below
            ``plddt_threshold`` with ``"#"``.
        plddt_threshold: pLDDT cut-off used when ``plddt_mask`` is True.
        foldseek_verbose: If True, foldseek prints its normal log output;
            otherwise it is run with ``-v 0``.

    Returns:
        A dict mapping chain ID -> ``(seq, struc_seq, combined_seq)``, where
        ``combined_seq`` interleaves each residue letter with the lower-cased
        3Di token (e.g. ``"M"`` + ``"d"`` -> ``"Md"``). If a chain ID appears
        more than once in foldseek's output, only the first one is kept.

    Raises:
        AssertionError: if ``foldseek`` or ``path`` does not exist, or if the
            number of pLDDT values differs from a kept chain's length.
        FileNotFoundError: if foldseek did not write an output file.
    """
    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    # The pLDDT scores come from the PDB file itself, so read them once
    # instead of once per line of foldseek output.
    plddts = extract_plddt(path) if plddt_mask else None

    tmp_save_path = os.path.join(
        tempfile.gettempdir(), f"get_struc_seq_{process_id}_{time.time()}.tsv"
    )

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact.
    cmd = [foldseek, "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", path, tmp_save_path]

    seq_dict = {}
    name = os.path.basename(path)
    try:
        completed = subprocess.run(cmd)
        if not os.path.exists(tmp_save_path):
            raise FileNotFoundError(
                f"foldseek wrote no output (exit code {completed.returncode}): {tmp_save_path}"
            )

        with open(tmp_save_path, "r") as r:
            for line in r:
                # Strip the newline so it never ends up in the 3Di sequence
                # when that column happens to be the last one.
                desc, seq, struc_seq = line.rstrip("\n").split("\t")[:3]

                # With --chain-name-mode 1 the first token of the description
                # is "<pdb file name>_<chain>".
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                # Skip unwanted / already-seen chains before doing any work,
                # so they cannot trigger the pLDDT length check below.
                if chains is not None and chain not in chains:
                    continue
                if chain in seq_dict:
                    continue

                if plddt_mask:
                    assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"
                    struc_seq = "".join(
                        "#" if score < plddt_threshold else token
                        for token, score in zip(struc_seq, plddts)
                    )

                combined_seq = "".join(a + b.lower() for a, b in zip(seq, struc_seq))
                seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Always clean up, even when parsing or masking fails.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            with contextlib.suppress(FileNotFoundError):
                os.remove(leftover)

    return seq_dict
|
|
|
|
|
def extract_plddt(pdb_path: str, chain: Optional[str] = None) -> np.ndarray:
    """Extract per-residue pLDDT scores from a PDB file.

    Reads the B-factor column of every ``ATOM`` record (where AlphaFold
    writes pLDDT) and averages it over the atoms of each residue. Residues
    are returned in the order they first appear in the file. ``HETATM``
    records are ignored.

    Fields are read from their fixed PDB columns rather than by splitting on
    whitespace. This means fused fields (e.g. chain + 4-digit residue number),
    insertion codes and missing element columns are handled correctly.

    Args:
        pdb_path: Path to pdb file.
        chain: If given, only residues of this chain ID are returned. If None
            (default), residues of all chains are returned, each chain's
            residues kept separate rather than merged by residue number.

    Returns:
        1-D float array with one mean B-factor per residue.
    """
    residue_scores = {}
    with open(pdb_path, "r") as r:
        for line in r:
            if not line.startswith("ATOM"):
                continue

            chain_id = line[21]  # column 22
            if chain is not None and chain_id != chain:
                continue

            # A residue is identified by chain ID, residue number
            # (columns 23-26) and insertion code (column 27).
            key = (chain_id, line[22:26].strip(), line[26])
            b_factor = float(line[60:66])  # columns 61-66
            residue_scores.setdefault(key, []).append(b_factor)

    return np.array([np.mean(scores) for scores in residue_scores.values()])
|
|
|
|
|
if __name__ == '__main__':
    # Quick smoke test on a single AlphaFold2 structure. The default paths are
    # machine-specific; override them on the command line:
    #     python <this file> [FOLDSEEK_BINARY] [PDB_FILE]
    foldseek = sys.argv[1] if len(sys.argv) > 1 else "/sujin/bin/foldseek"
    test_path = (sys.argv[2] if len(sys.argv) > 2
                 else "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb")

    # get_struc_seq takes no pLDDT file path: with plddt_mask=True it reads
    # the scores from the PDB file itself (via extract_plddt).
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())