|
---
library_name: stable-baselines3
tags:
- FetchPickAndPlace-v4
- deep-reinforcement-learning
- reinforcement-learning
- stable-baselines3
model-index:
- name: SAC
  results:
  - task:
      type: reinforcement-learning
      name: reinforcement-learning
    dataset:
      name: FetchPickAndPlace-v4
      type: FetchPickAndPlace-v4
    metrics:
    - type: mean_reward
      value: -9.70 +/- 4.17
      name: mean_reward
      verified: false
---
|
|
|
# SAC + HER Agent for FetchPickAndPlace-v4

## Model Overview

This repository contains a Soft Actor-Critic (SAC) agent trained with Hindsight Experience Replay (HER) on the `FetchPickAndPlace-v4` environment from `gymnasium-robotics`. The agent learns to pick and place objects with the environment's sparse reward (dense shaping can be enabled with a wrapper) and is intended for robotic manipulation research.
|
|
|
- **Algorithm:** Soft Actor-Critic (SAC)
- **Replay Buffer:** Hindsight Experience Replay (HER)
- **Environment:** FetchPickAndPlace-v4 (`gymnasium-robotics`)
- **Framework:** Stable Baselines3
|
|
|
## Training Details

- **Total Timesteps:** 500,000
- **Evaluation Frequency:** Every 2,000 steps (15 episodes per evaluation)
- **Checkpoint Frequency:** Every 50,000 steps (model + replay buffer)
- **Seed:** 42
- **Dense Shaping:** `False` by default; it can be enabled with a reward wrapper (see the sketch after this list)
- **Device:** `auto` (CUDA if available, otherwise CPU)
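
The dense-shaping wrapper is not included in this repository. A minimal sketch of what such a wrapper could look like, assuming the common choice of rewarding the negative distance between the achieved and desired goals (the class name `DenseShapingWrapper` is hypothetical):

```python
import numpy as np
import gymnasium as gym
import gymnasium_robotics  # registers the Fetch environments


class DenseShapingWrapper(gym.Wrapper):
    """Hypothetical wrapper: replace the sparse Fetch reward with the negative
    Euclidean distance between the achieved and desired goals."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        dense_reward = -float(np.linalg.norm(obs["achieved_goal"] - obs["desired_goal"]))
        return obs, dense_reward, terminated, truncated, info


# Wrap the environment before handing it to SAC:
env = DenseShapingWrapper(gym.make("FetchPickAndPlace-v4"))
```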
|
|
|
### Hyperparameters

| Parameter               | Value                    |
|-------------------------|--------------------------|
| Algorithm               | SAC                      |
| Policy                  | MultiInputPolicy         |
| Replay Buffer           | HER                      |
| n_sampled_goal          | 4                        |
| goal_selection_strategy | future                   |
| Batch Size              | 512                      |
| Buffer Size             | 1,000,000                |
| Learning Rate           | 1e-3                     |
| Gamma                   | 0.95                     |
| Tau                     | 0.05                     |
| Entropy Coefficient     | auto                     |
| Train Frequency         | 1 step                   |
| Gradient Steps          | 1                        |
| Tensorboard Log         | logs_pnp_sac_her/tb      |
| Seed                    | 42                       |
| Device                  | auto (CUDA if available) |
| Dense Shaping           | False (default)          |
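
For reference, here is a sketch of how these hyperparameters map onto a Stable Baselines3 training script, using `HerReplayBuffer` together with `EvalCallback` and `CheckpointCallback` for the evaluation and checkpoint frequencies listed above. The actual training script is not part of this repository, so treat the paths and callback wiring as assumptions:

```python
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import SAC, HerReplayBuffer
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

env = gym.make("FetchPickAndPlace-v4")
eval_env = gym.make("FetchPickAndPlace-v4")

model = SAC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=HerReplayBuffer,
    replay_buffer_kwargs=dict(n_sampled_goal=4, goal_selection_strategy="future"),
    batch_size=512,
    buffer_size=1_000_000,
    learning_rate=1e-3,
    gamma=0.95,
    tau=0.05,
    ent_coef="auto",
    train_freq=1,
    gradient_steps=1,
    tensorboard_log="logs_pnp_sac_her/tb",
    seed=42,
    device="auto",
)

callbacks = [
    # Evaluate every 2,000 steps for 15 episodes.
    EvalCallback(eval_env, eval_freq=2_000, n_eval_episodes=15, deterministic=True),
    # Save model + replay buffer every 50,000 steps.
    CheckpointCallback(
        save_freq=50_000,
        save_path="logs_pnp_sac_her",
        name_prefix="ckpt_sac_her",
        save_replay_buffer=True,
    ),
]

model.learn(total_timesteps=500_000, callback=callbacks)
model.save("sac_her_pnp")
model.save_replay_buffer("replay_buffer.pkl")
```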
|
|
|
## Files

- `sac_her_pnp.zip`: Final trained SAC model
- `ckpt_sac_her_250000_steps.zip`: Latest checkpoint
- `replay_buffer.pkl`: Replay buffer for continued training
- `replay.mp4`: Replay video of agent performance (manual generation recommended)
- `README.md`: This model card
|
|
|
## Usage

To load and use the model for inference:

```python
from stable_baselines3 import SAC
import gymnasium as gym
import gymnasium_robotics

env = gym.make("FetchPickAndPlace-v4", render_mode="rgb_array")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

obs, info = env.reset()
done = False
truncated = False
# Fetch episodes typically end by truncation at the time limit, so check both flags.
while not (done or truncated):
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    env.render()
env.close()
```
|
|
|
## Evaluation

To evaluate the agent over multiple episodes:

```python
from stable_baselines3 import SAC
import gymnasium as gym
import gymnasium_robotics

env = gym.make("FetchPickAndPlace-v4", render_mode="human")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

num_episodes = 10
for ep in range(num_episodes):
    obs, info = env.reset()
    done = False
    truncated = False
    episode_reward = 0
    while not (done or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        env.render()
        episode_reward += reward
    print(f"Episode {ep+1} reward: {episode_reward}")
env.close()
```
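
Alternatively, Stable Baselines3's `evaluate_policy` helper returns the mean and standard deviation of the episode reward directly. The reported `mean_reward` above was presumably computed this way, but the exact evaluation script is not included, so this is only a sketch:

```python
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy

env = gym.make("FetchPickAndPlace-v4")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

# 15 episodes matches the per-evaluation episode count used during training.
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=15, deterministic=True)
print(f"mean_reward = {mean_reward:.2f} +/- {std_reward:.2f}")
env.close()
```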
|
|
|
## Replay Video

If `replay.mp4` is not present, you can generate it manually:

```python
import gymnasium as gym
import gymnasium_robotics
from stable_baselines3 import SAC
import moviepy.editor as mpy  # note: the moviepy.editor module requires moviepy < 2.0

env = gym.make("FetchPickAndPlace-v4", render_mode="rgb_array")
model = SAC.load("path/to/sac_her_pnp.zip", env=env)

frames = []
obs, info = env.reset()
done = False
truncated = False
step = 0
max_steps = 1000

# Roll out one episode and collect rendered frames.
while not (done or truncated) and step < max_steps:
    frame = env.render()
    frames.append(frame)
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = env.step(action)
    step += 1

env.close()
clip = mpy.ImageSequenceClip(frames, fps=30)
clip.write_videofile("replay.mp4", codec="libx264")
```
|
|
|
## Continued Training

To continue training from a checkpoint:

```python
from stable_baselines3 import SAC
import gymnasium as gym
import gymnasium_robotics

env = gym.make("FetchPickAndPlace-v4", render_mode=None)
model = SAC.load("logs_pnp_sac_her/ckpt_sac_her_250000_steps.zip", env=env)
model.learn(total_timesteps=500_000, reset_num_timesteps=False)
```
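
Since the checkpoints also store the replay buffer (see Files), it can be restored before resuming so the off-policy updates keep the previously collected experience. A minimal variant of the snippet above; the buffer path is an assumption based on the file list:

```python
# Restore the saved replay buffer, then resume training.
model.load_replay_buffer("replay_buffer.pkl")
model.learn(total_timesteps=500_000, reset_num_timesteps=False)
```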
|
|
|
## Citation

If you use this model, please cite:

```
@misc{IntelliGrow_FetchPickAndPlace_SAC_HER,
  title={SAC + HER Agent for FetchPickAndPlace-v4},
  author={IntelliGrow},
  year={2025},
  howpublished={Hugging Face Hub},
  url={https://huggingface.co/IntelliGrow/FetchPickAndPlace-v4}
}
```
|
|
|
## License

MIT License

---

**Contact:** For questions or issues, open an issue on the [Hugging Face repository](https://huggingface.co/IntelliGrow/FetchPickAndPlace-v4).