File size: 11,410 Bytes
e0f25ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import math
from functools import partial

import jax
import jax.numpy as jnp

from jax2d.engine import PhysicsEngine, calculate_collision_matrix, recalculate_mass_and_inertia, select_shape
from jax2d.sim_state import RigidBody, Thruster
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams


def sample_dimensions(rng, static_env_params: StaticEnvParams, is_rect: bool, ued_params, max_shape_size=None):
    """Sample the size of a new shape, producing both rectangle and circle data.

    Rectangle half extents and a circle radius are always computed. ``is_rect``
    decides which of the two uses a random draw; for the other one the draw is
    replaced by zero, so it falls back to its minimum size.

    Args:
        rng: JAX PRNG key.
        static_env_params: Read for ``max_shape_size`` when no override is given.
        is_rect: True to randomise the rectangle, False to randomise the circle.
        ued_params: Read for ``circle_max_size_coeff``, ``large_rect_dim_chance``
            and ``large_rect_dim_scale``.
        max_shape_size: Optional override for ``static_env_params.max_shape_size``.

    Returns:
        A tuple ``(vertices, half_dimensions, radius)``: the four rectangle
        corners built from ``half_dimensions``, the half extents themselves,
        and the circle radius.
    """
    if max_shape_size is None:
        max_shape_size = static_env_params.max_shape_size

    rng, _rng = jax.random.split(rng)
    # Don't want overly small shapes
    min_rect_size = 0.05
    min_circle_size = 0.1
    # A square with this half extent has a half-diagonal of max_shape_size / 2.
    cap_rect = max_shape_size / 2.0 / jnp.sqrt(2.0)
    cap_circ = max_shape_size / 2.0 * ued_params.circle_max_size_coeff
    half_dimensions = (
        jax.lax.select(is_rect, jax.random.uniform(_rng, shape=(2,)), jnp.zeros(2, dtype=jnp.float32))
        * (cap_rect - min_rect_size)
        + min_rect_size
    )

    # With probability large_rect_dim_chance, stretch one randomly chosen axis
    # by large_rect_dim_scale; otherwise both axes keep a scale of 1.
    rng, _rng, __rng = jax.random.split(rng, 3)
    dim_scale = (
        jnp.ones(2)
        .at[jax.random.randint(_rng, shape=(), minval=0, maxval=2)]
        .set(
            jax.lax.select(
                jax.random.uniform(__rng) < ued_params.large_rect_dim_chance, ued_params.large_rect_dim_scale, 1.0
            )
        )
    )
    half_dimensions *= dim_scale

    # Same corner ordering as the module's rectangle helper: (+,+), (+,-), (-,-), (-,+).
    vertices = rectangle_vertices(half_dimensions)

    rng, _rng = jax.random.split(rng)
    radius = (
        jax.lax.select(is_rect, jnp.zeros((), dtype=jnp.float32), jax.random.uniform(_rng, shape=()))
        * (cap_circ - min_circle_size)
        + min_circle_size
    )
    return vertices, half_dimensions, radius


def count_roles(state: EnvState, static_env_params: StaticEnvParams, role: int, include_static_polys=True) -> int:
    """Count the active polygons and circles whose shape role equals ``role``.

    Args:
        state: Environment state holding the shapes and their roles.
        static_env_params: Read for ``num_static_fixated_polys``.
        role: Role value to look for.
        include_static_polys: When False, the first
            ``num_static_fixated_polys`` polygons are left out of the count.

    Returns:
        The number of matching active shapes (polygons plus circles).
    """
    polygon_active = state.polygon.active
    if not include_static_polys:
        # Treat the leading fixed polygons as inactive so they are not counted.
        polygon_active = polygon_active.at[: static_env_params.num_static_fixated_polys].set(False)

    polygon_hits = (state.polygon_shape_roles == role) * polygon_active
    circle_hits = (state.circle_shape_roles == role) * state.circle.active
    return polygon_hits.sum() + circle_hits.sum()


