# Spaces: Runtime error (page header text; not part of the module)
from functools import partial | |
import math | |
import chex | |
import jax | |
import jax.numpy as jnp | |
from flax.serialization import to_state_dict | |
from jax2d.engine import ( | |
calculate_collision_matrix, | |
calc_inverse_mass_polygon, | |
calc_inverse_mass_circle, | |
calc_inverse_inertia_circle, | |
calc_inverse_inertia_polygon, | |
recalculate_mass_and_inertia, | |
select_shape, | |
PhysicsEngine, | |
) | |
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster | |
from jax2d.maths import rmat | |
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams | |
from kinetix.environment.ued.mutators import ( | |
mutate_add_connected_shape_proper, | |
mutate_add_shape, | |
mutate_add_connected_shape, | |
mutate_add_thruster, | |
) | |
from kinetix.environment.ued.ued_state import UEDParams | |
from kinetix.environment.ued.util import ( | |
get_role, | |
sample_dimensions, | |
is_space_for_shape, | |
random_position_on_polygon, | |
random_position_on_circle, | |
are_there_shapes_present, | |
is_space_for_joint, | |
) | |
from kinetix.environment.utils import permute_state | |
from kinetix.util.saving import load_world_state_pickle | |
from flax import struct | |
from kinetix.environment.env import create_empty_env | |
def create_vmapped_filtered_distribution(
    rng,
    level_sampler,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    n_samples: int,
    env,
    do_filter_levels: bool,
    level_filter_sample_ratio: int,
    env_size_name: str,
    level_filter_n_steps: int,
):
    """Sample a batch of ``n_samples`` levels, optionally down-weighting trivial ones.

    When ``do_filter_levels`` is set and ``level_filter_n_steps > 0``,
    ``level_filter_sample_ratio * n_samples`` candidate levels are drawn, each
    one is rolled out for ``level_filter_n_steps`` steps with an all-zero
    action, and candidates whose reward at the selected ``done`` step exceeds
    0.5 get a sampling weight of 0.001 (instead of 1.0). ``n_samples`` levels
    are then drawn without replacement using the normalised weights.
    Otherwise ``n_samples`` levels are sampled directly.

    Args:
        rng: JAX PRNG key.
        level_sampler: Callable taking a single PRNG key and returning one
            level; it is vmapped over a batch of keys.
        env_params: Forwarded to ``env.action_space``, ``env.step`` and
            ``env.reset_to_level``.
        static_env_params: Not used by this function.
        ued_params: Not used by this function.
        n_samples: Number of levels to return.
        env: Environment providing ``action_space``, ``step`` and
            ``reset_to_level``; only touched on the filtering path.
        do_filter_levels: Enables the no-op filtering path.
        level_filter_sample_ratio: Oversampling factor for the candidate pool.
        env_size_name: Not used by this function.
        level_filter_n_steps: Number of no-op steps per candidate.

    Returns:
        The sampled levels, batched along a leading axis of size ``n_samples``.
    """
    batched_sampler = jax.vmap(level_sampler, in_axes=(0,))

    # Unfiltered path: sample exactly n_samples levels and return them.
    if not (do_filter_levels and level_filter_n_steps > 0):
        rng, sample_key = jax.random.split(rng)
        return batched_sampler(jax.random.split(sample_key, n_samples))

    # Oversample a pool of candidates to filter from.
    n_candidates = level_filter_sample_ratio * n_samples
    rng, sample_key = jax.random.split(rng)
    candidates = batched_sampler(jax.random.split(sample_key, n_candidates))

    def _noop_step(states, step_key):
        """Advance every candidate by one step using an all-zero action."""
        _, batch_key = jax.random.split(step_key)
        env_keys = jax.random.split(batch_key, n_candidates)
        zero_actions = jnp.zeros((n_candidates, *env.action_space(env_params).shape), dtype=jnp.int32)
        _, next_states, reward, done, _ = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            env_keys, states, zero_actions, env_params
        )
        return next_states, (done, reward)

    # Wrap the raw levels into env states.
    rng, reset_key = jax.random.split(rng)
    _, wrapped_candidates = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
        jax.random.split(reset_key, n_candidates), candidates, env_params
    )

    # Roll out the no-op policy; scan stacks (done, reward) along a leading step axis.
    rng, scan_key = jax.random.split(rng)
    step_keys = jax.random.split(scan_key, level_filter_n_steps)
    _, (done, rewards) = jax.lax.scan(_noop_step, wrapped_candidates, xs=step_keys)

    # argmax picks the first step with the largest `done` value per candidate.
    # NOTE(review): for a candidate that is never done this is step 0, so its
    # step-0 reward is what gets compared against the threshold below.
    first_done_step = jnp.argmax(done, axis=0)
    reward_at_done = rewards[first_done_step, jnp.arange(n_candidates)]
    solved_by_noop = reward_at_done > 0.5

    # Levels the no-op policy already solves are kept only with a tiny weight.
    weights = jnp.where(solved_by_noop, 0.001, 1.0)
    weights = weights / weights.sum()

    rng, choice_key = jax.random.split(rng)
    chosen = jax.random.choice(
        choice_key, jnp.arange(n_candidates), shape=(n_samples,), replace=False, p=weights
    )
    return jax.tree.map(lambda leaf: leaf[chosen], candidates)
