Spaces:
Runtime error
Runtime error
File size: 7,449 Bytes
e0f25ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import List, Sequence
import distrax
from kinetix.models.action_spaces import HybridActionDistribution, MultiDiscreteActionDistribution
class ScannedRNN(nn.Module):
    """GRU recurrence lifted with ``nn.scan`` over axis 0 of its inputs.

    Parameters are broadcast (shared) across scan steps rather than split per step.
    """

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Apply one GRU step (scanned over axis 0 of ``x``).

        Args:
            carry: GRU hidden state. Its last dimension sets the GRU width, so the
                cell always matches the carry produced by ``initialize_carry``.
            x: Tuple ``(ins, resets)``. ``ins.shape[0]`` is used as the batch size.
                ``resets`` is indexed as ``resets[:, None]``, presumably a per-batch
                boolean vector (TODO confirm). Where it is true, the carry is replaced
                by a freshly initialised one before the GRU step.

        Returns:
            Tuple ``(new_carry, y)`` from the GRU cell.
        """
        rnn_state = carry
        ins, resets = x
        # Previously hard-coded to 256; reading it from the carry keeps the cell and the
        # reset state consistent with whatever hidden_size the caller initialised with.
        hidden_size = rnn_state.shape[-1]
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], hidden_size),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features=hidden_size)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Build an initial GRU carry for ``batch_size`` rows of width ``hidden_size``.

        Args:
            batch_size: Number of rows in the carry.
            hidden_size: GRU width. The cell is now built with this many features, so
                the argument is honoured instead of being silently ignored.

        Returns:
            The carry from ``nn.GRUCell.initialize_carry``.
        """
        # Use a dummy key since the default state init fn is just zeros.
        cell = nn.GRUCell(features=hidden_size)
        return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
class GeneralActorCriticRNN(nn.Module):
    """Shared actor-critic head used by every encoder in this module.

    Given an embedding, optionally pass it through a ``ScannedRNN``, then through
    separate actor and critic MLP towers. Return the policy distribution and the
    value estimate.
    """

    action_dim: int  # width of the actor output layer (and of log_std in "continuous" mode)
    fc_layer_depth: int  # number of hidden Dense layers in each of the actor/critic towers
    fc_layer_width: int  # units per hidden Dense layer
    action_mode: str  # "continuous", "discrete", "multi_discrete" or "hybrid"
    hybrid_action_continuous_dim: int  # width of the continuous heads; only read in "hybrid" mode
    multi_discrete_number_of_dims_per_distribution: List[int]  # only read in "multi_discrete" mode
    add_generator_embedding: bool = False  # True is not supported here and raises NotImplementedError
    generator_embedding_number_of_timesteps: int = 10  # not read by this module
    recurrent: bool = False  # if True, run the embedding through a ScannedRNN first

    @nn.compact
    def __call__(self, hidden, obs, embedding, dones, activation):
        """Compute the policy distribution and value for ``embedding``.

        Args:
            hidden: RNN carry. It is passed through ``ScannedRNN`` when ``recurrent``
                and returned unchanged otherwise.
            obs: Not read here; accepted so all encoders share one call signature.
            embedding: Encoder features feeding both towers.
            dones: Reset flags paired with ``embedding`` as the ``ScannedRNN`` input
                (only used when ``recurrent``).
            activation: Callable applied after every hidden Dense layer.

        Returns:
            Tuple ``(hidden, pi, value)``. ``pi`` is the action distribution for
            ``action_mode``. ``value`` is the critic output with its trailing
            size-1 axis squeezed.

        Raises:
            NotImplementedError: If ``add_generator_embedding`` is set.
            ValueError: If ``action_mode`` is not a supported mode. Previously any
                unknown value silently built a hybrid distribution.
        """
        if self.add_generator_embedding:
            raise NotImplementedError()

        if self.recurrent:
            rnn_in = (embedding, dones)
            hidden, embedding = ScannedRNN()(hidden, rnn_in)

        actor_mean = embedding
        critic = embedding
        # NOTE: actor and critic layers are created interleaved (actor, critic, actor, ...).
        # Flax auto-names submodules in creation order, so keep this order to stay
        # compatible with existing parameter trees.
        for _ in range(self.fc_layer_depth):
            actor_mean = nn.Dense(
                self.fc_layer_width,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(actor_mean)
            actor_mean = activation(actor_mean)
            critic = nn.Dense(
                self.fc_layer_width,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(critic)
            critic = activation(critic)
        # Last hidden actor features (the raw embedding when depth == 0); the hybrid
        # continuous heads read from these rather than from the action logits.
        actor_mean_last = actor_mean
        actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)

        if self.action_mode == "discrete":
            pi = distrax.Categorical(logits=actor_mean)
        elif self.action_mode == "continuous":
            # State-independent log standard deviation, one entry per action dimension.
            actor_log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
            pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_log_std))
        elif self.action_mode == "multi_discrete":
            pi = MultiDiscreteActionDistribution(actor_mean, self.multi_discrete_number_of_dims_per_distribution)
        elif self.action_mode == "hybrid":
            actor_mean_continuous = nn.Dense(
                self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
            )(actor_mean_last)
            # exp keeps the state-dependent scale strictly positive.
            actor_mean_sigma = jnp.exp(
                nn.Dense(self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(
                    actor_mean_last
                )
            )
            pi = HybridActionDistribution(actor_mean, actor_mean_continuous, actor_mean_sigma)
        else:
            raise ValueError(
                f"Unknown action_mode {self.action_mode!r}; expected one of "
                "'discrete', 'continuous', 'multi_discrete' or 'hybrid'."
            )

        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)
        return hidden, pi, jnp.squeeze(critic, axis=-1)
