Spaces:
Runtime error
Runtime error
Commit
·
f864d05
1
Parent(s):
207c9f5
initial commit
Browse files- app.py +185 -0
- requirements.txt +74 -0
app.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
from datasets import load_from_disk, load_dataset
|
6 |
+
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
|
7 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
+
import numpy as np
|
9 |
+
import time
|
10 |
+
from annotated_text import annotated_text
|
11 |
+
|
12 |
+
# Hugging Face Hub organization that hosts the query/candidate/distance datasets.
ORG_ID = "cornell-authorship"
|
13 |
+
|
14 |
+
@st.cache
def preprocess_text(s):
    """Split *s* into word tokens.

    Every character that is neither alphanumeric nor a space is replaced by
    a space, then the string is split into non-empty tokens.
    """
    cleaned = ''.join(c if c.isalnum() or c == ' ' else ' ' for c in s)
    # After the cleanup pass the string contains only alphanumerics and
    # spaces, so str.split() (which drops empty fields) is exactly the
    # original filter(lambda x: x != '', cleaned.split(' ')).
    return cleaned.split()
|
17 |
+
|
18 |
+
@st.cache
def get_pairwise_distances(model):
    """Load the precomputed pairwise-distance table for *model* from the Hub.

    The table has columns such as 'queries', 'candidates' and 'distances',
    indexed by its 'index' column.
    """
    distance_split = load_dataset(f"{ORG_ID}/{model}_distance")["train"]
    return pd.DataFrame(distance_split).set_index('index')
|
23 |
+
|
24 |
+
@st.cache
def get_pairwise_distances_chunked(model, chunk):
    """Return the distance table for *model*; *chunk* is currently ignored.

    Chunked CSV loading was sketched out but never enabled, so this simply
    delegates to get_pairwise_distances().
    """
    return get_pairwise_distances(model)
|
31 |
+
|
32 |
+
@st.cache
def get_query_strings():
    """Load the English test-query texts and attach a 0-based 'index' column."""
    query_split = load_dataset(f"{ORG_ID}/IUR_Reddit_test_queries_english")["train"]
    queries = pd.DataFrame(query_split)
    queries['index'] = queries.reset_index().index
    return queries
|
43 |
+
|
44 |
+
@st.cache
def get_candidate_strings():
    """Load the English test-candidate texts and attach a 0-based 'index' column."""
    candidate_split = load_dataset(f"{ORG_ID}/IUR_Reddit_test_candidates_english")["train"]
    candidates = pd.DataFrame(candidate_split)
    candidates['index'] = candidates.reset_index().index
    return candidates
|
54 |
+
|
55 |
+
@st.cache
def get_embedding_dataset(model):
    """Fetch the precomputed embedding dataset for *model* from the Hub."""
    return load_dataset(f"{ORG_ID}/{model}_embedding")
|
60 |
+
|
61 |
+
@st.cache
def get_bad_queries(model):
    """Return the query rows that appear in *model*'s distance table.

    Only queries listed in the distance table's 'queries' column are kept;
    columns are restricted to the text, index and author ids.
    """
    query_positions = list(get_pairwise_distances(model)['queries'].unique())
    selected = get_query_strings().iloc[query_positions]
    return selected[['fullText', 'index', 'authorIDs']]
|
65 |
+
|
66 |
+
@st.cache
def get_gt_candidates(model, author):
    """Return every candidate row whose first author id equals *author*."""
    candidates = get_candidate_strings()
    first_authors = candidates['authorIDs'].apply(lambda ids: ids[0])
    return candidates[first_authors == author]
|
71 |
+
|
72 |
+
@st.cache
def get_candidate_text(l):
    """Return the full text of the candidate whose index label is *l*."""
    candidates = get_candidate_strings()
    return candidates.at[l, 'fullText']
|
75 |
+
|
76 |
+
@st.cache
def get_annotated_text(text, word, pos):
    """Highlight the first occurrence of *word* in *text* at or after *pos*.

    Returns a pair: the annotated-text segments
    (prefix, (word, 'SELECTED'), suffix) and the offset just past the match,
    which callers feed back in as *pos* for the next word.
    Raises ValueError (via str.index) if *word* does not occur.
    """
    begin = text.index(word, pos)
    finish = begin + len(word)
    segments = (text[:begin], (text[begin:finish], 'SELECTED'), text[finish:])
    return segments, finish
