tree3po's picture
Upload 190 files
e0f25ed verified
raw
history blame
12.3 kB
from functools import partial
import math
import chex
import jax
import jax.numpy as jnp
from flax.serialization import to_state_dict
from jax2d.engine import (
calculate_collision_matrix,
calc_inverse_mass_polygon,
calc_inverse_mass_circle,
calc_inverse_inertia_circle,
calc_inverse_inertia_polygon,
recalculate_mass_and_inertia,
select_shape,
PhysicsEngine,
)
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
from jax2d.maths import rmat
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
from kinetix.environment.ued.mutators import (
mutate_add_connected_shape_proper,
mutate_add_shape,
mutate_add_connected_shape,
mutate_add_thruster,
)
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.environment.ued.util import (
get_role,
sample_dimensions,
is_space_for_shape,
random_position_on_polygon,
random_position_on_circle,
are_there_shapes_present,
is_space_for_joint,
)
from kinetix.environment.utils import permute_state
from kinetix.util.saving import load_world_state_pickle
from flax import struct
from kinetix.environment.env import create_empty_env
@partial(jax.jit, static_argnums=(1, 3, 5, 6, 7, 8, 9, 10))
def create_vmapped_filtered_distribution(
rng,
level_sampler,
env_params: EnvParams,
static_env_params: StaticEnvParams,
ued_params: UEDParams,
n_samples: int,
env,
do_filter_levels: bool,
level_filter_sample_ratio: int,
env_size_name: str,
level_filter_n_steps: int,
):
if do_filter_levels and level_filter_n_steps > 0:
sample_ratio = level_filter_sample_ratio
n_unfiltered_samples = sample_ratio * n_samples
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, n_unfiltered_samples)
# unfiltered_levels = jax.vmap(level_sampler, in_axes=(0, None, None, None, None))(
# _rngs, env_params, static_env_params, ued_params, env_size_name
# )
unfiltered_levels = jax.vmap(level_sampler, in_axes=(0,))(_rngs)
#
# No-op filtering
def _noop_step(states, rng):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, n_unfiltered_samples)
action = jnp.zeros((n_unfiltered_samples, *env.action_space(env_params).shape), dtype=jnp.int32)
obs, states, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
_rngs, states, action, env_params
)
return states, (done, reward)
# Wrap levels
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, n_unfiltered_samples)
obsv, unfiltered_levels_wrapped = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
_rngs, unfiltered_levels, env_params
)
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, level_filter_n_steps)
_, (done, rewards) = jax.lax.scan(_noop_step, unfiltered_levels_wrapped, xs=_rngs)
done_indexes = jnp.argmax(done, axis=0)
done_rewards = rewards[done_indexes, jnp.arange(n_unfiltered_samples)]
noop_solved_indexes = done_rewards > 0.5
p = noop_solved_indexes * 0.001 + (1 - noop_solved_indexes) * 1.0
p /= p.sum()
rng, _rng = jax.random.split(rng)
level_indexes = jax.random.choice(
_rng, jnp.arange(n_unfiltered_samples), shape=(n_samples,), replace=False, p=p
)
levels = jax.tree.map(lambda x: x[level_indexes], unfiltered_levels)
else:
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, n_samples)
levels = jax.vmap(level_sampler, in_axes=(0,))(_rngs)
return levels
@partial(jax.jit, static_argnums=(1, 3, 4, 5))
def sample_kinetix_level(
rng,
engine: PhysicsEngine,
env_params: EnvParams,
static_env_params: StaticEnvParams,
ued_params: UEDParams,
env_size_name: str = "l",
):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, 12)
small_force_no_fixate = env_size_name == "s"
# Start with empty state
state = create_empty_env(static_env_params)
# Set the floor
prob_of_floor_colour = jnp.array(
[
ued_params.floor_prob_normal,
ued_params.floor_prob_green,
ued_params.floor_prob_blue,
ued_params.floor_prob_red,
]
)
floor_colour = jax.random.choice(_rngs[0], jnp.arange(4), p=prob_of_floor_colour)
state = state.replace(polygon_shape_roles=state.polygon_shape_roles.at[0].set(floor_colour))
# When we add shapes we don't want them to collide with already existing shapes
def _choose_proposal_with_least_collisions(proposals, bias=None):
rr, cr, cc = jax.vmap(engine.calculate_collision_manifolds)(proposals)
rr_collisions = jnp.sum(jnp.sum(rr.active.astype(jnp.int32), axis=-1), axis=-1)
cr_collisions = jnp.sum(cr.active.astype(jnp.int32), axis=-1)
cc_collisions = jnp.sum(cc.active.astype(jnp.int32), axis=-1)
all_collisions = jnp.concatenate(
[rr_collisions[:, None], cr_collisions[:, None], cc_collisions[:, None]], axis=1
)
num_collisions = jnp.sum(all_collisions, axis=-1)
if bias is not None:
num_collisions = num_collisions + bias
chosen_addition_idx = jnp.argmin(num_collisions)
return jax.tree.map(lambda x: x[chosen_addition_idx], proposals)
def _add_filtered_shape(rng, state, force_no_fixate=False):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, ued_params.add_shape_n_proposals)
proposed_additions = jax.vmap(mutate_add_shape, in_axes=(0, None, None, None, None, None))(
_rngs,
state,
env_params,
static_env_params,
ued_params,
jnp.logical_or(force_no_fixate, small_force_no_fixate),
)
return _choose_proposal_with_least_collisions(proposed_additions)
def _add_filtered_connected_shape(rng, state, force_rjoint=False):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, ued_params.add_shape_n_proposals)
proposed_additions, valid = jax.vmap(mutate_add_connected_shape, in_axes=(0, None, None, None, None, None))(
_rngs, state, env_params, static_env_params, ued_params, force_rjoint
)
bias = (jnp.ones(ued_params.add_shape_n_proposals) - 1 * valid) * ued_params.connect_no_visibility_bias
return _choose_proposal_with_least_collisions(proposed_additions, bias=bias)
# Add green and blue - make sure they're not both fixated
force_green_no_fixate = (jax.random.uniform(_rngs[1]) < 0.5) | (state.polygon_shape_roles[0] == 2)
state = _add_filtered_shape(_rngs[2], state, force_green_no_fixate)
state = _add_filtered_shape(_rngs[3], state, ~force_green_no_fixate)
# Forced controls
forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[4], (), 0, 3)]
force_thruster, force_motor = forced_control[0], forced_control[1]
# Forced motor
state = jax.lax.cond(
force_motor,
lambda: _add_filtered_connected_shape(_rngs[5], state, force_rjoint=True), # force the rjoint
lambda: _add_filtered_shape(_rngs[6], state),
)
# Forced thruster
state = jax.lax.cond(
force_thruster,
lambda: mutate_add_thruster(_rngs[7], state, env_params, static_env_params, ued_params),
lambda: state,
)
# Add rest of shapes
n_shapes_to_add = (
static_env_params.num_polygons + static_env_params.num_circles - 3 - static_env_params.num_static_fixated_polys
)
def _add_shape(state, rng):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, 3)
shape_add_type = jax.random.choice(
_rngs[0],
jnp.arange(3),
p=jnp.array(
[ued_params.add_connected_shape_chance, ued_params.add_shape_chance, ued_params.add_no_shape_chance]
),
)
state = jax.lax.switch(
shape_add_type,
[
lambda: _add_filtered_connected_shape(_rngs[1], state),
lambda: _add_filtered_shape(_rngs[2], state),
lambda: state,
],
)
return state, None
state, _ = jax.lax.scan(_add_shape, state, jax.random.split(_rngs[8], n_shapes_to_add))
# Add thrusters
n_thrusters_to_add = static_env_params.num_thrusters - 1
def _add_thruster(state, rng):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, 3)
state = jax.lax.cond(
jax.random.uniform(_rngs[0]) < ued_params.add_thruster_chance,
lambda: mutate_add_thruster(_rngs[1], state, env_params, static_env_params, ued_params),
lambda: state,
)
return state, None
state, _ = jax.lax.scan(_add_thruster, state, jax.random.split(_rngs[9], n_thrusters_to_add))
# Randomly swap green and blue to remove left-right bias
def _swap_roles(do_swap_roles, roles):
role1 = roles == 1
role2 = roles == 2
swapped_roles = roles * ~(role1 | role2) + role1.astype(int) * 2 + role2.astype(int) * 1
return jax.lax.select(do_swap_roles, swapped_roles, roles)
do_swap_roles = jax.random.uniform(_rngs[10], shape=()) < 0.5
# Don't want to swap if floor is non-standard
do_swap_roles &= state.polygon_shape_roles[0] == 0
state = state.replace(
polygon_shape_roles=_swap_roles(do_swap_roles, state.polygon_shape_roles),
circle_shape_roles=_swap_roles(do_swap_roles, state.circle_shape_roles),
)
return permute_state(_rngs[11], state, static_env_params)
@partial(jax.jit, static_argnums=(2, 4, 5))
def create_random_starting_distribution(
rng,
env_params: EnvParams,
static_env_params: StaticEnvParams,
ued_params: UEDParams,
env_size_name: str,
controllable=True,
):
rng, _rng = jax.random.split(rng)
_rngs = jax.random.split(_rng, 15)
d = to_state_dict(ued_params)
ued_params = UEDParams(
**(
d
| dict(
goal_body_size_factor=2.0,
thruster_power_multiplier=2.0,
max_shape_size=0.5,
)
),
)
prob_of_large_shapes = 0.05
ued_params_large_shapes = ued_params.replace(
max_shape_size=static_env_params.max_shape_size * 1.0, goal_body_size_factor=1.0
)
state = create_empty_env(env_params, static_env_params)
def _get_ued_params(rng):
rng, _rng, _rng2 = jax.random.split(rng, 3)
large_shapes = jax.random.uniform(_rng) < prob_of_large_shapes
params_to_use = jax.tree.map(
lambda x, y: jax.lax.select(large_shapes, x, y), ued_params_large_shapes, ued_params
)
return params_to_use
def _my_add_shape(rng, state):
rng, _rng, _rng2 = jax.random.split(rng, 3)
return mutate_add_shape(_rng, state, env_params, static_env_params, _get_ued_params(_rng2))
def _my_add_connected_shape(rng, state, **kwargs):
rng, _rng, _rng2 = jax.random.split(rng, 3)
return mutate_add_connected_shape_proper(
_rng, state, env_params, static_env_params, _get_ued_params(_rng2), **kwargs
)
# Add the green thing and blue thing
state = _my_add_shape(_rngs[0], state)
state = _my_add_shape(_rngs[1], state)
if controllable:
# Forced controls
forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[2], (), 0, 3)]
force_thruster, force_motor = forced_control[0], forced_control[1]
# Forced motor
state = jax.lax.cond(
force_motor,
lambda: _my_add_connected_shape(_rngs[3], state, force_rjoint=True), # force the rjoint
lambda: state,
)
# Forced thruster
state = jax.lax.cond(
force_thruster,
lambda: mutate_add_thruster(_rngs[4], state, env_params, static_env_params, ued_params),
lambda: state,
)
return permute_state(_rngs[7], state, static_env_params)