friendshipkim committed on
Commit
f864d05
·
1 Parent(s): 207c9f5

initial commit

Browse files
Files changed (2) hide show
  1. app.py +185 -0
  2. requirements.txt +74 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import sys
4
+ import os
5
+ from datasets import load_from_disk, load_dataset
6
+ from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ import numpy as np
9
+ import time
10
+ from annotated_text import annotated_text
11
+
12
+ ORG_ID = "cornell-authorship"
13
+
14
@st.cache
def preprocess_text(s):
    """Tokenize *s*: every non-alphanumeric, non-space character is replaced
    by a space, the result is split on single spaces, and empty tokens are
    dropped. Returns a list of word strings."""
    cleaned = ''.join(ch if ch.isalnum() or ch == ' ' else ' ' for ch in s)
    return [token for token in cleaned.split(' ') if token != '']
17
+
18
@st.cache
def get_pairwise_distances(model):
    """Load the precomputed pairwise-distance table for *model* from the Hub
    and return it as a DataFrame indexed by its 'index' column."""
    distance_split = load_dataset(f"{ORG_ID}/{model}_distance")["train"]
    return pd.DataFrame(distance_split).set_index('index')
23
+
24
@st.cache
def get_pairwise_distances_chunked(model, chunk):
    """Chunked variant kept for interface compatibility.

    *chunk* is currently ignored; a CSV-chunked implementation existed
    previously, and this now simply delegates to the full-table loader.
    """
    return get_pairwise_distances(model)
31
+
32
@st.cache
def get_query_strings():
    """Load the English test-query dataset from the Hub and attach a
    positional 'index' column (0..n-1) used for cross-referencing."""
    query_split = load_dataset(f"{ORG_ID}/IUR_Reddit_test_queries_english")["train"]
    frame = pd.DataFrame(query_split)
    frame['index'] = range(len(frame))
    return frame
43
+
44
@st.cache
def get_candidate_strings():
    """Load the English test-candidate dataset from the Hub and attach a
    positional 'index' column (0..n-1) used for cross-referencing."""
    candidate_split = load_dataset(f"{ORG_ID}/IUR_Reddit_test_candidates_english")["train"]
    frame = pd.DataFrame(candidate_split)
    frame['index'] = range(len(frame))
    return frame
54
+
55
@st.cache
def get_embedding_dataset(model):
    """Fetch the precomputed embedding dataset for *model* from the Hub.

    The returned object is indexed elsewhere by 'queries' and 'candidates'
    splits, each row carrying an 'embedding' field.
    """
    return load_dataset(f"{ORG_ID}/{model}_embedding")
60
+
61
@st.cache
def get_bad_queries(model):
    """Return the query rows (fullText, index, authorIDs) whose positions
    appear in *model*'s pairwise-distance table."""
    query_positions = list(get_pairwise_distances(model)['queries'].unique())
    all_queries = get_query_strings()
    return all_queries.iloc[query_positions][['fullText', 'index', 'authorIDs']]
65
+
66
@st.cache
def get_gt_candidates(model, author):
    """Return candidate rows whose first author ID equals *author*.

    *model* is unused but kept for interface stability with the other
    cached loaders.
    """
    candidates = get_candidate_strings()
    first_author = candidates['authorIDs'].apply(lambda ids: ids[0])
    return candidates[first_author == author]
71
+
72
@st.cache
def get_candidate_text(l):
    """Return the 'fullText' value of the candidate row labelled *l*."""
    candidates = get_candidate_strings()
    return candidates.at[l, 'fullText']
75
+
76
@st.cache
def get_annotated_text(text, word, pos):
    """Locate the next occurrence of *word* in *text* at or after *pos*.

    Returns a pair: a 3-tuple ``(prefix, (word, 'SELECTED'), suffix)``
    suitable for ``annotated_text(*...)``, and the index just past the
    matched word. Raises ValueError if *word* is not found at/after *pos*.
    """
    begin = text.index(word, pos)
    finish = begin + len(word)
    pieces = (text[:begin], (text[begin:finish], 'SELECTED'), text[finish:])
    return pieces, finish
