# Hugging Face Space (status: Sleeping)
import cv2 | |
import gradio as gr | |
import time | |
from huggingface_sb3 import load_from_hub | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.env_util import make_atari_env | |
from stable_baselines3.common.vec_env import VecFrameStack | |
from stable_baselines3.common.env_util import make_atari_env | |
# Module-level step cap. NOTE(review): nothing visible in this file reads it;
# replay() applies its own, smaller step limit. Confirm which cap is intended.
max_steps: int = 5_000
# The environment/model loading helpers below are adapted from Edward Beeching's code.
def load_env(env_name):
    """Create a one-environment Atari VecEnv whose observations stack 4 frames.

    Args:
        env_name: Atari environment id, e.g. ``"PongNoFrameskip-v4"``.

    Returns:
        ``make_atari_env(env_name, n_envs=1)`` wrapped in ``VecFrameStack``
        with ``n_stack=4``.
    """
    # Presumably matches the frame stacking the Hub checkpoints were trained
    # with (load_model pulls them per env_name) — confirm if n_stack changes.
    return VecFrameStack(make_atari_env(env_name, n_envs=1), n_stack=4)
def load_model(env_name):
    """Fetch the pretrained PPO checkpoint for ``env_name`` and load it.

    The file ``ppo-{env_name}.zip`` is downloaded from the Hub repository
    ``ThomasSimonini/ppo-{env_name}`` via ``load_from_hub``.

    Args:
        env_name: Atari environment id; it is used to build both the repo id
            and the checkpoint file name.

    Returns:
        The ``PPO`` model returned by ``PPO.load``.
    """
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    archive_name = f"ppo-{env_name}.zip"

    # Replace the schedule objects stored in the checkpoint with constant
    # stand-ins; the app only runs inference, so training schedules are unused.
    overrides = dict(
        learning_rate=0.0,
        lr_schedule=lambda _: 0.0,
        clip_range=lambda _: 0.0,
    )

    checkpoint_path = load_from_hub(repo_id, archive_name)
    return PPO.load(checkpoint_path, custom_objects=overrides)
def replay(env_name, time_sleep, max_steps=500):
    """Stream rendered frames of a pretrained PPO agent playing ``env_name``.

    Used as the Gradio callback: this is a generator, and each yielded value
    is an RGB frame from ``env.render(mode="rgb_array")`` for display.

    Args:
        env_name: Atari environment id; also selects the Hub checkpoint.
        time_sleep: Delay (passed to ``time.sleep``, i.e. seconds) before each
            frame is yielded.
        max_steps: Maximum number of agent steps (and frames) to play.
            Defaults to 500.

    Yields:
        The frame rendered just before each agent step.
    """
    env = load_env(env_name)
    # try/finally also runs when Gradio closes the generator early (e.g. the
    # client disconnects), so the emulator is always released.
    try:
        model = load_model(env_name)
        obs = env.reset()
        # range() yields exactly max_steps frames; the old counter stopped
        # after max_steps - 1.
        for _ in range(int(max_steps)):
            frame = env.render(mode="rgb_array")
            # predict() on a VecEnv observation already returns a batch of
            # actions (one per env) — pass it to step() as-is, not wrapped
            # in another list.
            action, _states = model.predict(obs)
            obs, _rewards, dones, _infos = env.step(action)
            time.sleep(time_sleep)
            yield frame
            # Single env (n_envs=1): stop once its episode reports done.
            if dones[0]:
                break
    finally:
        env.close()
# Gradio UI: pick an environment and a per-frame delay, then watch the agent.
demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ],
            # Pre-select an environment so submitting without a choice does
            # not call replay(None, ...) and fail on the checkpoint download.
            value="SpaceInvadersNoFrameskip-v4",
        ),
        # Delay between frames; passed as replay's time_sleep argument.
        gr.Slider(0.01, 1, value=0.05),
    ],
    gr.Image(shape=(300, 150)),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.05 by default).",
)

# Enable the queue BEFORE launching. replay() is a generator, and Gradio 3.x
# streams generator outputs only through the queue. launch() blocks while
# serving, so the previous `.launch().queue(...)` never configured the running
# app (and would call queue() on launch()'s return value on shutdown).
# `concurrency_count` is the Gradio 3.x name for the worker count.
demo.queue(concurrency_count=20, max_size=20).launch()