def sample_kinetix_level(
    rng,
    engine: PhysicsEngine,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    env_size_name: str = "l",
):
    """Build one random level by successively adding shapes, joints and thrusters.

    The procedure, each stage using its own pre-split key:

    1. Start from ``create_empty_env`` and draw the role of polygon 0 (the
       floor) from the four ``ued_params.floor_prob_*`` probabilities.
    2. Add two free shapes with complementary ``force_no_fixate`` flags.
    3. Pick one of (motor only, thruster only, both) uniformly and add the
       corresponding forced controls.
    4. Repeatedly choose between adding a connected shape, a free shape or
       nothing, then repeatedly add a thruster with probability
       ``ued_params.add_thruster_chance``.
    5. With probability 0.5, and only when the floor role is 0, swap roles
       1 and 2 everywhere, then return ``permute_state`` of the result.

    Whenever a shape is added, ``ued_params.add_shape_n_proposals`` proposals
    are generated and the one with the fewest active collision manifolds
    (as reported by ``engine.calculate_collision_manifolds``) is kept.

    Args:
        rng: JAX PRNG key.
        engine: Physics engine used to count collisions between proposals.
        env_params: Forwarded to the mutators.
        static_env_params: Forwarded to the mutators; its shape/thruster
            counts determine how many additions are attempted.
        ued_params: Probabilities and proposal counts for the generator.
        env_size_name: When equal to ``"s"``, ``force_no_fixate`` is set for
            every free shape that is added.

    Returns:
        The result of ``permute_state`` applied to the generated state.
    """
    rng, split_key = jax.random.split(rng)
    keys = jax.random.split(split_key, 12)

    never_fixate = env_size_name == "s"

    # Start with empty state
    level = create_empty_env(static_env_params)

    # Set the floor (polygon 0): roles are indexed normal, green, blue, red.
    floor_role_probs = jnp.array(
        [
            ued_params.floor_prob_normal,
            ued_params.floor_prob_green,
            ued_params.floor_prob_blue,
            ued_params.floor_prob_red,
        ]
    )
    floor_role = jax.random.choice(keys[0], jnp.arange(4), p=floor_role_probs)
    level = level.replace(polygon_shape_roles=level.polygon_shape_roles.at[0].set(floor_role))

    def _least_colliding(proposals, bias=None):
        """Return the proposal whose (optionally biased) collision count is smallest."""
        rr, cr, cc = jax.vmap(engine.calculate_collision_manifolds)(proposals)
        rr_hits = rr.active.astype(jnp.int32).sum(axis=(-2, -1))
        cr_hits = cr.active.astype(jnp.int32).sum(axis=-1)
        cc_hits = cc.active.astype(jnp.int32).sum(axis=-1)
        total_hits = jnp.stack([rr_hits, cr_hits, cc_hits], axis=1).sum(axis=-1)
        if bias is not None:
            total_hits = total_hits + bias
        best = jnp.argmin(total_hits)
        return jax.tree.map(lambda leaf: leaf[best], proposals)

    def _add_filtered_shape(rng, state, force_no_fixate=False):
        """Propose several free shapes and keep the least-colliding one."""
        _, proposal_key = jax.random.split(rng)
        proposal_keys = jax.random.split(proposal_key, ued_params.add_shape_n_proposals)
        propose = jax.vmap(mutate_add_shape, in_axes=(0, None, None, None, None, None))
        candidates = propose(
            proposal_keys,
            state,
            env_params,
            static_env_params,
            ued_params,
            jnp.logical_or(force_no_fixate, never_fixate),
        )
        return _least_colliding(candidates)

    def _add_filtered_connected_shape(rng, state, force_rjoint=False):
        """Propose several connected shapes and keep the least-colliding one."""
        _, proposal_key = jax.random.split(rng)
        proposal_keys = jax.random.split(proposal_key, ued_params.add_shape_n_proposals)
        propose = jax.vmap(mutate_add_connected_shape, in_axes=(0, None, None, None, None, None))
        candidates, valid = propose(proposal_keys, state, env_params, static_env_params, ued_params, force_rjoint)
        # Proposals flagged as not valid are penalised by connect_no_visibility_bias.
        not_valid = jnp.ones(ued_params.add_shape_n_proposals) - 1 * valid
        return _least_colliding(candidates, bias=not_valid * ued_params.connect_no_visibility_bias)

    # Add green and blue - make sure they're not both fixated
    green_no_fixate = (jax.random.uniform(keys[1]) < 0.5) | (level.polygon_shape_roles[0] == 2)
    level = _add_filtered_shape(keys[2], level, green_no_fixate)
    level = _add_filtered_shape(keys[3], level, ~green_no_fixate)

    # Forced controls: each row is (force_thruster, force_motor).
    control_options = jnp.array([[0, 1], [1, 0], [1, 1]])
    picked_control = control_options[jax.random.randint(keys[4], (), 0, 3)]
    force_thruster = picked_control[0]
    force_motor = picked_control[1]

    # Forced motor: a connected shape with a forced rjoint, else a plain shape.
    def _with_forced_motor():
        return _add_filtered_connected_shape(keys[5], level, force_rjoint=True)

    def _without_forced_motor():
        return _add_filtered_shape(keys[6], level)

    level = jax.lax.cond(force_motor, _with_forced_motor, _without_forced_motor)

    # Forced thruster
    level = jax.lax.cond(
        force_thruster,
        lambda: mutate_add_thruster(keys[7], level, env_params, static_env_params, ued_params),
        lambda: level,
    )

    # Add rest of shapes
    n_shapes_to_add = (
        static_env_params.num_polygons
        + static_env_params.num_circles
        - 3
        - static_env_params.num_static_fixated_polys
    )

    def _maybe_add_shape(state, step_key):
        """Scan body: add a connected shape, a free shape, or nothing."""
        _, sub_key = jax.random.split(step_key)
        sub_keys = jax.random.split(sub_key, 3)
        kind_probs = jnp.array(
            [ued_params.add_connected_shape_chance, ued_params.add_shape_chance, ued_params.add_no_shape_chance]
        )
        kind = jax.random.choice(sub_keys[0], jnp.arange(3), p=kind_probs)
        branches = [
            lambda: _add_filtered_connected_shape(sub_keys[1], state),
            lambda: _add_filtered_shape(sub_keys[2], state),
            lambda: state,
        ]
        return jax.lax.switch(kind, branches), None

    level, _ = jax.lax.scan(_maybe_add_shape, level, jax.random.split(keys[8], n_shapes_to_add))

    # Add thrusters
    n_thrusters_to_add = static_env_params.num_thrusters - 1

    def _maybe_add_thruster(state, step_key):
        """Scan body: add a thruster with probability add_thruster_chance."""
        _, sub_key = jax.random.split(step_key)
        sub_keys = jax.random.split(sub_key, 3)
        should_add = jax.random.uniform(sub_keys[0]) < ued_params.add_thruster_chance
        new_state = jax.lax.cond(
            should_add,
            lambda: mutate_add_thruster(sub_keys[1], state, env_params, static_env_params, ued_params),
            lambda: state,
        )
        return new_state, None

    level, _ = jax.lax.scan(_maybe_add_thruster, level, jax.random.split(keys[9], n_thrusters_to_add))

    # Randomly swap green and blue to remove left-right bias
    def _swap_roles(do_swap_roles, roles):
        """Exchange role values 1 and 2 when do_swap_roles is set; leave other roles alone."""
        is_one = roles == 1
        is_two = roles == 2
        untouched = roles * ~(is_one | is_two)
        swapped = untouched + is_one.astype(int) * 2 + is_two.astype(int) * 1
        return jax.lax.select(do_swap_roles, swapped, roles)

    do_swap = jax.random.uniform(keys[10], shape=()) < 0.5
    # Don't want to swap if floor is non-standard
    do_swap &= level.polygon_shape_roles[0] == 0

    level = level.replace(
        polygon_shape_roles=_swap_roles(do_swap, level.polygon_shape_roles),
        circle_shape_roles=_swap_roles(do_swap, level.circle_shape_roles),
    )

    return permute_state(keys[11], level, static_env_params)
