Spaces:
Runtime error
Runtime error
File size: 948 Bytes
e0f25ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from dataclasses import field
import jax.numpy as jnp
from flax import struct
from jax2d.sim_state import SimState, SimParams, StaticSimParams
@struct.dataclass
class EnvState(SimState):
    """Environment-level state layered on top of the physics ``SimState``.

    Every field declared on ``SimState`` is inherited; the fields below are
    the additions made by this class. All array fields are annotated as
    ``jnp.ndarray``; their shapes and dtypes are not fixed here.
    """

    # --- Actuator bookkeeping -------------------------------------------
    # NOTE(review): presumably these associate each thruster / motor with an
    # action slot (cf. num_*_bindings in StaticEnvParams) -- confirm.
    thruster_bindings: jnp.ndarray
    motor_bindings: jnp.ndarray
    # NOTE(review): looks like a per-motor "automatic" flag -- confirm.
    motor_auto: jnp.ndarray

    # --- Per-shape annotations (one array for polygons, one for circles) --
    # NOTE(review): role values presumably range over
    # EnvParams.num_shape_roles -- confirm against the caller.
    polygon_shape_roles: jnp.ndarray
    circle_shape_roles: jnp.ndarray

    polygon_highlighted: jnp.ndarray
    circle_highlighted: jnp.ndarray

    polygon_densities: jnp.ndarray
    circle_densities: jnp.ndarray

    # --- Episode progress -----------------------------------------------
    # The only field with a default here; a freshly built state starts at 0.
    timestep: int = 0
@struct.dataclass
class EnvParams(SimParams):
    """Environment parameters added on top of the inherited ``SimParams``.

    Every field has a default, so the class can be built without passing
    any of the fields declared here.
    """

    # NOTE(review): presumably the episode length cap, compared against
    # EnvState.timestep -- confirm where it is consumed.
    max_timesteps: int = 256

    # NOTE(review): presumably the world-unit -> pixel scale used when
    # rendering -- confirm.
    pixels_per_unit: int = 100

    # Multiplier for the dense reward term (applied outside this file).
    dense_reward_scale: float = 0.1

    # NOTE(review): presumably the number of distinct values allowed in the
    # *_shape_roles arrays of the environment state -- confirm.
    num_shape_roles: int = 4
@struct.dataclass
class StaticEnvParams(StaticSimParams):
    """Static environment configuration extending ``StaticSimParams``.

    Every field declared here is a plain Python value with a default, so no
    arguments are needed for these additions when constructing the class.
    """

    # --- Rendering / observation ----------------------------------------
    # Two-int tuple; NOTE(review): axis order (width, height) is an
    # assumption -- confirm against the renderer.
    screen_dim: tuple[int, int] = (500, 500)
    # NOTE(review): presumably the factor by which rendered frames are
    # reduced -- confirm.
    downscale: int = 4

    # --- Stepping ---------------------------------------------------------
    # NOTE(review): presumably the number of simulator steps per
    # environment step -- confirm.
    frame_skip: int = 1

    # --- Scene limits -----------------------------------------------------
    max_shape_size: int = 2

    # --- Action-space sizing ----------------------------------------------
    # NOTE(review): presumably the count of distinct binding slots for
    # motors and thrusters respectively -- confirm.
    num_motor_bindings: int = 4
    num_thruster_bindings: int = 2
|