File size: 8,848 Bytes
88d91f4 383512a 059d8f0 383512a 88d91f4 b774671 383512a fb67b80 88d91f4 e9f37ce 383512a 88d91f4 383512a 63a514a a51fb44 63a514a 383512a 63a514a 99e39f3 63a514a 99e39f3 63a514a 383512a 63a514a 383512a 7a4e68d 383512a 63a514a 383512a 63a514a 383512a 63a514a 383512a 63a514a 383512a 63a514a 383512a 63a514a 383512a e9f37ce 383512a 63a514a 059d8f0 b774671 63a514a b774671 63a514a b774671 63a514a b774671 63a514a 88d91f4 383512a 88d91f4 383512a 63a514a d625373 63a514a 383512a e9f37ce 88d91f4 63a514a 1347af3 63a514a 88d91f4 63a514a 3dcfe9e 63a514a 88d91f4 63a514a 88d91f4 63a514a 0496749 63a514a 0496749 059d8f0 63a514a 6867483 63a514a 383512a 63a514a 383512a 88d91f4 63a514a 070c3a1 9f4735a 070c3a1 63a514a 58a3f61 f6889a3 63a514a 5dac6c6 0fd6c86 63a514a f6889a3 63a514a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import os
import json
import requests
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import *
# Hub locations of the pre-computed leaderboard CSVs (one CSV per environment).
DATASET_REPO_URL = "https://huggingface.co/datasets/huggingface-projects/drlc-leaderboard-data"
DATASET_REPO_ID = "huggingface-projects/drlc-leaderboard-data"
# Token with write access, used to upload dataset updates and restart the Space.
HF_TOKEN = os.environ.get("HF_TOKEN")
block = gr.Blocks()  # top-level Gradio app; its UI is populated further below
api = HfApi(token=HF_TOKEN)
# Define RL environments.
# Each entry describes one leaderboard tab:
#   "rl_env_beautiful" -- display label for the UI tab.
#     NOTE(review): the trailing symbols (e.g. "π", "βοΈ") look like
#     mojibake-garbled emoji -- verify the file/display encoding.
#   "rl_env"           -- canonical environment id; used both for Hub model
#                         filtering and as the CSV filename stem.
#   "video_link"       -- optional replay video URL (mostly empty).
#   "global"           -- placeholder, never read in this file.
rl_envs = [
{"rl_env_beautiful": "LunarLander-v2 π", "rl_env": "LunarLander-v2", "video_link": "", "global": None},
{"rl_env_beautiful": "CartPole-v1", "rl_env": "CartPole-v1", "video_link": "https://huggingface.co/sb3/ppo-CartPole-v1/resolve/main/replay.mp4", "global": None},
{"rl_env_beautiful": "FrozenLake-v1-4x4-no_slippery βοΈ", "rl_env": "FrozenLake-v1-4x4-no_slippery", "video_link": "", "global": None},
{"rl_env_beautiful": "FrozenLake-v1-8x8-no_slippery βοΈ", "rl_env": "FrozenLake-v1-8x8-no_slippery", "video_link": "", "global": None},
{"rl_env_beautiful": "FrozenLake-v1-4x4 βοΈ", "rl_env": "FrozenLake-v1-4x4", "video_link": "", "global": None},
{"rl_env_beautiful": "FrozenLake-v1-8x8 βοΈ", "rl_env": "FrozenLake-v1-8x8", "video_link": "", "global": None},
{"rl_env_beautiful": "Taxi-v3 π", "rl_env": "Taxi-v3", "video_link": "", "global": None},
{"rl_env_beautiful": "CarRacing-v0 ποΈ", "rl_env": "CarRacing-v0", "video_link": "", "global": None},
{"rl_env_beautiful": "CarRacing-v2 ποΈ", "rl_env": "CarRacing-v2", "video_link": "", "global": None},
{"rl_env_beautiful": "MountainCar-v0 β°οΈ", "rl_env": "MountainCar-v0", "video_link": "", "global": None},
{"rl_env_beautiful": "SpaceInvadersNoFrameskip-v4 πΎ", "rl_env": "SpaceInvadersNoFrameskip-v4", "video_link": "", "global": None},
{"rl_env_beautiful": "PongNoFrameskip-v4 πΎ", "rl_env": "PongNoFrameskip-v4", "video_link": "", "global": None},
{"rl_env_beautiful": "BreakoutNoFrameskip-v4 π§±", "rl_env": "BreakoutNoFrameskip-v4", "video_link": "", "global": None},
{"rl_env_beautiful": "QbertNoFrameskip-v4 π¦", "rl_env": "QbertNoFrameskip-v4", "video_link": "", "global": None},
{"rl_env_beautiful": "BipedalWalker-v3", "rl_env": "BipedalWalker-v3", "video_link": "", "global": None},
{"rl_env_beautiful": "Walker2DBulletEnv-v0", "rl_env": "Walker2DBulletEnv-v0", "video_link": "", "global": None},
{"rl_env_beautiful": "AntBulletEnv-v0", "rl_env": "AntBulletEnv-v0", "video_link": "", "global": None},
{"rl_env_beautiful": "HalfCheetahBulletEnv-v0", "rl_env": "HalfCheetahBulletEnv-v0", "video_link": "", "global": None},
{"rl_env_beautiful": "PandaReachDense-v2", "rl_env": "PandaReachDense-v2", "video_link": "", "global": None},
{"rl_env_beautiful": "PandaReachDense-v3", "rl_env": "PandaReachDense-v3", "video_link": "", "global": None},
{"rl_env_beautiful": "Pixelcopter-PLE-v0", "rl_env": "Pixelcopter-PLE-v0", "video_link": "", "global": None}
]
# -------------------- Utility Functions --------------------
def restart():
    """Force a restart of the Space that hosts this leaderboard."""
    space_id = "huggingface-projects/Deep-Reinforcement-Learning-Leaderboard"
    print("RESTARTING SPACE...")
    api.restart_space(repo_id=space_id)
