Spaces:
Running
Running
import gradio as gr | |
from typing import Optional | |
import pandas as pd | |
from huggingface_hub import HfApi, hf_hub_download, CommitOperationAdd | |
import json | |
import os | |
# PR function as before
def create_pr_in_hf_dataset(new_entry, oauth_token: gr.OAuthToken):
    """Append ``new_entry`` to ``data.json`` in the dataset repo via a pull request.

    The function makes sure the dataset repository and an initial ``data.json``
    exist, downloads the current file, appends ``new_entry`` and opens a pull
    request containing the updated file.

    Args:
        new_entry: JSON-serialisable mapping describing the paper. Its
            ``'arxiv_id'`` key is used in the commit message.
        oauth_token: Gradio OAuth token; its ``token`` attribute is used to
            authenticate every Hub call.

    Returns:
        The URL of the created pull request on success, otherwise a
        human-readable string starting with ``"Error"``.
    """
    # Dataset and filename
    REPO_ID = 'IAMJB/paper-central-pr'
    FILENAME = 'data.json'

    api = HfApi()
    token = oauth_token.token

    # Ensure the repository exists and holds an initial (empty) data.json.
    try:
        api.create_repo(repo_id=REPO_ID, token=token, repo_type='dataset', exist_ok=True)
        files = api.list_repo_files(REPO_ID, repo_type='dataset', token=token)
        if FILENAME not in files:
            # Upload the payload from memory: a shared on-disk temp file would
            # race between concurrent users and leak if the commit failed.
            init_op = CommitOperationAdd(
                path_in_repo=FILENAME,
                path_or_fileobj=json.dumps([]).encode('utf-8'),
            )
            api.create_commit(
                repo_id=REPO_ID,
                operations=[init_op],
                commit_message="Initialize data.json",
                repo_type="dataset",
                token=token,
            )
    except Exception as e:
        return f"Error creating or accessing repository: {e}"

    # Download the existing entries from the dataset.
    try:
        local_filepath = hf_hub_download(
            repo_id=REPO_ID, filename=FILENAME, repo_type='dataset', token=token
        )
        with open(local_filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        # NOTE(review): best-effort fallback kept from the original. The
        # resulting PR then contains only the new entry, so reviewers should
        # check the diff before merging.
        print(f"Error downloading existing data: {e}")
        data = []

    if not isinstance(data, list):
        # Appending only makes sense for a JSON array; report instead of crashing.
        print(f"Error creating PR: {FILENAME} does not contain a JSON list")
        return "Error creating PR."

    # Add the new entry and serialise the whole file in memory.
    data.append(new_entry)
    payload = json.dumps(data, indent=2).encode('utf-8')
    commit = CommitOperationAdd(path_in_repo=FILENAME, path_or_fileobj=payload)

    # Create PR
    try:
        res = api.create_commit(
            repo_id=REPO_ID,
            operations=[commit],
            commit_message=f"Add new entry for arXiv ID {new_entry['arxiv_id']}",
            repo_type="dataset",
            create_pr=True,
            token=token,
        )
        pr_url = res.pr_url
    except Exception as e:
        print(f"Error creating PR: {e}")
        pr_url = "Error creating PR."
    return pr_url
def pr_paper_central_tab(paper_central_df):
    """Build the "PR Paper-central" tab and wire up its event handlers.

    The tab lets a logged-in user look up a paper by arXiv ID, edit the
    pre-filled metadata fields and submit them as a pull request through
    ``create_pr_in_hf_dataset``.

    Args:
        paper_central_df: DataFrame with an ``'arxiv_id'`` column; the other
            field columns are read with ``row.get`` and may be absent.
    """
    with gr.Column():
        gr.Markdown("## PR Paper-central")

        # Hidden login prompt component (never referenced by the handlers below).
        gr.Markdown("Please log in to proceed.", visible=False)

        # Input for arXiv ID
        arxiv_id_input = gr.Textbox(label="Enter arXiv ID")
        arxiv_id_button = gr.Button("Submit")

        # Message to display errors or information
        message = gr.Markdown("", visible=False)

        # Define the fields dynamically
        # NOTE(review): the lookup below uses the name 'type_'; if the
        # DataFrame column is actually called 'type', that textbox is always
        # pre-filled with "" — confirm against the caller's schema.
        fields = [
            {'name': 'paper_page', 'label': 'Paper Page'},
            {'name': 'github', 'label': 'GitHub URL'},
            {'name': 'conference_name', 'label': 'Conference Name'},
            {'name': 'type_', 'label': 'Type'},  # Renamed from 'type' to 'type_'
            {'name': 'proceedings', 'label': 'Proceedings'},
            # Add or remove fields here as needed
        ]
        input_fields = {
            field['name']: gr.Textbox(label=field['label'], visible=False)
            for field in fields
        }
        field_boxes = [input_fields[field['name']] for field in fields]

        # Button to create PR
        create_pr_button = gr.Button("Create PR", visible=False)

        # Output message
        pr_message = gr.Markdown("", visible=False)

        def _hidden_form(text):
            """Return updates that show ``text`` and hide every field and the PR button."""
            return (
                [gr.update(value=text, visible=True)]
                + [gr.update(visible=False) for _ in fields]
                + [gr.update(visible=False)]  # create_pr_button
            )

        def _display_value(value):
            """Map missing cells (None/NaN/NA) to an empty string for display."""
            if pd.api.types.is_scalar(value) and pd.isna(value):
                return ""
            return value

        # The OAuth token is not listed in ``inputs``; Gradio supplies it based
        # on the parameter's type annotation (None when the user is logged out).
        def check_login_and_handle_arxiv_id(arxiv_id, oauth_token: Optional[gr.OAuthToken]):
            """Reveal the pre-filled form for ``arxiv_id`` or show an error message."""
            if oauth_token is None:
                return _hidden_form("Please log in to proceed.")

            # Filter once and reuse the result for both the existence check
            # and the row lookup.
            matches = paper_central_df[paper_central_df['arxiv_id'] == arxiv_id]
            if matches.empty:
                return _hidden_form("arXiv ID not found. Please try again.")

            row = matches.iloc[0]
            updates = [gr.update(value="", visible=False)]  # message
            updates.extend(
                gr.update(value=_display_value(row.get(field['name'], "")), visible=True)
                for field in fields
            )
            updates.append(gr.update(visible=True))  # create_pr_button
            return updates

        arxiv_id_button.click(
            fn=check_login_and_handle_arxiv_id,
            inputs=[arxiv_id_input],
            outputs=[message] + field_boxes + [create_pr_button],
        )

        # ``message`` receives the current pr_message value and is unused; it
        # is kept so the positional inputs keep lining up with the signature.
        def create_pr(message, arxiv_id, paper_page, github, conference_name, type_, proceedings,
                      oauth_token: Optional[gr.OAuthToken] = None):
            """Submit the edited fields as a pull request and report the outcome."""
            if oauth_token is None:
                return gr.update(value="Please log in first.", visible=True)

            new_entry = {
                'arxiv_id': arxiv_id,
                'paper_page': paper_page,
                'github': github,
                'conference_name': conference_name,
                'type': type_,
                'proceedings': proceedings,
            }
            # Now add this to the dataset and create a PR
            pr_url = create_pr_in_hf_dataset(new_entry, oauth_token)

            # The helper reports failures as strings starting with "Error";
            # show those as-is instead of announcing a created PR.
            if isinstance(pr_url, str) and pr_url.startswith("Error"):
                return gr.update(value=pr_url, visible=True)
            return gr.update(value=f"PR created: {pr_url}", visible=True)

        create_pr_button.click(
            fn=create_pr,
            inputs=[pr_message, arxiv_id_input] + field_boxes,
            outputs=[pr_message],
        )