File size: 5,870 Bytes
821e8de
4b8b1d0
 
821e8de
4b8b1d0
 
f3feb7a
 
4b8b1d0
f3feb7a
4b8b1d0
cbdef19
4b8b1d0
f3feb7a
4b8b1d0
 
 
 
 
 
1c71ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3feb7a
 
1c71ebe
 
f3feb7a
1c71ebe
f3feb7a
1c71ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3feb7a
 
 
1c71ebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3feb7a
 
 
 
4b8b1d0
f3feb7a
4b8b1d0
 
 
 
 
1c71ebe
 
 
4b8b1d0
 
 
 
1c71ebe
 
 
4b8b1d0
 
 
f3feb7a
821e8de
4b8b1d0
1c71ebe
4b8b1d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import os
from moviepy.editor import *

# Dropdown label -> pre-recorded video of that trained agent.
# NOTE(review): the emoji in these labels look like mis-decoded UTF-8; they
# must stay byte-identical to the choices passed to gr.inputs.Dropdown below,
# otherwise the lookup in replay() will not match.
_REPLAY_VIDEOS = {
    "LunarLander-v2 πŸš€πŸ‘©β€πŸš€": "./LunarLander-v2.mp4",
    "CartPole-v1 πŸ•ΉοΈ": "./CartPole-v1.mp4",
    "Atari Space Invaders πŸ‘Ύ": "./SpaceInvadersNoFrameskip-v4.mp4",
}


def replay(option):
    """Return the path of a video showing the agent selected in the dropdown.

    The stored recording is re-encoded with moviepy into ``new_filename.mp4``
    and that file's path is returned for Gradio's video output.

    Args:
        option: One of the dropdown labels (a key of ``_REPLAY_VIDEOS``).

    Returns:
        The string ``'new_filename.mp4'``.

    Raises:
        ValueError: If ``option`` is not one of the known agent labels
            (previously an empty path was passed on to moviepy).
    """
    try:
        path = _REPLAY_VIDEOS[option]
    except (KeyError, TypeError):
        raise ValueError(f"Unknown agent option: {option!r}") from None

    # Per the original author, re-encoding through moviepy was the only
    # workaround found for a base64 video playback problem. The context
    # manager closes the clip's reader so each request does not leak it.
    with VideoFileClip(path) as videoclip:
        videoclip.write_videofile("new_filename.mp4")
    return 'new_filename.mp4'

# Gradio UI: a dropdown to pick a trained agent and a video output that shows
# the file path returned by replay().
iface = gr.Interface(
    replay,
    [
        # NOTE(review): these labels must stay byte-identical to the strings
        # replay() matches on; the emoji look like mis-decoded UTF-8.
        gr.inputs.Dropdown(["Atari Space Invaders πŸ‘Ύ", "CartPole-v1 πŸ•ΉοΈ", "LunarLander-v2 πŸš€πŸ‘©β€πŸš€"]),
    ],
    "video",
    title='Stable Baselines 3 with πŸ€—',
    description='',
    # The "There are currently 3 models:" paragraph is now closed before the
    # <ul>: a list is not valid inside <p>, and the <p> was never closed.
    article='''<div>
    <p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
    <p style="text-align: center"> Select the trained agent you want to watch perform.
    These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
    <p>There are currently 3 models:</p>
    <ul>
        <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
        <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
        <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
    </ul>
</div>''',
)

# Start the web server when the script runs.
iface.launch()

"""
TODO: Next version with live video generation
import gradio as gr
import os

from Recorder import Recorder

from stable_baselines3 import PPO


#The Agent plays and we generate the video
def replay(option):
  video_path = ""
  # Get the correct model
  if (option == "LunarLander-v2 πŸš€πŸ‘©β€πŸš€"):
    env_name = "Lunar Lander v2"
    agent_name = "PPO"
    print("TEST")
    hf_model_filename = "LunarLander-v2"
    hf_model_id = "ThomasSimonini/stable-baselines3-ppo-LunarLander-v2"
    video_path = replay_gym(hf_model_filename, hf_model_id)
  elif(option == "CartPole-v1 πŸ•ΉοΈ"):
      hf_model_filename = "CartPole-v1"
      hf_model_id = "ThomasSimonini/stable-baselines3-ppo-CartPole-v1"
      video_path = replay_gym(hf_model_filename, hf_model_id)
  elif(option == "Atari Space Invaders πŸ‘Ύ"):
    hf_model_filename = "SpaceInvadersNoFrameskip-v4"
    hf_model_id = "ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4"
    video_path = replay_atari(hf_model_filename, hf_model_id)
      #video_path = "./SpaceInvadersNoFrameskip-v4.mp4"

  return video_path


def replay_gym(hf_model_filename, hf_model_id):
  import gym
  from stable_baselines3.common.evaluation import evaluate_policy


  model = PPO.load_from_huggingface(hf_model_id,hf_model_filename)

  eval_env = gym.make(hf_model_filename)

  directory = './video'
  env = Recorder(eval_env, directory)

  obs = env.reset()
  done = False
  while not done:
      action, _state = model.predict(obs)
      obs, reward, done, info = env.step(action)
  clip = env.play()
  return clip


def replay_atari(hf_model_filename, hf_model_id):
  os.system("python -m atari_py.import_roms \"content/atari_roms\"")
  import gym
  from stable_baselines3.common.env_util import make_atari_env
  from stable_baselines3.common.vec_env import VecFrameStack

  from stable_baselines3.common.evaluation import evaluate_policy

  model = PPO.load_from_huggingface(hf_model_id, hf_model_filename)


  eval_env = make_atari_env(hf_model_filename, n_envs=1, seed=0)
  eval_env = VecFrameStack(eval_env, n_stack=4)

  model = PPO.load_from_huggingface(hf_model_id, hf_model_filename)

  import gym
  directory = './video'
  env = Recorder(eval_env, directory)

  obs = env.reset()
  done = False
  while not done:
      action, _state = model.predict(obs)
      obs, reward, done, info = env.step(action)
  clip = env.play()
  return clip



iface = gr.Interface(
    replay,
    [
        gr.inputs.Dropdown(["Atari Space Invaders πŸ‘Ύ", "CartPole-v1 πŸ•ΉοΈ", "LunarLander-v2 πŸš€πŸ‘©β€πŸš€"]),
    ],
    "video",
     title = 'Stable Baselines 3 with πŸ€—',
            description = '',
             article = 
                        '''<div>
                          <p style="text-align: center">This version of the RL library allows you to load models directly from the Hugging Face Hub</p>
                            <p style="text-align: center"> Select the trained agent you want to watch perform. We record your agent playing.
                            <p style="text-align: center"> Don't forget to <b>click on clear between each record.</b> </p>
                            These models are from <a href="https://github.com/araffin/rl-baselines-zoo">Stable Baseline Zoo</a></p>
                            <p>
                            There are currently 3 models:
                            <ul>
                              <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-SpaceInvadersNoFrameskip-v4">PPO SpaceInvadersNoFrameskip-v4</a></li>
                              <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-LunarLander-v2">PPO LunarLander-v2</a></li>
                              <li><a href="https://huggingface.co/ThomasSimonini/stable-baselines3-ppo-CartPole-v1">PPO CartPole-v1</a></li>
                            </ul>
                        </div>'''
            )
   

iface.launch()
"""