File size: 8,848 Bytes
88d91f4
383512a
059d8f0
383512a
 
88d91f4
 
 
b774671
383512a
fb67b80
88d91f4
 
 
e9f37ce
383512a
88d91f4
383512a
63a514a
a51fb44
63a514a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383512a
 
63a514a
 
99e39f3
63a514a
 
99e39f3
 
63a514a
 
 
 
 
383512a
63a514a
383512a
7a4e68d
383512a
 
63a514a
 
383512a
63a514a
383512a
 
 
 
63a514a
383512a
 
63a514a
383512a
63a514a
 
 
 
 
383512a
63a514a
383512a
e9f37ce
383512a
63a514a
 
059d8f0
b774671
63a514a
b774671
 
 
 
63a514a
b774671
 
63a514a
 
 
 
 
 
 
b774671
 
 
 
 
 
 
 
 
 
 
63a514a
88d91f4
383512a
88d91f4
383512a
63a514a
d625373
63a514a
383512a
e9f37ce
88d91f4
63a514a
1347af3
63a514a
 
88d91f4
63a514a
3dcfe9e
63a514a
 
 
 
 
88d91f4
 
63a514a
 
 
88d91f4
63a514a
0496749
63a514a
 
0496749
059d8f0
63a514a
 
6867483
63a514a
 
 
383512a
63a514a
383512a
88d91f4
63a514a
 
 
 
 
 
 
 
 
 
 
 
 
 
070c3a1
9f4735a
070c3a1
63a514a
 
 
 
 
58a3f61
f6889a3
63a514a
5dac6c6
0fd6c86
63a514a
 
f6889a3
 
63a514a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import json
import requests
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import *

# Hub dataset repository holding the leaderboard CSVs (one "<rl_env>.csv" per env).
DATASET_REPO_ID = "huggingface-projects/drlc-leaderboard-data"
DATASET_REPO_URL = "https://huggingface.co/datasets/huggingface-projects/drlc-leaderboard-data"
# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Shared Hub client (used below for list_models, upload_folder, restart_space).
api = HfApi(token=HF_TOKEN)
# Root Gradio container that the UI section populates and launches.
block = gr.Blocks()

