|
import faiss
|
|
import numpy as np
|
|
import pandas as pd
|
|
import os
|
|
import yaml
|
|
import glob
|
|
|
|
from easydict import EasyDict
|
|
from utils.constants import sequence_level
|
|
from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
|
|
from tqdm import tqdm
|
|
|
|
|
|
def load_model():
    """Build the ProTrek tri-modal model from the files under ``config.model_dir``.

    The protein encoder, structure encoder and checkpoint locations are
    discovered with glob patterns (first match wins); the text encoder
    directory name is fixed.

    Returns:
        The ``ProTrekTrimodalModel`` instance, switched to eval mode.

    Raises:
        IndexError: If any of the glob patterns matches no path.
    """
    model_dir = config.model_dir

    # Lookups happen in the same order as the config keys below, so the
    # first missing artefact is the one reported.
    protein_path = glob.glob(f"{model_dir}/esm2_*")[0]
    text_path = f"{model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    structure_path = glob.glob(f"{model_dir}/foldseek_*")[0]
    checkpoint_path = glob.glob(f"{model_dir}/*.pt")[0]

    net = ProTrekTrimodalModel(
        protein_config=protein_path,
        text_config=text_path,
        structure_config=structure_path,
        load_protein_pretrained=False,
        load_text_pretrained=False,
        from_checkpoint=checkpoint_path,
    )
    net.eval()
    return net
|
|
|
|
|
|
def load_faiss_index(index_path: str):
    """Read a FAISS index from disk and set its metric to inner product.

    Args:
        index_path: Path of the serialized index file.

    Returns:
        The index object returned by ``faiss.read_index``, with
        ``metric_type`` overwritten to ``faiss.METRIC_INNER_PRODUCT``.
    """
    # Pass the mmap I/O flag only when the config switch is truthy;
    # otherwise call read_index with the path alone.
    io_flags = (faiss.IO_FLAG_MMAP,) if config.faiss_config.IO_FLAG_MMAP else ()
    loaded = faiss.read_index(index_path, *io_flags)

    loaded.metric_type = faiss.METRIC_INNER_PRODUCT
    return loaded
|
|
|
|
|
|
def _read_ids(id_path):
    """Read a header-less TSV of IDs and return its values as a flat array."""
    return pd.read_csv(id_path, sep="\t", header=None).values.flatten()


def _load_db_indexes(db_entries, index_filename, desc):
    """Load one FAISS index and its ``ids.tsv`` for every configured database.

    Args:
        db_entries: Iterable of mappings providing ``"name"`` and
            ``"index_dir"`` keys.
        index_filename: File name of the index inside each ``index_dir``.
        desc: Label shown on the tqdm progress bar.

    Returns:
        Dict mapping database name to ``{"index": ..., "ids": ...}``.
    """
    loaded = {}
    for db in tqdm(db_entries, desc=desc):
        db_name = db["name"]
        index_dir = db["index_dir"]

        index = load_faiss_index(f"{index_dir}/{index_filename}")
        ids = _read_ids(f"{index_dir}/ids.tsv")

        loaded[db_name] = {"index": index, "ids": ids}
    return loaded


def load_index():
    """Load every sequence, structure and text FAISS index named in ``config``.

    Sequence and structure databases each contribute a single index file
    (``sequence.index`` / ``structure.index``) plus ``ids.tsv``.  Text
    databases contribute one index per subsection found under
    ``<index_dir>/subsections``; subsections whose index file does not exist
    are skipped.

    Returns:
        A tuple ``(all_index, valid_subsections)`` where

        * ``all_index["sequence"][db]`` and ``all_index["structure"][db]`` are
          ``{"index": ..., "ids": ...}`` dicts,
        * ``all_index["text"][db][subsection]`` has the same shape, and
        * ``valid_subsections[db]`` is the sorted list of subsections that
          were actually loaded for that text database.
    """
    all_index = {}

    all_index["sequence"] = _load_db_indexes(
        config.sequence_index_dir, "sequence.index", "Loading sequence index..."
    )

    print("Loading structure index...")
    all_index["structure"] = _load_db_indexes(
        config.structure_index_dir, "structure.index", "Loading structure index..."
    )

    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]
        text_dir = f"{index_dir}/subsections"

        # NOTE(review): this mutates the `sequence_level` object imported from
        # utils.constants, so "Global" becomes visible to every other user of
        # that constant. Kept as-is because other modules may depend on it.
        sequence_level.add("Global")

        per_subsection = {}
        for subsection in tqdm(sequence_level):
            # File names use underscores where subsection names have spaces.
            stem = subsection.replace(' ', '_')
            index_path = f"{text_dir}/{stem}.index"
            if not os.path.exists(index_path):
                continue

            text_index = load_faiss_index(index_path)
            text_ids = _read_ids(f"{text_dir}/{stem}_ids.tsv")
            per_subsection[subsection] = {"index": text_index, "ids": text_ids}

        all_index["text"][db_name] = per_subsection
        # The loaded keys are exactly the valid subsections for this database.
        valid_subsections[db_name] = sorted(per_subsection)

    return all_index, valid_subsections
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level initialisation: executed once, when this module is imported.
# ---------------------------------------------------------------------------

# Repository root = three path components above this file (file name, its
# directory, and that directory's parent are stripped). The path is made
# absolute first and split with os.path, so a relative ``__file__`` or a
# platform using a different separator still resolves to the right directory.
root_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
config_path = os.path.join(root_dir, "demo", "config.yaml")
with open(config_path, 'r', encoding='utf-8') as r:
    # EasyDict gives attribute-style access (config.model_dir, ...).
    config = EasyDict(yaml.safe_load(r))

# NOTE(review): the device is hard-coded; model.to(device) below will fail on
# a host without CUDA. Left unchanged because `device` is a module attribute
# other code may import.
device = "cuda"

print("Loading model...")
model = load_model()
model.to(device)

all_index, valid_subsections = load_index()
print("Done...")
|
|
|
|
|