|
|
|
|
|
import datetime |
|
import operator |
|
import pandas as pd |
|
import tqdm.auto |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
from huggingface_hub import HfApi |
|
from ragatouille import RAGPretrainedModel |
|
|
|
import gradio as gr |
|
from gradio_calendar import Calendar |
|
import datasets |
|
|
|
|
|
|
|
api = HfApi()

# Hub dataset repo holding the prebuilt ColBERT index over paper abstracts,
# and the local directory RAGatouille expects to load an index from.
INDEX_REPO_ID = "hysts-bot-data/daily-papers-abstract-index"
INDEX_DIR_PATH = ".ragatouille/colbert/indexes/daily-papers-abstract-index/"
api.snapshot_download(
    repo_id=INDEX_REPO_ID,
    repo_type="dataset",
    local_dir=INDEX_DIR_PATH,
)
abstract_retriever = RAGPretrainedModel.from_index(INDEX_DIR_PATH)

# Warm-up query: the first search triggers lazy index loading, so pay that
# cost at startup instead of on the first user request.
abstract_retriever.search("LLM")
|
|
|
|
|
def update_abstract_index() -> None:
    """Refresh the abstract search index from the Hub.

    Downloads the latest index snapshot, loads it into a new retriever,
    warms it up, and only then publishes it to the module-level
    ``abstract_retriever`` so concurrent searches never hit a cold index.
    """
    global abstract_retriever

    api.snapshot_download(
        repo_id=INDEX_REPO_ID,
        repo_type="dataset",
        local_dir=INDEX_DIR_PATH,
    )
    # Build and warm up the replacement retriever locally before swapping it
    # in: the original code reassigned the global first, so a request arriving
    # between reassignment and warm-up paid the lazy-initialization cost.
    new_retriever = RAGPretrainedModel.from_index(INDEX_DIR_PATH)
    new_retriever.search("LLM")
    abstract_retriever = new_retriever
|
|
|
|
|
# Re-download the abstract index at the top of every hour (UTC); a job firing
# up to 3 minutes late is still run rather than skipped.
scheduler_abstract = BackgroundScheduler()
scheduler_abstract.add_job(
    func=update_abstract_index,
    trigger="cron",
    minute=0,
    timezone="UTC",
    misfire_grace_time=3 * 60,
)
scheduler_abstract.start()
|
|
|
|
|
def get_df() -> pd.DataFrame:
    """Load the daily-papers dataset joined with its stats, newest first.

    Returns a DataFrame with the bulky ``abstract`` column dropped and a
    ``paper_page`` URL column (derived from ``arxiv_id``) appended.
    """
    df = pd.merge(
        left=datasets.load_dataset("hysts-bot-data/daily-papers", split="train").to_pandas(),
        right=datasets.load_dataset("hysts-bot-data/daily-papers-stats", split="train").to_pandas(),
        on="arxiv_id",
    )
    # Newest papers first.
    df = df[::-1].reset_index(drop=True)
    df["date"] = df["date"].dt.strftime("%Y-%m-%d")

    # Vectorized replacement for the original row-by-row iterrows/tqdm loop,
    # which only dropped one column and built one URL per row.
    df = df.drop(columns=["abstract"])
    df["paper_page"] = "https://huggingface.co/papers/" + df["arxiv_id"]
    return df
