from functools import partial
import time
from enum import IntEnum
from typing import Tuple
import chex
import hydra
import jax
import jax.numpy as jnp
import numpy as np
from omegaconf import OmegaConf
import optax
from flax import core, struct
from flax.training.train_state import TrainState as BaseTrainState
import wandb
from kinetix.environment.ued.distributions import (
    create_random_starting_distribution,
)
from kinetix.environment.ued.ued import (
    make_mutate_env,
    make_reset_train_function_with_list_of_levels,
    make_reset_train_function_with_mutations,
    make_vmapped_filtered_level_sampler,
)
from kinetix.util.config import (
generate_ued_params_from_config,
get_video_frequency,
init_wandb,
normalise_config,
save_data_to_local_file,
generate_params_from_config,
get_eval_level_groups,
)
from jaxued.environments.underspecified_env import EnvState
from jaxued.level_sampler import LevelSampler
from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
from flax.serialization import to_state_dict
import sys
sys.path.append("experiments")
from kinetix.environment.env import make_kinetix_env_from_name
from kinetix.environment.env_state import StaticEnvParams
from kinetix.environment.wrappers import (
UnderspecifiedToGymnaxWrapper,
LogWrapper,
DenseRewardWrapper,
AutoReplayWrapper,
)
from kinetix.models import make_network_from_config
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.models.actor_critic import ScannedRNN
from kinetix.util.learning import (
general_eval,
get_eval_levels,
no_op_and_random_rollout,
sample_trajectories_and_learn,
)
from kinetix.util.saving import (
load_train_state_from_wandb_artifact_path,
save_model_to_wandb,
)
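# Trains a Kinetix agent with UED: depending on the config this runs plain domain randomisation (DR),
# prioritised level replay (PLR), or ACCEL (PLR with level mutations), with jaxued's LevelSampler
# maintaining the buffer of training levels.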
class UpdateState(IntEnum):
DR = 0
REPLAY = 1
MUTATE = 2
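# Per-level complexity statistics (counts of shapes, joints, thrusters, and of shapes with the
# ball/goal roles), averaged over a batch of levels for logging.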
def get_level_complexity_metrics(all_levels: EnvState, static_env_params: StaticEnvParams):
def get_for_single_level(level):
return {
"complexity/num_shapes": level.polygon.active[static_env_params.num_static_fixated_polys :].sum()
+ level.circle.active.sum(),
"complexity/num_joints": level.joint.active.sum(),
"complexity/num_thrusters": level.thruster.active.sum(),
"complexity/num_rjoints": (level.joint.active * jnp.logical_not(level.joint.is_fixed_joint)).sum(),
"complexity/num_fjoints": (level.joint.active * (level.joint.is_fixed_joint)).sum(),
"complexity/has_ball": ((level.polygon_shape_roles == 1) * level.polygon.active).sum()
+ ((level.circle_shape_roles == 1) * level.circle.active).sum(),
"complexity/has_goal": ((level.polygon_shape_roles == 2) * level.polygon.active).sum()
+ ((level.circle_shape_roles == 2) * level.circle.active).sum(),
}
return jax.tree.map(lambda x: x.mean(), jax.vmap(get_for_single_level)(all_levels))
def get_ued_score_metrics(all_ued_scores):
(mc, pvl, learn) = all_ued_scores
scores = {}
for score, name in zip([mc, pvl, learn], ["MaxMC", "PVL", "Learnability"]):
scores[f"ued_scores/{name}/Mean"] = score.mean()
scores[f"ued_scores_additional/{name}/Max"] = score.max()
scores[f"ued_scores_additional/{name}/Min"] = score.min()
return scores
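# Extends the flax TrainState with the PLR level-sampler state, the type of the last update, and the
# most recent level/rollout batches from each branch (the latter are kept purely for logging).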
class TrainState(BaseTrainState):
sampler: core.FrozenDict[str, chex.ArrayTree] = struct.field(pytree_node=True)
update_state: UpdateState = struct.field(pytree_node=True)
# === Below is used for logging ===
num_dr_updates: int
num_replay_updates: int
num_mutation_updates: int
dr_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
replay_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
mutation_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
dr_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
replay_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
mutation_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
dr_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
replay_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
mutation_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
def train_state_to_log_dict(train_state: TrainState, level_sampler: LevelSampler) -> dict:
"""To prevent the entire (large) train_state to be copied to the CPU when doing logging, this function returns all of the important information in a dictionary format.
Anything in the `log` key will be logged to wandb.
Args:
train_state (TrainState):
level_sampler (LevelSampler):
Returns:
dict:
"""
sampler = train_state.sampler
idx = jnp.arange(level_sampler.capacity) < sampler["size"]
s = jnp.maximum(idx.sum(), 1)
return {
"log": {
"level_sampler/size": sampler["size"],
"level_sampler/episode_count": sampler["episode_count"],
"level_sampler/max_score": sampler["scores"].max(),
"level_sampler/weighted_score": (sampler["scores"] * level_sampler.level_weights(sampler)).sum(),
"level_sampler/mean_score": (sampler["scores"] * idx).sum() / s,
},
"info": {
"num_dr_updates": train_state.num_dr_updates,
"num_replay_updates": train_state.num_replay_updates,
"num_mutation_updates": train_state.num_mutation_updates,
},
}
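# Learnability of a level is estimated as p * (1 - p), where p is that level's success rate over the
# completed episodes in the rollout; it is highest for levels the agent solves about half the time.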
def compute_learnability(config, done, reward, info, num_envs):
num_agents = 1
rollout_length = config["num_steps"] * config["outer_rollout_steps"]
@partial(jax.vmap, in_axes=(None, 1, 1, 1))
@partial(jax.jit, static_argnums=(0,))
def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
idxs = jnp.arange(max_steps)
@partial(jax.vmap, in_axes=(0, 0))
def __ep_outcomes(start_idx, end_idx):
mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
r = jnp.sum(returns * mask)
goal_r = info["GoalR"]
success = jnp.sum(goal_r * mask)
            collision = 0  # not tracked in this setup; collision_rate below is always zero
            timeo = 0  # likewise, timeout_rate is always zero
l = end_idx - start_idx
return r, success, collision, timeo, l
done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
mask_done = jnp.where(done_idxs == max_steps, 0, 1)
ep_return, success, collision, timeo, length = __ep_outcomes(
jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
)
return {
"ep_return": ep_return.mean(where=mask_done),
"num_episodes": mask_done.sum(),
"num_success": success.sum(where=mask_done),
"success_rate": success.mean(where=mask_done),
"collision_rate": collision.mean(where=mask_done),
"timeout_rate": timeo.mean(where=mask_done),
"ep_len": length.mean(where=mask_done),
}
o = _calc_outcomes_by_agent(rollout_length, done, reward, info)
success_by_env = o["success_rate"].reshape((num_agents, num_envs))
learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
return (
learnability_by_env,
o["num_episodes"].reshape(num_agents, num_envs).sum(axis=0),
o["num_success"].reshape(num_agents, num_envs).T,
) # so agents is at the end.
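# The replay score ("regret" estimate) assigned to each level, selected by config["score_function"]:
# MaxMC, positive value loss (PVL), or learnability.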
def compute_score(
config: dict, dones: chex.Array, values: chex.Array, max_returns: chex.Array, reward, info, advantages: chex.Array
) -> chex.Array:
# Computes the score for each level
if config["score_function"] == "MaxMC":
return max_mc(dones, values, max_returns)
elif config["score_function"] == "pvl":
return positive_value_loss(dones, advantages)
elif config["score_function"] == "learnability":
learnability, num_episodes, num_success = compute_learnability(
config, dones, reward, info, config["num_train_envs"]
)
return learnability
else:
raise ValueError(f"Unknown score function: {config['score_function']}")
def compute_all_scores(
config: dict,
dones: chex.Array,
values: chex.Array,
max_returns: chex.Array,
reward,
info,
advantages: chex.Array,
return_success_rate=False,
):
mc = max_mc(dones, values, max_returns)
pvl = positive_value_loss(dones, advantages)
learnability, num_episodes, num_success = compute_learnability(
config, dones, reward, info, config["num_train_envs"]
)
if config["score_function"] == "MaxMC":
main_score = mc
elif config["score_function"] == "pvl":
main_score = pvl
elif config["score_function"] == "learnability":
main_score = learnability
else:
raise ValueError(f"Unknown score function: {config['score_function']}")
if return_success_rate:
success_rate = num_success.squeeze(1) / jnp.maximum(num_episodes, 1)
return main_score, (mc, pvl, learnability, success_rate)
return main_score, (mc, pvl, learnability)
@hydra.main(version_base=None, config_path="../configs", config_name="plr")
def main(config=None):
my_name = "PLR"
config = OmegaConf.to_container(config)
if config["ued"]["replay_prob"] == 0.0:
my_name = "DR"
elif config["ued"]["use_accel"]:
my_name = "ACCEL"
time_start = time.time()
config = normalise_config(config, my_name)
env_params, static_env_params = generate_params_from_config(config)
config["env_params"] = to_state_dict(env_params)
config["static_env_params"] = to_state_dict(static_env_params)
run = init_wandb(config, my_name)
config = wandb.config
time_prev = time.time()
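    # `log_eval` runs on the host (it is invoked through jax.debug.callback inside the jitted loop);
    # `time_prev` tracks wall-clock time between evaluations so steps-per-second can be reported.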
def log_eval(stats, train_state_info):
nonlocal time_prev
print(f"Logging update: {stats['update_count']}")
total_loss = jnp.mean(stats["losses"][0])
if jnp.isnan(total_loss):
print("NaN loss, skipping logging")
raise ValueError("NaN loss")
# generic stats
env_steps = int(
int(stats["update_count"]) * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
)
env_steps_delta = (
config["eval_freq"] * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
)
time_now = time.time()
log_dict = {
"timing/num_updates": stats["update_count"],
"timing/num_env_steps": env_steps,
"timing/sps": env_steps_delta / (time_now - time_prev),
"timing/sps_agg": env_steps / (time_now - time_start),
"loss/total_loss": jnp.mean(stats["losses"][0]),
"loss/value_loss": jnp.mean(stats["losses"][1][0]),
"loss/policy_loss": jnp.mean(stats["losses"][1][1]),
"loss/entropy_loss": jnp.mean(stats["losses"][1][2]),
}
time_prev = time_now
# evaluation performance
returns = stats["eval_returns"]
log_dict.update({"eval/mean_eval_return": returns.mean()})
log_dict.update({"eval/mean_eval_learnability": stats["eval_learn"].mean()})
log_dict.update({"eval/mean_eval_solve_rate": stats["eval_solves"].mean()})
log_dict.update({"eval/mean_eval_eplen": stats["eval_ep_lengths"].mean()})
for i in range(config["num_eval_levels"]):
log_dict[f"eval_avg_return/{config['eval_levels'][i]}"] = returns[i]
log_dict[f"eval_avg_learnability/{config['eval_levels'][i]}"] = stats["eval_learn"][i]
log_dict[f"eval_avg_solve_rate/{config['eval_levels'][i]}"] = stats["eval_solves"][i]
log_dict[f"eval_avg_episode_length/{config['eval_levels'][i]}"] = stats["eval_ep_lengths"][i]
log_dict[f"eval_get_max_eplen/{config['eval_levels'][i]}"] = stats["eval_get_max_eplen"][i]
log_dict[f"episode_return_bigger_than_negative/{config['eval_levels'][i]}"] = stats[
"episode_return_bigger_than_negative"
][i]
def _aggregate_per_size(values, name):
to_return = {}
for group_name, indices in eval_group_indices.items():
to_return[f"{name}_{group_name}"] = values[indices].mean()
return to_return
log_dict.update(_aggregate_per_size(returns, "eval_aggregate/return"))
log_dict.update(_aggregate_per_size(stats["eval_solves"], "eval_aggregate/solve_rate"))
if config["EVAL_ON_SAMPLED"]:
log_dict.update({"eval/mean_eval_return_sampled": stats["eval_dr_returns"].mean()})
log_dict.update({"eval/mean_eval_solve_rate_sampled": stats["eval_dr_solve_rates"].mean()})
log_dict.update({"eval/mean_eval_eplen_sampled": stats["eval_dr_eplen"].mean()})
# level sampler
log_dict.update(train_state_info["log"])
# images
log_dict.update(
{
"images/highest_scoring_level": wandb.Image(
np.array(stats["highest_scoring_level"]), caption="Highest scoring level"
)
}
)
log_dict.update(
{
"images/highest_weighted_level": wandb.Image(
np.array(stats["highest_weighted_level"]), caption="Highest weighted level"
)
}
)
for s in ["dr", "replay", "mutation"]:
if train_state_info["info"][f"num_{s}_updates"] > 0:
log_dict.update(
{
f"images/{s}_levels": [
wandb.Image(np.array(image), caption=f"{score}")
for image, score in zip(stats[f"{s}_levels"], stats[f"{s}_scores"])
]
}
)
if stats["log_videos"]:
# animations
rollout_ep = stats[f"{s}_ep_len"]
arr = np.array(stats[f"{s}_rollout"][:rollout_ep])
log_dict.update(
{
f"media/{s}_eval": wandb.Video(
arr.astype(np.uint8), fps=15, caption=f"{s.capitalize()} (len {rollout_ep})"
)
}
)
# DR, Replay and Mutate Returns
dr_inds = (stats["update_state"] == UpdateState.DR).nonzero()[0]
rep_inds = (stats["update_state"] == UpdateState.REPLAY).nonzero()[0]
mut_inds = (stats["update_state"] == UpdateState.MUTATE).nonzero()[0]
for name, inds in [
("DR", dr_inds),
("REPLAY", rep_inds),
("MUTATION", mut_inds),
]:
if len(inds) > 0:
log_dict.update(
{
f"{name}/episode_return": stats["episode_return"][inds].mean(),
f"{name}/mean_eplen": stats["returned_episode_lengths"][inds].mean(),
f"{name}/mean_success": stats["returned_episode_solved"][inds].mean(),
f"{name}/noop_return": stats["noop_returns"][inds].mean(),
f"{name}/noop_eplen": stats["noop_eplen"][inds].mean(),
f"{name}/noop_success": stats["noop_success"][inds].mean(),
f"{name}/random_return": stats["random_returns"][inds].mean(),
f"{name}/random_eplen": stats["random_eplen"][inds].mean(),
f"{name}/random_success": stats["random_success"][inds].mean(),
}
)
for k in stats:
if "complexity/" in k:
k2 = "complexity/" + name + "_" + k.replace("complexity/", "")
log_dict.update({k2: stats[k][inds].mean()})
if "ued_scores/" in k:
k2 = "ued_scores/" + name + "_" + k.replace("ued_scores/", "")
log_dict.update({k2: stats[k][inds].mean()})
# Eval rollout animations
if stats["log_videos"]:
            for i in range(config["num_eval_levels"]):
frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
frames = np.array(frames[:episode_length])
log_dict.update(
{
f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
frames.astype(np.uint8), fps=15, caption=f"Len ({episode_length})"
)
}
)
wandb.log(log_dict)
def get_all_metrics(
rng,
losses,
info,
init_env_state,
init_obs,
dones,
grads,
all_ued_scores,
new_levels,
):
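        # Baseline rollouts: episode return/length/success under a no-op policy and under a uniformly
        # random policy, logged alongside the learned policy's metrics for comparison.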
noop_returns, noop_len, noop_success, random_returns, random_lens, random_success = no_op_and_random_rollout(
env,
env_params,
rng,
init_obs,
init_env_state,
config["num_train_envs"],
config["num_steps"] * config["outer_rollout_steps"],
)
metrics = (
{
"losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
"returned_episode_lengths": (info["returned_episode_lengths"] * dones).sum()
/ jnp.maximum(1, dones.sum()),
"max_episode_length": info["returned_episode_lengths"].max(),
"levels_played": init_env_state.env_state.env_state,
"episode_return": (info["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
"episode_return_v2": (info["returned_episode_returns"] * info["returned_episode"]).sum()
/ jnp.maximum(1, info["returned_episode"].sum()),
"grad_norms": grads.mean(),
"noop_returns": noop_returns,
"noop_eplen": noop_len,
"noop_success": noop_success,
"random_returns": random_returns,
"random_eplen": random_lens,
"random_success": random_success,
"returned_episode_solved": (info["returned_episode_solved"] * dones).sum()
/ jnp.maximum(1, dones.sum()),
}
| get_level_complexity_metrics(new_levels, static_env_params)
| get_ued_score_metrics(all_ued_scores)
)
return metrics
# Setup the environment.
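    # Wrapper order: AutoReplayWrapper replays the current level when an episode ends, the
    # UnderspecifiedToGymnaxWrapper adapts the UED interface to a Gymnax-style one, DenseRewardWrapper
    # adds the (scaled) dense reward signal, and LogWrapper records the episode statistics used below.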
def make_env(static_env_params):
env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
env = AutoReplayWrapper(env)
env = UnderspecifiedToGymnaxWrapper(env)
env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
env = LogWrapper(env)
return env
env = make_env(static_env_params)
if config["train_level_mode"] == "list":
sample_random_level = make_reset_train_function_with_list_of_levels(
config, config["train_levels_list"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
)
elif config["train_level_mode"] == "random":
sample_random_level = make_reset_train_function_with_mutations(
env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
)
else:
raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")
if config["use_accel"] and config["accel_start_from_empty"]:
def make_sample_random_level():
def inner(rng):
def _inner_accel(rng):
return create_random_starting_distribution(
rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=True
)
def _inner_accel_not_controllable(rng):
return create_random_starting_distribution(
rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=False
)
rng, _rng = jax.random.split(rng)
return _inner_accel(_rng)
return inner
sample_random_level = make_sample_random_level()
sample_random_levels = make_vmapped_filtered_level_sampler(
sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
)
def generate_world():
raise NotImplementedError
def generate_eval_world(rng, env_params, static_env_params, level_idx):
# jax.random.split(jax.random.PRNGKey(101), num_levels), env_params, static_env_params, jnp.arange(num_levels)
raise NotImplementedError
_, eval_static_env_params = generate_params_from_config(
config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
)
eval_env = make_env(eval_static_env_params)
ued_params = generate_ued_params_from_config(config)
mutate_world = make_mutate_env(static_env_params, env_params, ued_params)
def make_render_fn(static_env_params):
render_fn_inner = make_render_pixels(env_params, static_env_params)
render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
return render_fn
render_fn = make_render_fn(static_env_params)
render_fn_eval = make_render_fn(eval_static_env_params)
if config["EVAL_ON_SAMPLED"]:
NUM_EVAL_DR_LEVELS = 200
key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)
# And the level sampler
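    # The jaxued LevelSampler implements the PLR buffer: `replay_prob` is the probability of taking a
    # replay step rather than sampling fresh levels, and `staleness_coeff` mixes staleness into the
    # prioritised sampling weights, as in standard PLR.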
level_sampler = LevelSampler(
capacity=config["level_buffer_capacity"],
replay_prob=config["replay_prob"],
staleness_coeff=config["staleness_coeff"],
minimum_fill_ratio=config["minimum_fill_ratio"],
prioritization=config["prioritization"],
prioritization_params={"temperature": config["temperature"], "k": config["topk_k"]},
duplicate_check=config["buffer_duplicate_check"],
)
@jax.jit
def create_train_state(rng) -> TrainState:
# Creates the train state
def linear_schedule(count):
frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / (
config["num_updates"] * config["outer_rollout_steps"]
)
return config["lr"] * frac
rng, _rng = jax.random.split(rng)
init_state = jax.tree.map(lambda x: x[0], sample_random_levels(_rng, 1))
rng, _rng = jax.random.split(rng)
obs, _ = env.reset_to_level(_rng, init_state, env_params)
ns = config["num_steps"] * config["outer_rollout_steps"]
obs = jax.tree.map(
lambda x: jnp.repeat(jnp.repeat(x[None, ...], config["num_train_envs"], axis=0)[None, ...], ns, axis=0),
obs,
)
init_x = (obs, jnp.zeros((ns, config["num_train_envs"]), dtype=jnp.bool_))
network = make_network_from_config(env, env_params, config)
rng, _rng = jax.random.split(rng)
network_params = network.init(_rng, ScannedRNN.initialize_carry(config["num_train_envs"]), init_x)
if config["anneal_lr"]:
tx = optax.chain(
optax.clip_by_global_norm(config["max_grad_norm"]),
optax.adam(learning_rate=linear_schedule, eps=1e-5),
)
else:
tx = optax.chain(
optax.clip_by_global_norm(config["max_grad_norm"]),
optax.adam(config["lr"], eps=1e-5),
)
pholder_level = jax.tree.map(lambda x: x[0], sample_random_levels(jax.random.PRNGKey(0), 1))
sampler = level_sampler.initialize(pholder_level, {"max_return": -jnp.inf})
pholder_level_batch = jax.tree_util.tree_map(
lambda x: jnp.array([x]).repeat(config["num_train_envs"], axis=0), pholder_level
)
pholder_rollout_batch = (
jax.tree.map(
lambda x: jnp.repeat(
jnp.expand_dims(x, 0), repeats=config["num_steps"] * config["outer_rollout_steps"], axis=0
),
init_state,
),
init_x[1][:, 0],
)
pholder_level_batch_scores = jnp.zeros((config["num_train_envs"],), dtype=jnp.float32)
train_state = TrainState.create(
apply_fn=network.apply,
params=network_params,
tx=tx,
sampler=sampler,
update_state=0,
num_dr_updates=0,
num_replay_updates=0,
num_mutation_updates=0,
dr_last_level_batch_scores=pholder_level_batch_scores,
replay_last_level_batch_scores=pholder_level_batch_scores,
mutation_last_level_batch_scores=pholder_level_batch_scores,
dr_last_level_batch=pholder_level_batch,
replay_last_level_batch=pholder_level_batch,
mutation_last_level_batch=pholder_level_batch,
dr_last_rollout_batch=pholder_rollout_batch,
replay_last_rollout_batch=pholder_rollout_batch,
mutation_last_rollout_batch=pholder_rollout_batch,
)
if config["load_from_checkpoint"] != None:
print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
train_state = load_train_state_from_wandb_artifact_path(
train_state,
config["load_from_checkpoint"],
load_only_params=config["load_only_params"],
legacy=config["load_legacy_checkpoint"],
)
return train_state
all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
eval_group_indices = get_eval_level_groups(config["eval_levels"])
@jax.jit
def train_step(carry: Tuple[chex.PRNGKey, TrainState], _):
"""
This is the main training loop. It basically calls either `on_new_levels`, `on_replay_levels`, or `on_mutate_levels` at every step.
"""
def on_new_levels(rng: chex.PRNGKey, train_state: TrainState):
"""
Samples new (randomly-generated) levels and evaluates the policy on these. It also then adds the levels to the level buffer if they have high-enough scores.
The agent is updated on these trajectories iff `config["exploratory_grad_updates"]` is True.
"""
sampler = train_state.sampler
# Reset
rng, rng_levels, rng_reset = jax.random.split(rng, 3)
new_levels = sample_random_levels(rng_levels, config["num_train_envs"])
init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
jax.random.split(rng_reset, config["num_train_envs"]), new_levels, env_params
)
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
# Rollout
(
(rng, train_state, new_hstate, last_obs, last_env_state),
(
obs,
actions,
rewards,
dones,
log_probs,
values,
info,
advantages,
targets,
losses,
grads,
rollout_states,
),
) = sample_trajectories_and_learn(
env,
env_params,
config,
rng,
train_state,
init_hstate,
init_obs,
init_env_state,
update_grad=config["exploratory_grad_updates"],
return_states=True,
)
max_returns = compute_max_returns(dones, rewards)
scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
sampler, _ = level_sampler.insert_batch(sampler, new_levels, scores, {"max_return": max_returns})
rng, _rng = jax.random.split(rng)
metrics = {
"update_state": UpdateState.DR,
} | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, new_levels)
train_state = train_state.replace(
sampler=sampler,
update_state=UpdateState.DR,
num_dr_updates=train_state.num_dr_updates + 1,
dr_last_level_batch=new_levels,
dr_last_level_batch_scores=scores,
dr_last_rollout_batch=jax.tree.map(
lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
),
)
return (rng, train_state), metrics
def on_replay_levels(rng: chex.PRNGKey, train_state: TrainState):
"""
This samples levels from the level buffer, and updates the policy on them.
"""
sampler = train_state.sampler
# Collect trajectories on replay levels
rng, rng_levels, rng_reset = jax.random.split(rng, 3)
sampler, (level_inds, levels) = level_sampler.sample_replay_levels(
sampler, rng_levels, config["num_train_envs"]
)
init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
jax.random.split(rng_reset, config["num_train_envs"]), levels, env_params
)
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
(
(rng, train_state, new_hstate, last_obs, last_env_state),
(
obs,
actions,
rewards,
dones,
log_probs,
values,
info,
advantages,
targets,
losses,
grads,
rollout_states,
),
) = sample_trajectories_and_learn(
env,
env_params,
config,
rng,
train_state,
init_hstate,
init_obs,
init_env_state,
update_grad=True,
return_states=True,
)
max_returns = jnp.maximum(
level_sampler.get_levels_extra(sampler, level_inds)["max_return"], compute_max_returns(dones, rewards)
)
scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
sampler = level_sampler.update_batch(sampler, level_inds, scores, {"max_return": max_returns})
rng, _rng = jax.random.split(rng)
metrics = {
"update_state": UpdateState.REPLAY,
} | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, levels)
train_state = train_state.replace(
sampler=sampler,
update_state=UpdateState.REPLAY,
num_replay_updates=train_state.num_replay_updates + 1,
replay_last_level_batch=levels,
replay_last_level_batch_scores=scores,
replay_last_rollout_batch=jax.tree.map(
lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
),
)
return (rng, train_state), metrics
def on_mutate_levels(rng: chex.PRNGKey, train_state: TrainState):
"""
This mutates the previous batch of replay levels and potentially adds them to the level buffer.
This also updates the policy iff `config["exploratory_grad_updates"]` is True.
"""
sampler = train_state.sampler
rng, rng_mutate, rng_reset = jax.random.split(rng, 3)
# mutate
parent_levels = train_state.replay_last_level_batch
child_levels = jax.vmap(mutate_world, (0, 0, None))(
jax.random.split(rng_mutate, config["num_train_envs"]), parent_levels, config["num_edits"]
)
init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
jax.random.split(rng_reset, config["num_train_envs"]), child_levels, env_params
)
init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
# rollout
(
(rng, train_state, new_hstate, last_obs, last_env_state),
(
obs,
actions,
rewards,
dones,
log_probs,
values,
info,
advantages,
targets,
losses,
grads,
rollout_states,
),
) = sample_trajectories_and_learn(
env,
env_params,
config,
rng,
train_state,
init_hstate,
init_obs,
init_env_state,
update_grad=config["exploratory_grad_updates"],
return_states=True,
)
max_returns = compute_max_returns(dones, rewards)
scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
sampler, _ = level_sampler.insert_batch(sampler, child_levels, scores, {"max_return": max_returns})
rng, _rng = jax.random.split(rng)
metrics = {"update_state": UpdateState.MUTATE,} | get_all_metrics(
_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, child_levels
)
train_state = train_state.replace(
sampler=sampler,
                # Deliberately set back to DR rather than MUTATE, so the next step samples a fresh
                # DR-vs-replay decision; the branch formula in `train_step` assumes update_state is
                # either DR or REPLAY.
                update_state=UpdateState.DR,
num_mutation_updates=train_state.num_mutation_updates + 1,
mutation_last_level_batch=child_levels,
mutation_last_level_batch_scores=scores,
mutation_last_rollout_batch=jax.tree.map(
lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
),
)
return (rng, train_state), metrics
rng, train_state = carry
rng, rng_replay = jax.random.split(rng)
# The train step makes a decision on which branch to take, either on_new, on_replay or on_mutate.
# on_mutate is only called if the replay branch has been taken before (as it uses `train_state.update_state`).
branches = [
on_new_levels,
on_replay_levels,
]
if config["use_accel"]:
            s = train_state.update_state
            # If the previous update was a replay step (s == 1), always take the mutation branch (2);
            # otherwise (s == 0) sample between new DR levels (0) and replaying from the buffer (1).
            branch = (1 - s) * level_sampler.sample_replay_decision(train_state.sampler, rng_replay) + 2 * s
branches.append(on_mutate_levels)
else:
branch = level_sampler.sample_replay_decision(train_state.sampler, rng_replay).astype(int)
return jax.lax.switch(branch, branches, rng, train_state)
@partial(jax.jit, static_argnums=(2,))
def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
"""
This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
"""
num_levels = config["num_eval_levels"]
return general_eval(
rng,
eval_env,
env_params,
train_state,
all_eval_levels,
env_params.max_timesteps,
num_levels,
keep_states=keep_states,
return_trajectories=True,
)
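    # Evaluates the policy on a fixed set of randomly sampled (DR) levels; only used when
    # config["EVAL_ON_SAMPLED"] is set.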
@partial(jax.jit, static_argnums=(2,))
def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
return general_eval(
rng,
env,
env_params,
train_state,
DR_EVAL_LEVELS,
env_params.max_timesteps,
NUM_EVAL_DR_LEVELS,
keep_states=keep_states,
)
@jax.jit
def train_and_eval_step(runner_state, _):
"""
This function runs the train_step for a certain number of iterations, and then evaluates the policy.
It returns the updated train state, and a dictionary of metrics.
"""
# Train
(rng, train_state), metrics = jax.lax.scan(train_step, runner_state, None, config["eval_freq"])
# Eval
metrics["update_count"] = (
train_state.num_dr_updates + train_state.num_replay_updates + train_state.num_mutation_updates
)
vid_frequency = get_video_frequency(config, metrics["update_count"])
should_log_videos = metrics["update_count"] % vid_frequency == 0
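        # Learnability on the held-out eval levels: aggregate episode and success counts over the
        # eval attempts, then compute p * (1 - p) from the resulting per-level success rate.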
def _compute_eval_learnability(dones, rewards, infos):
@jax.vmap
def _single(d, r, i):
learn, num_eps, num_succ = compute_learnability(config, d, r, i, config["num_eval_levels"])
return num_eps, num_succ.squeeze(-1)
num_eps, num_succ = _single(dones, rewards, infos)
num_eps, num_succ = num_eps.sum(axis=0), num_succ.sum(axis=0)
success_rate = num_succ / jnp.maximum(1, num_eps)
return success_rate * (1 - success_rate)
@jax.jit
def _get_eval(rng):
metrics = {}
rng, rng_eval = jax.random.split(rng)
(states, cum_rewards, done_idx, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
eval, (0, None)
)(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
# learnability here of the holdout set:
eval_learn = _compute_eval_learnability(eval_dones, eval_rewards, eval_infos)
# Collect Metrics
eval_returns = cum_rewards.mean(axis=0) # (num_eval_levels,)
eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
1, eval_dones.sum(axis=1)
)
eval_solves = eval_solves.mean(axis=0)
metrics["eval_returns"] = eval_returns
metrics["eval_ep_lengths"] = episode_lengths.mean(axis=0)
metrics["eval_learn"] = eval_learn
metrics["eval_solves"] = eval_solves
metrics["eval_get_max_eplen"] = (episode_lengths == env_params.max_timesteps).mean(axis=0)
metrics["episode_return_bigger_than_negative"] = (cum_rewards > -0.4).mean(axis=0)
if config["EVAL_ON_SAMPLED"]:
states_dr, cum_rewards_dr, done_idx_dr, episode_lengths_dr, infos_dr = jax.vmap(
eval_on_dr_levels, (0, None)
)(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
eval_dr_returns = cum_rewards_dr.mean(axis=0).mean()
eval_dr_eplen = episode_lengths_dr.mean(axis=0).mean()
my_eval_dones = infos_dr["returned_episode"]
eval_dr_solves = (infos_dr["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
1, my_eval_dones.sum(axis=1)
)
metrics["eval_dr_returns"] = eval_dr_returns
metrics["eval_dr_eplen"] = eval_dr_eplen
metrics["eval_dr_solve_rates"] = eval_dr_solves
return metrics, states, episode_lengths, cum_rewards
@jax.jit
def _get_videos(rng, states, episode_lengths, cum_rewards):
metrics = {"log_videos": True}
# just grab the first run
states, episode_lengths = jax.tree_util.tree_map(
lambda x: x[0], (states, episode_lengths)
) # (num_steps, num_eval_levels, ...), (num_eval_levels,)
# And one attempt
states = jax.tree_util.tree_map(lambda x: x[:, :], states)
episode_lengths = episode_lengths[:]
images = jax.vmap(jax.vmap(render_fn_eval))(
states.env_state.env_state.env_state
) # (num_steps, num_eval_levels, ...)
            frames = images.transpose(
                0, 1, 4, 2, 3
            )  # wandb.Video expects the colour channel before the spatial dimensions
@jax.jit
def _get_video(rollout_batch):
states = rollout_batch[0]
images = jax.vmap(render_fn)(states) # dimensions are (steps, x, y, 3)
return (
# jax.tree.map(lambda x: x[:].transpose(0, 2, 1, 3)[:, ::-1], images).transpose(0, 3, 1, 2),
images.transpose(0, 3, 1, 2),
# images.transpose(0, 1, 4, 2, 3),
                    rollout_batch[1].argmax(),  # index of the first done flag, used as the episode length
)
# rollouts
metrics["dr_rollout"], metrics["dr_ep_len"] = _get_video(train_state.dr_last_rollout_batch)
metrics["replay_rollout"], metrics["replay_ep_len"] = _get_video(train_state.replay_last_rollout_batch)
metrics["mutation_rollout"], metrics["mutation_ep_len"] = _get_video(
train_state.mutation_last_rollout_batch
)
metrics["eval_animation"] = (frames, episode_lengths)
metrics["eval_returns_video"] = cum_rewards[0]
metrics["eval_len_video"] = episode_lengths
# Eval on sampled
return metrics
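        # `_get_dummy_videos` mirrors the output structure of `_get_videos` with zero-filled arrays so
        # that both branches of the `jax.lax.cond` below return pytrees with identical shapes and dtypes.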
@jax.jit
def _get_dummy_videos(rng, states, episode_lengths, cum_rewards):
n_eval = config["num_eval_levels"]
nsteps = env_params.max_timesteps
nsteps2 = config["outer_rollout_steps"] * config["num_steps"]
img_size = (
env.static_env_params.screen_dim[0] // env.static_env_params.downscale,
env.static_env_params.screen_dim[1] // env.static_env_params.downscale,
)
return {
"log_videos": False,
"dr_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
"dr_ep_len": jnp.zeros((), jnp.int32),
"replay_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
"replay_ep_len": jnp.zeros((), jnp.int32),
"mutation_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
"mutation_ep_len": jnp.zeros((), jnp.int32),
# "eval_returns": jnp.zeros((n_eval,), jnp.float32),
# "eval_solves": jnp.zeros((n_eval,), jnp.float32),
# "eval_learn": jnp.zeros((n_eval,), jnp.float32),
# "eval_ep_lengths": jnp.zeros((n_eval,), jnp.int32),
"eval_animation": (
jnp.zeros((nsteps, n_eval, 3, *img_size), jnp.float32),
jnp.zeros((n_eval,), jnp.int32),
),
"eval_returns_video": jnp.zeros((n_eval,), jnp.float32),
"eval_len_video": jnp.zeros((n_eval,), jnp.int32),
}
rng, rng_eval, rng_vid = jax.random.split(rng, 3)
metrics_eval, states, episode_lengths, cum_rewards = _get_eval(rng_eval)
metrics = {
**metrics,
**metrics_eval,
**jax.lax.cond(
should_log_videos, _get_videos, _get_dummy_videos, rng_vid, states, episode_lengths, cum_rewards
),
}
max_num_images = 8
top_regret_ones = max_num_images // 2
bot_regret_ones = max_num_images - top_regret_ones
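        # Render the lowest- and highest-scoring levels from the most recent batch of each branch for
        # the wandb image panels.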
@jax.jit
def get_values(level_batch, scores):
args = jnp.argsort(scores) # low scores are at the start, high scores are at the end
low_scores = args[:bot_regret_ones]
high_scores = args[-top_regret_ones:]
low_levels = jax.tree.map(lambda x: x[low_scores], level_batch)
high_levels = jax.tree.map(lambda x: x[high_scores], level_batch)
low_scores = scores[low_scores]
high_scores = scores[high_scores]
# now concatenate:
return jax.vmap(render_fn)(
jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), low_levels, high_levels)
), jnp.concatenate([low_scores, high_scores], axis=0)
metrics["dr_levels"], metrics["dr_scores"] = get_values(
train_state.dr_last_level_batch, train_state.dr_last_level_batch_scores
)
metrics["replay_levels"], metrics["replay_scores"] = get_values(
train_state.replay_last_level_batch, train_state.replay_last_level_batch_scores
)
metrics["mutation_levels"], metrics["mutation_scores"] = get_values(
train_state.mutation_last_level_batch, train_state.mutation_last_level_batch_scores
)
        def _t(i):
            # argmax of an all-False `dones` vector is 0, so treat a value of 0 as a full-length rollout.
            return jax.lax.select(i == 0, config["num_steps"], i)
metrics["dr_ep_len"] = _t(train_state.dr_last_rollout_batch[1][:].argmax())
metrics["replay_ep_len"] = _t(train_state.replay_last_rollout_batch[1][:].argmax())
metrics["mutation_ep_len"] = _t(train_state.mutation_last_rollout_batch[1][:].argmax())
highest_scoring_level = level_sampler.get_levels(train_state.sampler, train_state.sampler["scores"].argmax())
highest_weighted_level = level_sampler.get_levels(
train_state.sampler, level_sampler.level_weights(train_state.sampler).argmax()
)
metrics["highest_scoring_level"] = render_fn(highest_scoring_level)
metrics["highest_weighted_level"] = render_fn(highest_weighted_level)
# log_eval(metrics, train_state_to_log_dict(runner_state[1], level_sampler))
jax.debug.callback(log_eval, metrics, train_state_to_log_dict(runner_state[1], level_sampler))
return (rng, train_state), {"update_count": metrics["update_count"]}
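    # Saves a checkpoint to wandb (via jax.debug.callback) every `checkpoint_save_freq` updates, when a
    # save path is configured.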
def log_checkpoint(update_count, train_state):
if config["save_path"] is not None and config["checkpoint_save_freq"] > 1:
steps = (
int(update_count)
* int(config["num_train_envs"])
* int(config["num_steps"])
* int(config["outer_rollout_steps"])
)
# save_params_to_wandb(train_state.params, steps, config)
save_model_to_wandb(train_state, steps, config)
def train_eval_and_checkpoint_step(runner_state, _):
runner_state, metrics = jax.lax.scan(
train_and_eval_step, runner_state, xs=jnp.arange(config["checkpoint_save_freq"] // config["eval_freq"])
)
jax.debug.callback(log_checkpoint, metrics["update_count"][-1], runner_state[1])
return runner_state, metrics
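    # Overall loop structure: the outer scan runs num_updates // checkpoint_save_freq checkpointing
    # steps; each of these scans train_and_eval_step checkpoint_save_freq // eval_freq times; and each
    # of those scans train_step eval_freq times before evaluating and logging.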
# Set up the train states
rng = jax.random.PRNGKey(config["seed"])
rng_init, rng_train = jax.random.split(rng)
train_state = create_train_state(rng_init)
runner_state = (rng_train, train_state)
runner_state, metrics = jax.lax.scan(
train_eval_and_checkpoint_step,
runner_state,
xs=jnp.arange((config["num_updates"]) // (config["checkpoint_save_freq"])),
)
if config["save_path"] is not None:
# save_params_to_wandb(runner_state[1].params, config["total_timesteps"], config)
save_model_to_wandb(runner_state[1], config["total_timesteps"], config, is_final=True)
return runner_state[1]
if __name__ == "__main__":
main()