def random_position_on_triangle(rng, vertices):
    """Sample a random point inside the triangle given by the first three vertices.

    Args:
        rng: JAX PRNG key.
        vertices: Array of vertex positions; only the first three rows are used.

    Returns:
        The sampled position.
    """
    tri = vertices[:3]
    _, key_radial, key_edge = jax.random.split(rng, 3)
    radial = jnp.sqrt(jax.random.uniform(key_radial))
    along_edge = jax.random.uniform(key_edge)
    # https://www.reddit.com/r/godot/comments/mqp29g/how_do_i_get_a_random_position_inside_a_collision/
    # Pick a point on the edge tri[1] -> tri[2], expressed relative to tri[0],
    # then move from tri[0] towards it by the square root of a uniform draw.
    offset_to_edge_point = -tri[0] + tri[1] + along_edge * (tri[2] - tri[1])
    return tri[0] + radial * offset_to_edge_point


def random_position_on_rectangle(rng, vertices):
    """Sample a random point in the axis-aligned bounding box of the first four vertices.

    Args:
        rng: JAX PRNG key.
        vertices: Array of 2D vertex positions; only the first four rows are used.

    Returns:
        Array ``[x, y]`` with each coordinate drawn uniformly between the
        minimum and maximum of that coordinate over the four vertices.
    """
    quad = vertices[:4]
    _, key_x, key_y = jax.random.split(rng, 3)
    fractions = jnp.stack([jax.random.uniform(key_x), jax.random.uniform(key_y)])

    # Per-axis bounds of the four corners.
    lower = jnp.min(quad, axis=0)
    upper = jnp.max(quad, axis=0)

    return lower + fractions * (upper - lower)


def random_position_on_polygon(rng, vertices, n_vertices, static_env_params: StaticEnvParams):
    """Sample a random point on a polygon with at most four vertices.

    Args:
        rng: JAX PRNG key, shared by both candidate samplers.
        vertices: Vertex positions of the polygon.
        n_vertices: Number of vertices in use; ``<= 3`` selects the triangle sampler.
        static_env_params: Read for ``max_polygon_vertices``.

    Returns:
        The triangle sample when ``n_vertices <= 3``, otherwise the rectangle sample.
    """
    assert static_env_params.max_polygon_vertices <= 4, "Only supports up to 4 vertices"

    # Both candidates are computed; the select keeps exactly one of them.
    triangle_point = random_position_on_triangle(rng, vertices)
    rectangle_point = random_position_on_rectangle(rng, vertices)
    treat_as_triangle = n_vertices <= 3
    return jax.lax.select(treat_as_triangle, triangle_point, rectangle_point)


def random_position_on_circle(rng, radius, on_centre_chance):
    """Sample a local position on a circle, optionally snapping to its centre.

    Args:
        rng: JAX PRNG key.
        radius: Radius of the circle.
        on_centre_chance: A uniform draw below this value returns the centre.

    Returns:
        ``[0, 0]`` when the centre is chosen, otherwise a point at a uniformly
        drawn angle and a uniformly drawn distance in ``[0, radius)``.
    """
    key_centre, key_angle, key_distance = jax.random.split(rng, 3)

    snap_to_centre = jax.random.uniform(key_centre) < on_centre_chance

    # Polar sample: the distance itself (not the area) is drawn uniformly.
    angle = jax.random.uniform(key_angle, shape=()) * 2 * math.pi
    distance = jax.random.uniform(key_distance, shape=()) * radius
    offset = jnp.array([distance * jnp.cos(angle), distance * jnp.sin(angle)])

    centre = jnp.array([0.0, 0.0])
    return jax.lax.select(snap_to_centre, centre, offset)


