Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from jax2d import joint | |
| from jax2d.engine import select_shape | |
| from jax2d.maths import rmat | |
| from jax2d.sim_state import RigidBody | |
| from jaxgl.maths import dist_from_line | |
| from jaxgl.renderer import clear_screen, make_renderer | |
| from jaxgl.shaders import ( | |
| fragment_shader_quad, | |
| fragment_shader_edged_quad, | |
| make_fragment_shader_texture, | |
| nearest_neighbour, | |
| make_fragment_shader_quad_textured, | |
| ) | |
| from kinetix.render.textures import ( | |
| THRUSTER_TEXTURE_16_RGBA, | |
| RJOINT_TEXTURE_6_RGBA, | |
| FJOINT_TEXTURE_6_RGBA, | |
| ) | |
| from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState | |
| from flax import struct | |
def make_render_pixels(
    params,
    static_params: StaticEnvParams,
):
    """Create a function that draws an ``EnvState`` into a pixel array.

    Colour tables, size constants and the per-shape renderers are built once
    here; the returned closure only performs the per-state drawing.

    Args:
        params: Environment parameters; only ``pixels_per_unit`` is read.
        static_params: Static environment parameters; ``screen_dim``,
            ``downscale``, ``max_shape_size``, ``max_polygon_vertices`` and
            ``num_polygons`` are read.

    Returns:
        A function ``render_pixels(state)`` that draws, in order, the floor,
        polygons, circles, joints and thrusters onto a padded canvas and
        returns the canvas with the padding cropped off.
    """
    downscale = static_params.downscale
    screen_dim = static_params.screen_dim

    # Side lengths (in texels) used when sampling the joint and thruster textures.
    joint_tex_size = 6
    thruster_tex_size = 16

    # Pixels per unit distance
    ppud = params.pixels_per_unit // downscale

    # Width of the off-screen border added on every side of the canvas so that
    # patches near the screen edge still fit; it is also the side length of
    # the polygon and circle patches, and the amount cropped at the end.
    margin_px = static_params.max_shape_size * ppud
    half_shape_size = static_params.max_shape_size / 2.0

    downscaled_screen_dim = (screen_dim[0] // downscale, screen_dim[1] // downscale)
    full_screen_size = (
        downscaled_screen_dim[0] + 2 * margin_px,
        downscaled_screen_dim[1] + 2 * margin_px,
    )

    BACKGROUND_COLOUR = jnp.array([255.0, 255.0, 255.0])
    FIXATED_COLOUR = jnp.array([80, 80, 80])

    # Indexed by shape role.
    ROLE_COLOURS = jnp.array(
        [
            [160.0, 160.0, 160.0],  # None
            [0.0, 204.0, 0.0],  # Green: The ball
            [0.0, 102.0, 204.0],  # Blue: The goal
            [255.0, 102.0, 102.0],  # Red: Death Objects
        ]
    )

    # Row 0 is used for joints that are fixed or whose motor is off; row
    # ``binding + 1`` is used for a motor/thruster with that binding index.
    JOINT_COLOURS = jnp.array(
        [
            [255, 255, 255],  # white
            [255, 255, 0],  # yellow
            [255, 0, 255],  # purple/magenta
            [0, 255, 255],  # cyan
            [255, 153, 51],  # orange
        ]
    )

    def _tint_thruster_texture(tint):
        # Scale the RGB channels of the texture's first 9 rows by the tint.
        scale = tint[None, None, :] / 255.0
        return THRUSTER_TEXTURE_16_RGBA.at[:9, :, :3].mul(scale)

    # One tinted copy of the thruster texture per entry of JOINT_COLOURS.
    coloured_thruster_textures = jax.vmap(_tint_thruster_texture)(JOINT_COLOURS)

    def _get_colour(shape_role, inverse_inertia):
        """Pick the fill colour for a shape.

        Shapes with non-zero ``inverse_inertia`` use their role colour. Shapes
        with zero ``inverse_inertia`` use ``FIXATED_COLOUR`` when they have no
        role (role 0) and the role colour at half brightness otherwise.
        Callers in this function pass the shape's inverse mass.
        """
        role_colour = ROLE_COLOURS[shape_role]
        fixated_flag = (inverse_inertia == 0) * 1
        has_role_flag = (shape_role != 0) * 1
        palette = jnp.array(
            [
                role_colour,
                role_colour,
                FIXATED_COLOUR,
                role_colour * 0.5,
            ]
        )
        return palette[2 * fixated_flag + has_role_flag]

    cleared_screen = clear_screen(full_screen_size, BACKGROUND_COLOUR)

    def _world_space_to_pixel_space(x):
        # Shift by the border (in world units) and scale to pixels.
        return (x + static_params.max_shape_size) * ppud

    def fragment_shader_kinetix_circle(position, current_frag, unit_position, uniform):
        """Shade one pixel of a circle: filled disc, black rim and a black rotation marker."""
        centre, radius, rotation, colour, mask = uniform

        centre_dist = jnp.sqrt(jnp.square(position - centre).sum())
        inside = centre_dist <= radius
        on_rim = centre_dist > radius - 2

        # TODO - precompute?
        normal = jnp.array([jnp.sin(rotation), -jnp.cos(rotation)])
        # Presumably the distance to the line through ``centre`` and
        # ``centre + normal`` (helper comes from jaxgl); only the half on the
        # non-positive side of ``normal`` is drawn as the marker.
        line_dist = dist_from_line(position, centre, centre + normal)
        on_marker = (line_dist < 1) & (jnp.dot(normal, position - centre) <= 0)

        fragment = jax.lax.select(on_rim | on_marker, jnp.zeros(3), colour)
        return jax.lax.select(inside & mask, fragment, current_frag)

    def fragment_shader_kinetix_joint(position, current_frag, unit_position, uniform):
        """Shade one pixel of a joint by alpha-blending its tinted texture over the canvas."""
        texture, colour, mask = uniform

        tex_coord = (
            jnp.array(
                [
                    joint_tex_size * unit_position[0],
                    joint_tex_size * unit_position[1],
                ]
            )
            - 0.5
        )
        texel = nearest_neighbour(texture, tex_coord)
        # Masked-out joints get zero alpha, so the canvas shows through unchanged.
        texel = texel.at[3].mul(mask)
        texel = texel.at[:3].mul(colour / 255.0)

        alpha = texel[3]
        return (alpha * texel[:3]) + ((1.0 - alpha) * current_frag)

    thruster_pixel_size = thruster_tex_size // downscale
    # The thruster patch is sized by the texture's diagonal (plus one pixel)
    # so the rotated texture still fits inside the axis-aligned patch.
    thruster_pixel_size_diagonal = (thruster_pixel_size * np.sqrt(2)).astype(jnp.int32) + 1

    def fragment_shader_kinetix_thruster(fragment_position, current_frag, unit_position, uniform):
        """Shade one pixel of a thruster by sampling its texture in the thruster's rotated frame."""
        thruster_position, rotation, texture, mask = uniform

        offset = fragment_position - thruster_position
        tex_position = jnp.matmul(rmat(-rotation), offset) / thruster_pixel_size + 0.5

        # Only pixels that land inside the unit square of the texture are drawn.
        inside_texture = (
            (tex_position[0] >= 0) & (tex_position[0] <= 1) & (tex_position[1] >= 0) & (tex_position[1] <= 1)
        )
        visible = mask & inside_texture

        eps = 0.001
        tex_coord = (
            jnp.array(
                [
                    thruster_tex_size * tex_position[0],
                    thruster_tex_size * tex_position[1],
                ]
            )
            - 0.5
            + eps
        )
        texel = nearest_neighbour(texture, tex_coord)
        texel = texel.at[3].mul(visible)

        alpha = texel[3]
        return (alpha * texel[:3]) + ((1.0 - alpha) * current_frag)

    patch_size = (margin_px, margin_px)
    joint_pixel_size = joint_tex_size // downscale

    circle_renderer = make_renderer(full_screen_size, fragment_shader_kinetix_circle, patch_size, batched=True)
    quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, patch_size, batched=True)
    big_quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, downscaled_screen_dim)
    joint_renderer = make_renderer(
        full_screen_size, fragment_shader_kinetix_joint, (joint_pixel_size, joint_pixel_size), batched=True
    )
    thruster_renderer = make_renderer(
        full_screen_size,
        fragment_shader_kinetix_thruster,
        (thruster_pixel_size_diagonal, thruster_pixel_size_diagonal),
        batched=True,
    )

    def _draw_floor(canvas, state):
        # Polygon 0 is drawn with the screen-sized quad renderer.
        floor_vertices = _world_space_to_pixel_space(
            state.polygon.position[0, None, :] + state.polygon.vertices[0]
        )
        floor_uniform = (
            floor_vertices,
            _get_colour(state.polygon_shape_roles[0], 0),
            jnp.zeros(3),
            True,
        )
        floor_patch_position = _world_space_to_pixel_space(jnp.zeros(2, dtype=jnp.int32))
        return big_quad_renderer(canvas, floor_patch_position, floor_uniform)

    def _draw_polygons(canvas, state):
        patch_positions = _world_space_to_pixel_space(state.polygon.position - half_shape_size).astype(jnp.int32)

        # One rotation matrix per polygon, repeated for each of its vertices.
        rotation_matrices = jax.vmap(rmat)(state.polygon.rotation)
        rotation_matrices = jnp.repeat(
            rotation_matrices[:, None, :, :], repeats=static_params.max_polygon_vertices, axis=1
        )
        rotated_vertices = jax.vmap(jax.vmap(jnp.matmul))(rotation_matrices, state.polygon.vertices)
        vertices_pixel_space = _world_space_to_pixel_space(state.polygon.position[:, None, :] + rotated_vertices)

        fill_colours = jax.vmap(_get_colour)(state.polygon_shape_roles, state.polygon.inverse_mass)
        edge_colours = jnp.zeros((static_params.num_polygons, 3))

        uniforms = (
            vertices_pixel_space,
            fill_colours,
            edge_colours,
            state.polygon.active,
        )
        return quad_renderer(canvas, patch_positions, uniforms)

    def _draw_circles(canvas, state):
        centres_pixel_space = _world_space_to_pixel_space(state.circle.position)
        radii_pixel_space = state.circle.radius * ppud
        patch_positions = _world_space_to_pixel_space(state.circle.position - half_shape_size).astype(jnp.int32)
        fill_colours = jax.vmap(_get_colour)(state.circle_shape_roles, state.circle.inverse_mass)

        uniforms = (
            centres_pixel_space,
            radii_pixel_space,
            state.circle.rotation,
            fill_colours,
            state.circle.active,
        )
        return circle_renderer(canvas, patch_positions, uniforms)

    def _draw_joints(canvas, state):
        # Centre each joint patch on the joint's position.
        patch_positions = jnp.round(
            _world_space_to_pixel_space(state.joint.global_position) - (joint_pixel_size // 2)
        ).astype(jnp.int32)

        # Fixed joints use the fixed-joint texture, all others the revolute one.
        pick_texture = jax.vmap(jax.lax.select, in_axes=(0, None, None))
        textures = pick_texture(state.joint.is_fixed_joint, FJOINT_TEXTURE_6_RGBA, RJOINT_TEXTURE_6_RGBA)

        # Index 0 unless the joint is a non-fixed joint with its motor on.
        is_motorised = state.joint.motor_on & (~state.joint.is_fixed_joint)
        colours = JOINT_COLOURS[(state.motor_bindings + 1) * is_motorised]

        uniforms = (textures, colours, state.joint.active)
        return joint_renderer(canvas, patch_positions, uniforms)

    def _draw_thrusters(canvas, state):
        positions = jnp.round(_world_space_to_pixel_space(state.thruster.global_position)).astype(jnp.int32)
        patch_positions = positions - (thruster_pixel_size_diagonal // 2)
        textures = coloured_thruster_textures[state.thruster_bindings + 1]

        # Thruster rotation is stored relative to the shape it is attached to.
        parent_shapes = jax.vmap(select_shape, in_axes=(None, 0, None))(
            state, state.thruster.object_index, static_params
        )
        rotations = state.thruster.rotation + parent_shapes.rotation

        uniforms = (positions, rotations, textures, state.thruster.active)
        return thruster_renderer(canvas, patch_positions, uniforms)

    def render_pixels(state: EnvState):
        """Draw ``state`` and return the visible (un-padded) part of the canvas."""
        canvas = cleared_screen
        # Later layers are drawn over earlier ones.
        for draw_layer in (_draw_floor, _draw_polygons, _draw_circles, _draw_joints, _draw_thrusters):
            canvas = draw_layer(canvas, state)

        # Crop out the sides
        return canvas[margin_px:-margin_px, margin_px:-margin_px]

    return render_pixels
@struct.dataclass
class PixelsObservation:
    """Pixel-based observation produced by the RL rendering wrapper.

    Declared as a flax ``struct.dataclass`` so that it can be constructed from
    keyword arguments (a plain class with only annotations has no such
    constructor) and handled by JAX as a pytree.
    """

    # Rendered frame; the RL wrapper in this file stores the pixels divided by 255.
    image: jnp.ndarray
    # Non-image features; the RL wrapper in this file stores a one-element
    # array holding ``state.gravity[1] / 10.0``.
    global_info: jnp.ndarray
def make_render_pixels_rl(params, static_params: StaticEnvParams):
    """Create a renderer that returns a ``PixelsObservation`` for a state.

    Args:
        params: Environment parameters, forwarded to ``make_render_pixels``.
        static_params: Static environment parameters, forwarded likewise.

    Returns:
        A function mapping a state to a ``PixelsObservation`` whose ``image``
        is the rendered frame divided by 255 and whose ``global_info`` is a
        one-element array containing ``state.gravity[1] / 10.0``.
    """
    render_pixels = make_render_pixels(params, static_params)

    def inner(state):
        # Scale colour values from the 0-255 range down to 0-1.
        scaled_image = render_pixels(state) / 255.0
        gravity_feature = jnp.array([state.gravity[1] / 10.0])
        return PixelsObservation(image=scaled_image, global_info=gravity_feature)

    return inner