freddyaboulton HF staff committed on
Commit
43351bd
·
1 Parent(s): 38753f4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import time
4
+
5
+ from huggingface_sb3 import load_from_hub
6
+
7
+ from stable_baselines3 import PPO
8
+ from stable_baselines3.common.env_util import make_atari_env
9
+ from stable_baselines3.common.vec_env import VecFrameStack
10
+
11
+ from stable_baselines3.common.env_util import make_atari_env
12
+
13
# Module-level step budget.
# NOTE(review): replay() applies its own local cap of 500 steps, so this value
# does not limit a replay as written — confirm whether it is still needed.
max_steps = 5000 # Let's try with 5000 steps.
14
+
15
+ # Loading functions were taken from Edward Beeching's code
16
def load_env(env_name):
    """Create the Atari environment used for replays.

    Builds a vectorized environment containing a single copy of ``env_name``
    and wraps it so that observations are stacks of 4 frames.

    Args:
        env_name: Atari environment id handed to ``make_atari_env``.

    Returns:
        The ``VecFrameStack``-wrapped vectorized environment.
    """
    single_env = make_atari_env(env_name, n_envs=1)
    return VecFrameStack(single_env, n_stack=4)
20
+
21
def load_model(env_name):
    """Download and load the pretrained PPO agent for ``env_name``.

    The checkpoint is fetched from the ``ThomasSimonini/ppo-<env_name>`` Hub
    repository (file ``ppo-<env_name>.zip``) and loaded with the learning
    rate, learning-rate schedule and clip range replaced by zero-valued
    stand-ins.

    Args:
        env_name: Atari environment id; also selects the Hub repo and file.

    Returns:
        The ``PPO`` model returned by ``PPO.load``.
    """
    repo_id = f"ThomasSimonini/ppo-{env_name}"
    filename = f"ppo-{env_name}.zip"
    checkpoint_path = load_from_hub(repo_id, filename)

    # Entries substituted for the saved ones while loading.
    # NOTE(review): presumably these only matter for training, not for the
    # inference-only use here — confirm.
    overrides = dict(
        learning_rate=0.0,
        lr_schedule=lambda _: 0.0,
        clip_range=lambda _: 0.0,
    )
    return PPO.load(checkpoint_path, custom_objects=overrides)
36
+
37
def replay(env_name, time_sleep):
    """Stream frames of a pretrained agent playing ``env_name``.

    This is a generator: it loads the environment and the matching model,
    then for each step renders the current frame, lets the model pick an
    action, advances the environment, waits ``time_sleep`` seconds and yields
    the rendered frame. It stops when the episode ends or after 500 frames,
    whichever comes first, and always closes the environment on exit.

    Args:
        env_name: Atari environment id passed to ``load_env`` and
            ``load_model``.
        time_sleep: Delay in seconds between consecutive frames.

    Yields:
        The frame returned by ``env.render(mode="rgb_array")`` for each step.
    """
    step_limit = 500  # local cap on the number of frames streamed per replay
    env = load_env(env_name)
    try:
        model = load_model(env_name)
        obs = env.reset()
        for _ in range(step_limit):
            frame = env.render(mode="rgb_array")
            action, _states = model.predict(obs)
            # predict() on a vectorized observation already returns one action
            # per environment, so it is passed to step() without re-wrapping.
            obs, _reward, done, _info = env.step(action)
            time.sleep(time_sleep)
            yield frame
            # step() reports one flag per environment; load_env builds
            # exactly one, so only index 0 is relevant.
            if done[0]:
                break
    finally:
        # Runs on normal completion and when the consumer abandons the
        # generator, so the emulator is not leaked.
        env.close()
55
+
56
# Gradio UI: pick an Atari environment and a per-frame delay, then watch the
# frames streamed by the ``replay`` generator in an image component.
demo = gr.Interface(
    replay,
    [
        gr.Dropdown(
            [
                "SpaceInvadersNoFrameskip-v4",
                "PongNoFrameskip-v4",
                "SeaquestNoFrameskip-v4",
                "QbertNoFrameskip-v4",
            ]
        ),
        gr.Slider(0.01, 1, value=0.05),
    ],
    gr.Image(),
    title="Watch Agents playing Atari games 🤖",
    description="Select an environment to watch a Hugging Face's trained deep reinforcement learning agent.",
    article="time_sleep is the time delay between each frame (0.05 by default).",
)

# The queue must be enabled on the Interface *before* launching: ``replay`` is
# a generator, and launch() does not return the Interface, so chaining
# ``.launch().queue()`` never configured the queue on the app.
demo.queue()
demo.launch()