|
82 |
+
|
83 |
+
class AgGridBuilder:
    """Builds AgGrid widgets, giving each grid a unique Streamlit key.

    Streamlit requires a distinct ``key`` per rendered widget, so a
    class-level counter is incremented on every build.
    """

    # Monotonically increasing widget key shared by all grids.
    __static_key = 0

    # Original code omitted @staticmethod; it only worked because the
    # function was accessed on the class. Marking it static keeps that
    # call style working and also makes instance access safe.
    @staticmethod
    def build_ag_grid(table, display_columns):
        """Render *table* as a paginated, single-select AgGrid.

        *display_columns* controls which columns configure the grid options;
        the full *table* is still passed to AgGrid for display.
        """
        AgGridBuilder.__static_key += 1
        options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
        options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
        options_builder.configure_selection(selection_mode='single', pre_selected_rows=[0])
        options = options_builder.build()
        return AgGrid(table, gridOptions=options, fit_columns_on_grid_load=True,
                      key=AgGridBuilder.__static_key, reload_data=True,
                      update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED)
|
92 |
+
|
93 |
+
if __name__ == '__main__':
    st.set_page_config(layout="wide")

    # Only one precomputed model is published at the moment; the sidebar
    # select box is kept so more models can be added without UI changes.
    models = ['luar_clone2_top_100']

    with st.sidebar:
        current_model = st.selectbox(
            "Select Model to analyze",
            models
        )

    pairwise_distances = get_pairwise_distances(current_model)
    embedding_dataset = get_embedding_dataset(current_model)

    candidate_string_grid = None
    gt_candidate_string_grid = None

    # --- Section 1: the query text, with word-level highlighting ---
    with st.container():
        t1 = time.time()
        st.title("Full Text")
        col1, col2 = st.columns([14, 2])
        t2 = time.time()
        query_table = get_bad_queries(current_model)
        t3 = time.time()
        with col2:
            # max_value is shape[0]-1: the widget maximum must itself be a
            # valid positional index (shape[0] crashed at the upper bound).
            index = st.number_input('Enter Query number to inspect',
                                    min_value=0, max_value=query_table.shape[0] - 1, step=1)
            # iloc, not loc: query_table was built with .iloc[...] so its
            # labels are original query ids, not 0..n-1; the 'index' lookup
            # two lines below already used iloc.
            query_text = query_table.iloc[index]['fullText']
            preprocessed_query_text = preprocess_text(query_text)
            # max_value == len(...) is correct here: the word number is
            # 1-based and 0 means "no highlight".
            text_highlight_index = st.number_input('Enter word #', min_value=0,
                                                   max_value=len(preprocessed_query_text), step=1)
            query_index = int(query_table.iloc[index]['index'])

        with col1:
            # Reset the highlight bookkeeping on first render or when the
            # user rewinds to word 0.
            if 'pos_highlight' not in st.session_state or text_highlight_index == 0:
                st.session_state['pos_highlight'] = text_highlight_index
                st.session_state['pos_history'] = [0]

            # Stepping backwards: drop the last two recorded offsets so the
            # substring search restarts from the previous word's position.
            if st.session_state['pos_highlight'] > text_highlight_index:
                st.session_state['pos_history'] = st.session_state['pos_history'][:-2]
                if len(st.session_state['pos_history']) == 0:
                    st.session_state['pos_history'] = [0]
            # The no-highlight branch must yield a ONE-TUPLE (query_text,);
            # the original bare (query_text) made annotated_text(*...) splat
            # the string character by character.
            anotated_text_, pos = get_annotated_text(
                query_text,
                preprocessed_query_text[text_highlight_index - 1],
                st.session_state['pos_history'][-1]
            ) if text_highlight_index >= 1 else ((query_text,), 0)
            if st.session_state['pos_highlight'] < text_highlight_index:
                st.session_state['pos_history'].append(pos)
            st.session_state['pos_highlight'] = text_highlight_index
            annotated_text(*anotated_text_)
        t4 = time.time()

    # --- Section 2: the model's recommended candidates for this query ---
    with st.container():
        st.title("Top 16 Recommended Candidates")
        col1, col2, col3 = st.columns([10, 4, 2])
        rec_candidates = pairwise_distances[pairwise_distances["queries"] == query_index]['candidates']
        recommended = list(rec_candidates)
        with col3:
            # max_value len-1 (not len): the chosen value indexes the list.
            candidate_rec_index = st.number_input('Enter recommended candidate number to inspect',
                                                  min_value=0, max_value=len(recommended) - 1, step=1)
            print("l:", recommended, query_index)
            pairwise_candidate_index = int(recommended[candidate_rec_index])
        with col1:
            st.header("Text")
            t1 = time.time()
            st.write(get_candidate_text(pairwise_candidate_index))
            t2 = time.time()
        with col2:
            st.header("Cosine Distance")
            match = pairwise_distances[
                (pairwise_distances['queries'] == query_index)
                & (pairwise_distances['candidates'] == pairwise_candidate_index)]
            # .iloc[0]: float(Series) is deprecated and assumed one row anyway.
            st.write(float(match['distances'].iloc[0]))
        print(f"candidate string retrieval: {t2-t1}")

    # --- Section 3: ground-truth candidates by the query's author ---
    with st.container():
        t1 = time.time()
        st.title("Candidates With Same Authors As Query")
        col1, col2, col3 = st.columns([10, 4, 2])
        t2 = time.time()
        # NOTE(review): iloc[query_index] positions by the ORIGINAL dataset
        # index, while Section 1 positioned by row number within query_table
        # (iloc[index]) — confirm this asymmetry is intended.
        gt_candidates = get_gt_candidates(current_model, query_table.iloc[query_index]['authorIDs'][0])
        t3 = time.time()

        with col3:
            # Typo fixed ("truthnumber") and max_value made a valid index.
            candidate_index = st.number_input('Enter ground truth number to inspect',
                                              min_value=0, max_value=gt_candidates.shape[0] - 1, step=1)
            gt_candidate_index = int(gt_candidates.iloc[candidate_index]['index'])
        with col1:
            st.header("Text")
            st.write(gt_candidates.iloc[candidate_index]['fullText'])
        with col2:
            t4 = time.time()
            st.header("Cosine Distance")
            # Distance = 1 - cosine similarity of the stored embeddings.
            st.write(1 - cosine_similarity(
                np.array([embedding_dataset['queries'][query_index]['embedding']]),
                np.array([embedding_dataset['candidates'][gt_candidate_index]['embedding']]))[0, 0])
            t5 = time.time()
        print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")