def download_leaderboard_dataset():
    """Fetch a local snapshot of the leaderboard dataset repo and return its path."""
    print("Downloading leaderboard dataset...")
    local_path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
    return local_path
def get_metadata(model_id):
    """Load the YAML card metadata from a model's README.md.

    Returns None when the model has no README.md on the Hub (HTTP 404).
    """
    try:
        readme_file = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        card_metadata = metadata_load(readme_file)
    except requests.exceptions.HTTPError:
        # Model repo without a README.md -- treat as "no metadata".
        return None
    return card_metadata
def parse_metrics_accuracy(meta):
    """Extract the first reported metric value from a model card's metadata.

    For course models this is the "mean_reward +/- std_reward" string stored
    under model-index -> results -> metrics -> value.

    Returns None when the metadata has no model-index or the structure is
    empty/malformed, instead of raising -- one bad model card must not abort
    a whole leaderboard update run.
    """
    try:
        return meta["model-index"][0]["results"][0]["metrics"][0]["value"]
    except (KeyError, IndexError, TypeError):
        # Missing key, empty list, or unexpected node type in the card data.
        return None
def parse_rewards(accuracy):
    """Split a "mean +/- std" accuracy value into (mean_reward, std_reward).

    Args:
        accuracy: the raw metric value from the model card, usually a string
            like "200.50 +/- 30.20", sometimes a bare number, possibly None.

    Returns:
        (mean_reward, std_reward) floats. Missing/unparseable means fall back
        to -1000 (sinks the model to the bottom of the ranking); a missing or
        unparseable std falls back to 0.
    """
    default_reward = -1000
    default_std = -1000
    if accuracy is None:
        return default_reward, default_std
    parsed = str(accuracy).split('+/-')
    # Guard float() -- card values are free-form text and may be empty,
    # whitespace-only, or non-numeric; a ValueError here must not crash
    # the whole update run.
    try:
        mean_reward = float(parsed[0].strip()) if parsed[0].strip() else default_reward
    except ValueError:
        mean_reward = default_reward
    try:
        std_reward = float(parsed[1].strip()) if len(parsed) > 1 and parsed[1].strip() else 0
    except ValueError:
        std_reward = 0
    return mean_reward, std_reward
def get_model_ids(rl_env):
    """List the Hub model ids tagged with the given RL environment."""
    matching_models = api.list_models(filter=rl_env)
    return [model.modelId for model in matching_models]