class ActorCriticPixelsRNN(nn.Module):
    """Actor-critic for image observations.

    Two strided conv layers encode ``obs.image``. The result is concatenated with
    ``obs.global_info`` and handed to ``GeneralActorCriticRNN``.
    """

    action_dim: int  # forwarded to GeneralActorCriticRNN, which uses it as an int layer width
    fc_layer_depth: int
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    activation: str  # "relu" selects nn.relu; any other value selects nn.tanh
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x, **kwargs):
        """Encode pixel observations and return ``(hidden, pi, value)``.

        Args:
            hidden: RNN carry forwarded to ``GeneralActorCriticRNN``.
            x: Tuple ``(observation, dones)``. The observation exposes ``.image`` and
                ``.global_info``. When ``add_generator_embedding`` is set, those live
                under ``observation.obs`` instead.
            **kwargs: Accepted and ignored.

        Returns:
            The ``(hidden, pi, value)`` tuple from ``GeneralActorCriticRNN``.
        """
        activation = nn.relu if self.activation == "relu" else nn.tanh

        og_obs, dones = x
        obs = og_obs.obs if self.add_generator_embedding else og_obs

        image = obs.image
        global_info = obs.global_info

        # The conv encoder always uses ReLU; `activation` only applies to the MLP towers.
        h = nn.Conv(features=16, kernel_size=(8, 8), strides=(4, 4))(image)
        h = nn.relu(h)
        h = nn.Conv(features=32, kernel_size=(4, 4), strides=(2, 2))(h)
        h = nn.relu(h)

        # Keep the two leading axes (presumably time and batch; TODO confirm) and
        # flatten the spatial/channel axes so they can be joined with global_info.
        embedding = h.reshape(h.shape[0], h.shape[1], -1)
        embedding = jnp.concatenate([embedding, global_info], axis=-1)

        return GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )(hidden, og_obs, embedding, dones, activation)

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Return an initial RNN carry; delegates to ``ScannedRNN.initialize_carry``."""
        return ScannedRNN.initialize_carry(batch_size, hidden_size)
class ActorCriticSymbolicRNN(nn.Module):
    """Actor-critic for flat (symbolic) observations.

    A single Dense + ReLU layer embeds the observation, which is then handed to
    ``GeneralActorCriticRNN``.
    """

    action_dim: int  # forwarded to GeneralActorCriticRNN, which uses it as an int layer width
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    fc_layer_depth: int
    activation: str  # "relu" selects nn.relu; any other value selects nn.tanh
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x):
        """Embed symbolic observations and return ``(hidden, pi, value)``.

        Args:
            hidden: RNN carry forwarded to ``GeneralActorCriticRNN``.
            x: Tuple ``(observation, dones)``. The observation is fed directly to a
                Dense layer, or its ``.obs`` attribute is when ``add_generator_embedding``
                is set.

        Returns:
            The ``(hidden, pi, value)`` tuple from ``GeneralActorCriticRNN``.
        """
        activation = nn.relu if self.activation == "relu" else nn.tanh

        og_obs, dones = x
        obs = og_obs.obs if self.add_generator_embedding else og_obs

        embedding = nn.Dense(
            self.fc_layer_width,
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(obs)
        # The input embedding always uses ReLU; `activation` only applies to the MLP towers.
        embedding = nn.relu(embedding)

        return GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )(hidden, og_obs, embedding, dones, activation)

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Return an initial RNN carry; delegates to ``ScannedRNN.initialize_carry``."""
        return ScannedRNN.initialize_carry(batch_size, hidden_size)
|