"""
Page for similarities
"""

################
# DEPENDENCIES #
################
import streamlit as st
import pandas as pd
from scipy.sparse import load_npz
import pickle
import faiss
from sentence_transformers import SentenceTransformer
from modules.result_table import show_table
import modules.semantic_search as semantic_search
from functions.filter_projects import filter_projects
from functions.calc_matches import calc_matches
import psutil
import os
import gc

def get_process_memory():
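    # Return the resident set size (RSS) of the current process in megabytes.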
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024) 

# Cache DATA
# Load similarity matrix
@st.cache_data
def load_sim_matrix():
    loaded_matrix = load_npz("src/similarities.npz")
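    # Densify the sparse matrix; the full dense array is held in memory by the cache.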
    dense_matrix = loaded_matrix.toarray()

    return dense_matrix

# Load Projects DFs
@st.cache_data
def load_projects():
    orgas_df = pd.read_csv("src/projects/project_orgas.csv")
    region_df = pd.read_csv("src/projects/project_region.csv")
    sector_df = pd.read_csv("src/projects/project_sector.csv")
    status_df = pd.read_csv("src/projects/project_status.csv")
    texts_df = pd.read_csv("src/projects/project_texts.csv")

    projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')

    return projects_df

# Load CRS 3 data
@st.cache_data
def getCRS3():
    # Read in CRS3 CODELISTS
    crs3_df = pd.read_csv('src/codelists/crs3_codes.csv')
    CRS3_CODES = crs3_df['code'].tolist()
    CRS3_NAME = crs3_df['name'].tolist()
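    # Map "name - code" display labels to their numeric CRS 3 codes.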
    CRS3_MERGED = {f"{name} - {code}": code for name, code in zip(CRS3_NAME, CRS3_CODES)}

    return CRS3_MERGED

# Load CRS 5 data
@st.cache_data
def getCRS5():
    # Read in CRS5 CODELISTS
    crs5_df = pd.read_csv('src/codelists/crs5_codes.csv')
    CRS5_CODES = crs5_df['code'].tolist()
    CRS5_NAME = crs5_df['name'].tolist()
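    # Map each CRS 5 code to a one-element list holding its "name - code" display label.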
    CRS5_MERGED = {code: [f"{name} - {code}"] for name, code in zip(CRS5_NAME, CRS5_CODES)}

    return CRS5_MERGED

# Load SDG data
@st.cache_data
def getSDG():
    # Read in SDG CODELISTS
    sdg_df = pd.read_csv('src/codelists/sdg_goals.csv')
    SDG_NAMES = sdg_df['name'].tolist()

    return SDG_NAMES

# Load Sentence Transformer Model
@st.cache_resource
def load_model():
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model


# Load Embeddings and build the FAISS index
# Use cache_resource here (not cache_data): the FAISS index is a native object
# that Streamlit cannot pickle when copying cached return values.
@st.cache_resource
def load_embeddings_and_index():
    # Load embeddings
    with open("src/embeddings.pkl", "rb") as fIn:
        stored_data = pickle.load(fIn)
    sentences = stored_data["sentences"]
    embeddings = stored_data["embeddings"]

    # Build an exact L2 (Euclidean) FAISS index over the stored embeddings
    dimension = embeddings.shape[1]  
    faiss_index = faiss.IndexFlatL2(dimension)
    faiss_index.add(embeddings)

    return sentences, embeddings, faiss_index

# USE CACHE FUNCTIONS 
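# These loaders execute once and are served from Streamlit's cache on subsequent reruns.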
sim_matrix = load_sim_matrix()
projects_df = load_projects()

CRS3_MERGED = getCRS3()
CRS5_MERGED = getCRS5()
SDG_NAMES = getSDG()

model = load_model()
sentences, embeddings, faiss_index = load_embeddings_and_index()
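# Note: model, sentences, embeddings and faiss_index are loaded at import time
# but are not referenced again in this file.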

def show_page():
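    # Page flow: render the filter widgets, filter the projects accordingly,
    # look up matches in the precomputed similarity matrix, and display the result.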
    st.write(f"Current RAM usage of this app: {get_process_memory():.2f} MB")
    st.write("Similarities")

    st.session_state.crs5_option_disabled = True
    col1, col2 = st.columns([1, 1])
    with col1:
        # CRS 3 SELECTION
        crs3_option = st.multiselect(
                        'CRS 3',
                        CRS3_MERGED,
                        placeholder="Select"
                        )

        # CRS 5 SELECTION
        ## Only enable the CRS 5 select field once a CRS 3 code is selected
        if crs3_option:
            st.session_state.crs5_option_disabled = False

        ## List CRS 5 labels dependent on the selected CRS 3 codes: a CRS 5 code matches
        ## when its first three digits equal the code at the end of the selected label.
        crs5_list = [
            txt[0].replace('"', "")
            for crs3_item in crs3_option
            for code, txt in CRS5_MERGED.items()
            if str(code)[:3] == str(crs3_item)[-3:]
        ]

        ## crs5 select field
        crs5_option = st.multiselect(
            'CRS 5',
            crs5_list,
            placeholder="Select",
            disabled=st.session_state.crs5_option_disabled
            )
        
        # SDG SELECTION
        sdg_option = st.selectbox(
            label='SDG',
            index=None,
            placeholder="Select SDG",
            options=SDG_NAMES[:-1],
            )

    
    with col2:
        st.write("x")


    # CRS CODE LIST
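    # Extract the numeric codes from the trailing digits of the "name - code" labels.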
    crs3_list = [i[-3:] for i in crs3_option]
    crs5_list = [i[-5:] for i in crs5_option]

    # SDG CODE
    # Only the first character of the selected label is passed on to the filter.
    sdg_str = sdg_option[0] if sdg_option is not None else ""

    # FILTER DF WITH SELECTED FILTER OPTIONS
    filtered_df = filter_projects(projects_df, crs3_list, crs5_list, sdg_str)

    # FIND MATCHES
    p1_df, p2_df = calc_matches(filtered_df, projects_df, sim_matrix)

    # SHOW THE RESULT
    show_table(p1_df, p2_df)

    # Release per-run intermediates; gc.collect() keeps the reported RAM usage low.
    del p1_df, p2_df, crs3_list, crs5_list, sdg_str, filtered_df
    gc.collect()