|
|
|
|
|
class Prettifier: |
|
@staticmethod |
|
def get_github_link(link: str) -> str: |
|
if not link: |
|
return "" |
|
return Prettifier.create_link("github", link) |
|
|
|
@staticmethod |
|
def create_link(text: str, url: str) -> str: |
|
return f'<a href="{url}" target="_blank">{text}</a>' |
|
|
|
@staticmethod |
|
def to_div(text: str | None, category_name: str) -> str: |
|
if text is None: |
|
text = "" |
|
class_name = f"{category_name}-{text.lower()}" |
|
return f'<div class="{class_name}">{text}</div>' |
|
|
|
def __call__(self, df: pd.DataFrame) -> pd.DataFrame: |
|
new_rows = [] |
|
for _, row in df.iterrows(): |
|
new_row = { |
|
"date": Prettifier.create_link(row.date, f"https://huggingface.co/papers?date={row.date}"), |
|
"paper_page": Prettifier.create_link(row.arxiv_id, row.paper_page), |
|
"title": row["title"], |
|
"github": self.get_github_link(row.github), |
|
"๐": row["upvotes"], |
|
"๐ฌ": row["num_comments"], |
|
} |
|
new_rows.append(new_row) |
|
return pd.DataFrame(new_rows) |
|
|
|
|
|
class PaperList:
    """Holds the raw papers DataFrame and serves prettified, filtered views."""

    # (column name, gradio Dataframe datatype) pairs, in display order.
    COLUMN_INFO = [
        ["date", "markdown"],
        ["paper_page", "markdown"],
        ["title", "str"],
        ["github", "markdown"],
        ["๐", "number"],
        ["๐ฌ", "number"],
    ]

    def __init__(self, df: pd.DataFrame):
        # Raw data kept for filtering; prettified view precomputed for display.
        self.df_raw = df
        self._prettifier = Prettifier()
        self.df_prettified = self._prettifier(df).loc[:, self.column_names]

    @property
    def column_names(self) -> list[str]:
        """Display column names, in order."""
        return list(map(operator.itemgetter(0), self.COLUMN_INFO))

    @property
    def column_datatype(self) -> list[str]:
        """Gradio datatypes matching ``column_names``."""
        return list(map(operator.itemgetter(1), self.COLUMN_INFO))

    def search(
        self,
        start_date: datetime.datetime,
        end_date: datetime.datetime,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
    ) -> pd.DataFrame:
        """Filter papers by date range, title substring, and abstract similarity.

        Returns the prettified DataFrame of matches; when an abstract query is
        given, rows are ordered by retrieval rank.
        """
        df = self.df_raw.copy()
        df["date"] = pd.to_datetime(df["date"])

        # Date range filter, inclusive on both ends. Copy so the strftime
        # assignment below doesn't trigger a chained-assignment warning.
        df = df[(df["date"] >= start_date) & (df["date"] <= end_date)].copy()
        df["date"] = df["date"].dt.strftime("%Y-%m-%d")

        # Title filter. regex=False treats the query literally, so inputs
        # containing regex metacharacters like '(' or '+' no longer raise
        # re.error (str.contains defaults to regex=True).
        if title_search_query:
            df = df[df["title"].str.contains(title_search_query, case=False, regex=False)]

        # Abstract filter: keep only papers retrieved by the semantic index,
        # ordered by retrieval rank, deduplicated while preserving order.
        if abstract_search_query:
            results = abstract_retriever.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["arxiv_id"])
            seen = set()
            found_ids = []
            for result in results:
                arxiv_id = result["document_id"]
                if arxiv_id not in remaining_ids or arxiv_id in seen:
                    continue
                seen.add(arxiv_id)
                found_ids.append(arxiv_id)
            df = df[df["arxiv_id"].isin(found_ids)].set_index("arxiv_id").reindex(index=found_ids).reset_index()

        return self._prettifier(df).loc[:, self.column_names]
|
|
|
|
|
# Module-level snapshot of the paper data; refreshed hourly by the scheduler below.
paper_list = PaperList(get_df())
|
|
|
|
|
def update_paper_list() -> None:
    """Rebuild the global ``paper_list`` from freshly downloaded datasets."""
    global paper_list
    paper_list = PaperList(get_df())
|
|
|
|
|
# Refresh the paper data at the top of every hour (UTC); a job firing up to
# 60 seconds late is still run rather than skipped.
scheduler_data = BackgroundScheduler()
scheduler_data.add_job(
    func=update_paper_list,
    trigger="cron",
    minute=0,
    timezone="UTC",
    misfire_grace_time=60,
)
scheduler_data.start()
|
|
|
|
|
|
|
# Markdown heading shown at the top of the page.
DESCRIPTION = "# [Daily Papers](https://huggingface.co/papers)"