# RL environments shown on the leaderboard, one tab per entry.
#   rl_env_beautiful: tab / heading label shown in the UI
#   rl_env:           Hub model filter and CSV file stem ("<rl_env>.csv")
# NOTE(review): "video_link" and "global" are not read anywhere in this file.
rl_envs = [
    {"rl_env_beautiful": "LunarLander-v2 🚀", "rl_env": "LunarLander-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "CartPole-v1", "rl_env": "CartPole-v1", "video_link": "https://huggingface.co/sb3/ppo-CartPole-v1/resolve/main/replay.mp4", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-4x4-no_slippery ❄️", "rl_env": "FrozenLake-v1-4x4-no_slippery", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-8x8-no_slippery ❄️", "rl_env": "FrozenLake-v1-8x8-no_slippery", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-4x4 ❄️", "rl_env": "FrozenLake-v1-4x4", "video_link": "", "global": None},
    {"rl_env_beautiful": "FrozenLake-v1-8x8 ❄️", "rl_env": "FrozenLake-v1-8x8", "video_link": "", "global": None},
    {"rl_env_beautiful": "Taxi-v3 🚖", "rl_env": "Taxi-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "CarRacing-v0 🏎️", "rl_env": "CarRacing-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "CarRacing-v2 🏎️", "rl_env": "CarRacing-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "MountainCar-v0 ⛰️", "rl_env": "MountainCar-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "SpaceInvadersNoFrameskip-v4 👾", "rl_env": "SpaceInvadersNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "PongNoFrameskip-v4 🎾", "rl_env": "PongNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "BreakoutNoFrameskip-v4 🧱", "rl_env": "BreakoutNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "QbertNoFrameskip-v4 🐦", "rl_env": "QbertNoFrameskip-v4", "video_link": "", "global": None},
    {"rl_env_beautiful": "BipedalWalker-v3", "rl_env": "BipedalWalker-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "Walker2DBulletEnv-v0", "rl_env": "Walker2DBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "AntBulletEnv-v0", "rl_env": "AntBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "HalfCheetahBulletEnv-v0", "rl_env": "HalfCheetahBulletEnv-v0", "video_link": "", "global": None},
    {"rl_env_beautiful": "PandaReachDense-v2", "rl_env": "PandaReachDense-v2", "video_link": "", "global": None},
    {"rl_env_beautiful": "PandaReachDense-v3", "rl_env": "PandaReachDense-v3", "video_link": "", "global": None},
    {"rl_env_beautiful": "Pixelcopter-PLE-v0", "rl_env": "Pixelcopter-PLE-v0", "video_link": "", "global": None},
]

# -------------------- Utility Functions --------------------

def restart():
    """Ask the Hub to restart the leaderboard Space.

    Logs a marker line first, then calls ``api.restart_space`` for the
    ``huggingface-projects/Deep-Reinforcement-Learning-Leaderboard`` Space.
    """
    space_id = "huggingface-projects/Deep-Reinforcement-Learning-Leaderboard"
    print("RESTARTING SPACE...")
    api.restart_space(repo_id=space_id)

def download_leaderboard_dataset():
    """Download a snapshot of the leaderboard dataset repository.

    Called at startup to seed the UI and again by every scheduled update.

    Returns:
        The local directory returned by ``snapshot_download`` for
        ``DATASET_REPO_ID``.
    """
    print("Downloading leaderboard dataset...")
    snapshot_dir = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
    return snapshot_dir

def get_metadata(model_id):
    """Load the model-card metadata of a Hub model.

    Args:
        model_id: Hub model id, e.g. ``"user/model-name"``.

    Returns:
        Whatever ``metadata_load`` produces for the model's README.md, or
        ``None`` when downloading the README raises an HTTP error (for example
        a repo without a README.md). Callers treat any falsy result as
        "no metadata".
    """
    try:
        card_file = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        card_metadata = metadata_load(card_file)
    except requests.exceptions.HTTPError:
        return None
    return card_metadata

def parse_metrics_accuracy(meta):
    """Return the first metric value reported in a model card's ``model-index``.

    Args:
        meta: Model-card metadata mapping (as loaded from the README front matter).

    Returns:
        ``meta["model-index"][0]["results"][0]["metrics"][0]["value"]``, or
        ``None`` when the card has no ``model-index`` or any step of that path
        is missing or malformed (empty list, missing key, wrong container type).
    """
    try:
        return meta["model-index"][0]["results"][0]["metrics"][0]["value"]
    except (KeyError, IndexError, TypeError):
        # One hand-edited / malformed card must not abort the whole leaderboard
        # update (this runs inside thread_map); treat it like a card without results.
        return None

def parse_rewards(accuracy):
    """Split a reported ``"mean +/- std"`` metric into mean and std rewards.

    Args:
        accuracy: Metric value from the model card (string or number), or
            ``None`` when the card reports no metric.

    Returns:
        ``(mean_reward, std_reward)``:
        - ``None`` or an unparseable value (e.g. ``"n/a"``, ``"200 ± 10"``)
          gives the sentinel pair ``(-1000, -1000)``.
        - A value without a ``+/-`` part (or with an empty one) has std ``0``.
        - An empty mean part gives mean ``-1000``.
    """
    default_std = -1000
    default_reward = -1000
    if accuracy is None:
        return default_reward, default_std

    parts = str(accuracy).split('+/-')
    mean_text = parts[0].strip()
    std_text = parts[1].strip() if len(parts) > 1 else ""
    try:
        mean_reward = float(mean_text) if mean_text else default_reward
        std_reward = float(std_text) if std_text else 0
    except ValueError:
        # Treat garbage like a missing metric instead of raising out of the
        # thread pool and aborting the whole leaderboard update.
        return default_reward, default_std
    return mean_reward, std_reward

def get_model_ids(rl_env):
    """List the ids of Hub models returned by ``api.list_models(filter=rl_env)``.

    Args:
        rl_env: Environment id used as the Hub model filter.

    Returns:
        A list of ``modelId`` strings, in the order the Hub returns them.
    """
    matching_models = api.list_models(filter=rl_env)
    return [info.modelId for info in matching_models]

def update_leaderboard_dataset_parallel(rl_env, path):
    """Rebuild and save the ranked leaderboard CSV for one RL environment.

    Every model returned by ``get_model_ids(rl_env)`` is processed concurrently
    with ``thread_map``. Its model-card metadata is fetched, and the first
    reported metric is parsed into a mean/std reward pair. Models without
    metadata are skipped. The rows are ranked with ``rank_dataframe`` and
    written to ``<path>/<rl_env>.csv``.

    Args:
        rl_env: Environment id; used as the Hub model filter and the CSV file stem.
        path: Directory the CSV is written into.

    Returns:
        The ranked ``pandas.DataFrame`` that was written. It always has the
        leaderboard columns, even when no model produced a row.
    """
    columns = ["User", "Model", "Results", "Mean Reward", "Std Reward"]
    model_ids = get_model_ids(rl_env)

    def process_model(model_id):
        """Build one leaderboard row for ``model_id``, or ``None`` to skip it."""
        meta = get_metadata(model_id)
        if not meta:
            return None
        accuracy = parse_metrics_accuracy(meta)
        mean_reward, std_reward = parse_rewards(accuracy)
        # parse_rewards returns a negative sentinel std (-1000) when there is
        # no usable metric. Subtracting it gives -1000 - (-1000) == 0, which
        # would rank metric-less models above genuine negative scores. A real
        # std is never negative, so in that case rank on the mean alone.
        if std_reward >= 0:
            results = mean_reward - std_reward
        else:
            results = mean_reward
        return {
            "User": model_id.split('/')[0],
            "Model": model_id,
            "Results": results,
            "Mean Reward": mean_reward,
            "Std Reward": std_reward,
        }

    # thread_map already returns a list.
    rows = thread_map(process_model, model_ids, desc="Processing models")
    data = [row for row in rows if row is not None]

    # Explicit columns keep the frame sortable by rank_dataframe (and the CSV
    # readable by the UI) even when `data` is empty.
    ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data, columns=columns))
    ranked_dataframe.to_csv(os.path.join(path, f"{rl_env}.csv"), index=False)

    return ranked_dataframe

def rank_dataframe(dataframe):
    """Order leaderboard rows best-first and number them.

    Rows are sorted descending by ``Results``, with ties broken by ``User`` and
    then ``Model`` (also descending). A 1-based ``Ranking`` column is then
    inserted as the first column. The input frame is not modified.

    Args:
        dataframe: Frame with at least ``Results``, ``User`` and ``Model`` columns.

    Returns:
        A new, sorted DataFrame with the leading ``Ranking`` column.
    """
    sort_keys = ['Results', 'User', 'Model']
    ordered = dataframe.sort_values(by=sort_keys, ascending=False)
    ordered.insert(loc=0, column='Ranking', value=range(1, ordered.shape[0] + 1))
    return ordered

def run_update_dataset():
    """Rebuild every environment's leaderboard CSV and upload them to the Hub.

    Scheduled periodically. A fresh snapshot of the leaderboard dataset is
    downloaded. Each environment in ``rl_envs`` is re-ranked into
    ``<rl_env>.csv`` inside that snapshot directory, and the directory is then
    uploaded to ``DATASET_REPO_ID``.

    A failure while processing one environment is logged and skipped, so the
    remaining environments are still refreshed and uploaded. The failed
    environment's CSV typically keeps the content of the downloaded snapshot.
    """
    import traceback

    path_ = download_leaderboard_dataset()
    for env in rl_envs:
        try:
            update_leaderboard_dataset_parallel(env["rl_env"], path_)
        except Exception:
            # Job-level boundary: isolate per-environment failures (network
            # errors, malformed cards) instead of losing the whole refresh.
            print(f"Failed to update leaderboard for {env['rl_env']}:")
            traceback.print_exc()

    print("Uploading updated dataset...")
    api.upload_folder(
        folder_path=path_,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset"
    )

def filter_data(rl_env, path, user_id):
    """Return the leaderboard rows of one environment that belong to a user.

    Args:
        rl_env: Environment id; selects ``<path>/<rl_env>.csv``.
        path: Directory holding the leaderboard CSVs.
        user_id: User name typed into the search box. Surrounding whitespace
            is ignored, and ``None`` is treated as an empty string.

    Returns:
        The matching rows (all columns, in file order) as a DataFrame; empty
        when nothing matches.
    """
    data_df = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    wanted = (user_id or "").strip()
    # Compare as text: if every user name in a file looks numeric, read_csv
    # infers an integer column and a plain string comparison never matches.
    return data_df[data_df["User"].astype(str) == wanted]

# -------------------- Gradio UI --------------------

print("Initializing dataset...")
path_ = download_leaderboard_dataset()

with block:
    gr.Markdown("""
    # 🏆 Deep Reinforcement Learning Course Leaderboard 🏆

    This leaderboard displays trained agents from the [Deep Reinforcement Learning Course](https://huggingface.co/learn/deep-rl-course/unit0/introduction?fw=pt).
    
    **Models are ranked using `mean_reward - std_reward`.**
    
    If you can't find your model, please wait for the next update (every 2 hours).
    """)

    # Snapshot directory downloaded at startup, handed to callbacks as state.
    grpath = gr.State(path_)

    # One tab per environment: a user search box, search/reset buttons and the
    # ranked table loaded from "<rl_env>.csv" in the startup snapshot.
    for env in rl_envs:
        with gr.TabItem(env["rl_env_beautiful"]):
            gr.Markdown(f"## {env['rl_env_beautiful']}")
            user_id = gr.Textbox(label="Your user ID")
            search_btn = gr.Button("Search 🔎")
            reset_btn = gr.Button("Clear Search")
            # Per-tab environment id held in a State so each tab's callbacks
            # receive their own value rather than the module-level loop
            # variable `env`, which ends up bound to the last environment.
            env_state = gr.State(env["rl_env"])

            gr_dataframe = gr.Dataframe(
                value=pd.read_csv(os.path.join(path_, f"{env['rl_env']}.csv")),
                headers=["Ranking 🏆", "User 🤗", "Model 🤖", "Results", "Mean Reward", "Std Reward"],
                datatype=["number", "markdown", "markdown", "number", "number", "number"],
                # NOTE(review): per Gradio docs, "dynamic" lets the row count
                # change instead of fixing it at 100 — confirm this is intended.
                row_count=(100, "dynamic"),
            )

            search_btn.click(fn=filter_data, inputs=[env_state, grpath, user_id], outputs=gr_dataframe)
            # Reload this tab's full table. The environment comes from
            # env_state, not a closure over `env` (which would always resolve
            # to the last entry of rl_envs).
            reset_btn.click(
                fn=lambda rl_env, path: pd.read_csv(os.path.join(path, f"{rl_env}.csv")),
                inputs=[env_state, grpath],
                outputs=gr_dataframe,
            )


# -------------------- Scheduler --------------------

# Background jobs: rebuild/upload the leaderboard data every 2 hours and
# restart the Space every 3 hours.
# NOTE(review): a restart presumably re-runs this module and re-creates the
# scheduler, resetting both interval timers, so the 2-hour update may fire only
# once per 3-hour process lifetime — confirm against the "every 2 hours" text.
scheduler = BackgroundScheduler()
scheduler.add_job(run_update_dataset, trigger='interval', hours=2)
scheduler.add_job(restart, trigger='interval', hours=3)
scheduler.start()

# Serve the Gradio app (executed at module level, after the scheduler is running).
block.launch()