from typing import Dict, cast

import attr

from mlagents.torch_utils import torch, default_device

from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil
from mlagents_envs.timers import timed
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import (
    TrainerSettings,
    OnPolicyHyperparamSettings,
    ScheduleType,
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil

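# PPO-specific hyperparameters: beta is the entropy regularization strength, epsilon
# the clipping range of the surrogate objective, lambd the GAE lambda, and num_epoch
# the number of passes over the experience buffer per update.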
@attr.s(auto_attribs=True)
class PPOSettings(OnPolicyHyperparamSettings):
    beta: float = 5.0e-3
    epsilon: float = 0.2
    lambd: float = 0.95
    num_epoch: int = 3
    shared_critic: bool = False
    learning_rate_schedule: ScheduleType = ScheduleType.LINEAR
    beta_schedule: ScheduleType = ScheduleType.LINEAR
    epsilon_schedule: ScheduleType = ScheduleType.LINEAR


class TorchPPOOptimizer(TorchOptimizer):
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        """
        Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy.
        The PPO optimizer has a value estimator and a loss function.
        :param policy: A TorchPolicy object that will be updated by this PPO Optimizer.
        :param trainer_settings: Trainer settings that specify the properties of the trainer.
        """
        super().__init__(policy, trainer_settings)
        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

        self.hyperparameters: PPOSettings = cast(
            PPOSettings, trainer_settings.hyperparameters
        )

        params = list(self.policy.actor.parameters())
        if self.hyperparameters.shared_critic:
            self._critic = policy.actor
        else:
            self._critic = ValueNetwork(
                reward_signal_names,
                policy.behavior_spec.observation_specs,
                network_settings=trainer_settings.network_settings,
            )
            self._critic.to(default_device())
            params += list(self._critic.parameters())

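        # The learning rate, PPO clip epsilon, and entropy coefficient beta are each
        # decayed over max_steps according to their configured schedules.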
        self.decay_learning_rate = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.learning_rate,
            1e-10,
            self.trainer_settings.max_steps,
        )
        self.decay_epsilon = ModelUtils.DecayedValue(
            self.hyperparameters.epsilon_schedule,
            self.hyperparameters.epsilon,
            0.1,
            self.trainer_settings.max_steps,
        )
        self.decay_beta = ModelUtils.DecayedValue(
            self.hyperparameters.beta_schedule,
            self.hyperparameters.beta,
            1e-5,
            self.trainer_settings.max_steps,
        )

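        # A single Adam optimizer updates the actor parameters and, when the critic is
        # not shared with the actor, the critic parameters as well.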
        self.optimizer = torch.optim.Adam(
            params, lr=self.trainer_settings.hyperparameters.learning_rate
        )
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }

        self.stream_names = list(self.reward_signals.keys())

    @property
    def critic(self):
        return self._critic

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        # Get decayed parameters
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
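        # Gather the returns and old value estimates recorded for each reward signal.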
        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)]
            )
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)]
            )

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        # Get value memories
        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
            )
        ]
        if len(value_memories) > 0:
            value_memories = torch.stack(value_memories).unsqueeze(0)

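        # Re-evaluate the sampled actions under the current policy and critic to get
        # fresh log probabilities, entropy, and value estimates.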
        run_out = self.policy.actor.get_stats(
            current_obs,
            actions,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )

        log_probs = run_out["log_probs"]
        entropy = run_out["entropy"]

        values, _ = self.critic.critic_pass(
            current_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        log_probs = log_probs.flatten()
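        # PPO losses: a clipped (trust-region) value loss, a clipped surrogate policy
        # loss, and an entropy bonus, all masked to valid timesteps.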
        loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
        value_loss = ModelUtils.trust_region_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        policy_loss = ModelUtils.trust_region_policy_loss(
            ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
            log_probs,
            old_log_probs,
            loss_masks,
            decay_eps,
        )
        loss = (
            policy_loss
            + 0.5 * value_loss
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        return update_stats

    # TODO move module update into TorchOptimizer for reward_provider
    def get_modules(self):
        modules = {
            "Optimizer:value_optimizer": self.optimizer,
            "Optimizer:critic": self._critic,
        }
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules
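

# For reference only: a minimal sketch of the clipped surrogate objective that
# ModelUtils.trust_region_policy_loss is expected to compute. The function name is
# hypothetical, the masked mean is simplified to a plain mean, and nothing in the
# trainer calls it.
def _example_clipped_policy_loss(
    advantages: torch.Tensor,
    log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    epsilon: float,
) -> torch.Tensor:
    # Probability ratio between the current policy and the policy that collected the data.
    ratio = torch.exp(log_probs - old_log_probs)
    # Clip the ratio so a single update cannot move the policy too far.
    clipped_ratio = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)
    # PPO maximizes the clipped surrogate, so the loss is its negation.
    return -torch.min(ratio * advantages, clipped_ratio * advantages).mean()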