# (scraper residue removed: "File size: 4,005 Bytes", commit hashes, and the
# line-number gutter from the web code view — not part of the source file)
import faiss
import numpy as np
import pandas as pd
import os
import yaml
import glob
from easydict import EasyDict
from utils.constants import sequence_level
from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
from tqdm import tqdm
# Debug aid: show what is mounted under /data at import time
print(os.listdir("/data"))
def load_model():
    """Build the ProTrek tri-modal model from files under ``config.model_dir``.

    Relies on the module-level ``config`` (loaded from demo/config.yaml
    further down in this file).

    Returns:
        ProTrekTrimodalModel switched to eval mode.

    Raises:
        IndexError: if one of the glob patterns matches no file.
    """
    model_config = {
        # sorted() makes the [0] pick deterministic — glob order is
        # OS/filesystem dependent when several candidates match.
        "protein_config": sorted(glob.glob(f"{config.model_dir}/esm2_*"))[0],
        "text_config": f"{config.model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "structure_config": sorted(glob.glob(f"{config.model_dir}/foldseek_*"))[0],
        "load_protein_pretrained": False,
        "load_text_pretrained": False,
        "from_checkpoint": sorted(glob.glob(f"{config.model_dir}/*.pt"))[0],
    }
    model = ProTrekTrimodalModel(**model_config)
    model.eval()
    return model
def load_faiss_index(index_path: str):
    """Read a faiss index from disk, honoring the mmap flag in ``config``.

    When ``config.faiss_config.IO_FLAG_MMAP`` is truthy the index is
    memory-mapped instead of fully loaded. The metric is forced to inner
    product regardless of how the index was built.
    """
    io_flags = faiss.IO_FLAG_MMAP if config.faiss_config.IO_FLAG_MMAP else 0
    index = faiss.read_index(index_path, io_flags)
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    return index
def _load_db_indexes(db_list, index_filename, desc):
    """Load one faiss index plus its uniprot id list per database entry.

    Each item in ``db_list`` is a mapping with "name" and "index_dir" keys.

    Returns:
        {db_name: {"index": <faiss index>, "ids": <1-D ndarray of ids>}}
    """
    loaded = {}
    for db in tqdm(db_list, desc=desc):
        index_dir = db["index_dir"]
        index = load_faiss_index(f"{index_dir}/{index_filename}")
        # ids.tsv: one id per row, no header; flatten to a 1-D array
        ids = pd.read_csv(f"{index_dir}/ids.tsv", sep="\t", header=None).values.flatten()
        loaded[db["name"]] = {"index": index, "ids": ids}
    return loaded


def load_index():
    """Load every faiss index configured in the module-level ``config``.

    Returns:
        all_index: {"sequence": {db: {...}}, "structure": {db: {...}},
            "text": {db: {subsection: {"index": ..., "ids": ...}}}}
        valid_subsections: {db_name: sorted list of subsections that actually
            have an index file on disk}
    """
    all_index = {}

    # Protein sequence indexes (duplicated loop in the original — now shared)
    all_index["sequence"] = _load_db_indexes(
        config.sequence_index_dir, "sequence.index", "Loading sequence index..."
    )

    # Protein structure indexes
    print("Loading structure index...")
    all_index["structure"] = _load_db_indexes(
        config.structure_index_dir, "structure.index", "Loading structure index..."
    )

    # Text indexes: one faiss index per (database, subsection) pair.
    # "Global" is indexed in addition to the per-subsection levels; the add is
    # idempotent, so it is hoisted out of the loop (the original re-added it
    # every iteration). NOTE(review): the original comment claimed "Taxonomic
    # lineage" is removed from sequence_level here, but no removal happens.
    sequence_level.add("Global")
    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        text_dir = f"{db['index_dir']}/subsections"
        all_index["text"][db_name] = {}
        valid_subsections[db_name] = set()
        for subsection in tqdm(sequence_level):
            stem = subsection.replace(' ', '_')
            index_path = f"{text_dir}/{stem}.index"
            # Skip subsections that were never indexed for this database
            if not os.path.exists(index_path):
                continue
            text_index = load_faiss_index(index_path)
            text_ids = pd.read_csv(
                f"{text_dir}/{stem}_ids.tsv", sep="\t", header=None
            ).values.flatten()
            all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids}
            valid_subsections[db_name].add(subsection)

    # Deterministic subsection ordering for downstream consumers
    for db_name in valid_subsections:
        valid_subsections[db_name] = sorted(valid_subsections[db_name])

    return all_index, valid_subsections
# Load the config file
# NOTE(review): assumes POSIX "/" separators and that this file lives exactly
# three path components below the repository root — confirm before relocating.
root_dir = __file__.rsplit("/", 3)[0]
config_path = f"{root_dir}/demo/config.yaml"
with open(config_path, 'r', encoding='utf-8') as r:
    # Module-level config consumed by load_model / load_faiss_index / load_index
    config = EasyDict(yaml.safe_load(r))
device = "cuda"  # declared but unused while the .to(device) call stays commented out
print("Loading model...")
model = load_model()
# model.to(device)
all_index, valid_subsections = load_index()
print("Done...")
# model = None
# all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}