Spaces:
Runtime error
Runtime error
File size: 6,590 Bytes
0c73150 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import streamlit as st
from .sidebar import render_sidebar
from requests_toolkit import ArxivQuery,IEEEQuery,PaperWithCodeQuery
from trendflow.lrt.clustering.clusters import SingleCluster
from trendflow.lrt.clustering.config import Configuration
from trendflow.lrt import ArticleList, LiteratureResearchTool
from trendflow.lrt_instance import *
from .charts import build_bar_charts
def home():
    """Render the landing page: sidebar controls plus the search form,
    then dispatch to ``render_body`` once the user submits a query."""
    # Sidebar yields the user's platform selection and search settings.
    platforms, number_papers, start_year, end_year, hyperparams = render_sidebar()

    # Body head: the query form. clear_on_submit=False keeps the query text
    # visible after submission so the user can tweak and re-run it.
    with st.form("my_form", clear_on_submit=False):
        st.markdown('''# 👋 Hi, enter your query here :)''')
        query_input = st.text_input(
            'Enter your query:',
            placeholder='''e.g. "Machine learning"''',
            # label_visibility='collapsed',
            value=''
        )
        show_preview = st.checkbox('show paper preview')
        # Every form must have a submit button.
        submitted = st.form_submit_button("Search")
        if submitted:
            # Was a bare magic number: papers shown per platform in the preview.
            num_papers_preview = 5
            render_body(platforms, number_papers, num_papers_preview, query_input,
                        show_preview, start_year, end_year,
                        hyperparams,
                        hyperparams['standardization'])
def __preview__(platforms, num_papers, num_papers_preview, query_input, start_year, end_year):
    """Render a quick markdown preview of the search results, one table per platform.

    Bug fixed: the original reassigned ``num_papers_preview`` via ``min(...)``
    inside each platform branch, so a platform that returned few results capped
    the preview size of every platform rendered after it. Each platform now
    caps its own row count independently.

    :param platforms: iterable of platform names selected by the user
    :param num_papers: how many papers to fetch from each platform
    :param num_papers_preview: max rows to show per platform table
    :param query_input: the user's search query string
    :param start_year: lower bound of the publication-year filter (IEEE only)
    :param end_year: upper bound of the publication-year filter (IEEE only)
    """
    with st.spinner('Searching...'):
        paperInGeneral = st.empty()  # placeholder for the whole preview markdown
        paperInGeneral_md = '''# 0 Query Results Preview
We have found following papers for you! (displaying 5 papers for each literature platforms)
'''

        def _table_rows(results, title_key, year_key):
            # Build markdown table rows for up to num_papers_preview results,
            # flattening embedded newlines so the table layout survives.
            rows = ''
            for i in range(min(len(results), num_papers_preview)):
                title = str(results[i][title_key]).replace('\n', ' ')
                publication_year = str(results[i][year_key]).replace('\n', ' ')
                rows += f'''|{i + 1}|{title}|{publication_year}|\n'''
            return rows

        if 'IEEE' in platforms:
            paperInGeneral_md += '''## IEEE
| ID| Paper Title | Publication Year |
| -------- | -------- | -------- |
'''
            # SECURITY NOTE(review): API key is hard-coded in source; it should
            # be loaded from an environment variable or secrets store instead.
            IEEEQuery.__setup_api_key__('vpd9yy325enruv27zj2d353e')
            ieee = IEEEQuery.query(query_input, start_year, end_year, num_papers)
            paperInGeneral_md += _table_rows(ieee, 'title', 'publication_year')
        if 'Arxiv' in platforms:
            paperInGeneral_md += '''
## Arxiv
| ID| Paper Title | Publication Year |
| -------- | -------- | -------- |
'''
            arxiv = ArxivQuery.query(query_input, max_results=num_papers)
            paperInGeneral_md += _table_rows(arxiv, 'title', 'published')
        if 'Paper with Code' in platforms:
            paperInGeneral_md += '''
## Paper with Code
| ID| Paper Title | Publication Year |
| -------- | -------- | -------- |
'''
            pwc = PaperWithCodeQuery.query(query_input, items_per_page=num_papers)
            paperInGeneral_md += _table_rows(pwc, 'title', 'published')
        paperInGeneral.markdown(paperInGeneral_md)