def get_role(rng, state: EnvState, static_env_params: StaticEnvParams, initial_p=None) -> int:
    """Randomly choose a role (0-3) for the next shape, given the roles already present.

    Roles 1, 2 and 3 are the ones counted as ball, goal and lava below; role 0
    is everything else.

    Args:
        rng: JAX PRNG key.
        state: Environment state whose existing roles are counted.
        static_env_params: Passed through to ``count_roles``.
        initial_p: Optional base weight per role (length 4); defaults to all ones.

    Returns:
        The sampled role index.
    """
    if initial_p is None:
        initial_p = jnp.array([1.0, 1.0, 1.0, 1.0])

    needs_ball, needs_goal, needs_lava = (
        count_roles(state, static_env_params, role) == 0 for role in (1, 2, 3)
    )

    # Ball and goal always come first: roles 0 and 3 only get weight once both exist.
    ball_and_goal_placed = jnp.logical_not(needs_ball) & jnp.logical_not(needs_goal)

    # A role that is already present gets weight zero, so we can never have
    # more than one ball or goal (the same holds for lava, at a third of the weight).
    role_mask = jnp.array(
        [ball_and_goal_placed, needs_ball, needs_goal, ball_and_goal_placed * needs_lava / 3]
    )
    weights = initial_p * role_mask
    return jax.random.choice(rng, jnp.array([0, 1, 2, 3]), p=weights)


def is_space_for_shape(state: EnvState):
    """Return whether at least one polygon or circle slot is inactive."""
    slot_in_use = jnp.concatenate([state.polygon.active, state.circle.active])
    return jnp.any(jnp.logical_not(slot_in_use))


def is_space_for_joint(state: EnvState):
    """Return whether at least one joint slot is inactive."""
    joint_slot_free = jnp.logical_not(state.joint.active)
    return jnp.any(joint_slot_free)


def are_there_shapes_present(state: EnvState, static_env_params: StaticEnvParams):
    """Return whether any shape is active, ignoring the leading fixed polygons.

    The polygon flags come first in the concatenated array, so clearing the
    first ``num_static_fixated_polys`` entries removes those polygons from
    consideration.
    """
    all_active = jnp.concatenate([state.polygon.active, state.circle.active])
    n_ignored = static_env_params.num_static_fixated_polys
    considered = all_active.at[:n_ignored].set(False)
    n_present = considered.sum()
    return n_present > 0


@partial(jax.jit, static_argnums=(2, 9))
def add_rigidbody_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    vertices: jnp.ndarray,
    n_vertices: int,
    radius: float,
    shape_role: int,
    density: float = 1,
    is_circle: bool = False,
):
    """Write a new rigid body into the first inactive circle or polygon slot.

    The body starts at rest with zero rotation. After it is written, the
    collision matrix and the masses/inertias are recomputed. If every slot of
    the chosen kind is already active, ``state`` is returned unchanged.

    Args:
        state: Environment state to add the body to.
        env_params: Not used by this function.
        static_env_params: Static parameters (static under jit); passed to the
            collision-matrix and mass/inertia helpers.
        position: Position of the new body.
        vertices: Polygon vertices stored on the body.
        n_vertices: Number of vertices stored on the body.
        radius: Radius stored on the body.
        shape_role: Role recorded for the new shape.
        density: Density recorded for the new shape.
        is_circle: Static flag; True writes into the circle arrays, False into
            the polygon arrays.

    Returns:
        The updated state, or the input state when no slot is free.
    """
    body = RigidBody(
        position=position,
        velocity=jnp.array([0.0, 0.0]),
        inverse_mass=1.0,
        inverse_inertia=1.0,
        rotation=0.0,
        angular_velocity=0.0,
        radius=radius,
        active=True,
        friction=1.0,
        vertices=vertices,
        n_vertices=n_vertices,
        collision_mode=1,
        restitution=0.0,
    )

    # is_circle is a static jit argument, so this choice is made while tracing.
    occupied = state.circle.active if is_circle else state.polygon.active

    # argmin over the active flags picks the first inactive slot.
    slot = jnp.argmin(occupied)
    has_free_slot = jnp.logical_not(occupied).sum() > 0

    def _write_slot(stacked, value):
        # Place one field of the new body at the chosen slot of the stacked field.
        return stacked.at[slot].set(value)

    def _insert(current):
        if is_circle:
            updated = current.replace(
                circle=jax.tree.map(_write_slot, current.circle, body),
                circle_densities=current.circle_densities.at[slot].set(density),
                circle_shape_roles=current.circle_shape_roles.at[slot].set(shape_role),
            )
        else:
            updated = current.replace(
                polygon=jax.tree.map(_write_slot, current.polygon, body),
                polygon_densities=current.polygon_densities.at[slot].set(density),
                polygon_shape_roles=current.polygon_shape_roles.at[slot].set(shape_role),
            )

        updated = updated.replace(
            collision_matrix=calculate_collision_matrix(static_env_params, updated.joint),
        )
        return recalculate_mass_and_inertia(
            updated, static_env_params, updated.polygon_densities, updated.circle_densities
        )

    def _leave_unchanged(current):
        return current

    return jax.lax.cond(has_free_slot, _insert, _leave_unchanged, state)


