import faiss
import numpy as np
import pandas as pd
import os
import yaml
import glob

from easydict import EasyDict
from utils.constants import sequence_level
from model.ProTrek.protrek_trimodal_model import ProTrekTrimodalModel
from tqdm import tqdm

# Debug print: show what is currently available under /data
print(os.listdir("/data"))
def load_model():
    """Build the ProTrek tri-modal model from the sub-model configs and checkpoint under config.model_dir."""
    model_config = {
        "protein_config": glob.glob(f"{config.model_dir}/esm2_*")[0],
        "text_config": f"{config.model_dir}/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "structure_config": glob.glob(f"{config.model_dir}/foldseek_*")[0],
        "load_protein_pretrained": False,
        "load_text_pretrained": False,
        "from_checkpoint": glob.glob(f"{config.model_dir}/*.pt")[0]
    }

    model = ProTrekTrimodalModel(**model_config)
    model.eval()
    return model


def load_faiss_index(index_path: str):
    """Load a faiss index from disk, optionally memory-mapped to keep RAM usage low."""
    if config.faiss_config.IO_FLAG_MMAP:
        # Memory-map the index file instead of reading it fully into memory
        index = faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    else:
        index = faiss.read_index(index_path)

    # Compare embeddings by inner product
    index.metric_type = faiss.METRIC_INNER_PRODUCT
    return index
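

# Illustrative sketch (not called anywhere in this module): querying a loaded index with plain
# faiss calls. The query matrix must be float32; because the metric is inner product, the
# embeddings are assumed to be L2-normalized so scores behave like cosine similarities.
# The function name and the `query`/`topk` parameters are hypothetical, for illustration only.
def _example_search(index, query: np.ndarray, topk: int = 5):
    q = np.ascontiguousarray(query.astype(np.float32).reshape(1, -1))
    scores, rows = index.search(q, topk)  # both arrays have shape (1, topk)
    return scores[0], rows[0]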


def load_index():
    """Load every faiss index declared in the config: sequence, structure, and per-subsection text."""
    all_index = {}
    
    # Load protein sequence index
    all_index["sequence"] = {}
    for db in tqdm(config.sequence_index_dir, desc="Loading sequence index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]

        index_path = f"{index_dir}/sequence.index"
        sequence_index = load_faiss_index(index_path)

        id_path = f"{index_dir}/ids.tsv"
        uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()

        all_index["sequence"][db_name] = {"index": sequence_index, "ids": uniprot_ids}

    # Load protein structure index
    print("Loading structure index...")
    all_index["structure"] = {}
    for db in tqdm(config.structure_index_dir, desc="Loading structure index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]

        index_path = f"{index_dir}/structure.index"
        structure_index = load_faiss_index(index_path)

        id_path = f"{index_dir}/ids.tsv"
        uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()

        all_index["structure"][db_name] = {"index": structure_index, "ids": uniprot_ids}
    
    # Load text index
    all_index["text"] = {}
    valid_subsections = {}
    for db in tqdm(config.text_index_dir, desc="Loading text index..."):
        db_name = db["name"]
        index_dir = db["index_dir"]
        all_index["text"][db_name] = {}
        text_dir = f"{index_dir}/subsections"
        
        # Remove "Taxonomic lineage" from sequence_level. This is a special case which we don't need to index.
        valid_subsections[db_name] = set()
        sequence_level.add("Global")
        for subsection in tqdm(sequence_level):
            index_path = f"{text_dir}/{subsection.replace(' ', '_')}.index"
            if not os.path.exists(index_path):
                continue
                
            text_index = load_faiss_index(index_path)
            
            id_path = f"{text_dir}/{subsection.replace(' ', '_')}_ids.tsv"
            text_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
            
            all_index["text"][db_name][subsection] = {"index": text_index, "ids": text_ids}
            valid_subsections[db_name].add(subsection)
    
    # Sort valid_subsections
    for db_name in valid_subsections:
        valid_subsections[db_name] = sorted(list(valid_subsections[db_name]))

    return all_index, valid_subsections


# Load the config file (the repository root is three path components above this file)
root_dir = __file__.rsplit("/", 3)[0]
config_path = f"{root_dir}/demo/config.yaml"
with open(config_path, 'r', encoding='utf-8') as r:
    config = EasyDict(yaml.safe_load(r))
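
# For reference, a minimal sketch of the fields this module reads from demo/config.yaml,
# inferred from the attribute accesses above. The database names and paths are placeholders;
# the real file may contain additional settings.
#
#   model_dir: /path/to/weights
#   faiss_config:
#     IO_FLAG_MMAP: true
#   sequence_index_dir:
#     - name: UniRef50
#       index_dir: /path/to/index/UniRef50
#   structure_index_dir:
#     - name: UniRef50
#       index_dir: /path/to/index/UniRef50
#   text_index_dir:
#     - name: UniRef50
#       index_dir: /path/to/index/UniRef50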

device = "cuda"

print("Loading model...")
model = load_model()
# model.to(device)

all_index, valid_subsections = load_index()
print("Done...")
# model = None
# all_index, valid_subsections = {"text": {}, "sequence": {"UniRef50": None}, "structure": {"UniRef50": None}}, {}
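

# Illustrative sketch (never executed on import) of how the structures built above fit together:
# `all_index` maps modality -> database name -> {"index": faiss index, "ids": id array}, and
# `valid_subsections` maps database name -> sorted list of text subsections that have an index.
# The query embedding is assumed to come from the ProTrek encoders and to be L2-normalized;
# the function name and the `db_name`/`topk` parameters are hypothetical.
def _example_query_sequence_db(query_embedding: np.ndarray, db_name: str = "UniRef50", topk: int = 10):
    entry = all_index["sequence"][db_name]
    q = np.ascontiguousarray(query_embedding.astype(np.float32).reshape(1, -1))
    scores, rows = entry["index"].search(q, topk)
    # Map faiss row numbers back to the UniProt ids loaded from ids.tsv
    return [(entry["ids"][r], float(s)) for r, s in zip(rows[0], scores[0])]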