# Markdown shown at the bottom of the page.
FOOT_NOTE = """\
Related useful Spaces:
- [Semantic Scholar Paper Recommender](https://huggingface.co/spaces/librarian-bots/recommend_similar_papers) by [davanstrien](https://huggingface.co/davanstrien)
- [ArXiv CS RAG](https://huggingface.co/spaces/bishmoy/Arxiv-CS-RAG) by [bishmoy](https://huggingface.co/bishmoy)
- [Paper Q&A](https://huggingface.co/spaces/chansung/paper_qa) by [chansung](https://huggingface.co/chansung)
"""
|
|
|
|
|
def update_df() -> pd.DataFrame:
    """Return the current prettified table (used to refresh the UI on page load)."""
    return paper_list.df_prettified
|
|
|
|
|
def update_num_papers(df: pd.DataFrame) -> str:
    """Format a 'shown / total' counter for the papers table."""
    shown = len(df)
    total = len(paper_list.df_raw)
    return f"{shown} / {total}"
|
|
|
|
|
def search(
    start_date: datetime.datetime,
    end_date: datetime.datetime,
    search_title: str,
    search_abstract: str,
    max_num_to_retrieve: int,
) -> pd.DataFrame:
    """Delegate a search to the current global ``paper_list``.

    Kept as a module-level wrapper so Gradio event handlers always hit the
    latest ``paper_list``, which the background scheduler replaces hourly.
    """
    return paper_list.search(start_date, end_date, search_title, search_abstract, max_num_to_retrieve)
|
|
|
|
|
# Build the Gradio UI: search inputs on top, paper count, then the results table.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        search_title = gr.Textbox(label="Search title")
        with gr.Row():
            with gr.Column(scale=4):
                search_abstract = gr.Textbox(
                    label="Search abstract",
                    info="The result may not be accurate as the abstract does not contain all the information.",
                )
            with gr.Column(scale=1):
                max_num_to_retrieve = gr.Slider(
                    label="Max number to retrieve",
                    info="This is used only for search on abstracts.",
                    minimum=1,
                    maximum=len(paper_list.df_raw),
                    step=1,
                    value=100,
                )
        with gr.Row():
            start_date = Calendar(label="Start date", type="date", value="2023-05-05")
            # datetime.datetime.utcnow() is deprecated (Python 3.12+) and returns a
            # naive datetime; use an explicitly UTC-aware "now" instead.
            end_date = Calendar(
                label="End date",
                type="date",
                value=datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d"),
            )

    num_papers = gr.Textbox(label="Number of papers", value=update_num_papers(paper_list.df_raw), interactive=False)
    df = gr.Dataframe(
        value=paper_list.df_prettified,
        datatype=paper_list.column_datatype,
        type="pandas",
        interactive=False,
        height=1000,
        elem_id="table",
        column_widths=["10%", "10%", "60%", "10%", "5%", "5%"],
        wrap=True,
    )

    gr.Markdown(FOOT_NOTE)

    # Explicit search button: update the table, then refresh the counter.
    search_event = gr.Button("Search")
    search_event.click(
        fn=search,
        inputs=[start_date, end_date, search_title, search_abstract, max_num_to_retrieve],
        outputs=df,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
    )

    # Live search: any change to a filter input re-runs the same pipeline.
    for trigger in [start_date, end_date, search_title, search_abstract, max_num_to_retrieve]:
        trigger.change(
            fn=search,
            inputs=[start_date, end_date, search_title, search_abstract, max_num_to_retrieve],
            outputs=df,
        ).then(
            fn=update_num_papers,
            inputs=df,
            outputs=num_papers,
            queue=False,
        )

    # On page load, show the freshest data (the background scheduler may have
    # refreshed `paper_list` since the process started).
    demo.load(
        fn=update_df,
        outputs=df,
        queue=False,
    ).then(
        fn=update_num_papers,
        inputs=df,
        outputs=num_papers,
        queue=False,
    )
|
|
|
if __name__ == "__main__":
    # Queueing enabled; the programmatic API is disabled — UI use only.
    demo.queue(api_open=False).launch(show_api=False)