File size: 946 Bytes
821e8de
 
f3feb7a
 
 
 
 
 
cbdef19
f3feb7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbdef19
f3feb7a
 
821e8de
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr

def replay(name):
    """Return the path of the pre-recorded replay video for a dropdown choice.

    Args:
        name: The label selected in the environment dropdown, e.g.
            ``"CartPole-v1 🕹️"``. Only the leading environment name is
            used for matching, so any emoji decoration after it and any
            surrounding whitespace are ignored.

    Returns:
        The relative path of the matching ``.mp4`` file, or ``None`` when
        ``name`` is empty/``None`` or does not start with a known
        environment name.
    """
    # (environment-name prefix of the dropdown label, replay video path)
    replays = (
        ("LunarLander-v2", "./LunarLander-v2.mp4"),
        ("CartPole-v1", "./CartPole-v1.mp4"),
        ("Atari Space Invaders", "./SpaceInvadersNoFrameskip-v4.mp4"),
    )
    # Match on the prefix rather than the full label so that emoji
    # decorations (whose encoding is easy to mangle) and stray whitespace
    # in the UI labels cannot silently break the lookup.
    label = (name or "").strip()
    for prefix, video_path in replays:
        if label.startswith(prefix):
            return video_path
    return None
	
"""
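TODO: Next version — generate the replay video live from the trained
Hugging Face Hub model instead of serving the pre-recorded .mp4 files.
Unfinished draft below, kept for reference only (this string is never executed):

def replay_classical(hf_model_filename, hf_model_id):
	import gym
	from stable_baselines3 import PPO
	from stable_baselines3.common.evaluation import evaluate_policy

	model = PPO.load_from_huggingface(hf_model_id,hf_model_filename)

	eval_env = gym.make(option)  # NOTE: `option` is undefined here; presumably the selected environment id


def replay_atari(hf_model_filename, hf_model_id):
	# not written yet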
"""

#iface = gr.Interface(fn=, inputs="dropdown", outputs="text")



iface = gr.Interface(
    replay,
    [
        gr.inputs.Dropdown(["Atari Space Invaders πŸ‘Ύ", "CartPole-v1 πŸ•ΉοΈ", "LunarLander-v2 πŸš€πŸ‘©β€πŸš€ "]),
    ],
    "video"
   
)

iface.launch()