def create_random_starting_distribution(
    rng,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    env_size_name: str,
    controllable=True,
):
    """Build a simple random level: two shapes plus, optionally, forced controls.

    ``ued_params`` is first overridden with ``goal_body_size_factor=2.0``,
    ``thruster_power_multiplier=2.0`` and ``max_shape_size=0.5``. Each shape
    addition then uses, with probability 0.05, a "large shapes" variant of
    those parameters (``max_shape_size=static_env_params.max_shape_size`` and
    ``goal_body_size_factor=1.0``).

    Args:
        rng: JAX PRNG key.
        env_params: Forwarded to the mutators.
        static_env_params: Forwarded to the mutators and to ``create_empty_env``.
        ued_params: Base generator parameters (overridden as described above).
        env_size_name: Not used by this function.
        controllable: When true, uniformly pick one of (motor only, thruster
            only, both) and add the corresponding connected shape / thruster.

    Returns:
        The result of ``permute_state`` applied to the generated state.
    """
    rng, _rng = jax.random.split(rng)
    # Only indices 0-4 and 7 are consumed. The split count is left at 15
    # because changing it may change the derived keys and hence the levels.
    _rngs = jax.random.split(_rng, 15)

    overrides = dict(
        goal_body_size_factor=2.0,
        thruster_power_multiplier=2.0,
        max_shape_size=0.5,
    )
    ued_params = UEDParams(**(to_state_dict(ued_params) | overrides))

    prob_of_large_shapes = 0.05
    ued_params_large_shapes = ued_params.replace(
        max_shape_size=static_env_params.max_shape_size * 1.0, goal_body_size_factor=1.0
    )

    # Fix: this used to be create_empty_env(env_params, static_env_params),
    # which disagrees with the single-argument call in sample_kinetix_level
    # in this same file; the two call sites cannot both match one signature.
    # NOTE(review): confirm against create_empty_env's actual signature.
    state = create_empty_env(static_env_params)

    def _get_ued_params(rng):
        """Pick the large-shape params with probability prob_of_large_shapes, else the base ones."""
        # Three-way split kept (only the middle key is used) so the draw is unchanged.
        _, choice_key, _ = jax.random.split(rng, 3)
        large_shapes = jax.random.uniform(choice_key) < prob_of_large_shapes
        return jax.tree.map(
            lambda x, y: jax.lax.select(large_shapes, x, y), ued_params_large_shapes, ued_params
        )

    def _my_add_shape(rng, state):
        """Add one free shape using freshly drawn UED params."""
        _, shape_key, params_key = jax.random.split(rng, 3)
        return mutate_add_shape(shape_key, state, env_params, static_env_params, _get_ued_params(params_key))

    def _my_add_connected_shape(rng, state, **kwargs):
        """Add one connected shape using freshly drawn UED params."""
        _, shape_key, params_key = jax.random.split(rng, 3)
        return mutate_add_connected_shape_proper(
            shape_key, state, env_params, static_env_params, _get_ued_params(params_key), **kwargs
        )

    # Add the green thing and blue thing
    state = _my_add_shape(_rngs[0], state)
    state = _my_add_shape(_rngs[1], state)

    if controllable:
        # Forced controls: each row is (force_thruster, force_motor).
        forced_control = jnp.array([[0, 1], [1, 0], [1, 1]])[jax.random.randint(_rngs[2], (), 0, 3)]
        force_thruster, force_motor = forced_control[0], forced_control[1]

        # Forced motor
        state = jax.lax.cond(
            force_motor,
            lambda: _my_add_connected_shape(_rngs[3], state, force_rjoint=True),  # force the rjoint
            lambda: state,
        )

        # Forced thruster (uses the base overridden params, not the per-shape draw)
        state = jax.lax.cond(
            force_thruster,
            lambda: mutate_add_thruster(_rngs[4], state, env_params, static_env_params, ued_params),
            lambda: state,
        )

    return permute_state(_rngs[7], state, static_env_params)