82
+
83
class AgGridBuilder:
    """Factory for AgGrid widgets.

    Streamlit requires a unique ``key`` per widget; a class-level counter is
    incremented on every call so each rendered grid receives a fresh key and
    ``reload_data`` takes effect.
    """

    __static_key = 0  # shared widget-key counter, bumped once per grid

    # FIX: the method was defined without `self` and without @staticmethod,
    # so it only worked when called via the class and would break on an
    # instance. @staticmethod makes both call styles valid (backward
    # compatible).
    @staticmethod
    def build_ag_grid(table, display_columns):
        """Render *table* as a paginated (10 rows/page), single-select grid.

        Parameters:
            table: pandas DataFrame backing the grid.
            display_columns: columns used to derive the grid options.

        Returns:
            The AgGrid response object (carries the current selection).
        """
        AgGridBuilder.__static_key += 1
        options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
        options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
        options_builder.configure_selection(selection_mode='single', pre_selected_rows=[0])
        options = options_builder.build()
        return AgGrid(table, gridOptions=options, fit_columns_on_grid_load=True,
                      key=AgGridBuilder.__static_key, reload_data=True,
                      update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED)
92
+
93
if __name__ == '__main__':
    st.set_page_config(layout="wide")

    # Only one precomputed model is currently published under the org.
    models = ['luar_clone2_top_100']

    with st.sidebar:
        current_model = st.selectbox(
            "Select Model to analyze",
            models
        )

    pairwise_distances = get_pairwise_distances(current_model)
    embedding_dataset = get_embedding_dataset(current_model)

    candidate_string_grid = None
    gt_candidate_string_grid = None

    # --- Section 1: pick a query and optionally highlight one of its words ---
    with st.container():
        t1 = time.time()
        st.title("Full Text")
        col1, col2 = st.columns([14, 2])
        t2 = time.time()
        query_table = get_bad_queries(current_model)
        t3 = time.time()
        with col2:
            # FIX: max_value was query_table.shape[0], one past the last valid
            # row; the last valid position is shape[0] - 1.
            index = st.number_input('Enter Query number to inspect', min_value = 0, max_value = query_table.shape[0] - 1, step = 1)
            # NOTE(review): .loc[index] treats the input as an index *label*
            # while .iloc[index] below treats it as a *position*; these
            # disagree whenever query_table's labels are not 0..n-1 — confirm
            # which is intended before unifying.
            query_text = query_table.loc[index]['fullText']
            preprocessed_query_text = preprocess_text(query_text)
            # 0 means "no highlight"; words are addressed 1-based below.
            text_highlight_index = st.number_input('Enter word #', min_value = 0, max_value = len(preprocessed_query_text), step = 1)
            query_index = int(query_table.iloc[index]['index'])

        with col1:
            # Reset highlight tracking on first run or when highlighting is off.
            if 'pos_highlight' not in st.session_state or text_highlight_index == 0:
                st.session_state['pos_highlight'] = text_highlight_index
                st.session_state['pos_history'] = [0]

            # Stepping backwards: drop the last two recorded positions so the
            # substring search restarts from the previous word's start.
            if st.session_state['pos_highlight'] > text_highlight_index:
                st.session_state['pos_history'] = st.session_state['pos_history'][:-2]
                if len(st.session_state['pos_history']) == 0:
                    st.session_state['pos_history'] = [0]
            # FIX: the no-highlight branch previously yielded a bare string
            # ((query_text), 0), so annotated_text(*...) unpacked it into
            # individual characters; it must be a 1-tuple.
            anotated_text_, pos = get_annotated_text(query_text, preprocessed_query_text[text_highlight_index-1], st.session_state['pos_history'][-1]) if text_highlight_index >= 1 else ((query_text,), 0)
            if st.session_state['pos_highlight'] < text_highlight_index:
                st.session_state['pos_history'].append(pos)
            st.session_state['pos_highlight'] = text_highlight_index
            annotated_text(*anotated_text_)
        t4 = time.time()

    # --- Section 2: the model's recommended candidates for this query ---
    with st.container():
        st.title("Top 16 Recommended Candidates")
        col1, col2, col3 = st.columns([10, 4, 2])
        rec_candidates = pairwise_distances[pairwise_distances["queries"]==query_index]['candidates']
        l = list(rec_candidates)
        with col3:
            # FIX: max_value was len(l); l[len(l)] would raise IndexError.
            candidate_rec_index = st.number_input('Enter recommended candidate number to inspect', min_value = 0, max_value = len(l) - 1, step = 1)
            print("l:", l, query_index)
            pairwise_candidate_index = int(l[candidate_rec_index])
        with col1:
            st.header("Text")
            t1 = time.time()
            st.write(get_candidate_text(pairwise_candidate_index))
            t2 = time.time()
        with col2:
            st.header("Cosine Distance")
            # Look up the precomputed distance for this (query, candidate) pair.
            st.write(float(pairwise_distances[
                ( pairwise_distances['queries'] == query_index )
                &
                ( pairwise_distances['candidates'] == pairwise_candidate_index)]['distances']))
        print(f"candidate string retrieval: {t2-t1}")

    # --- Section 3: ground-truth candidates sharing the query's author ---
    with st.container():
        t1 = time.time()
        st.title("Candidates With Same Authors As Query")
        col1, col2, col3 = st.columns([10, 4, 2])
        t2 = time.time()
        # NOTE(review): iloc[query_index] indexes the *filtered* bad-query
        # table with a position taken from the *full* query set — this looks
        # like it should be the user-selected `index`; confirm before changing.
        gt_candidates = get_gt_candidates(current_model, query_table.iloc[query_index]['authorIDs'][0])
        t3 = time.time()

        with col3:
            # FIX: label typo ("truthnumber") and off-by-one max_value
            # (shape[0] is one past the last valid iloc position).
            candidate_index = st.number_input('Enter ground truth number to inspect', min_value = 0, max_value = gt_candidates.shape[0] - 1, step = 1)
            gt_candidate_index = int(gt_candidates.iloc[candidate_index]['index'])
        with col1:
            st.header("Text")
            st.write(gt_candidates.iloc[candidate_index]['fullText'])
        with col2:
            t4 = time.time()
            st.header("Cosine Distance")
            # Distance recomputed from the stored embeddings: 1 - cosine similarity.
            st.write(1-cosine_similarity(np.array([embedding_dataset['queries'][query_index]['embedding']]), np.array([embedding_dataset['candidates'][gt_candidate_index]['embedding']]))[0,0])
            t5 = time.time()
        print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")
requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.3.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ async-timeout==4.0.2
5
+ attrs==22.1.0
6
+ blis==0.7.9
7
+ catalogue==2.0.8
8
+ certifi==2022.12.7
9
+ charset-normalizer==2.1.1
10
+ click==8.1.3
11
+ confection==0.0.3
12
+ cymem==2.0.7
13
+ datasets==2.7.1
14
+ dill==0.3.6
15
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
16
+ einops==0.6.0
17
+ filelock==3.8.2
18
+ frozenlist==1.3.3
19
+ fsspec==2022.11.0
20
+ huggingface-hub==0.13.1
21
+ idna==3.4
22
+ Jinja2==3.1.2
23
+ joblib==1.2.0
24
+ langcodes==3.3.0
25
+ MarkupSafe==2.1.1
26
+ multidict==6.0.3
27
+ multiprocess==0.70.14
28
+ murmurhash==1.0.9
29
+ numpy==1.23.5
30
+ nvidia-cublas-cu11==11.10.3.66
31
+ nvidia-cuda-nvrtc-cu11==11.7.99
32
+ nvidia-cuda-runtime-cu11==11.7.99
33
+ nvidia-cudnn-cu11==8.5.0.96
34
+ packaging==22.0
35
+ pandas==1.5.2
36
+ pathy==0.10.1
37
+ Pillow==9.3.0
38
+ preshed==3.0.8
39
+ pyarrow==10.0.1
40
+ pydantic==1.10.2
41
+ python-dateutil==2.8.2
42
+ pytorch-lightning==1.8.5.post0
43
+ pytorch-metric-learning==1.6.3
44
+ pytz==2022.6
45
+ PyYAML==6.0
46
+ regex==2022.10.31
47
+ requests==2.28.1
48
+ responses==0.18.0
49
+ sacremoses==0.0.53
50
+ scikit-learn==1.2.0
51
+ scipy==1.9.3
52
+ six==1.16.0
53
+ smart-open==6.3.0
54
+ spacy==3.4.3
55
+ spacy-legacy==3.0.10
56
+ spacy-loggers==1.0.4
57
+ srsly==2.4.5
58
+ thinc==8.1.5
59
+ threadpoolctl==3.1.0
60
+ tokenizers==0.10.3
61
+ torch==1.13.0
62
+ torchvision==0.14.0
63
+ tqdm==4.64.1
64
+ transformers==4.14.1
65
+ typer==0.7.0
66
+ typing_extensions==4.4.0
67
+ urllib3==1.26.13
68
+ wasabi==0.10.1
69
+ xxhash==3.1.0
70
+ yarl==1.8.2
71
+ streamlit==1.17.0
72
+ streamlit-aggrid==0.3.4.post3
73
+ --extra-index-url https://download.pytorch.org/whl/cu116
74
+ --trusted-host download.pytorch.org