Adapting committed · Commit b17c1e6 · 1 Parent(s): 675214a
app.py CHANGED
@@ -25,7 +25,10 @@ with st.form("my_form",clear_on_submit=False):
 
     if submitted:
         # body
-        render_body(platforms, number_papers, 5, query_input, show_preview, start_year, end_year, hyperparams)
+        render_body(platforms, number_papers, 5, query_input,
+                    show_preview, start_year, end_year,
+                    hyperparams,
+                    hyperparams['standardization'])
         # '''
         # bar = (
         #     Bar()
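Note for readers following the call chain: the new ninth positional argument lines up with the `standardization` parameter added to `render_body` (see widgets/body.py below). An equivalent, more explicit keyword form of the same call would be:

    render_body(platforms, number_papers, 5, query_input,
                show_preview, start_year, end_year,
                hyperparams,
                standardization=hyperparams['standardization'])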
lrt/clustering/clustering_pipeline.py CHANGED
@@ -3,6 +3,7 @@ from .config import BaselineConfig, Configuration
 from ..utils import __create_model__
 import numpy as np
 from sklearn.cluster import KMeans
+from sklearn.preprocessing import StandardScaler
 from yellowbrick.cluster import KElbowVisualizer
 from .clusters import ClusterList
 
@@ -42,7 +43,7 @@ class ClusterPipeline:
         print(f'>>> finished dimension reduction...')
         return embeddings
 
-    def __3_clustering__(self, embeddings, return_cluster_centers = False, max_k: int = 10):
+    def __3_clustering__(self, embeddings, return_cluster_centers = False, max_k: int = 10, standardization = False):
         '''
 
         :param embeddings: Nxd
@@ -52,6 +53,16 @@ class ClusterPipeline:
             return embeddings
         else:
             print(f'>>> start clustering...')
+
+            ######## new: standardization ########
+            if standardization:
+                print(f'>>> start standardization...')
+                scaler = StandardScaler()
+                embeddings = scaler.fit_transform(embeddings)
+                print(f'>>> finished standardization...')
+            ######## new: standardization ########
+
+
             model = KMeans()
             visualizer = KElbowVisualizer(
                 model, k=(2, max_k+1), metric='silhouette', timings=False, locate_elbow=False
@@ -93,11 +104,11 @@ class ClusterPipeline:
         return clusters
 
 
-    def __call__(self, documents: List[str], max_k: int):
+    def __call__(self, documents: List[str], max_k: int, standardization = False):
         print(f'>>> pipeline starts...')
         x = self.__1_generate_word_embeddings__(documents)
         x = self.__2_dimenstion_reduction__(x)
-        clusters = self.__3_clustering__(x, max_k=max_k)
+        clusters = self.__3_clustering__(x, max_k=max_k, standardization=standardization)
         outputs = self.__4_keywords_extraction__(clusters, documents)
         print(f'>>> pipeline finished!\n')
         return outputs
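For context, a minimal standalone sketch of what the new flag does, with dummy data standing in for the pipeline's real Nxd sentence embeddings: StandardScaler z-scores each embedding dimension (zero mean, unit variance), so no single large-scale dimension dominates the Euclidean distances that KMeans minimizes.

    import numpy as np
    from sklearn.cluster import KMeans
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(50, 8))   # stand-in for Nxd embeddings
    embeddings[:, 0] *= 100                 # one dimension on a much larger scale

    scaled = StandardScaler().fit_transform(embeddings)  # per-column z-score
    labels = KMeans(n_clusters=3, n_init=10).fit_predict(scaled)
    print(labels)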
lrt/lrt.py CHANGED
@@ -49,15 +49,16 @@ class LiteratureResearchTool:
              max_k: int,
              platforms: List[str] = ['IEEE', 'Arxiv', 'Paper with Code'],
              loading_ctx_manager = None,
+             standardization = False
              ):
 
 
         for platform in platforms:
             if loading_ctx_manager:
                 with loading_ctx_manager():
-                    clusters, articles = self.__platformPipeline__(platform,query,num_papers,start_year,end_year,max_k)
+                    clusters, articles = self.__platformPipeline__(platform,query,num_papers,start_year,end_year,max_k,standardization)
             else:
-                clusters, articles = self.__platformPipeline__(platform, query, num_papers, start_year, end_year,max_k)
+                clusters, articles = self.__platformPipeline__(platform, query, num_papers, start_year, end_year,max_k,standardization)
 
             clusters.sort()
             yield clusters,articles
@@ -69,7 +70,8 @@
                              num_papers: int,
                              start_year: int,
                              end_year: int,
-                             max_k: int
+                             max_k: int,
+                             standardization
                              ) -> (ClusterList,ArticleList):
 
         @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
@@ -82,7 +84,7 @@
             articles = ArticleList.parse_ieee_articles(
                 self.literature_search.ieee(query, start_year, end_year, num_papers))  # ArticleList
             abstracts = articles.getAbstracts()  # List[str]
-            clusters = self.cluster_pipeline(abstracts,max_k)
+            clusters = self.cluster_pipeline(abstracts,max_k,standardization)
             clusters = self.__postprocess_clusters__(clusters)
             return clusters, articles
 
@@ -94,7 +96,7 @@
             articles = ArticleList.parse_arxiv_articles(
                 self.literature_search.arxiv(query, num_papers))  # ArticleList
             abstracts = articles.getAbstracts()  # List[str]
-            clusters = self.cluster_pipeline(abstracts,max_k)
+            clusters = self.cluster_pipeline(abstracts,max_k,standardization)
             clusters = self.__postprocess_clusters__(clusters)
             return clusters, articles
 
@@ -106,7 +108,7 @@
             articles = ArticleList.parse_pwc_articles(
                 self.literature_search.paper_with_code(query, num_papers))  # ArticleList
             abstracts = articles.getAbstracts()  # List[str]
-            clusters = self.cluster_pipeline(abstracts,max_k)
+            clusters = self.cluster_pipeline(abstracts,max_k,standardization)
             clusters = self.__postprocess_clusters__(clusters)
             return clusters, articles
 
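End to end, `LiteratureResearchTool.__call__` remains a generator yielding one `(clusters, articles)` pair per platform; the new flag simply rides along into `cluster_pipeline`. A hedged usage sketch — query, counts, and `config` are made-up illustration values, not taken from this commit:

    model = LiteratureResearchTool(config)   # config built as elsewhere in the app
    generator = model('knowledge graphs', 20, 2018, 2022,
                      max_k=5, platforms=['IEEE', 'Arxiv'],
                      standardization=True)  # new in this commit
    for clusters, articles in generator:
        ...  # render the clusters/articles for each platform in turn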
widgets/body.py CHANGED
@@ -55,7 +55,7 @@ We have found following papers for you! (displaying 5 papers for each literature
 
     paperInGeneral.markdown(paperInGeneral_md)
 
-def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview:bool, start_year, end_year, hyperparams: dict):
+def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview:bool, start_year, end_year, hyperparams: dict, standardization = False):
 
     tmp = st.empty()
     if query_input != '':
@@ -79,7 +79,7 @@ def render_body(platforms, num_papers, num_papers_preview, query_input, show_preview
         )
         model = LiteratureResearchTool(config)
 
-        generator = model(query_input, num_papers, start_year, end_year, max_k=hyperparams['max_k'], platforms=platforms)
+        generator = model(query_input, num_papers, start_year, end_year, max_k=hyperparams['max_k'], platforms=platforms, standardization=standardization)
         for i,plat in enumerate(platforms):
             clusters, articles = next(generator)
             st.markdown(f'''# {i+1} {plat} Results''')
widgets/sidebar.py CHANGED
@@ -3,7 +3,7 @@ import datetime
 # from .utils import PACKAGE_ROOT
 from lrt.utils.functions import template
 
-APP_VERSION = 'v1.2.0'
+APP_VERSION = 'v1.3.0'
 
 def render_sidebar():
     icons = f'''
@@ -70,9 +70,10 @@ def render_sidebar():
     with st.sidebar:
         st.markdown('## Adjust hyperparameters')
         with st.expander('Clustering Options'):
-            dr = st.selectbox('1) Dimension Reduction', options=['none', 'pca'], index=0)
+            standardization = st.selectbox('1) Standardization before clustering', options=['no', 'yes'], index=0)
+            dr = st.selectbox('2) Dimension reduction', options=['none', 'pca'], index=0)
             tmp = min(number_papers,15)
-            max_k = st.slider('2) Max number of clusters', 2, tmp, tmp//2)
+            max_k = st.slider('3) Max number of clusters', 2, tmp, tmp//2)
 
         with st.expander('Keyphrases Generation Options'):
             model_cpt = st.selectbox(label='Model checkpoint', options=template.keywords_extraction.keys(),index=0)
@@ -88,5 +89,6 @@ def render_sidebar():
     return platforms, number_papers, start_year, end_year, dict(
         dimension_reduction= dr,
         max_k = max_k,
-        model_cpt = model_cpt
+        model_cpt = model_cpt,
+        standardization = (standardization == 'yes')
     )
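Putting the sidebar and the form handler together, a sketch of the new data flow (the dict values in the comment are illustrative, and `submitted`, `query_input`, and `show_preview` come from the app.py form shown above):

    platforms, number_papers, start_year, end_year, hyperparams = render_sidebar()
    # hyperparams now resembles:
    # {'dimension_reduction': 'none', 'max_k': 7,
    #  'model_cpt': ..., 'standardization': False}
    if submitted:
        render_body(platforms, number_papers, 5, query_input,
                    show_preview, start_year, end_year,
                    hyperparams,
                    hyperparams['standardization'])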