sighmon committed · Commit ea37acb · verified · 1 Parent(s): 30b8bf0

Update README.md

Files changed (1): README.md (+73 -5)
README.md CHANGED
@@ -26,12 +26,80 @@ This is a trained model of a **A2C** agent playing **PandaReachDense-v3**
 using the [stable-baselines3 library](https://github.com/DLR-RM/stable-baselines3).
 
 ## Usage (with Stable-baselines3)
-TODO: Add your code
-
 
 ```python
-from stable_baselines3 import ...
-from huggingface_sb3 import load_from_hub
+import os
+
+import gymnasium as gym
+import panda_gym
+
+from huggingface_sb3 import load_from_hub, package_to_hub
+
+from stable_baselines3 import A2C
+from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+from stable_baselines3.common.env_util import make_vec_env
+
+from huggingface_hub import notebook_login
+
+
+env_id = "PandaReachDense-v3"
+
+# Create the env
+env = gym.make(env_id)
+
+# Get the state space and action space
+s_size = env.observation_space.shape
+a_size = env.action_space
+
+env = make_vec_env(env_id, n_envs=4)
+
+# Adding this wrapper to normalize the observation and the reward
+env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.)
+
+model = A2C(
+    policy = "MultiInputPolicy",
+    env = env,
+    verbose=1,
+)
+
+# Train
+model.learn(1_000_000)
+
+# Save the model and VecNormalize statistics when saving the agent
+model.save("a2c-PandaReachDense-v3")
+env.save("vec_normalize.pkl")
+
+from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+
+# Load the saved statistics
+eval_env = DummyVecEnv([lambda: gym.make("PandaReachDense-v3")])
+eval_env = VecNormalize.load("vec_normalize.pkl", eval_env)
+
+# We need to override the render_mode
+eval_env.render_mode = "rgb_array"
+
+# do not update them at test time
+eval_env.training = False
+# reward normalization is not needed at test time
+eval_env.norm_reward = False
+
+# Load the agent
+model = A2C.load("a2c-PandaReachDense-v3")
+
+mean_reward, std_reward = evaluate_policy(model, eval_env)
+
+print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
+
+from huggingface_sb3 import package_to_hub
 
+package_to_hub(
+    model=model,
+    model_name=f"a2c-{env_id}",
+    model_architecture="A2C",
+    env_id=env_id,
+    eval_env=eval_env,
+    repo_id=f"sighmon/a2c-{env_id}",
+    commit_message="With working video",
+)
 ```
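
Note that the new code imports `notebook_login` but never calls it: `package_to_hub` can only push once the session is authenticated with the Hugging Face Hub. A minimal sketch of that missing step, assuming the code runs in a notebook:

```python
# Minimal sketch (not part of the commit): authenticate before package_to_hub.
from huggingface_hub import notebook_login

notebook_login()  # prompts for a Hub access token in the notebook
```

From a plain script or terminal, `huggingface-cli login` or `huggingface_hub.login()` serves the same purpose.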
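
Likewise, `load_from_hub` is imported but unused. A minimal sketch of loading the published agent back from the Hub; the filenames are assumptions inferred from `model_name=f"a2c-{env_id}"` and `env.save("vec_normalize.pkl")` above, not verified repository contents:

```python
import gymnasium as gym
import panda_gym  # registers the Panda tasks with gymnasium

from huggingface_sb3 import load_from_hub
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

repo_id = "sighmon/a2c-PandaReachDense-v3"

# Download the checkpoint (filename assumed from model_name above)
checkpoint = load_from_hub(repo_id=repo_id, filename="a2c-PandaReachDense-v3.zip")
model = A2C.load(checkpoint)

# Rebuild the evaluation env and restore the normalization statistics
# (assumes vec_normalize.pkl was pushed alongside the model)
stats_path = load_from_hub(repo_id=repo_id, filename="vec_normalize.pkl")
eval_env = DummyVecEnv([lambda: gym.make("PandaReachDense-v3")])
eval_env = VecNormalize.load(stats_path, eval_env)
eval_env.training = False     # do not update statistics at test time
eval_env.norm_reward = False  # evaluate on raw rewards

mean_reward, std_reward = evaluate_policy(model, eval_env)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
```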