def rectangle_vertices(half_dim):
    """Return the four corners of an origin-centred rectangle with half extents ``half_dim``.

    The corners are ordered (+x, +y), (+x, -y), (-x, -y), (-x, +y).
    """
    corner_signs = ((1, 1), (1, -1), (-1, -1), (-1, 1))
    return jnp.array([half_dim * jnp.array(signs) for signs in corner_signs])


# More Manual Control
@partial(jax.jit, static_argnums=(2,))
def add_rectangle_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    width: float,
    height: float,
    shape_role: int,
    density: float = 1,
):
    """Add a ``width`` x ``height`` rectangle centred at ``position`` as a polygon.

    Thin wrapper around ``add_rigidbody_to_state``: the rectangle is passed as
    a four-vertex polygon with radius 0.

    Args:
        state: Environment state to add the rectangle to.
        env_params: Forwarded to ``add_rigidbody_to_state``.
        static_env_params: Static parameters (static under jit).
        position: Position of the rectangle.
        width: Full width of the rectangle.
        height: Full height of the rectangle.
        shape_role: Role recorded for the new shape.
        density: Density recorded for the new shape.

    Returns:
        Whatever ``add_rigidbody_to_state`` returns for these arguments.
    """
    half_extents = jnp.array([width, height]) / 2
    corners = rectangle_vertices(half_extents)

    return add_rigidbody_to_state(
        state,
        env_params,
        static_env_params,
        position,
        corners,
        4,  # n_vertices
        0.0,  # radius
        shape_role,
        density,
        is_circle=False,
    )


@partial(jax.jit, static_argnums=(2,))
def add_circle_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    radius: float,
    shape_role: int,
    density: float = 1,
):
    """Add a circle of the given ``radius`` at ``position``.

    Thin wrapper around ``add_rigidbody_to_state`` with ``is_circle=True``;
    the vertex data is a zero placeholder and the vertex count is 0.

    Args:
        state: Environment state to add the circle to.
        env_params: Forwarded to ``add_rigidbody_to_state``.
        static_env_params: Static parameters (static under jit).
        position: Position of the circle.
        radius: Radius of the circle.
        shape_role: Role recorded for the new shape.
        density: Density recorded for the new shape.

    Returns:
        Whatever ``add_rigidbody_to_state`` returns for these arguments.
    """
    placeholder_vertices = jnp.array([0.0, 0.0])

    return add_rigidbody_to_state(
        state,
        env_params,
        static_env_params,
        position,
        placeholder_vertices,
        0,  # n_vertices
        radius,
        shape_role,
        density,
        is_circle=True,
    )


