Spaces:
Runtime error
Runtime error
updates
Browse files
- app.py +4 -1
- lrt/clustering/clustering_pipeline.py +14 -3
- lrt/lrt.py +8 -6
- widgets/body.py +2 -2
- widgets/sidebar.py +6 -4
app.py
CHANGED
@@ -25,7 +25,10 @@ with st.form("my_form",clear_on_submit=False):
|
|
25 |
|
26 |
if submitted:
|
27 |
# body
|
28 |
-
render_body(platforms, number_papers, 5, query_input,
|
|
|
|
|
|
|
29 |
# '''
|
30 |
# bar = (
|
31 |
# Bar()
|
|
|
25 |
|
26 |
if submitted:
|
27 |
# body
|
28 |
+
render_body(platforms, number_papers, 5, query_input,
|
29 |
+
show_preview, start_year, end_year,
|
30 |
+
hyperparams,
|
31 |
+
hyperparams['standardization'])
|
32 |
# '''
|
33 |
# bar = (
|
34 |
# Bar()
|
lrt/clustering/clustering_pipeline.py
CHANGED
@@ -3,6 +3,7 @@ from .config import BaselineConfig, Configuration
|
|
3 |
from ..utils import __create_model__
|
4 |
import numpy as np
|
5 |
from sklearn.cluster import KMeans
|
|
|
6 |
from yellowbrick.cluster import KElbowVisualizer
|
7 |
from .clusters import ClusterList
|
8 |
|
@@ -42,7 +43,7 @@ class ClusterPipeline:
|
|
42 |
print(f'>>> finished dimension reduction...')
|
43 |
return embeddings
|
44 |
|
45 |
-
def __3_clustering__(self, embeddings, return_cluster_centers = False, max_k: int =10):
|
46 |
'''
|
47 |
|
48 |
:param embeddings: Nxd
|
@@ -52,6 +53,16 @@ class ClusterPipeline:
|
|
52 |
return embeddings
|
53 |
else:
|
54 |
print(f'>>> start clustering...')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
model = KMeans()
|
56 |
visualizer = KElbowVisualizer(
|
57 |
model, k=(2, max_k+1), metric='silhouette', timings=False, locate_elbow=False
|
@@ -93,11 +104,11 @@ class ClusterPipeline:
|
|
93 |
return clusters
|
94 |
|
95 |
|
96 |
-
def __call__(self, documents: List[str], max_k:int):
|
97 |
print(f'>>> pipeline starts...')
|
98 |
x = self.__1_generate_word_embeddings__(documents)
|
99 |
x = self.__2_dimenstion_reduction__(x)
|
100 |
-
clusters = self.__3_clustering__(x,max_k=max_k)
|
101 |
outputs = self.__4_keywords_extraction__(clusters, documents)
|
102 |
print(f'>>> pipeline finished!\n')
|
103 |
return outputs
|
|
|
3 |
from ..utils import __create_model__
|
4 |
import numpy as np
|
5 |
from sklearn.cluster import KMeans
|
6 |
+
from sklearn.preprocessing import StandardScaler
|
7 |
from yellowbrick.cluster import KElbowVisualizer
|
8 |
from .clusters import ClusterList
|
9 |
|
|
|
43 |
print(f'>>> finished dimension reduction...')
|
44 |
return embeddings
|
45 |
|
46 |
+
def __3_clustering__(self, embeddings, return_cluster_centers = False, max_k: int =10, standarization = False):
|
47 |
'''
|
48 |
|
49 |
:param embeddings: Nxd
|
|
|
53 |
return embeddings
|
54 |
else:
|
55 |
print(f'>>> start clustering...')
|
56 |
+
|
57 |
+
######## new: standarization ########
|
58 |
+
if standarization:
|
59 |
+
print(f'>>> start standardization...')
|
60 |
+
scaler = StandardScaler()
|
61 |
+
embeddings = scaler.fit_transform(embeddings)
|
62 |
+
print(f'>>> finished standardization...')
|
63 |
+
######## new: standarization ########
|
64 |
+
|
65 |
+
|
66 |
model = KMeans()
|
67 |
visualizer = KElbowVisualizer(
|
68 |
model, k=(2, max_k+1), metric='silhouette', timings=False, locate_elbow=False
|
|
|
104 |
return clusters
|
105 |
|
106 |
|
107 |
+
def __call__(self, documents: List[str], max_k:int, standarization = False):
|
108 |
print(f'>>> pipeline starts...')
|
109 |
x = self.__1_generate_word_embeddings__(documents)
|
110 |
x = self.__2_dimenstion_reduction__(x)
|
111 |
+
clusters = self.__3_clustering__(x,max_k=max_k,standarization=standarization)
|
112 |
outputs = self.__4_keywords_extraction__(clusters, documents)
|
113 |
print(f'>>> pipeline finished!\n')
|
114 |
return outputs
|
lrt/lrt.py
CHANGED
@@ -49,15 +49,16 @@ class LiteratureResearchTool:
|
|
49 |
max_k: int,
|
50 |
platforms: List[str] = ['IEEE', 'Arxiv', 'Paper with Code'],
|
51 |
loading_ctx_manager = None,
|
|
|
52 |
):
|
53 |
|
54 |
|
55 |
for platform in platforms:
|
56 |
if loading_ctx_manager:
|
57 |
with loading_ctx_manager():
|
58 |
-
clusters, articles = self.__platformPipeline__(platform,query,num_papers,start_year,end_year,max_k)
|
59 |
else:
|
60 |
-
clusters, articles = self.__platformPipeline__(platform, query, num_papers, start_year, end_year,max_k)
|
61 |
|
62 |
clusters.sort()
|
63 |
yield clusters,articles
|
@@ -69,7 +70,8 @@ class LiteratureResearchTool:
|
|
69 |
num_papers: int,
|
70 |
start_year: int,
|
71 |
end_year: int,
|
72 |
-
max_k: int
|
|
|
73 |
) -> (ClusterList,ArticleList):
|
74 |
|
75 |
@st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
|
@@ -82,7 +84,7 @@ class LiteratureResearchTool:
|
|
82 |
articles = ArticleList.parse_ieee_articles(
|
83 |
self.literature_search.ieee(query, start_year, end_year, num_papers)) # ArticleList
|
84 |
abstracts = articles.getAbstracts() # List[str]
|
85 |
-
clusters = self.cluster_pipeline(abstracts,max_k)
|
86 |
clusters = self.__postprocess_clusters__(clusters)
|
87 |
return clusters, articles
|
88 |
|
@@ -94,7 +96,7 @@ class LiteratureResearchTool:
|
|
94 |
articles = ArticleList.parse_arxiv_articles(
|
95 |
self.literature_search.arxiv(query, num_papers)) # ArticleList
|
96 |
abstracts = articles.getAbstracts() # List[str]
|
97 |
-
clusters = self.cluster_pipeline(abstracts,max_k)
|
98 |
clusters = self.__postprocess_clusters__(clusters)
|
99 |
return clusters, articles
|
100 |
|
@@ -106,7 +108,7 @@ class LiteratureResearchTool:
|
|
106 |
articles = ArticleList.parse_pwc_articles(
|
107 |
self.literature_search.paper_with_code(query, num_papers)) # ArticleList
|
108 |
abstracts = articles.getAbstracts() # List[str]
|
109 |
-
clusters = self.cluster_pipeline(abstracts,max_k)
|
110 |
clusters = self.__postprocess_clusters__(clusters)
|
111 |
return clusters, articles
|
112 |
|
|
|
49 |
max_k: int,
|
50 |
platforms: List[str] = ['IEEE', 'Arxiv', 'Paper with Code'],
|
51 |
loading_ctx_manager = None,
|
52 |
+
standardization = False
|
53 |
):
|
54 |
|
55 |
|
56 |
for platform in platforms:
|
57 |
if loading_ctx_manager:
|
58 |
with loading_ctx_manager():
|
59 |
+
clusters, articles = self.__platformPipeline__(platform,query,num_papers,start_year,end_year,max_k,standardization)
|
60 |
else:
|
61 |
+
clusters, articles = self.__platformPipeline__(platform, query, num_papers, start_year, end_year,max_k,standardization)
|
62 |
|
63 |
clusters.sort()
|
64 |
yield clusters,articles
|
|
|
70 |
num_papers: int,
|
71 |
start_year: int,
|
72 |
end_year: int,
|
73 |
+
max_k: int,
|
74 |
+
standardization
|
75 |
) -> (ClusterList,ArticleList):
|
76 |
|
77 |
@st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
|
|
|
84 |
articles = ArticleList.parse_ieee_articles(
|
85 |
self.literature_search.ieee(query, start_year, end_year, num_papers)) # ArticleList
|
86 |
abstracts = articles.getAbstracts() # List[str]
|
87 |
+
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
88 |
clusters = self.__postprocess_clusters__(clusters)
|
89 |
return clusters, articles
|
90 |
|
|
|
96 |
articles = ArticleList.parse_arxiv_articles(
|
97 |
self.literature_search.arxiv(query, num_papers)) # ArticleList
|
98 |
abstracts = articles.getAbstracts() # List[str]
|
99 |
+
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
100 |
clusters = self.__postprocess_clusters__(clusters)
|
101 |
return clusters, articles
|
102 |
|
|
|
108 |
articles = ArticleList.parse_pwc_articles(
|
109 |
self.literature_search.paper_with_code(query, num_papers)) # ArticleList
|
110 |
abstracts = articles.getAbstracts() # List[str]
|
111 |
+
clusters = self.cluster_pipeline(abstracts,max_k,standardization)
|
112 |
clusters = self.__postprocess_clusters__(clusters)
|
113 |
return clusters, articles
|
114 |
|
widgets/body.py
CHANGED
@@ -55,7 +55,7 @@ We have found following papers for you! (displaying 5 papers for each literature
|
|
55 |
|
56 |
paperInGeneral.markdown(paperInGeneral_md)
|
57 |
|
58 |
-
def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview:bool, start_year, end_year, hyperparams: dict):
|
59 |
|
60 |
tmp = st.empty()
|
61 |
if query_input != '':
|
@@ -79,7 +79,7 @@ def render_body(platforms, num_papers, num_papers_preview, query_input, show_pre
|
|
79 |
)
|
80 |
model = LiteratureResearchTool(config)
|
81 |
|
82 |
-
generator = model(query_input, num_papers, start_year, end_year, max_k=hyperparams['max_k'], platforms=platforms)
|
83 |
for i,plat in enumerate(platforms):
|
84 |
clusters, articles = next(generator)
|
85 |
st.markdown(f'''# {i+1} {plat} Results''')
|
|
|
55 |
|
56 |
paperInGeneral.markdown(paperInGeneral_md)
|
57 |
|
58 |
+
def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview:bool, start_year, end_year, hyperparams: dict, standardization = False):
|
59 |
|
60 |
tmp = st.empty()
|
61 |
if query_input != '':
|
|
|
79 |
)
|
80 |
model = LiteratureResearchTool(config)
|
81 |
|
82 |
+
generator = model(query_input, num_papers, start_year, end_year, max_k=hyperparams['max_k'], platforms=platforms, standardization=standardization)
|
83 |
for i,plat in enumerate(platforms):
|
84 |
clusters, articles = next(generator)
|
85 |
st.markdown(f'''# {i+1} {plat} Results''')
|
widgets/sidebar.py
CHANGED
@@ -3,7 +3,7 @@ import datetime
|
|
3 |
# from .utils import PACKAGE_ROOT
|
4 |
from lrt.utils.functions import template
|
5 |
|
6 |
-
APP_VERSION = 'v1.
|
7 |
|
8 |
def render_sidebar():
|
9 |
icons = f'''
|
@@ -70,9 +70,10 @@ def render_sidebar():
|
|
70 |
with st.sidebar:
|
71 |
st.markdown('## Adjust hyperparameters')
|
72 |
with st.expander('Clustering Options'):
|
73 |
-
|
|
|
74 |
tmp = min(number_papers,15)
|
75 |
-
max_k = st.slider('
|
76 |
|
77 |
with st.expander('Keyphrases Generation Options'):
|
78 |
model_cpt = st.selectbox(label='Model checkpoint', options=template.keywords_extraction.keys(),index=0)
|
@@ -88,5 +89,6 @@ def render_sidebar():
|
|
88 |
return platforms, number_papers, start_year, end_year, dict(
|
89 |
dimension_reduction= dr,
|
90 |
max_k = max_k,
|
91 |
-
model_cpt = model_cpt
|
|
|
92 |
)
|
|
|
3 |
# from .utils import PACKAGE_ROOT
|
4 |
from lrt.utils.functions import template
|
5 |
|
6 |
+
APP_VERSION = 'v1.3.0'
|
7 |
|
8 |
def render_sidebar():
|
9 |
icons = f'''
|
|
|
70 |
with st.sidebar:
|
71 |
st.markdown('## Adjust hyperparameters')
|
72 |
with st.expander('Clustering Options'):
|
73 |
+
standardization = st.selectbox('1) Standardization before clustering', options=['no', 'yes'], index=0 )
|
74 |
+
dr = st.selectbox('2) Dimension reduction', options=['none', 'pca'], index=0)
|
75 |
tmp = min(number_papers,15)
|
76 |
+
max_k = st.slider('3) Max number of clusters', 2,tmp , tmp//2)
|
77 |
|
78 |
with st.expander('Keyphrases Generation Options'):
|
79 |
model_cpt = st.selectbox(label='Model checkpoint', options=template.keywords_extraction.keys(),index=0)
|
|
|
89 |
return platforms, number_papers, start_year, end_year, dict(
|
90 |
dimension_reduction= dr,
|
91 |
max_k = max_k,
|
92 |
+
model_cpt = model_cpt,
|
93 |
+
standardization = True if standardization == 'yes' else False
|
94 |
)
|