Spaces:
Running
Running
updates to codebase for embeddings and RAG QA.
Browse files
.DS_Store
CHANGED
|
Binary files a/.DS_Store and b/.DS_Store differ
|
|
|
absts/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
pages/{2_arxiv_embedding.py → 1_arxiv_embedding_explorer.py}
RENAMED
|
@@ -74,9 +74,9 @@ def density_estimation(m1, m2, xmin=0, ymin=0, xmax=15, ymax=15):
|
|
| 74 |
st.sidebar.markdown('This is a widget that allows you to look for papers containing specific phrases in the dataset and show it as a heatmap. Enter the phrase of interest, then change the size and opacity of the heatmap as desired to find the high-density regions. Hover over blue points to see the details of individual papers.')
|
| 75 |
st.sidebar.markdown('`Note`: (i) if you enter a query that is not in the corpus of abstracts, it will return an error. just enter a different query in that case. (ii) there are some empty tooltips when you hover, these correspond to the underlying hexbins, and can be ignored.')
|
| 76 |
|
| 77 |
-
st.sidebar.text_input("Search query", key="phrase", value="")
|
| 78 |
-
alpha_value = st.sidebar.slider("Pick the hexbin opacity",0.0,1.0,0.
|
| 79 |
-
size_value = st.sidebar.slider("Pick the hexbin
|
| 80 |
|
| 81 |
phrase=st.session_state.phrase
|
| 82 |
|
|
@@ -103,10 +103,19 @@ ID: $index
|
|
| 103 |
"""
|
| 104 |
|
| 105 |
p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5,15),
|
| 106 |
-
title="UMAP projection of
|
| 107 |
|
| 108 |
-
p.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1], size=size_value,
|
| 109 |
-
|
| 110 |
p.circle('x', 'y', size=3, source=source, alpha=0.3)
|
| 111 |
-
|
| 112 |
st.bokeh_chart(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
st.sidebar.markdown('This is a widget that allows you to look for papers containing specific phrases in the dataset and show it as a heatmap. Enter the phrase of interest, then change the size and opacity of the heatmap as desired to find the high-density regions. Hover over blue points to see the details of individual papers.')
|
| 75 |
st.sidebar.markdown('`Note`: (i) if you enter a query that is not in the corpus of abstracts, it will return an error. just enter a different query in that case. (ii) there are some empty tooltips when you hover, these correspond to the underlying hexbins, and can be ignored.')
|
| 76 |
|
| 77 |
+
st.sidebar.text_input("Search query", key="phrase", value="Quenching")
|
| 78 |
+
alpha_value = st.sidebar.slider("Pick the hexbin opacity",0.0,1.0,0.81)
|
| 79 |
+
size_value = st.sidebar.slider("Pick the hexbin gridsize",10,50,20)
|
| 80 |
|
| 81 |
phrase=st.session_state.phrase
|
| 82 |
|
|
|
|
| 103 |
"""
|
| 104 |
|
| 105 |
p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5,15),
|
| 106 |
+
title="UMAP projection of embeddings for the astro-ph.GA corpus"+phrase)
|
| 107 |
|
| 108 |
+
# p.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1], size=size_value,
|
| 109 |
+
# palette = np.flip(OrRd[8]), alpha=alpha_value)
|
| 110 |
p.circle('x', 'y', size=3, source=source, alpha=0.3)
|
|
|
|
| 111 |
st.bokeh_chart(p)
|
| 112 |
+
|
| 113 |
+
fig = plt.figure(figsize=(10.5,9*0.8328))
|
| 114 |
+
plt.scatter(embedding[0:,0], embedding[0:,1],s=2,alpha=0.1)
|
| 115 |
+
plt.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1],
|
| 116 |
+
gridsize=size_value, cmap = 'viridis', alpha=alpha_value,extent=(-1,16,1.5,16),mincnt=10)
|
| 117 |
+
plt.title("UMAP localization of heatmap keyword: "+phrase)
|
| 118 |
+
plt.axis([0,15,2.5,15]);
|
| 119 |
+
clbr = plt.colorbar(); clbr.set_label('# papers')
|
| 120 |
+
plt.axis('off')
|
| 121 |
+
st.pyplot(fig)
|
pages/{1_paper_search.py → 2_paper_search.py}
RENAMED
|
File without changes
|
pages/{3_qa_sources_v2.py → 3_answering_questions.py}
RENAMED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# set the environment variables needed for openai package to know to reach out to azure
|
| 2 |
import os
|
| 3 |
import datetime
|
| 4 |
import faiss
|
|
@@ -181,7 +180,7 @@ def list_similar_papers_v2(model_data,
|
|
| 181 |
for i in range(start_range,start_range+return_n):
|
| 182 |
|
| 183 |
abstracts_relevant.append(all_text[sims[i]])
|
| 184 |
-
fhdr = all_authors[sims[i]][0]['name'].split()[-1] + all_arxivid[sims[i]][0:2] +'_'+ all_arxivid[sims[i]]
|
| 185 |
fhdrs.append(fhdr)
|
| 186 |
textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
|
| 187 |
textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
|
|
@@ -325,7 +324,7 @@ def run_rag(query, return_n = 10, show_authors = True, show_summary = True):
|
|
| 325 |
temp = temp[0:-2] + ' et al. 19' + temp[-2:]
|
| 326 |
temp = '['+temp+']('+all_links[int(srcnames[i].split('_')[0].split('/')[1])]+')'
|
| 327 |
st.markdown(temp)
|
| 328 |
-
|
| 329 |
|
| 330 |
fig = plt.figure(figsize=(9,9))
|
| 331 |
plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
|
|
@@ -338,100 +337,15 @@ def run_rag(query, return_n = 10, show_authors = True, show_summary = True):
|
|
| 338 |
|
| 339 |
return rag_answer
|
| 340 |
|
| 341 |
-
def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
|
| 342 |
-
|
| 343 |
-
show_authors = True
|
| 344 |
-
show_summary = True
|
| 345 |
-
sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
|
| 346 |
-
doc_id = query,
|
| 347 |
-
input_type='keywords',
|
| 348 |
-
show_authors = show_authors, show_summary = show_summary,
|
| 349 |
-
return_n = return_n)
|
| 350 |
-
|
| 351 |
-
temp_abst = ''
|
| 352 |
-
loaders = []
|
| 353 |
-
for i in range(len(absts)):
|
| 354 |
-
temp_abst = absts[i]
|
| 355 |
-
|
| 356 |
-
try:
|
| 357 |
-
text_file = open("absts/"+fhdrs[i]+".txt", "w")
|
| 358 |
-
except:
|
| 359 |
-
os.mkdir('absts')
|
| 360 |
-
text_file = open("absts/"+fhdrs[i]+".txt", "w")
|
| 361 |
-
n = text_file.write(temp_abst)
|
| 362 |
-
text_file.close()
|
| 363 |
-
loader = TextLoader("absts/"+fhdrs[i]+".txt")
|
| 364 |
-
loaders.append(loader)
|
| 365 |
-
|
| 366 |
-
lc_index = VectorstoreIndexCreator().from_loaders(loaders)
|
| 367 |
-
|
| 368 |
-
st.markdown('### User query: '+query)
|
| 369 |
-
if show_pure_answer == True:
|
| 370 |
-
st.markdown('pure answer:')
|
| 371 |
-
st.markdown(lc_index.query(query))
|
| 372 |
-
st.markdown(' ')
|
| 373 |
-
st.markdown('#### context-based answer from sources:')
|
| 374 |
-
output = lc_index.query_with_sources(query + ' Let\'s work this out in a step by step way to be sure we have the right answer.' ) #zero-shot in-context prompting from Zhou+22, Kojima+22
|
| 375 |
-
st.markdown(output['answer'])
|
| 376 |
-
opstr = '#### Primary sources: \n'
|
| 377 |
-
st.markdown(opstr)
|
| 378 |
-
|
| 379 |
-
# opstr = ''
|
| 380 |
-
# for i in range(len(output['sources'])):
|
| 381 |
-
# opstr = opstr +'\n'+ output['sources'][i]
|
| 382 |
-
|
| 383 |
-
textstr = ''
|
| 384 |
-
ng = len(output['sources'].split())
|
| 385 |
-
abs_indices = []
|
| 386 |
-
|
| 387 |
-
for i in range(ng):
|
| 388 |
-
if i == (ng-1):
|
| 389 |
-
tempid = output['sources'].split()[i].split('_')[1][0:-4]
|
| 390 |
-
else:
|
| 391 |
-
tempid = output['sources'].split()[i].split('_')[1][0:-5]
|
| 392 |
-
try:
|
| 393 |
-
abs_index = all_arxivid.index(tempid)
|
| 394 |
-
abs_indices.append(abs_index)
|
| 395 |
-
textstr = textstr + str(i+1)+'. **'+ all_titles[abs_index] +' \n'
|
| 396 |
-
textstr = textstr + '**ArXiv:** ['+all_arxivid[abs_index]+'](https://arxiv.org/abs/'+all_arxivid[abs_index]+') \n'
|
| 397 |
-
textstr = textstr + '**Authors:** '
|
| 398 |
-
temp = all_authors[abs_index]
|
| 399 |
-
for ak in range(4):
|
| 400 |
-
if ak < len(temp)-1:
|
| 401 |
-
textstr = textstr + temp[ak].name + ', '
|
| 402 |
-
else:
|
| 403 |
-
textstr = textstr + temp[ak].name + ' \n'
|
| 404 |
-
if len(temp) > 3:
|
| 405 |
-
textstr = textstr + ' et al. \n'
|
| 406 |
-
textstr = textstr + '**Summary:** '
|
| 407 |
-
text = all_text[abs_index]
|
| 408 |
-
text = text.replace('\n', ' ')
|
| 409 |
-
textstr = textstr + summarizer.summarize(text) + ' \n'
|
| 410 |
-
except:
|
| 411 |
-
textstr = textstr + output['sources'].split()[i]
|
| 412 |
-
# opstr = opstr + ' \n ' + output['sources'].split()[i][6:-5].split('_')[0]
|
| 413 |
-
# opstr = opstr + ' \n Arxiv id: ' + output['sources'].split()[i][6:-5].split('_')[1]
|
| 414 |
-
|
| 415 |
-
textstr = textstr + ' '
|
| 416 |
-
textstr = textstr + ' \n'
|
| 417 |
-
st.markdown(textstr)
|
| 418 |
-
|
| 419 |
-
fig = plt.figure(figsize=(9,9))
|
| 420 |
-
plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
|
| 421 |
-
plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
|
| 422 |
-
plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
|
| 423 |
-
st.pyplot(fig)
|
| 424 |
-
|
| 425 |
-
if show_all_sources == True:
|
| 426 |
-
st.markdown('\n #### Other interesting papers:')
|
| 427 |
-
st.markdown(sims)
|
| 428 |
-
return output
|
| 429 |
|
| 430 |
st.title('ArXiv-based question answering')
|
| 431 |
st.markdown('[Includes papers up to: `'+dateval+'`]')
|
| 432 |
-
st.markdown('Concise answers for questions using arxiv abstracts + GPT-4.
|
|
|
|
|
|
|
| 433 |
|
| 434 |
-
query = st.text_input('Your question here:',
|
| 435 |
-
|
|
|
|
| 436 |
|
| 437 |
-
sims =
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import datetime
|
| 3 |
import faiss
|
|
|
|
| 180 |
for i in range(start_range,start_range+return_n):
|
| 181 |
|
| 182 |
abstracts_relevant.append(all_text[sims[i]])
|
| 183 |
+
fhdr = str(sims[i])+'_'+all_authors[sims[i]][0]['name'].split()[-1] + all_arxivid[sims[i]][0:2] +'_'+ all_arxivid[sims[i]]
|
| 184 |
fhdrs.append(fhdr)
|
| 185 |
textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
|
| 186 |
textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
|
|
|
|
| 324 |
temp = temp[0:-2] + ' et al. 19' + temp[-2:]
|
| 325 |
temp = '['+temp+']('+all_links[int(srcnames[i].split('_')[0].split('/')[1])]+')'
|
| 326 |
st.markdown(temp)
|
| 327 |
+
abs_indices = np.array(srcindices)
|
| 328 |
|
| 329 |
fig = plt.figure(figsize=(9,9))
|
| 330 |
plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
|
|
|
|
| 337 |
|
| 338 |
return rag_answer
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
st.title('ArXiv-based question answering')
|
| 342 |
st.markdown('[Includes papers up to: `'+dateval+'`]')
|
| 343 |
+
st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. You might need to wait for a few seconds for the GPT-4 query to return an answer (check top right corner to see if it is still running).')
|
| 344 |
+
st.markdown('The answers are followed by relevant source(s) used in the answer, a graph showing which part of the astro-ph.GA manifold it drew the answer from (tightly clustered points generally indicate high quality/consensus answers) followed by a bunch of relevant papers used by the RAG to compose the answer.')
|
| 345 |
+
st.markdown('If this does not satisfactorily answer your question or rambles too much, you can also try the older `qa_sources_v1` page.')
|
| 346 |
|
| 347 |
+
query = st.text_input('Your question here:',
|
| 348 |
+
value="What causes galaxy quenching at high redshifts?")
|
| 349 |
+
return_n = st.slider('How many papers should I show?', 1, 30, 10)
|
| 350 |
|
| 351 |
+
sims = run_rag(query, return_n = return_n)
|
pages/{3_qa_sources_v1.py → 4_qa_sources_v1.py}
RENAMED
|
@@ -118,7 +118,7 @@ def find_papers_by_author(auth_name):
|
|
| 118 |
|
| 119 |
return doc_ids
|
| 120 |
|
| 121 |
-
def faiss_based_indices(input_vector, nindex=10):
|
| 122 |
xq = input_vector.reshape(-1,1).T.astype('float32')
|
| 123 |
D, I = index.search(xq, nindex)
|
| 124 |
return I[0], D[0]
|
|
@@ -126,7 +126,7 @@ def faiss_based_indices(input_vector, nindex=10):
|
|
| 126 |
def list_similar_papers_v2(model_data,
|
| 127 |
doc_id = [], input_type = 'doc_id',
|
| 128 |
show_authors = False, show_summary = False,
|
| 129 |
-
return_n = 10):
|
| 130 |
|
| 131 |
arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
|
| 132 |
|
|
@@ -152,7 +152,7 @@ def list_similar_papers_v2(model_data,
|
|
| 152 |
print('unrecognized input type.')
|
| 153 |
return
|
| 154 |
|
| 155 |
-
sims, dists = faiss_based_indices(inferred_vector, return_n+2)
|
| 156 |
textstr = ''
|
| 157 |
abstracts_relevant = []
|
| 158 |
fhdrs = []
|
|
@@ -182,30 +182,9 @@ def list_similar_papers_v2(model_data,
|
|
| 182 |
textstr = textstr + ' \n'
|
| 183 |
return textstr, abstracts_relevant, fhdrs, sims
|
| 184 |
|
| 185 |
-
|
| 186 |
-
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
|
| 187 |
-
headers = {
|
| 188 |
-
"Content-Type": "application/json",
|
| 189 |
-
"Authorization": f"Bearer {openai.api_key}",
|
| 190 |
-
}
|
| 191 |
-
|
| 192 |
-
data = {
|
| 193 |
-
"model": model,
|
| 194 |
-
"messages": messages,
|
| 195 |
-
"temperature": temperature,
|
| 196 |
-
}
|
| 197 |
-
|
| 198 |
-
if max_tokens is not None:
|
| 199 |
-
data["max_tokens"] = max_tokens
|
| 200 |
-
response = requests.post(API_ENDPOINT, headers=headers, data=json.dumps(data))
|
| 201 |
-
if response.status_code == 200:
|
| 202 |
-
return response.json()["choices"][0]["message"]["content"]
|
| 203 |
-
else:
|
| 204 |
-
raise Exception(f"Error {response.status_code}: {response.text}")
|
| 205 |
-
|
| 206 |
model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
|
| 207 |
|
| 208 |
-
def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
|
| 209 |
|
| 210 |
show_authors = True
|
| 211 |
show_summary = True
|
|
@@ -213,7 +192,7 @@ def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources =
|
|
| 213 |
doc_id = query,
|
| 214 |
input_type='keywords',
|
| 215 |
show_authors = show_authors, show_summary = show_summary,
|
| 216 |
-
return_n = return_n)
|
| 217 |
|
| 218 |
temp_abst = ''
|
| 219 |
loaders = []
|
|
@@ -300,5 +279,8 @@ st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. Please
|
|
| 300 |
|
| 301 |
query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
|
| 302 |
return_n = st.slider('How many papers should I show?', 1, 20, 10)
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
-
sims = run_query(query, return_n = return_n)
|
|
|
|
| 118 |
|
| 119 |
return doc_ids
|
| 120 |
|
| 121 |
+
def faiss_based_indices(input_vector, nindex=10, yrmin = 1990, yrmax = 2024):
|
| 122 |
xq = input_vector.reshape(-1,1).T.astype('float32')
|
| 123 |
D, I = index.search(xq, nindex)
|
| 124 |
return I[0], D[0]
|
|
|
|
| 126 |
def list_similar_papers_v2(model_data,
|
| 127 |
doc_id = [], input_type = 'doc_id',
|
| 128 |
show_authors = False, show_summary = False,
|
| 129 |
+
return_n = 10, yrmin = 1990, yrmax = 2024):
|
| 130 |
|
| 131 |
arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
|
| 132 |
|
|
|
|
| 152 |
print('unrecognized input type.')
|
| 153 |
return
|
| 154 |
|
| 155 |
+
sims, dists = faiss_based_indices(inferred_vector, return_n+2, yrmin = 1990, yrmax = 2024)
|
| 156 |
textstr = ''
|
| 157 |
abstracts_relevant = []
|
| 158 |
fhdrs = []
|
|
|
|
| 182 |
textstr = textstr + ' \n'
|
| 183 |
return textstr, abstracts_relevant, fhdrs, sims
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
|
| 186 |
|
| 187 |
+
def run_query(query, return_n = 3, yrmin = 1990, yrmax = 2024, show_pure_answer = False, show_all_sources = True):
|
| 188 |
|
| 189 |
show_authors = True
|
| 190 |
show_summary = True
|
|
|
|
| 192 |
doc_id = query,
|
| 193 |
input_type='keywords',
|
| 194 |
show_authors = show_authors, show_summary = show_summary,
|
| 195 |
+
return_n = return_n, yrmin = 1990, yrmax = 2024)
|
| 196 |
|
| 197 |
temp_abst = ''
|
| 198 |
loaders = []
|
|
|
|
| 279 |
|
| 280 |
query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
|
| 281 |
return_n = st.slider('How many papers should I show?', 1, 20, 10)
|
| 282 |
+
yrmin = st.slider('Min year', 1990,2023, 1990)
|
| 283 |
+
yrmax = st.slider('Max year', 1990, 2024, 2024)
|
| 284 |
+
|
| 285 |
|
| 286 |
+
sims = run_query(query, return_n = return_n, yrmin = yrmin, yrmax = yrmax)
|