# NOTE: header residue from a hosted "Spaces" file-viewer page.
# Reported status: Runtime error. File size: 3,751 bytes. Revision: 6541245.
# (The viewer's line-number gutter, 1-102, was dropped.)
from .clustering import *
from typing import List
import textdistance as td
from .utils import UnionFind, ArticleList
from .academic_query import AcademicQuery
class LiteratureResearchTool:
    """Query literature platforms and cluster the abstracts that come back.

    Calling an instance yields one ``(clusters, articles)`` pair per requested
    platform. Each cluster is annotated with a ``top_5_keyphrases`` list
    before it is returned.
    """

    # Platforms queried by ``__call__`` when the caller passes no ``platforms``.
    DEFAULT_PLATFORMS = ('IEEE', 'Arxiv', 'Paper with Code')

    # Two keyphrases are treated as the same phrase when their
    # Ratcliff-Obershelp similarity is strictly greater than this value.
    KEYPHRASE_SIMILARITY_THRESHOLD = 0.8

    # Number of keyphrase groups kept on each cluster.
    TOP_KEYPHRASE_GROUPS = 5

    def __init__(self, cluster_config: Configuration = None):
        """Create the tool.

        :param cluster_config: forwarded unchanged to ``ClusterPipeline``.
        """
        # The AcademicQuery class itself (not an instance) is stored; the
        # per-platform search functions are invoked directly on it.
        self.literature_search = AcademicQuery
        self.cluster_pipeline = ClusterPipeline(cluster_config)

    def __postprocess_clusters__(self, clusters: ClusterList) -> ClusterList:
        """Add the top-5 keyphrase groups to each cluster.

        For every cluster, blank keyphrases are dropped, the rest are grouped
        with ``UnionFind`` using a string-similarity predicate, and the five
        largest groups are stored on ``cluster.top_5_keyphrases``, each group
        joined into a single ``'/'``-separated string.

        :param clusters: clusters to annotate (modified in place).
        :return: the same ``clusters`` object.
        """
        threshold = self.KEYPHRASE_SIMILARITY_THRESHOLD
        keep = self.TOP_KEYPHRASE_GROUPS

        def is_similar(left, right):
            return td.ratcliff_obershelp(left, right) > threshold

        def is_valid_keyphrase(phrase):
            # Reject None, the empty string and whitespace-only strings.
            return phrase is not None and phrase != '' and not phrase.isspace()

        for cluster in clusters:
            candidates = [
                phrase
                for phrase in cluster.get_keyphrases().keys()
                if is_valid_keyphrase(phrase)
            ]
            unionfind = UnionFind(candidates, is_similar)
            unionfind.union_step()
            # Largest groups first; sorted() is stable, so equally sized
            # groups keep the order in which get_unions() reported them.
            largest_groups = sorted(
                unionfind.get_unions().values(), key=len, reverse=True
            )[:keep]
            cluster.top_5_keyphrases = ['/'.join(group) for group in largest_groups]
        return clusters

    def __call__(self,
                 query: str,
                 num_papers: int,
                 start_year: int,
                 end_year: int,
                 platforms: List[str] = None,
                 best_k: int = 5,
                 loading_ctx_manager=None,
                 decorator: callable = None
                 ):
        """Run the search-and-cluster pipeline on each platform in turn.

        This is a generator: nothing is fetched until it is iterated.

        :param query: search query sent to each platform.
        :param num_papers: number of papers requested from each platform.
        :param start_year: first publication year (only the IEEE search uses it).
        :param end_year: last publication year (only the IEEE search uses it).
        :param platforms: platform names to query, in order; ``None`` means
            ``DEFAULT_PLATFORMS``.
        :param best_k: forwarded to the cluster pipeline as ``best_k``.
        :param loading_ctx_manager: optional context manager entered around
            each platform's pipeline run (e.g. a loading indicator).
        :param decorator: accepted for interface compatibility; not used here.
        :return: yields ``(clusters, articles)`` once per platform, with
            ``clusters.sort()`` already applied.
        :raises ValueError: if a platform name is not supported.
        """
        # Local import keeps this block independent of the file's import list.
        from contextlib import nullcontext

        # A None default avoids sharing one mutable list between all calls.
        if platforms is None:
            platforms = self.DEFAULT_PLATFORMS

        for platform in platforms:
            # A falsy manager means "no wrapper", matching the old if/else.
            context = loading_ctx_manager if loading_ctx_manager else nullcontext()
            with context:
                clusters, articles = self.__platformPipeline__(
                    platform, query, num_papers, start_year, end_year, best_k)
            # Sorting and yielding happen outside the context manager.
            clusters.sort()
            yield clusters, articles

    def __platformPipeline__(self, platforn_name: str,
                             query: str,
                             num_papers: int,
                             start_year: int,
                             end_year: int,
                             best_k: int = 5
                             ) -> (ClusterList, ArticleList):
        """Fetch articles from one platform and cluster their abstracts.

        :param platforn_name: ``'IEEE'``, ``'Arxiv'`` or ``'Paper with Code'``
            (the misspelled parameter name is kept for compatibility).
        :param query: search query.
        :param num_papers: number of papers to request.
        :param start_year: passed to the IEEE search only.
        :param end_year: passed to the IEEE search only.
        :param best_k: forwarded to the cluster pipeline as ``best_k``.
        :return: ``(clusters, articles)`` for the platform.
        :raises ValueError: if ``platforn_name`` is not a supported platform.
        """
        search = self.literature_search if platforn_name in self.DEFAULT_PLATFORMS else None

        if platforn_name == 'IEEE':
            articles = ArticleList.parse_ieee_articles(
                search.ieee(query, start_year, end_year, num_papers))
        elif platforn_name == 'Arxiv':
            articles = ArticleList.parse_arxiv_articles(
                search.arxiv(query, num_papers))
        elif platforn_name == 'Paper with Code':
            articles = ArticleList.parse_pwc_articles(
                search.paper_with_code(query, num_papers))
        else:
            # Previously this fell through and returned None, which surfaced
            # in the caller as an opaque tuple-unpacking TypeError.
            raise ValueError(
                f"Unsupported platform {platforn_name!r}; "
                f"expected one of {self.DEFAULT_PLATFORMS}"
            )

        # Shared tail: every platform clusters its abstracts the same way.
        abstracts = articles.getAbstracts()
        clusters = self.cluster_pipeline(abstracts, best_k=best_k)
        clusters = self.__postprocess_clusters__(clusters)
        return clusters, articles
|