def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview: bool, start_year, end_year,
                hyperparams: dict, standardization=False):
    """Run the literature-research pipeline for the submitted query and render
    the results: an optional raw-paper preview, then per-platform cluster
    overviews (bar charts) and per-cluster article tables with keyphrases.

    Does nothing when ``query_input`` is empty.

    :param platforms: platform names to search, in display order
    :param num_papers: papers fetched per platform
    :param num_papers_preview: rows per platform in the preview tables
    :param query_input: the user's search query; empty string skips rendering
    :param show_preview: whether to render the raw results preview first
    :param start_year: publication-year lower bound
    :param end_year: publication-year upper bound
    :param hyperparams: dict with keys 'dimension_reduction', 'model_cpt',
        'cluster_model', 'max_k', 'standardization'
    :param standardization: whether to standardize embeddings before clustering
    """
    tmp = st.empty()
    if query_input != '':
        tmp.markdown(f'You entered query: `{query_input}`')

        # preview
        if show_preview:
            __preview__(platforms, num_papers, num_papers_preview, query_input, start_year, end_year)

        with st.spinner("Clustering and generating..."):
            # Reuse the pre-built baseline pipeline when the hyperparameters
            # match its defaults; otherwise build a custom configuration.
            if hyperparams['dimension_reduction'] == 'none' \
                    and hyperparams['model_cpt'] == 'keyphrase-transformer' \
                    and hyperparams['cluster_model'] == 'kmeans-euclidean':
                model = baseline_lrt
            else:
                config = Configuration(
                    plm='''all-mpnet-base-v2''',
                    dimension_reduction=hyperparams['dimension_reduction'],
                    clustering=hyperparams['cluster_model'],
                    keywords_extraction=hyperparams['model_cpt']
                )
                model = LiteratureResearchTool(config)

            # yield_ produces one (clusters, articles) pair per platform,
            # in the same order as `platforms`.
            generator = model.yield_(query_input, num_papers, start_year, end_year, max_k=hyperparams['max_k'],
                                     platforms=platforms, standardization=standardization)
            for i, plat in enumerate(platforms):
                clusters, articles = next(generator)
                st.markdown(f'''# {i + 1} {plat} Results''')
                clusters.sort()
                st.markdown(f'''## {i + 1}.1 Clusters Overview''')
                # f-prefixes removed from literals that contain no placeholders.
                st.markdown('''In this section we show the overview of the clusters, more specifically,''')
                st.markdown('''\n- the number of papers in each cluster\n- the number of keyphrases of each cluster''')
                st.bokeh_chart(build_bar_charts(
                    x_range=[f'Cluster {k + 1}' for k in range(len(clusters))],
                    y_names=['Number of Papers', 'Number of Keyphrases'],
                    y_data=[[len(c) for c in clusters], [len(c.get_keyphrases()) for c in clusters]]
                ))
                st.markdown(f'''## {i + 1}.2 Cluster Details''')
                st.markdown('''In this section we show the details of each cluster, including''')
                st.markdown('''\n- the article information in the cluster\n- the keyphrases of the cluster''')
                for j, cluster in enumerate(clusters):
                    assert isinstance(cluster, SingleCluster)  # TODO: remove this line
                    # `id` renamed: the original shadowed the builtin.
                    element_ids = cluster.get_elements()
                    articles_in_cluster = ArticleList([articles[article_id] for article_id in element_ids])
                    st.markdown(f'''**Cluster {j + 1}**''')
                    st.dataframe(articles_in_cluster.to_dataframe())
                    st.markdown('''The top 5 keyphrases of this cluster are:''')
                    # Build the bullet list in one pass instead of += in a loop.
                    st.markdown(''.join(f'''- `{keyphrase}`\n''' for keyphrase in cluster.top_5_keyphrases))
|