# Source: kinet-test / kinetix / render / renderer_symbolic_entity.py
# Uploaded by tree3po ("Upload 46 files"), commit 581eeac (verified), 4.94 kB.
from cmath import rect
from functools import partial
import jax
import jax.numpy as jnp
from flax import struct
from jax2d.engine import get_pairwise_interaction_indices
from kinetix.environment.env_state import EnvState
from kinetix.render.renderer_symbolic_common import (
make_circle_features,
make_joint_features,
make_polygon_features,
make_thruster_features,
make_unified_shape_features,
)
@struct.dataclass
class EntityObservation:
    """Entity-wise symbolic observation returned by ``render_entities``.

    Each entity kind (circles, polygons, joints, thrusters) carries a feature
    array and a matching mask. ``attention_mask`` holds the stacked pairwise
    shape masks, and the ``*_indexes`` arrays are passed through from the
    joint / thruster feature builders.
    """

    # Feature arrays, one per entity kind.
    circles: jnp.ndarray
    polygons: jnp.ndarray
    joints: jnp.ndarray
    thrusters: jnp.ndarray

    # Masks accompanying each feature array above.
    circle_mask: jnp.ndarray
    polygon_mask: jnp.ndarray
    joint_mask: jnp.ndarray
    thruster_mask: jnp.ndarray

    # Stack of pairwise masks over the shapes (polygons first, then circles).
    attention_mask: jnp.ndarray
    # NOTE: a dedicated ``collision_mask`` field is currently disabled.

    # Index arrays produced alongside the joint / thruster features.
    joint_indexes: jnp.ndarray
    thruster_indexes: jnp.ndarray
def make_render_entities(params, static_params):
    """Create a renderer that turns an ``EnvState`` into an ``EntityObservation``.

    The pairwise interaction index tables are fetched once, up front. Circle
    indices in them are shifted by ``static_params.num_polygons`` so that they
    address the flattened ``[polygons, circles]`` ordering used by the
    pairwise masks.

    Args:
        params: Forwarded unchanged to the feature builders.
        static_params: Supplies ``num_polygons`` / ``num_circles`` and is
            forwarded to the feature builders and the pair-index lookup.

    Returns:
        A function ``render_entities(state)`` producing an
        ``EntityObservation``.
    """
    _, _, _, cc_pairs, cr_pairs, rr_pairs = get_pairwise_interaction_indices(static_params)
    polygon_count = static_params.num_polygons
    # Circles come after the polygons in the flattened shape ordering.
    cc_pairs = cc_pairs + polygon_count
    cr_pairs = cr_pairs.at[:, 0].add(polygon_count)

    def render_entities(state: EnvState):
        """Render ``state`` as per-entity features plus pairwise shape masks."""
        # Sanitise every leaf of the state before any features are derived.
        state = jax.tree_util.tree_map(jnp.nan_to_num, state)

        circle_feats, circle_valid = make_circle_features(state, params, static_params)
        polygon_feats, polygon_valid = make_polygon_features(state, params, static_params)
        thruster_feats, thruster_idx, thruster_valid = make_thruster_features(state, params, static_params)
        joint_feats, joint_idx, joint_valid = make_joint_features(state, params, static_params)

        def _append_gravity(feats):
            # Every shape embedding gets one extra column: scaled y-gravity.
            gravity_column = jnp.zeros((feats.shape[0], 1)) + state.gravity[1] / 10
            return jnp.concatenate([feats, gravity_column], axis=-1)

        polygon_feats = _append_gravity(polygon_feats)
        circle_feats = _append_gravity(circle_feats)

        # The pairwise masks cover the shapes only: polygons, then circles.
        shape_valid = jnp.concatenate([polygon_valid, circle_valid], axis=0)
        num_shapes = static_params.num_polygons + static_params.num_circles

        def _pairwise_shape_masks(valid):
            """Return the stacked pairwise masks for a 1-D validity vector."""
            n = valid.shape[0]
            # An entry survives only if both its row and column shapes are valid.
            both_valid = valid[:, None] & valid[None, :]

            # Identity (self-attention) with the shape block fully connected.
            all_true_block = jnp.ones((num_shapes, num_shapes), dtype=bool)
            fully_connected = jnp.eye(n, n, dtype=bool).at[:num_shapes, :num_shapes].set(all_true_block)
            fully_connected = fully_connected & both_valid

            # Shapes linked directly by a joint. Invalid joints use index
            # pair (0, 0), so that entry is cleared again afterwards.
            joint_linked = (
                jnp.zeros((n, n), dtype=bool)
                .at[joint_idx[:, 0], joint_idx[:, 1]]
                .set(True)
                .at[0, 0]
                .set(False)
            )
            joint_linked = joint_linked & both_valid

            # Complement of the collision matrix.
            non_colliding = jnp.logical_not(state.collision_matrix) & both_valid

            # Pairs with an active collision manifold. The masked collision
            # matrix only serves as the shape/dtype template here.
            in_contact = jnp.zeros_like(state.collision_matrix & both_valid)
            rr_active = jnp.logical_or(
                state.acc_rr_manifolds.active[..., 0],
                state.acc_rr_manifolds.active[..., 1],
            )
            manifold_sources = (
                (rr_pairs, rr_active),
                (cr_pairs, state.acc_cr_manifolds.active),
                (cc_pairs, state.acc_cc_manifolds.active),
            )
            for pairs, active in manifold_sources:
                in_contact = in_contact.at[pairs[:, 0], pairs[:, 1]].set(active)
            in_contact = in_contact & both_valid

            return jnp.stack([fully_connected, non_colliding, joint_linked, in_contact], axis=0)

        return EntityObservation(
            circles=circle_feats,
            polygons=polygon_feats,
            joints=joint_feats,
            thrusters=thruster_feats,
            circle_mask=circle_valid,
            polygon_mask=polygon_valid,
            joint_mask=joint_valid,
            thruster_mask=thruster_valid,
            attention_mask=_pairwise_shape_masks(shape_valid),
            joint_indexes=joint_idx,
            thruster_indexes=thruster_idx,
        )

    return render_entities