|
requirements.txt
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==1.3.0
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.3.1
|
4 |
+
async-timeout==4.0.2
|
5 |
+
attrs==22.1.0
|
6 |
+
blis==0.7.9
|
7 |
+
catalogue==2.0.8
|
8 |
+
certifi==2022.12.7
|
9 |
+
charset-normalizer==2.1.1
|
10 |
+
click==8.1.3
|
11 |
+
confection==0.0.3
|
12 |
+
cymem==2.0.7
|
13 |
+
datasets==2.7.1
|
14 |
+
dill==0.3.6
|
15 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
|
16 |
+
einops==0.6.0
|
17 |
+
filelock==3.8.2
|
18 |
+
frozenlist==1.3.3
|
19 |
+
fsspec==2022.11.0
|
20 |
+
huggingface-hub==0.13.1
|
21 |
+
idna==3.4
|
22 |
+
Jinja2==3.1.2
|
23 |
+
joblib==1.2.0
|
24 |
+
langcodes==3.3.0
|
25 |
+
MarkupSafe==2.1.1
|
26 |
+
multidict==6.0.3
|
27 |
+
multiprocess==0.70.14
|
28 |
+
murmurhash==1.0.9
|
29 |
+
numpy==1.23.5
|
30 |
+
nvidia-cublas-cu11==11.10.3.66
|
31 |
+
nvidia-cuda-nvrtc-cu11==11.7.99
|
32 |
+
nvidia-cuda-runtime-cu11==11.7.99
|
33 |
+
nvidia-cudnn-cu11==8.5.0.96
|
34 |
+
packaging==22.0
|
35 |
+
pandas==1.5.2
|
36 |
+
pathy==0.10.1
|
37 |
+
Pillow==9.3.0
|
38 |
+
preshed==3.0.8
|
39 |
+
pyarrow==10.0.1
|
40 |
+
pydantic==1.10.2
|
41 |
+
python-dateutil==2.8.2
|
42 |
+
pytorch-lightning==1.8.5.post0
|
43 |
+
pytorch-metric-learning==1.6.3
|
44 |
+
pytz==2022.6
|
45 |
+
PyYAML==6.0
|
46 |
+
regex==2022.10.31
|
47 |
+
requests==2.28.1
|
48 |
+
responses==0.18.0
|
49 |
+
sacremoses==0.0.53
|
50 |
+
scikit-learn==1.2.0
|
51 |
+
scipy==1.9.3
|
52 |
+
six==1.16.0
|
53 |
+
smart-open==6.3.0
|
54 |
+
spacy==3.4.3
|
55 |
+
spacy-legacy==3.0.10
|
56 |
+
spacy-loggers==1.0.4
|
57 |
+
srsly==2.4.5
|
58 |
+
thinc==8.1.5
|
59 |
+
threadpoolctl==3.1.0
|
60 |
+
tokenizers==0.10.3
|
61 |
+
torch==1.13.0
|
62 |
+
torchvision==0.14.0
|
63 |
+
tqdm==4.64.1
|
64 |
+
transformers==4.14.1
|
65 |
+
typer==0.7.0
|
66 |
+
typing_extensions==4.4.0
|
67 |
+
urllib3==1.26.13
|
68 |
+
wasabi==0.10.1
|
69 |
+
xxhash==3.1.0
|
70 |
+
yarl==1.8.2
|
71 |
+
streamlit==1.17.0
|
72 |
+
streamlit-aggrid==0.3.4.post3
|
73 |
+
--extra-index-url http://download.pytorch.org/whl/cu116
|
74 |
+
--trusted-host download.pytorch.org
|