# Source: kinet-test / kinetix/render/renderer_symbolic_flat.py
# (Hugging Face upload by tree3po, "Upload 46 files", commit 581eeac, 3.8 kB)
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax2d import joint
from jax2d.engine import select_shape
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from jaxgl.maths import dist_from_line
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import (
fragment_shader_quad,
fragment_shader_edged_quad,
make_fragment_shader_texture,
nearest_neighbour,
make_fragment_shader_quad_textured,
)
from kinetix.render.renderer_symbolic_common import (
make_circle_features,
make_joint_features,
make_polygon_features,
make_thruster_features,
)
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
from flax import struct
def make_render_symbolic(params, static_params: StaticEnvParams):
    """Build a function that renders a state as one flat symbolic observation.

    Args:
        params: Environment parameters, forwarded unchanged to the
            ``make_*_features`` helpers.
        static_params: Static environment parameters. ``num_polygons`` and
            ``num_circles`` are read here to size the polygon mask and the
            one-hot shape encodings.

    Returns:
        A function ``render_symbolic(state)`` that returns a 1-D array made of
        the flattened polygon, circle, joint and thruster features followed
        by the scaled vertical gravity component.
    """
    n_polys = static_params.num_polygons
    nshapes = n_polys + static_params.num_circles

    # Polygons 1, 2 and 3 are dropped from the observation (the original code
    # names these the walls and ceiling). The mask depends only on static
    # parameters, so it is built once here instead of on every render call.
    # A slice is used rather than fancy indexing so that fewer than four
    # polygons does not raise IndexError; for >= 4 polygons it is identical.
    keep_polygon = np.ones(n_polys, dtype=bool)
    keep_polygon[1:4] = False

    def render_symbolic(state):
        """Return the flat, clipped, NaN-free symbolic observation of ``state``."""
        polygon_features, polygon_mask = make_polygon_features(state, params, static_params)
        polygon_features = polygon_features[keep_polygon]
        polygon_mask = polygon_mask[keep_polygon]

        circle_features, circle_mask = make_circle_features(state, params, static_params)
        joint_features, joint_idxs, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_idxs, thruster_mask = make_thruster_features(state, params, static_params)

        # Only the first half of the joint rows is used in this flat encoding.
        # NOTE(review): presumably make_joint_features emits every joint twice
        # (once per attached shape) -- confirm against that helper.
        two_J = joint_features.shape[0]
        J = two_J // 2

        # Append one-hot encodings of the two connected shapes to each joint.
        # The one-hot width is the full shape count (including the polygons
        # dropped above), so shape indices keep their original meaning.
        joint_features = jnp.concatenate(
            [
                joint_features[:J],  # shape (J, K)
                jax.nn.one_hot(joint_idxs[:J, 0], nshapes),  # shape (J, N)
                jax.nn.one_hot(joint_idxs[:J, 1], nshapes),  # shape (J, N)
            ],
            axis=1,
        )
        # Append a one-hot encoding of the shape each thruster is attached to.
        thruster_features = jnp.concatenate(
            [
                thruster_features,
                jax.nn.one_hot(thruster_idxs, nshapes),
            ],
            axis=1,
        )

        # Zero out the rows of inactive entities, then flatten each group.
        polygon_features = jnp.where(polygon_mask[:, None], polygon_features, 0.0).flatten()
        circle_features = jnp.where(circle_mask[:, None], circle_features, 0.0).flatten()
        joint_features = jnp.where(joint_mask[:J, None], joint_features, 0.0).flatten()
        thruster_features = jnp.where(thruster_mask[:, None], thruster_features, 0.0).flatten()

        def _get_manifold_features(manifold):
            """Flatten a collision manifold into a feature vector.

            Currently unused: the calls that would add manifold features to
            the observation are commented out below.
            """
            collision_mask_features = jnp.concatenate(
                [
                    manifold.normal,
                    jnp.expand_dims(manifold.penetration, axis=-1),
                    manifold.collision_point,
                    jnp.expand_dims(manifold.acc_impulse_normal, axis=-1),
                    jnp.expand_dims(manifold.acc_impulse_tangent, axis=-1),
                ],
                axis=-1,
            )
            # Inactive manifolds contribute all-zero features.
            return (collision_mask_features * manifold.active[..., None]).flatten()

        obs = jnp.concatenate(
            [
                polygon_features,
                circle_features,
                joint_features,
                thruster_features,
                jnp.array([state.gravity[1]]) / 10,
                # Manifold features are currently excluded from the observation:
                # _get_manifold_features(state.acc_cc_manifolds),
                # _get_manifold_features(state.acc_cr_manifolds),
                # _get_manifold_features(state.acc_rr_manifolds),
            ],
            axis=0,
        )

        # Bounds are passed positionally: the `a_min`/`a_max` keyword names
        # are deprecated in JAX in favour of `min`/`max`, and the positional
        # form is accepted by both the old and the new signature.
        obs = jnp.clip(obs, -10.0, 10.0)
        # Any NaNs that survive clipping are replaced with zeros.
        obs = jnp.nan_to_num(obs)
        return obs

    return render_symbolic