# Hugging Face Space: scheduled daily-papers CSV updater
# (scraped page chrome removed: "Spaces: Running on CPU Upgrade")
import datetime | |
import pathlib | |
import re | |
import tempfile | |
import pandas as pd | |
import requests | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from huggingface_hub import HfApi, Repository | |
from huggingface_hub.utils import RepositoryNotFoundError | |
class SpaceRestarter:
    """Restarts a Hugging Face Space through the Hub API.

    Validates at construction time that the configured HF token has write
    permission and that the target Space exists.
    """

    def __init__(self, space_id: str):
        """Check token permission and Space existence.

        Args:
            space_id: Full repo id of the Space (e.g. ``user/space-name``).

        Raises:
            ValueError: If the token lacks write permission or the Space
                does not exist.
        """
        self.api = HfApi()
        if self.api.get_token_permission() != 'write':
            raise ValueError('The HF token must have write permission.')
        try:
            self.api.space_info(repo_id=space_id)
        except RepositoryNotFoundError as e:
            # FIX: chain the original exception so the underlying
            # repository-not-found cause is preserved in tracebacks (B904).
            raise ValueError('The Space ID does not exist.') from e
        self.space_id = space_id

    def restart(self) -> None:
        """Trigger a restart of the Space via the Hub API."""
        self.api.restart_space(self.space_id)
def find_github_links(summary: str) -> str:
    """Extract the single GitHub URL mentioned in *summary*.

    Matches plain repo URLs and ``/tree/<ref>/<dir>`` or ``/blob/<ref>/<file>``
    deep links, dropping a trailing period left over from sentence punctuation.

    Returns:
        The URL, or an empty string when none is present.

    Raises:
        RuntimeError: If more than one GitHub link is found.
    """
    pattern = (r'https://github.com/[^/]+/[^/)}, ]+'
               r'(?:/(?:tree|blob)/[^/]+/[^/)}, ]+)?')
    matches = re.findall(pattern, summary)
    if not matches:
        return ''
    if len(matches) > 1:
        raise RuntimeError(f'Found multiple GitHub links: {matches}')
    url = matches[0]
    if url.endswith('.'):
        url = url[:-1]
    return url.strip()
class RepoUpdater:
    """Clones a Hub repo holding ``papers.csv`` and appends new daily papers.

    The CSV schema visible here is ``date, arxiv_id, github`` (one row per
    paper); rows already present (by ``arxiv_id``) are never duplicated.
    """

    def __init__(self, repo_id: str, repo_type: str):
        """Clone ``repo_id`` into a temp directory and pull latest.

        Args:
            repo_id: Repo to clone (e.g. ``user/space-name``).
            repo_type: Hub repo type (e.g. ``'space'`` or ``'dataset'``).

        Raises:
            ValueError: If the HF token lacks write permission.
        """
        api = HfApi()
        if api.get_token_permission() != 'write':
            raise ValueError('The HF token must have write permission.')
        name = api.whoami()['name']
        # FIX: tempfile.tempdir is None until gettempdir() has been called
        # once; gettempdir() always returns a usable directory path.
        repo_dir = pathlib.Path(tempfile.gettempdir()) / repo_id.split('/')[-1]
        self.csv_path = repo_dir / 'papers.csv'
        self.repo = Repository(
            local_dir=repo_dir,
            clone_from=repo_id,
            repo_type=repo_type,
            git_user=name,
            git_email=f'{name}@users.noreply.huggingface.co')
        self.repo.git_pull()

    def update(self) -> None:
        """Fetch yesterday's and today's daily papers and extend the CSV.

        Papers whose arXiv id is already present are skipped; papers whose
        summary contains multiple GitHub links are logged and skipped.
        """
        # FIX: use UTC so the queried dates line up with the UTC cron
        # trigger that drives this updater (naive now() is host-TZ bound).
        now = datetime.datetime.now(datetime.timezone.utc)
        yesterday = (now - datetime.timedelta(days=1)).strftime('%Y-%m-%d')
        today = now.strftime('%Y-%m-%d')
        daily_papers = []
        for date in (yesterday, today):
            # FIX: a timeout keeps a stalled request from hanging the
            # scheduler thread forever; raise_for_status surfaces HTTP
            # errors clearly instead of failing later inside .json().
            res = requests.get(
                f'https://huggingface.co/api/daily_papers?date={date}',
                timeout=30)
            res.raise_for_status()
            daily_papers.append({'date': date, 'papers': res.json()})
        self.repo.git_pull()
        df = pd.read_csv(self.csv_path, dtype=str).fillna('')
        rows = [row for _, row in df.iterrows()]
        # Known ids, used to de-duplicate across runs.
        arxiv_ids = {row.arxiv_id for row in rows}
        for d in daily_papers:
            date = d['date']
            papers = d['papers']
            for paper in papers:
                arxiv_id = paper['paper']['id']
                if arxiv_id in arxiv_ids:
                    continue
                try:
                    github = find_github_links(paper['paper']['summary'])
                except RuntimeError as e:
                    # Ambiguous summaries (multiple links) are skipped,
                    # not fatal — this is a best-effort scrape.
                    print(e)
                    continue
                rows.append(
                    pd.Series({
                        'date': date,
                        'arxiv_id': arxiv_id,
                        'github': github,
                    }))
        df = pd.DataFrame(rows).reset_index(drop=True)
        df.to_csv(self.csv_path, index=False)

    def push(self) -> None:
        """Commit and push the local clone back to the Hub."""
        self.repo.push_to_hub()
class UpdateScheduler:
    """Runs the papers-CSV refresh on a UTC cron schedule.

    On each tick: update the local clone; if anything changed, push it to
    the Hub (which redeploys the Space), otherwise restart the Space so it
    reloads its data.
    """

    def __init__(self, space_id: str, cron_hour: str, cron_minute: str):
        """Wire up the restarter, updater and background cron job.

        Args:
            space_id: Space repo id used for both restarting and updating.
            cron_hour: Cron hour field (UTC).
            cron_minute: Cron minute field (UTC).
        """
        self.space_restarter = SpaceRestarter(space_id=space_id)
        self.repo_updater = RepoUpdater(repo_id=space_id, repo_type='space')
        self.scheduler = BackgroundScheduler()
        self.scheduler.add_job(
            self._update,
            'cron',
            hour=cron_hour,
            minute=cron_minute,
            second=0,
            timezone='UTC')

    def _update(self) -> None:
        """One scheduled tick: refresh the CSV, then push or restart."""
        self.repo_updater.update()
        if not self.repo_updater.repo.is_repo_clean():
            # New rows were written — push them to the Hub.
            self.repo_updater.push()
        else:
            # Nothing changed; restart so the Space reloads anyway.
            self.space_restarter.restart()

    def start(self) -> None:
        """Start the background scheduler."""
        self.scheduler.start()