import functools
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import List, Sequence
import distrax
import jax
from kinetix.models.actor_critic import GeneralActorCriticRNN, ScannedRNN
from kinetix.render.renderer_symbolic_entity import EntityObservation
from flax.linen.attention import MultiHeadDotProductAttention


class Gating(nn.Module):
    # code taken from https://github.com/dhruvramani/Transformers-RL/blob/master/layers.py
    d_input: int
    bg: float = 0.0

    @nn.compact
    def __call__(self, x, y):
        r = jax.nn.sigmoid(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(x))
        z = jax.nn.sigmoid(
            nn.Dense(self.d_input, use_bias=False)(y)
            + nn.Dense(self.d_input, use_bias=False)(x)
            - self.param("gating_bias", constant(self.bg), (self.d_input,))
        )
        h = jnp.tanh(nn.Dense(self.d_input, use_bias=False)(y) + nn.Dense(self.d_input, use_bias=False)(r * x))
        g = (1 - z) * x + (z * h)
        return g


class transformer_layer(nn.Module):
    num_heads: int
    out_features: int
    qkv_features: int
    gating: bool = False
    gating_bias: float = 0.0

    def setup(self):
        self.attention1 = MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.qkv_features, out_features=self.out_features
        )
        self.ln1 = nn.LayerNorm()
        self.dense1 = nn.Dense(self.out_features)
        self.dense2 = nn.Dense(self.out_features)
        self.ln2 = nn.LayerNorm()
        if self.gating:
            self.gate1 = Gating(self.out_features, self.gating_bias)
            self.gate2 = Gating(self.out_features, self.gating_bias)

    def __call__(self, queries: jnp.ndarray, mask: jnp.ndarray):
        # After reading the paper, this is what I think we should do:
        # First layernorm, then do attention
        queries_n = self.ln1(queries)
        y = self.attention1(queries_n, mask=mask)
        if self.gating:
            # and gate
            y = self.gate1(queries, jax.nn.relu(y))
        else:
            y = queries + y
        # Dense after norming, crucially no relu.
        e = self.dense1(self.ln2(y))
        if self.gating:
            # and gate again
            # This may be the wrong way around
            e = self.gate2(y, jax.nn.relu(e))
        else:
            e = y + e
        return e
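
# Illustrative usage sketch, not part of the model: how a single gated
# transformer_layer can be initialised and applied to a batch of entity
# embeddings. All shapes and hyperparameters below are assumptions chosen
# purely for the example; the real values come from the model config.
def _example_transformer_layer_forward():
    rng = jax.random.PRNGKey(0)
    T, B, N, K, H = 2, 3, 5, 16, 4  # time, batch, entities, features, heads (assumed)
    layer = transformer_layer(num_heads=H, out_features=K, qkv_features=K, gating=True)
    embeddings = jnp.zeros((T, B, N, K))
    # The attention mask broadcasts to (..., num_heads, query_length, key_length).
    mask = jnp.ones((T, B, H, N, N), dtype=bool)
    params = layer.init(rng, embeddings, mask)
    return layer.apply(params, embeddings, mask)  # (T, B, N, K)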

class Transformer(nn.Module):
    encoder_size: int
    num_heads: int
    qkv_features: int
    num_layers: int
    gating: bool = False
    gating_bias: float = 0.0

    def setup(self):
        # self.encoder = nn.Dense(self.encoder_size)
        # self.positional_encoding = PositionalEncoding(self.encoder_size, max_len=self.max_len)
        self.tf_layers = [
            transformer_layer(
                num_heads=self.num_heads,
                qkv_features=self.qkv_features,
                out_features=self.encoder_size,
                gating=self.gating,
                gating_bias=self.gating_bias,
            )
            for _ in range(self.num_layers)
        ]
        self.joint_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
        self.thruster_layers = [nn.Dense(self.encoder_size) for _ in range(self.num_layers)]
        # self.pos_emb = PositionalEmbedding(self.encoder_size)

    def __call__(
        self,
        shape_embeddings: jnp.ndarray,
        shape_attention_mask,
        joint_embeddings,
        joint_mask,
        joint_indexes,
        thruster_embeddings,
        thruster_mask,
        thruster_indexes,
    ):
        # forward eval so obs is only one timestep
        # encoded = self.encoder(shape_embeddings)
        # pos_embed = self.pos_emb(jnp.arange(1 + memories.shape[-3], -1, -1))[: 1 + memories.shape[-3]]
        for tf_layer, joint_layer, thruster_layer in zip(self.tf_layers, self.joint_layers, self.thruster_layers):
            # Do attention
            shape_embeddings = tf_layer(shape_embeddings, shape_attention_mask)

            # Joints
            # T, B, 2J, (2SE + JE)
            @jax.vmap
            @jax.vmap
            def do_index2(to_ind, ind):
                return to_ind[ind]

            joint_shape_embeddings = jnp.concatenate(
                [
                    do_index2(shape_embeddings, joint_indexes[..., 0]),
                    do_index2(shape_embeddings, joint_indexes[..., 1]),
                    joint_embeddings,
                ],
                axis=-1,
            )
            shape_joint_entity_delta = joint_layer(joint_shape_embeddings) * joint_mask[..., None]

            @jax.vmap
            @jax.vmap
            def add2(addee, index, adder):
                return addee.at[index].add(adder)

            # Thrusters
            thruster_shape_embeddings = jnp.concatenate(
                [
                    do_index2(shape_embeddings, thruster_indexes),
                    thruster_embeddings,
                ],
                axis=-1,
            )
            shape_thruster_entity_delta = thruster_layer(thruster_shape_embeddings) * thruster_mask[..., None]

            shape_embeddings = add2(shape_embeddings, joint_indexes[..., 0], shape_joint_entity_delta)
            shape_embeddings = add2(shape_embeddings, thruster_indexes, shape_thruster_entity_delta)

        return shape_embeddings
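
# Illustrative usage sketch, not part of the model: how the entity Transformer
# above expects its inputs to line up. joint_indexes holds, per joint, the
# indexes of the two shapes it connects; thruster_indexes holds, per thruster,
# the index of the shape it is attached to. Every shape, count and
# hyperparameter below is an assumption made only for this example.
def _example_entity_transformer_forward():
    rng = jax.random.PRNGKey(0)
    T, B, N, J, TH, K, H = 1, 2, 4, 2, 1, 16, 4  # time, batch, shapes, joints, thrusters, features, heads
    model = Transformer(encoder_size=K, num_heads=H, qkv_features=K, num_layers=1, gating=True)
    shape_emb = jnp.zeros((T, B, N, K))
    attn_mask = jnp.ones((T, B, H, N, N), dtype=bool)  # which shapes may attend to which
    joint_emb = jnp.zeros((T, B, J, K))
    joint_mask = jnp.ones((T, B, J), dtype=bool)  # which joints are active
    joint_idx = jnp.zeros((T, B, J, 2), dtype=jnp.int32)  # two shape indexes per joint
    thruster_emb = jnp.zeros((T, B, TH, K))
    thruster_mask = jnp.ones((T, B, TH), dtype=bool)
    thruster_idx = jnp.zeros((T, B, TH), dtype=jnp.int32)  # one shape index per thruster
    args = (shape_emb, attn_mask, joint_emb, joint_mask, joint_idx, thruster_emb, thruster_mask, thruster_idx)
    params = model.init(rng, *args)
    return model.apply(params, *args)  # (T, B, N, K)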

class ActorCriticTransformer(nn.Module):
    action_dim: Sequence[int]
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    transformer_size: int
    transformer_encoder_size: int
    transformer_depth: int
    fc_layer_depth: int
    num_heads: int
    activation: str
    aggregate_mode: str  # "dummy" or "mean" or "dummy_and_mean"
    full_attention_mask: bool  # if true, only mask out inactives, and have everything attend to everything else
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        og_obs, dones = x
        if self.add_generator_embedding:
            obs = og_obs.obs
        else:
            obs = og_obs

        # Each observation field is [T, B, N, L]:
        # T - time
        # B - batch size
        # N - number of entities
        # L - unembedded entity size
        obs: EntityObservation

        def _single_encoder(features, entity_id, concat=True):
            # Assume two entity types, so a single ID feature suffices;
            # reserve one feature for it when it is concatenated on.
            num_to_remove = 1 if concat else 0
            embedding = activation(
                nn.Dense(
                    self.transformer_encoder_size - num_to_remove,
                    kernel_init=orthogonal(np.sqrt(2)),
                    bias_init=constant(0.0),
                )(features)
            )
            if concat:
                # Append a scalar entity-type ID channel (0 for circles, 1 for polygons).
                id_1h = jnp.zeros((*embedding.shape[:3], 1)).at[:, :, :, 0].set(entity_id)
                return jnp.concatenate([embedding, id_1h], axis=-1)
            else:
                return embedding

        circle_encodings = _single_encoder(obs.circles, 0)
        polygon_encodings = _single_encoder(obs.polygons, 1)
        joint_encodings = _single_encoder(obs.joints, -1, False)
        thruster_encodings = _single_encoder(obs.thrusters, -1, False)
        # Each of these is (T, B, N, K): (time, batch, num_entities, embedding_size)

        # T, B, M, K
        shape_encodings = jnp.concatenate([polygon_encodings, circle_encodings], axis=2)
        # T, B, M
        shape_mask = jnp.concatenate([obs.polygon_mask, obs.circle_mask], axis=2)

        def mask_out_inactives(flat_active_mask, matrix_attention_mask):
            matrix_attention_mask = matrix_attention_mask & (flat_active_mask[:, None]) & (flat_active_mask[None, :])
            return matrix_attention_mask

        joint_indexes = obs.joint_indexes
        thruster_indexes = obs.thruster_indexes

        if self.aggregate_mode == "dummy" or self.aggregate_mode == "dummy_and_mean":
            T, B, _, K = circle_encodings.shape
            dummy = jnp.ones((T, B, 1, K))
            shape_encodings = jnp.concatenate([dummy, shape_encodings], axis=2)
            shape_mask = jnp.concatenate(
                [jnp.ones((T, B, 1), dtype=bool), shape_mask],
                axis=2,
            )
            N = obs.attention_mask.shape[-1]
            overall_mask = (
                jnp.ones((T, B, obs.attention_mask.shape[2], N + 1, N + 1), dtype=bool)
                .at[:, :, :, 1:, 1:]
                .set(obs.attention_mask)
            )
            overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)
            # To account for the dummy entity
            joint_indexes = joint_indexes + 1
            thruster_indexes = thruster_indexes + 1
        else:
            overall_mask = obs.attention_mask

        if self.full_attention_mask:
            overall_mask = jnp.ones(overall_mask.shape, dtype=bool)
            overall_mask = jax.vmap(jax.vmap(mask_out_inactives))(shape_mask, overall_mask)

        # Now do attention on these
        embedding = Transformer(
            num_layers=self.transformer_depth,
            num_heads=self.num_heads,
            qkv_features=self.transformer_size,
            encoder_size=self.transformer_encoder_size,
            gating=True,
            gating_bias=0.0,
        )(
            shape_encodings,
            # Repeat the mask's head dimension so it matches the number of attention heads.
            jnp.repeat(overall_mask, repeats=self.num_heads // overall_mask.shape[2], axis=2),
            joint_encodings,
            obs.joint_mask,
            joint_indexes,
            thruster_encodings,
            obs.thruster_mask,
            thruster_indexes,
        )

        if self.aggregate_mode == "mean" or self.aggregate_mode == "dummy_and_mean":
            embedding = jnp.mean(embedding, axis=2, where=shape_mask[..., None])
        else:
            embedding = embedding[:, :, 0]  # Take the dummy entity as the embedding of the entire scene.

        return GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )(hidden, og_obs, embedding, dones, activation)
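
# Illustrative sketch of the two scene-readout modes used above, with made-up
# shapes: "mean" (and "dummy_and_mean") averages only over active entities via
# the `where` argument, while "dummy" reads out the always-active dummy entity
# that was prepended at index 0.
def _example_scene_readout():
    T, B, N, K = 1, 2, 3, 4  # time, batch, entities, features (assumed)
    embedding = jnp.arange(T * B * N * K, dtype=jnp.float32).reshape(T, B, N, K)
    shape_mask = jnp.array([[[True, True, False], [True, False, False]]])  # (T, B, N)
    mean_readout = jnp.mean(embedding, axis=2, where=shape_mask[..., None])  # masked mean over entities
    dummy_readout = embedding[:, :, 0]  # embedding of the dummy entity at index 0
    return mean_readout, dummy_readout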