def update_leaderboard_dataset_parallel(rl_env, path):
    """Rebuild one environment's leaderboard CSV, fetching model metadata in parallel.

    Args:
        rl_env: canonical environment id (also the CSV filename stem).
        path: local directory of the dataset snapshot to write into.

    Returns:
        The ranked DataFrame that was written to ``{path}/{rl_env}.csv``.
    """

    def build_row(model_id):
        # One leaderboard row per model; None when the card has no metadata.
        metadata = get_metadata(model_id)
        if not metadata:
            return None
        mean_reward, std_reward = parse_rewards(parse_metrics_accuracy(metadata))
        return {
            "User": model_id.split('/')[0],
            "Model": model_id,
            "Results": mean_reward - std_reward,  # ranking score
            "Mean Reward": mean_reward,
            "Std Reward": std_reward,
        }

    rows = thread_map(build_row, get_model_ids(rl_env), desc="Processing models")
    records = [row for row in rows if row is not None]
    ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(records))
    ranked_dataframe.to_csv(os.path.join(path, f"{rl_env}.csv"), index=False)
    return ranked_dataframe
def rank_dataframe(dataframe):
    """Order rows best-first by Results (ties broken by User then Model,
    both descending) and prepend a 1-based Ranking column."""
    ordered = dataframe.sort_values(by=['Results', 'User', 'Model'], ascending=False)
    ordered['Ranking'] = range(1, len(ordered) + 1)
    other_columns = [col for col in ordered.columns if col != 'Ranking']
    return ordered[['Ranking'] + other_columns]
def run_update_dataset():
    """Scheduled job: rebuild every environment's leaderboard and push to the Hub."""
    dataset_path = download_leaderboard_dataset()
    for env in rl_envs:
        update_leaderboard_dataset_parallel(env["rl_env"], dataset_path)
    print("Uploading updated dataset...")
    api.upload_folder(
        folder_path=dataset_path,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset",
    )
def filter_data(rl_env, path, user_id):
    """Return the rows of an environment's leaderboard CSV belonging to one user."""
    leaderboard = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    user_mask = leaderboard["User"] == user_id
    return leaderboard[user_mask]
# -------------------- Gradio UI --------------------
print("Initializing dataset...")
path_ = download_leaderboard_dataset()
with block:
    gr.Markdown("""
# π Deep Reinforcement Learning Course Leaderboard π
This leaderboard displays trained agents from the [Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction?fw=pt).
**Models are ranked using `mean_reward - std_reward`.**
If you can't find your model, please wait for the next update (every 2 hours).
""")
    grpath = gr.State(path_)  # dataset path, passed into click handlers as a state input
    for env in rl_envs:
        with gr.TabItem(env["rl_env_beautiful"]):
            gr.Markdown(f"## {env['rl_env_beautiful']}")
            user_id = gr.Textbox(label="Your user ID")
            search_btn = gr.Button("Search π")
            reset_btn = gr.Button("Clear Search")
            env_state = gr.State(env["rl_env"])  # environment name as a state input
            gr_dataframe = gr.Dataframe(
                value=pd.read_csv(os.path.join(path_, f"{env['rl_env']}.csv")),
                headers=["Ranking π", "User π€", "Model π€", "Results", "Mean Reward", "Std Reward"],
                datatype=["number", "markdown", "markdown", "number", "number", "number"],
                row_count=(100, "dynamic"),  # allows displaying all rows dynamically
            )
            # gr.State components carry the env name and dataset path into the handler.
            search_btn.click(fn=filter_data, inputs=[env_state, grpath, user_id], outputs=gr_dataframe)
            # Bind the env name as a default argument: a plain closure over the
            # loop variable `env` would be late-bound, so every Clear button
            # would reload the CSV of the LAST environment in the loop.
            reset_btn.click(
                fn=lambda env_name=env["rl_env"]: pd.read_csv(os.path.join(path_, f"{env_name}.csv")),
                inputs=[],
                outputs=gr_dataframe,
            )
# -------------------- Scheduler --------------------
scheduler = BackgroundScheduler()
# Refresh the leaderboard CSVs every 2 hours; restart the Space every 3 hours.
for job_fn, every_hours in ((run_update_dataset, 2), (restart, 3)):
    scheduler.add_job(job_fn, 'interval', hours=every_hours)
scheduler.start()
block.launch()
|