import asyncio
import os
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict

import gradio as gr
import polars as pl
from cachetools import TTLCache, cached
from datasets import Dataset
from dotenv import load_dotenv
from httpx import AsyncClient, Client
from huggingface_hub import DatasetCard, hf_hub_url, list_datasets
from tqdm.auto import tqdm

load_dotenv()

LIMIT = 15_000
CACHE_TIME = 60 * 60 * 1  # 1 hour

# Organizations whose datasets should be excluded from the listing
REMOVE_ORGS = {
    "HuggingFaceM4",
    "HuggingFaceBR4",
    "open-llm-leaderboard",
    "TrainingDataPro",
}

HF_TOKEN = os.getenv("HF_TOKEN")
USER_AGENT = os.getenv("USER_AGENT")

if not HF_TOKEN or not USER_AGENT:
    raise ValueError(
        "Missing required environment variables. Please ensure both HF_TOKEN and USER_AGENT are set."
    )

headers = {"authorization": f"Bearer {HF_TOKEN}", "user-agent": USER_AGENT}

client = Client(
    headers=headers,
    timeout=30,
)
async_client = AsyncClient(
    headers=headers,
    timeout=30,
    http2=True,
)

cache = TTLCache(maxsize=10, ttl=CACHE_TIME)
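
# The TTL cache above memoizes prep_final_data below; the semaphores cap the
# number of concurrent requests made against the Hub. They live at module
# scope so every call shares the same limiter (a semaphore created inside a
# coroutine would be a fresh object on each call and never actually throttle
# anything).
README_SEMAPHORE = asyncio.Semaphore(30)
PREVIEW_SEMAPHORE = asyncio.Semaphore(10)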

def get_initial_data():
    """Fetch metadata for the most recently created datasets on the Hub."""
    datasets = list_datasets(
        limit=LIMIT,
        sort="createdAt",
        direction=-1,
        expand=[
            "trendingScore",
            "createdAt",
            "author",
            "downloads",
            "likes",
            "cardData",
            "lastModified",
            "private",
        ],
    )
    return [d.__dict__ for d in tqdm(datasets)]
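
# Each element above is a plain dict of DatasetInfo attributes; with the
# expand list used, it carries keys such as "id", "author", "created_at",
# "last_modified", "private", "downloads", "likes", "trending_score", and
# "card_data", which the column lists below select from.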
keep_initial = [
    "id",
    "author",
    "created_at",
    "last_modified",
    "private",
    "downloads",
    "likes",
    "trending_score",
    "card_data",
    "cardData",
]

keep_final = [
    "id",
    "author",
    "created_at",
    "last_modified",
    "downloads",
    "likes",
    "trending_score",
]

def prepare_initial_df():
    ds = get_initial_data()
    df = pl.LazyFrame(ds).select(keep_initial)
    # remove private datasets
    df = df.filter(~pl.col("private"))
    df = df.filter(~pl.col("author").is_in(REMOVE_ORGS))
    df = df.filter(~pl.col("id").str.contains("my-distiset"))
    df = df.select(keep_final)
    return df.collect()

async def get_readme_len(row: Dict[str, Any]):
    """Fetch a dataset's README.md and record the length of its card text."""
    try:
        url = hf_hub_url(row["id"], "README.md", repo_type="dataset")
        async with README_SEMAPHORE:
            resp = await async_client.get(url)
        if resp.status_code == 200:
            card = DatasetCard(resp.text)
            row["len"] = len(card.text)
        else:
            row["len"] = 0  # Use 0 instead of None to avoid type issues
        return row
    except Exception as e:
        print(e)
        row["len"] = 0  # Use 0 instead of None to avoid type issues
        return row

def prepare_data_with_readme_len(df: pl.DataFrame):
    ds = Dataset.from_polars(df)
    # `map` accepts async functions in recent `datasets` releases and runs
    # them concurrently
    ds = ds.map(get_readme_len)
    return ds

async def check_ds_server_valid(row):
    """Check whether datasets-server can render a preview for the dataset."""
    try:
        url = f"https://datasets-server.huggingface.co/is-valid?dataset={row['id']}"
        async with PREVIEW_SEMAPHORE:
            response = await async_client.get(url)
        if response.status_code != 200:
            row["has_server_preview"] = False
            return row  # bail out early; a non-200 body may not be JSON
        data = response.json()
        preview = data.get("preview")
        row["has_server_preview"] = preview is not None
        return row
    except Exception as e:
        print(e)
        row["has_server_preview"] = False
        return row

def prep_data_with_server_preview(ds):
    ds = ds.map(check_ds_server_valid)
    return ds.to_polars()

def render_model_hub_link(hub_id):
    # Currently unused: prep_final_data builds the same anchor tag inline
    # with polars expressions
    link = f"https://huggingface.co/datasets/{hub_id}"
    return (
        f'<a target="_blank" href="{link}" style="color: var(--link-text-color);'
        f' text-decoration: underline;text-decoration-style: dotted;">{hub_id}</a>'
    )
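
# A sketch of how the helper above could be wired in instead of the inline
# expression used in prep_final_data:
#     df = df.with_columns(
#         pl.col("id")
#         .map_elements(render_model_hub_link, return_dtype=pl.String)
#         .alias("hub_id")
#     )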

@cached(cache)  # also memoize in memory so repeated UI calls skip disk I/O
def prep_final_data():
    # Check if we have a valid cached parquet file
    cache_dir = "cache"
    os.makedirs(cache_dir, exist_ok=True)

    # A cache file is valid if it was created within the last CACHE_TIME seconds
    now = time.time()
    cache_valid_time = now - CACHE_TIME

    # Look for a still-valid cache file
    valid_cache_file = None
    for filename in os.listdir(cache_dir):
        if filename.startswith("dataset_cache_") and filename.endswith(".parquet"):
            try:
                # Extract the creation timestamp from the filename
                timestamp = float(
                    filename.replace("dataset_cache_", "").replace(".parquet", "")
                )
                if timestamp > cache_valid_time:
                    valid_cache_file = os.path.join(cache_dir, filename)
                    break
            except ValueError:
                continue

    # If we have a valid cache file, load it
    if valid_cache_file:
        print(f"Loading data from cache: {valid_cache_file}")
        return pl.read_parquet(valid_cache_file)
    # Otherwise, generate the data and cache it
    print("Generating fresh data...")
    df = prepare_initial_df()
    ds = prepare_data_with_readme_len(df)
    df = prep_data_with_server_preview(ds)

    # Format the ID column as HTML links using string concatenation instead of regex
    df = df.with_columns(
        (
            pl.lit('<a target="_blank" href="https://huggingface.co/datasets/')
            + pl.col("id")
            + pl.lit(
                '" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">'
            )
            + pl.col("id")
            + pl.lit("</a>")
        ).alias("hub_id")
    )
    df = df.drop("id")
    df = df.sort(by=["trending_score", "likes", "downloads", "len"], descending=True)

    # Make hub_id the first column
    df = df.select(
        [
            "hub_id",
            "author",
            "created_at",
            "last_modified",
            "downloads",
            "likes",
            "trending_score",
            "len",
            "has_server_preview",
        ]
    )

    # Save to cache
    cache_file = os.path.join(cache_dir, f"dataset_cache_{now}.parquet")
    df.write_parquet(cache_file)

    # Clean up expired cache files
    for filename in os.listdir(cache_dir):
        if filename.startswith("dataset_cache_") and filename.endswith(".parquet"):
            try:
                timestamp = float(
                    filename.replace("dataset_cache_", "").replace(".parquet", "")
                )
                if timestamp <= cache_valid_time:
                    os.remove(os.path.join(cache_dir, filename))
            except ValueError:
                continue

    return df

def filter_by_max_age(df, max_age_days):
    return df.filter(
        pl.col("created_at")
        > (datetime.now(timezone.utc) - timedelta(days=max_age_days))
    )


def filter_by_min_len(df, min_len):
    return df.filter(pl.col("len") >= min_len)


def filter_by_server_preview(df, needs_server_preview):
    # Only restrict the results when the "exclude" box is ticked; a plain
    # equality filter would hide every dataset *with* a preview by default.
    if needs_server_preview:
        return df.filter(pl.col("has_server_preview"))
    return df


def filter_df(max_age_days, min_len, needs_server_preview):
    df = prep_final_data().lazy()
    df = filter_by_max_age(df, max_age_days)
    df = filter_by_min_len(df, min_len)
    df = filter_by_server_preview(df, needs_server_preview)
    df = df.sort(by=["trending_score", "likes", "downloads", "len"], descending=True)
    return df.collect()
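
# Example: datasets from the last 30 days with a README of at least 500
# characters, restricted to those with a working datasets-server preview:
#     filter_df(30, 500, True)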

with gr.Blocks() as demo:
    gr.Markdown("# Recent Datasets on the Hub")
    gr.Markdown(
        "Datasets added in the past 90 days with a README.md and some metadata."
    )
    with gr.Row():
        max_age_days = gr.Slider(
            label="Max Age (days)",
            value=7,
            minimum=0,
            maximum=90,
            step=1,
            interactive=True,
        )
        min_len = gr.Slider(
            label="Minimum README Length",
            value=300,
            minimum=0,
            maximum=1000,
            step=50,
            interactive=True,
        )
        needs_server_preview = gr.Checkbox(
            label="Exclude datasets without datasets-server preview?",
            value=False,
            interactive=True,
        )

    output = gr.DataFrame(
        value=filter_df(7, 300, False),
        interactive=False,
        datatype="markdown",
    )

    def update_df(age, length, preview):
        return filter_df(age, length, preview)

    # Connect the input components to the update function
    for component in [max_age_days, min_len, needs_server_preview]:
        component.change(
            fn=update_df,
            inputs=[max_age_days, min_len, needs_server_preview],
            outputs=[output],
        )

demo.launch()