@partial(jax.jit, static_argnums=(2,))
def add_thruster_to_object(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    shape_index: int,
    rotation: float,
    colour: int,
    thruster_power_multiplier: float,
):
    """Attach a thruster to the shape at ``shape_index``, if that is possible.

    The thruster is written into the first inactive thruster slot. Nothing is
    changed when the target shape is inactive or when no thruster slot is free.

    Args:
        state: Environment state to add the thruster to.
        env_params: Not used by this function.
        static_env_params: Static parameters (static under jit); passed to ``select_shape``.
        shape_index: Index of the shape, as understood by ``select_shape``.
        rotation: Rotation stored on the thruster.
        colour: Value written to ``thruster_bindings`` for the new thruster.
        thruster_power_multiplier: Factor applied to the mass-derived power.

    Returns:
        The updated state, or the input state when the thruster cannot be added.
    """
    # Looked up once: used for the guard below and for the thruster's power and position.
    shape = select_shape(state, shape_index, static_env_params)

    def dummy(state):
        return state

    def do_add(state: EnvState):
        # argmin over the active flags picks the first inactive thruster slot.
        thruster_idx = jnp.argmin(state.thruster.active)

        # A zero inverse mass is replaced by 1 so the division below stays finite.
        safe_inverse_mass = jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)

        thruster = Thruster(
            object_index=shape_index,
            active=True,
            relative_position=jnp.array([0.0, 0.0]),  # a bit of a hack but reasonable.
            rotation=rotation,
            power=1.0 / safe_inverse_mass * thruster_power_multiplier,
            global_position=shape.position,
        )

        state = state.replace(
            # jax.tree.map matches the rest of this module; jax.tree_map is the deprecated alias.
            thruster=jax.tree.map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
            thruster_bindings=state.thruster_bindings.at[thruster_idx].set(colour),
        )

        return state

    can_add = (shape.active) & (jnp.logical_not(state.thruster.active).sum() > 0)
    return jax.lax.cond(can_add, do_add, dummy, state)


def make_velocities_zero(state: EnvState):
    """Return ``state`` with the linear and angular velocities of all polygons and circles zeroed.

    The velocities are multiplied by 0, which keeps each array's shape and dtype.
    """

    def _at_rest(bodies):
        return bodies.replace(
            angular_velocity=bodies.angular_velocity * 0,
            velocity=bodies.velocity * 0,
        )

    return state.replace(polygon=_at_rest(state.polygon), circle=_at_rest(state.circle))


def make_do_dummy_step(
    params: EnvParams, static_sim_params: StaticEnvParams, zero_collisions=True, zero_velocities=True
):
    """Build a function that runs one physics step with gravity switched off.

    Args:
        params: Environment parameters passed to every engine step.
        static_sim_params: Static parameters used to build the ``PhysicsEngine``
            and to size the all-zero action vector.
        zero_collisions: When True, the collision matrix is cleared for the
            duration of the step.
        zero_velocities: When True, all shape velocities are zeroed after the step.

    Returns:
        A function ``do_dummy_step(state) -> state``. Gravity and the collision
        matrix are restored to their original values before it returns.
    """
    env = PhysicsEngine(static_sim_params)

    @jax.jit
    def _step_fn(state):
        # All-zero action: one entry per joint plus one per thruster.
        state, _ = env.step(state, params, jnp.zeros((static_sim_params.num_joints + static_sim_params.num_thrusters,)))
        return state

    def do_dummy_step(state: EnvState) -> EnvState:
        # Remember the real values so they can be put back after the step.
        og_col = state.collision_matrix
        g = state.gravity
        # `& (not zero_collisions)` clears every entry when zero_collisions is
        # True and leaves the matrix as it is otherwise.
        state = state.replace(
            collision_matrix=state.collision_matrix & (not zero_collisions), gravity=state.gravity * 0
        )
        state = _step_fn(state)
        state = state.replace(gravity=g, collision_matrix=og_col)
        if zero_velocities:
            state = make_velocities_zero(state)
        return state

    return do_dummy_step