updates
- lrt/clustering/clusters.py  +30 -1
- lrt/lrt.py  +26 -15
lrt/clustering/clusters.py    CHANGED
@@ -1,6 +1,32 @@
 from typing import List, Iterable, Union
 from pprint import pprint
 
+class KeyphraseCount:
+
+    def __init__(self, keyphrase: str, count: int) -> None:
+        super().__init__()
+        self.keyphrase = keyphrase
+        self.count = count
+
+    @classmethod
+    def reduce(cls, kcs: list) :
+        '''
+        kcs: List[KeyphraseCount]
+        '''
+        keys = ''
+        count = 0
+
+        for i in range(len(kcs)-1):
+            kc = kcs[i]
+            keys += kc.keyphrase + '/'
+            count += kc.count
+
+        keys += kcs[-1].keyphrase
+        count += kcs[-1].count
+        return KeyphraseCount(keys, count)
+
+
+
 class SingleCluster:
     def __init__(self):
         self.__container__ = []
@@ -12,7 +38,10 @@ class SingleCluster:
     def elements(self) -> List:
         return self.__container__
     def get_keyphrases(self):
-        return self.__keyphrases__
+        ret = []
+        for key, count in self.__keyphrases__.items():
+            ret.append(KeyphraseCount(key,count))
+        return ret
     def add_keyphrase(self, keyphrase:Union[str,Iterable]):
         if isinstance(keyphrase,str):
            if keyphrase not in self.__keyphrases__.keys():
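
Note: the new KeyphraseCount.reduce classmethod joins the keyphrases of a group with '/' and sums their counts. A minimal usage sketch follows; the phrases and counts are illustrative, not taken from the repo.

# Illustrative use of the KeyphraseCount class added above.
from lrt.clustering.clusters import KeyphraseCount

kcs = [
    KeyphraseCount('neural network', 3),
    KeyphraseCount('neural networks', 2),
]

merged = KeyphraseCount.reduce(kcs)
print(merged.keyphrase)  # 'neural network/neural networks'
print(merged.count)      # 5
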
lrt/lrt.py    CHANGED
@@ -5,6 +5,8 @@ from .utils import UnionFind, ArticleList
 from .academic_query import AcademicQuery
 import streamlit as st
 from tokenizers import Tokenizer
+from .clustering.clusters import KeyphraseCount
+
 
 
 class LiteratureResearchTool:
@@ -13,31 +15,40 @@ class LiteratureResearchTool:
         self.cluster_pipeline = ClusterPipeline(cluster_config)
 
 
-    def __postprocess_clusters__(self, clusters: ClusterList) ->ClusterList:
+    def __postprocess_clusters__(self, clusters: ClusterList,query: str) ->ClusterList:
         '''
         add top-5 keyphrases to each cluster
         :param clusters:
         :return: clusters
         '''
-        def condition(x, y):
-            return td.ratcliff_obershelp(x, y) > 0.8
+        def condition(x: KeyphraseCount, y: KeyphraseCount):
+            return td.ratcliff_obershelp(x.keyphrase, y.keyphrase) > 0.8
+
+        def valid_keyphrase(x:KeyphraseCount):
+            tmp = x.keyphrase
+            return tmp is not None and tmp != '' and not tmp.isspace() and len(tmp)!=1\
+                and tmp != query
 
-        def valid_keyphrase(x:str):
-            return x is not None and x != '' and not x.isspace()
 
         for cluster in clusters:
-
-            keyphrases = cluster.get_keyphrases()
-            keyphrases = list(keyphrases.keys())
+
+            keyphrases = cluster.get_keyphrases() # [kc]
             keyphrases = list(filter(valid_keyphrase,keyphrases))
             unionfind = UnionFind(keyphrases, condition)
             unionfind.union_step()
 
-            keyphrases = sorted(list(unionfind.get_unions().values()), key=len, reverse=True)[:5] # top-5 keyphrases: list
+            tmp = unionfind.get_unions() # dict(root_id = [kc])
+            tmp = tmp.values() # [[kc]]
+            # [[kc]] -> [ new kc] -> sorted
+            tmp = [KeyphraseCount.reduce(x) for x in tmp]
+            keyphrases = sorted(tmp,key= lambda x: x.count,reverse=True)[:5]
+            keyphrases = [x.keyphrase for x in keyphrases]
 
-            for i in keyphrases:
-                tmp = '/'.join(i)
-                cluster.top_5_keyphrases.append(tmp)
+            # keyphrases = sorted(list(unionfind.get_unions().values()), key=len, reverse=True)[:5] # top-5 keyphrases: list
+            # for i in keyphrases:
+            #     tmp = '/'.join(i)
+            #     cluster.top_5_keyphrases.append(tmp)
+            cluster.top_5_keyphrases = keyphrases
 
         return clusters
 
@@ -85,7 +96,7 @@ class LiteratureResearchTool:
                 self.literature_search.ieee(query, start_year, end_year, num_papers)) # ArticleList
             abstracts = articles.getAbstracts() # List[str]
             clusters = self.cluster_pipeline(abstracts,max_k,standardization)
-            clusters = self.__postprocess_clusters__(clusters)
+            clusters = self.__postprocess_clusters__(clusters,query)
             return clusters, articles
 
         @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
@@ -97,7 +108,7 @@ class LiteratureResearchTool:
                 self.literature_search.arxiv(query, num_papers)) # ArticleList
             abstracts = articles.getAbstracts() # List[str]
             clusters = self.cluster_pipeline(abstracts,max_k,standardization)
-            clusters = self.__postprocess_clusters__(clusters)
+            clusters = self.__postprocess_clusters__(clusters,query)
             return clusters, articles
 
         @st.cache(hash_funcs={Tokenizer: Tokenizer.__hash__},allow_output_mutation=True)
@@ -109,7 +120,7 @@ class LiteratureResearchTool:
                 self.literature_search.paper_with_code(query, num_papers)) # ArticleList
             abstracts = articles.getAbstracts() # List[str]
             clusters = self.cluster_pipeline(abstracts,max_k,standardization)
-            clusters = self.__postprocess_clusters__(clusters)
+            clusters = self.__postprocess_clusters__(clusters,query)
             return clusters, articles
 
         if platforn_name == 'IEEE':
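
Note: the reworked __postprocess_clusters__ now groups near-duplicate keyphrases (Ratcliff/Obershelp similarity > 0.8), merges each group with KeyphraseCount.reduce, and keeps the five merged phrases with the highest total count, skipping phrases that are empty, a single character, or identical to the query. Below is a standalone sketch of that step, assuming 'td' is the textdistance package; a simple greedy grouping stands in for the repo's UnionFind helper, and the sample phrases are made up.

# Sketch of the keyphrase post-processing: filter, group by similarity, reduce, rank.
# Greedy grouping replaces the repo's UnionFind helper; values are illustrative.
import textdistance as td  # assumed to be the 'td' alias used in lrt/lrt.py
from lrt.clustering.clusters import KeyphraseCount

def top_keyphrases(kcs, query, k=5):
    # mirrors valid_keyphrase: drop empty, blank, single-character, or query-identical phrases
    kcs = [kc for kc in kcs
           if kc.keyphrase and not kc.keyphrase.isspace()
           and len(kc.keyphrase) != 1 and kc.keyphrase != query]

    groups = []  # each group holds KeyphraseCount objects with similar keyphrases
    for kc in kcs:
        for group in groups:
            if td.ratcliff_obershelp(kc.keyphrase, group[0].keyphrase) > 0.8:
                group.append(kc)
                break
        else:
            groups.append([kc])

    merged = [KeyphraseCount.reduce(group) for group in groups]  # join phrases, sum counts
    merged.sort(key=lambda x: x.count, reverse=True)
    return [x.keyphrase for x in merged[:k]]

print(top_keyphrases(
    [KeyphraseCount('graph neural network', 4),
     KeyphraseCount('graph neural networks', 3),
     KeyphraseCount('transformer', 2)],
    query='machine learning'))
# ['graph neural network/graph neural networks', 'transformer']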