from functools import partial
import time
from enum import IntEnum
from typing import Tuple

import chex
import hydra
import jax
import jax.numpy as jnp
import numpy as np
from omegaconf import OmegaConf
import optax
from flax import core, struct
from flax.training.train_state import TrainState as BaseTrainState
import wandb
from kinetix.environment.ued.distributions import (
    create_random_starting_distribution,
)
from kinetix.environment.ued.ued import (
    make_mutate_env,
    make_reset_train_function_with_list_of_levels,
    make_reset_train_function_with_mutations,
    make_vmapped_filtered_level_sampler,
)
from kinetix.util.config import (
    generate_ued_params_from_config,
    get_video_frequency,
    init_wandb,
    normalise_config,
    save_data_to_local_file,
    generate_params_from_config,
    get_eval_level_groups,
)
from jaxued.environments.underspecified_env import EnvState
from jaxued.level_sampler import LevelSampler
from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
from flax.serialization import to_state_dict
import sys

sys.path.append("experiments")

from kinetix.environment.env import make_kinetix_env_from_name
from kinetix.environment.env_state import StaticEnvParams
from kinetix.environment.wrappers import (
    UnderspecifiedToGymnaxWrapper,
    LogWrapper,
    DenseRewardWrapper,
    AutoReplayWrapper,
)
from kinetix.models import make_network_from_config
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.models.actor_critic import ScannedRNN
from kinetix.util.learning import (
    general_eval,
    get_eval_levels,
    no_op_and_random_rollout,
    sample_trajectories_and_learn,
)
from kinetix.util.saving import (
    load_train_state_from_wandb_artifact_path,
    save_model_to_wandb,
)


class UpdateState(IntEnum):
    DR = 0
    REPLAY = 1
    MUTATE = 2


def get_level_complexity_metrics(all_levels: EnvState, static_env_params: StaticEnvParams):
    def get_for_single_level(level):
        return {
            "complexity/num_shapes": level.polygon.active[static_env_params.num_static_fixated_polys :].sum()
            + level.circle.active.sum(),
            "complexity/num_joints": level.joint.active.sum(),
            "complexity/num_thrusters": level.thruster.active.sum(),
            "complexity/num_rjoints": (level.joint.active * jnp.logical_not(level.joint.is_fixed_joint)).sum(),
            "complexity/num_fjoints": (level.joint.active * (level.joint.is_fixed_joint)).sum(),
            "complexity/has_ball": ((level.polygon_shape_roles == 1) * level.polygon.active).sum()
            + ((level.circle_shape_roles == 1) * level.circle.active).sum(),
            "complexity/has_goal": ((level.polygon_shape_roles == 2) * level.polygon.active).sum()
            + ((level.circle_shape_roles == 2) * level.circle.active).sum(),
        }

    return jax.tree.map(lambda x: x.mean(), jax.vmap(get_for_single_level)(all_levels))


def get_ued_score_metrics(all_ued_scores):
    (mc, pvl, learn) = all_ued_scores
    scores = {}
    for score, name in zip([mc, pvl, learn], ["MaxMC", "PVL", "Learnability"]):
        scores[f"ued_scores/{name}/Mean"] = score.mean()
        scores[f"ued_scores_additional/{name}/Max"] = score.max()
        scores[f"ued_scores_additional/{name}/Min"] = score.min()
    return scores


class TrainState(BaseTrainState):
    sampler: core.FrozenDict[str, chex.ArrayTree] = struct.field(pytree_node=True)
    update_state: UpdateState = struct.field(pytree_node=True)
    # === Below is used for logging ===
    num_dr_updates: int
    num_replay_updates: int
    num_mutation_updates: int
    dr_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
    replay_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
    mutation_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
    dr_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
    replay_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
    mutation_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
    dr_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
    replay_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
    mutation_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)


# region PPO helper functions
# endregion


def train_state_to_log_dict(train_state: TrainState, level_sampler: LevelSampler) -> dict:
    """To prevent the entire (large) train_state from being copied to the CPU when logging, this function
    extracts the important information and returns it as a dictionary.

    Anything in the `log` key will be logged to wandb.

    Args:
        train_state (TrainState):
        level_sampler (LevelSampler):

    Returns:
        dict:
    """
    sampler = train_state.sampler
    idx = jnp.arange(level_sampler.capacity) < sampler["size"]
    s = jnp.maximum(idx.sum(), 1)
    return {
        "log": {
            "level_sampler/size": sampler["size"],
            "level_sampler/episode_count": sampler["episode_count"],
            "level_sampler/max_score": sampler["scores"].max(),
            "level_sampler/weighted_score": (sampler["scores"] * level_sampler.level_weights(sampler)).sum(),
            "level_sampler/mean_score": (sampler["scores"] * idx).sum() / s,
        },
        "info": {
            "num_dr_updates": train_state.num_dr_updates,
            "num_replay_updates": train_state.num_replay_updates,
            "num_mutation_updates": train_state.num_mutation_updates,
        },
    }


def compute_learnability(config, done, reward, info, num_envs):
    num_agents = 1
    BATCH_ACTORS = num_envs * num_agents

    rollout_length = config["num_steps"] * config["outer_rollout_steps"]

    # Vectorise over the environment axis (axis 1 of the (T, num_envs) rollout arrays), so each
    # level's episodes are aggregated separately.
    @partial(jax.vmap, in_axes=(None, 1, 1, 1))
    def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
        idxs = jnp.arange(max_steps)

        # Vectorise over the (up to 50) episode boundaries found below.
        @partial(jax.vmap, in_axes=(0, 0))
        def __ep_outcomes(start_idx, end_idx):
            mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
            r = jnp.sum(returns * mask)
            goal_r = info["GoalR"]
            success = jnp.sum(goal_r * mask)
            collision = 0
            timeo = 0
            l = end_idx - start_idx
            return r, success, collision, timeo, l

        done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
        mask_done = jnp.where(done_idxs == max_steps, 0, 1)
        ep_return, success, collision, timeo, length = __ep_outcomes(
            jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
        )

        return {
            "ep_return": ep_return.mean(where=mask_done),
            "num_episodes": mask_done.sum(),
            "num_success": success.sum(where=mask_done),
            "success_rate": success.mean(where=mask_done),
            "collision_rate": collision.mean(where=mask_done),
            "timeout_rate": timeo.mean(where=mask_done),
            "ep_len": length.mean(where=mask_done),
        }

    done_by_env = done.reshape((-1, num_agents, num_envs))
    reward_by_env = reward.reshape((-1, num_agents, num_envs))
    o = _calc_outcomes_by_agent(rollout_length, done, reward, info)
    success_by_env = o["success_rate"].reshape((num_agents, num_envs))
    # Learnability score: p * (1 - p), where p is the per-level success rate over the rollout.
    learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
    return (
        learnability_by_env,
        o["num_episodes"].reshape(num_agents, num_envs).sum(axis=0),
        o["num_success"].reshape(num_agents, num_envs).T,
    )  # so agents is at the end.
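

# Worked example of the learnability score used above (illustrative only, not part of the
# training path): a level the agent solves roughly half the time scores highest, while
# trivially easy or impossible levels score zero.
#   p = 0.00 -> p * (1 - p) = 0.00
#   p = 0.50 -> p * (1 - p) = 0.25
#   p = 1.00 -> p * (1 - p) = 0.00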


def compute_score(
    config: dict, dones: chex.Array, values: chex.Array, max_returns: chex.Array, reward, info, advantages: chex.Array
) -> chex.Array:
    # Computes the score for each level
    if config["score_function"] == "MaxMC":
        return max_mc(dones, values, max_returns)
    elif config["score_function"] == "pvl":
        return positive_value_loss(dones, advantages)
    elif config["score_function"] == "learnability":
        learnability, num_episodes, num_success = compute_learnability(
            config, dones, reward, info, config["num_train_envs"]
        )
        return learnability
    else:
        raise ValueError(f"Unknown score function: {config['score_function']}")


def compute_all_scores(
    config: dict,
    dones: chex.Array,
    values: chex.Array,
    max_returns: chex.Array,
    reward,
    info,
    advantages: chex.Array,
    return_success_rate=False,
):
    mc = max_mc(dones, values, max_returns)
    pvl = positive_value_loss(dones, advantages)
    learnability, num_episodes, num_success = compute_learnability(
        config, dones, reward, info, config["num_train_envs"]
    )
    if config["score_function"] == "MaxMC":
        main_score = mc
    elif config["score_function"] == "pvl":
        main_score = pvl
    elif config["score_function"] == "learnability":
        main_score = learnability
    else:
        raise ValueError(f"Unknown score function: {config['score_function']}")
    if return_success_rate:
        success_rate = num_success.squeeze(1) / jnp.maximum(num_episodes, 1)
        return main_score, (mc, pvl, learnability, success_rate)
    return main_score, (mc, pvl, learnability)
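

# Note on the scores computed above: `max_mc` and `positive_value_loss` are the PLR regret
# estimators imported from jaxued.utils (roughly, MaxMC is the gap between the best return seen
# on a level and the value estimate, and PVL is the mean of the clipped-positive GAE advantages),
# while `learnability` is the success-rate-based score from `compute_learnability`.
# `config["score_function"]` selects which one drives the level buffer; all three are logged.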


# A Hydra entry point is implied by the `hydra` import and the OmegaConf config handling below.
# The config_path/config_name used here are assumptions; adjust them to the repository's config layout.
@hydra.main(version_base=None, config_path="configs", config_name="plr")
def main(config=None):
    my_name = "PLR"
    config = OmegaConf.to_container(config)
    if config["ued"]["replay_prob"] == 0.0:
        my_name = "DR"
    elif config["ued"]["use_accel"]:
        my_name = "ACCEL"

    time_start = time.time()
    config = normalise_config(config, my_name)
    env_params, static_env_params = generate_params_from_config(config)
    config["env_params"] = to_state_dict(env_params)
    config["static_env_params"] = to_state_dict(static_env_params)

    run = init_wandb(config, my_name)
    config = wandb.config

    time_prev = time.time()

    def log_eval(stats, train_state_info):
        nonlocal time_prev
        print(f"Logging update: {stats['update_count']}")
        total_loss = jnp.mean(stats["losses"][0])
        if jnp.isnan(total_loss):
            print("NaN loss, skipping logging")
            raise ValueError("NaN loss")

        # generic stats
        env_steps = int(
            int(stats["update_count"]) * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
        )
        env_steps_delta = (
            config["eval_freq"] * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
        )
        time_now = time.time()
        log_dict = {
            "timing/num_updates": stats["update_count"],
            "timing/num_env_steps": env_steps,
            "timing/sps": env_steps_delta / (time_now - time_prev),
            "timing/sps_agg": env_steps / (time_now - time_start),
            "loss/total_loss": jnp.mean(stats["losses"][0]),
            "loss/value_loss": jnp.mean(stats["losses"][1][0]),
            "loss/policy_loss": jnp.mean(stats["losses"][1][1]),
            "loss/entropy_loss": jnp.mean(stats["losses"][1][2]),
        }
        time_prev = time_now

        # evaluation performance
        returns = stats["eval_returns"]
        log_dict.update({"eval/mean_eval_return": returns.mean()})
        log_dict.update({"eval/mean_eval_learnability": stats["eval_learn"].mean()})
        log_dict.update({"eval/mean_eval_solve_rate": stats["eval_solves"].mean()})
        log_dict.update({"eval/mean_eval_eplen": stats["eval_ep_lengths"].mean()})
        for i in range(config["num_eval_levels"]):
            log_dict[f"eval_avg_return/{config['eval_levels'][i]}"] = returns[i]
            log_dict[f"eval_avg_learnability/{config['eval_levels'][i]}"] = stats["eval_learn"][i]
            log_dict[f"eval_avg_solve_rate/{config['eval_levels'][i]}"] = stats["eval_solves"][i]
            log_dict[f"eval_avg_episode_length/{config['eval_levels'][i]}"] = stats["eval_ep_lengths"][i]
            log_dict[f"eval_get_max_eplen/{config['eval_levels'][i]}"] = stats["eval_get_max_eplen"][i]
            log_dict[f"episode_return_bigger_than_negative/{config['eval_levels'][i]}"] = stats[
                "episode_return_bigger_than_negative"
            ][i]

        def _aggregate_per_size(values, name):
            to_return = {}
            for group_name, indices in eval_group_indices.items():
                to_return[f"{name}_{group_name}"] = values[indices].mean()
            return to_return

        log_dict.update(_aggregate_per_size(returns, "eval_aggregate/return"))
        log_dict.update(_aggregate_per_size(stats["eval_solves"], "eval_aggregate/solve_rate"))

        if config["EVAL_ON_SAMPLED"]:
            log_dict.update({"eval/mean_eval_return_sampled": stats["eval_dr_returns"].mean()})
            log_dict.update({"eval/mean_eval_solve_rate_sampled": stats["eval_dr_solve_rates"].mean()})
            log_dict.update({"eval/mean_eval_eplen_sampled": stats["eval_dr_eplen"].mean()})

        # level sampler
        log_dict.update(train_state_info["log"])

        # images
        log_dict.update(
            {
                "images/highest_scoring_level": wandb.Image(
                    np.array(stats["highest_scoring_level"]), caption="Highest scoring level"
                )
            }
        )
        log_dict.update(
            {
                "images/highest_weighted_level": wandb.Image(
                    np.array(stats["highest_weighted_level"]), caption="Highest weighted level"
                )
            }
        )

        for s in ["dr", "replay", "mutation"]:
            if train_state_info["info"][f"num_{s}_updates"] > 0:
                log_dict.update(
                    {
                        f"images/{s}_levels": [
                            wandb.Image(np.array(image), caption=f"{score}")
                            for image, score in zip(stats[f"{s}_levels"], stats[f"{s}_scores"])
                        ]
                    }
                )
                if stats["log_videos"]:
                    # animations
                    rollout_ep = stats[f"{s}_ep_len"]
                    arr = np.array(stats[f"{s}_rollout"][:rollout_ep])
                    log_dict.update(
                        {
                            f"media/{s}_eval": wandb.Video(
                                arr.astype(np.uint8), fps=15, caption=f"{s.capitalize()} (len {rollout_ep})"
                            )
                        }
                    )
                    # * 255

        # DR, Replay and Mutate Returns
        dr_inds = (stats["update_state"] == UpdateState.DR).nonzero()[0]
        rep_inds = (stats["update_state"] == UpdateState.REPLAY).nonzero()[0]
        mut_inds = (stats["update_state"] == UpdateState.MUTATE).nonzero()[0]

        for name, inds in [
            ("DR", dr_inds),
            ("REPLAY", rep_inds),
            ("MUTATION", mut_inds),
        ]:
            if len(inds) > 0:
                log_dict.update(
                    {
                        f"{name}/episode_return": stats["episode_return"][inds].mean(),
                        f"{name}/mean_eplen": stats["returned_episode_lengths"][inds].mean(),
                        f"{name}/mean_success": stats["returned_episode_solved"][inds].mean(),
                        f"{name}/noop_return": stats["noop_returns"][inds].mean(),
                        f"{name}/noop_eplen": stats["noop_eplen"][inds].mean(),
                        f"{name}/noop_success": stats["noop_success"][inds].mean(),
                        f"{name}/random_return": stats["random_returns"][inds].mean(),
                        f"{name}/random_eplen": stats["random_eplen"][inds].mean(),
                        f"{name}/random_success": stats["random_success"][inds].mean(),
                    }
                )
                for k in stats:
                    if "complexity/" in k:
                        k2 = "complexity/" + name + "_" + k.replace("complexity/", "")
                        log_dict.update({k2: stats[k][inds].mean()})
                    if "ued_scores/" in k:
                        k2 = "ued_scores/" + name + "_" + k.replace("ued_scores/", "")
                        log_dict.update({k2: stats[k][inds].mean()})

        # Eval rollout animations
        if stats["log_videos"]:
            for i in range(config["num_eval_levels"]):
                frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
                frames = np.array(frames[:episode_length])
                log_dict.update(
                    {
                        f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
                            frames.astype(np.uint8), fps=15, caption=f"Len ({episode_length})"
                        )
                    }
                )

        wandb.log(log_dict)

    def get_all_metrics(
        rng,
        losses,
        info,
        init_env_state,
        init_obs,
        dones,
        grads,
        all_ued_scores,
        new_levels,
    ):
        noop_returns, noop_len, noop_success, random_returns, random_lens, random_success = no_op_and_random_rollout(
            env,
            env_params,
            rng,
            init_obs,
            init_env_state,
            config["num_train_envs"],
            config["num_steps"] * config["outer_rollout_steps"],
        )
        metrics = (
            {
                "losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
                "returned_episode_lengths": (info["returned_episode_lengths"] * dones).sum()
                / jnp.maximum(1, dones.sum()),
                "max_episode_length": info["returned_episode_lengths"].max(),
                "levels_played": init_env_state.env_state.env_state,
                "episode_return": (info["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
                "episode_return_v2": (info["returned_episode_returns"] * info["returned_episode"]).sum()
                / jnp.maximum(1, info["returned_episode"].sum()),
                "grad_norms": grads.mean(),
                "noop_returns": noop_returns,
                "noop_eplen": noop_len,
                "noop_success": noop_success,
                "random_returns": random_returns,
                "random_eplen": random_lens,
                "random_success": random_success,
                "returned_episode_solved": (info["returned_episode_solved"] * dones).sum()
                / jnp.maximum(1, dones.sum()),
            }
            | get_level_complexity_metrics(new_levels, static_env_params)
            | get_ued_score_metrics(all_ued_scores)
        )
        return metrics

    # Setup the environment.
    def make_env(static_env_params):
        env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
        env = AutoReplayWrapper(env)
        env = UnderspecifiedToGymnaxWrapper(env)
        env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
        env = LogWrapper(env)
        return env

    env = make_env(static_env_params)

    if config["train_level_mode"] == "list":
        sample_random_level = make_reset_train_function_with_list_of_levels(
            config, config["train_levels_list"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
        )
    elif config["train_level_mode"] == "random":
        sample_random_level = make_reset_train_function_with_mutations(
            env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
        )
    else:
        raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")

    if config["use_accel"] and config["accel_start_from_empty"]:

        def make_sample_random_level():
            def inner(rng):
                def _inner_accel(rng):
                    return create_random_starting_distribution(
                        rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=True
                    )

                def _inner_accel_not_controllable(rng):
                    return create_random_starting_distribution(
                        rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=False
                    )

                rng, _rng = jax.random.split(rng)
                return _inner_accel(_rng)

            return inner

        sample_random_level = make_sample_random_level()

    sample_random_levels = make_vmapped_filtered_level_sampler(
        sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
    )

    def generate_world():
        raise NotImplementedError

    def generate_eval_world(rng, env_params, static_env_params, level_idx):
        # jax.random.split(jax.random.PRNGKey(101), num_levels), env_params, static_env_params, jnp.arange(num_levels)
        raise NotImplementedError

    _, eval_static_env_params = generate_params_from_config(
        config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
    )
    eval_env = make_env(eval_static_env_params)
    ued_params = generate_ued_params_from_config(config)

    mutate_world = make_mutate_env(static_env_params, env_params, ued_params)

    def make_render_fn(static_env_params):
        render_fn_inner = make_render_pixels(env_params, static_env_params)
        render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
        return render_fn

    render_fn = make_render_fn(static_env_params)
    render_fn_eval = make_render_fn(eval_static_env_params)

    if config["EVAL_ON_SAMPLED"]:
        NUM_EVAL_DR_LEVELS = 200
        key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
        DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)

    # And the level sampler
    level_sampler = LevelSampler(
        capacity=config["level_buffer_capacity"],
        replay_prob=config["replay_prob"],
        staleness_coeff=config["staleness_coeff"],
        minimum_fill_ratio=config["minimum_fill_ratio"],
        prioritization=config["prioritization"],
        prioritization_params={"temperature": config["temperature"], "k": config["topk_k"]},
        duplicate_check=config["buffer_duplicate_check"],
    )
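    # The sampler above is PLR's level buffer: `replay_prob` controls how often a step replays
    # buffered levels rather than sampling fresh ones, `staleness_coeff` mixes staleness into the
    # replay distribution, and `prioritization` (with the temperature / top-k parameters) turns the
    # per-level scores into sampling weights. See jaxued's LevelSampler for the exact semantics.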

    def create_train_state(rng) -> TrainState:
        # Creates the train state
        def linear_schedule(count):
            # Linearly anneal the learning rate over the course of training.
            frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / (
                config["num_updates"] * config["outer_rollout_steps"]
            )
            return config["lr"] * frac

        rng, _rng = jax.random.split(rng)
        init_state = jax.tree.map(lambda x: x[0], sample_random_levels(_rng, 1))
        rng, _rng = jax.random.split(rng)
        obs, _ = env.reset_to_level(_rng, init_state, env_params)
        ns = config["num_steps"] * config["outer_rollout_steps"]
        obs = jax.tree.map(
            lambda x: jnp.repeat(jnp.repeat(x[None, ...], config["num_train_envs"], axis=0)[None, ...], ns, axis=0),
            obs,
        )
        init_x = (obs, jnp.zeros((ns, config["num_train_envs"]), dtype=jnp.bool_))
        network = make_network_from_config(env, env_params, config)
        rng, _rng = jax.random.split(rng)
        network_params = network.init(_rng, ScannedRNN.initialize_carry(config["num_train_envs"]), init_x)

        if config["anneal_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(config["lr"], eps=1e-5),
            )

        pholder_level = jax.tree.map(lambda x: x[0], sample_random_levels(jax.random.PRNGKey(0), 1))
        sampler = level_sampler.initialize(pholder_level, {"max_return": -jnp.inf})
        pholder_level_batch = jax.tree_util.tree_map(
            lambda x: jnp.array([x]).repeat(config["num_train_envs"], axis=0), pholder_level
        )
        pholder_rollout_batch = (
            jax.tree.map(
                lambda x: jnp.repeat(
                    jnp.expand_dims(x, 0), repeats=config["num_steps"] * config["outer_rollout_steps"], axis=0
                ),
                init_state,
            ),
            init_x[1][:, 0],
        )
        pholder_level_batch_scores = jnp.zeros((config["num_train_envs"],), dtype=jnp.float32)
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
            sampler=sampler,
            update_state=0,
            num_dr_updates=0,
            num_replay_updates=0,
            num_mutation_updates=0,
            dr_last_level_batch_scores=pholder_level_batch_scores,
            replay_last_level_batch_scores=pholder_level_batch_scores,
            mutation_last_level_batch_scores=pholder_level_batch_scores,
            dr_last_level_batch=pholder_level_batch,
            replay_last_level_batch=pholder_level_batch,
            mutation_last_level_batch=pholder_level_batch,
            dr_last_rollout_batch=pholder_rollout_batch,
            replay_last_rollout_batch=pholder_rollout_batch,
            mutation_last_rollout_batch=pholder_rollout_batch,
        )

        if config["load_from_checkpoint"] is not None:
            print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
            train_state = load_train_state_from_wandb_artifact_path(
                train_state,
                config["load_from_checkpoint"],
                load_only_params=config["load_only_params"],
                legacy=config["load_legacy_checkpoint"],
            )
        return train_state

    all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
    eval_group_indices = get_eval_level_groups(config["eval_levels"])

    def train_step(carry: Tuple[chex.PRNGKey, TrainState], _):
        """
        This is the main training loop. It basically calls either `on_new_levels`, `on_replay_levels`, or `on_mutate_levels` at every step.
        """

        def on_new_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            Samples new (randomly-generated) levels and evaluates the policy on these. It also then adds the levels to the level buffer if they have high-enough scores.
            The agent is updated on these trajectories iff `config["exploratory_grad_updates"]` is True.
            """
            sampler = train_state.sampler

            # Reset
            rng, rng_levels, rng_reset = jax.random.split(rng, 3)
            new_levels = sample_random_levels(rng_levels, config["num_train_envs"])
            init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
                jax.random.split(rng_reset, config["num_train_envs"]), new_levels, env_params
            )
            init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

            # Rollout
            (
                (rng, train_state, new_hstate, last_obs, last_env_state),
                (
                    obs,
                    actions,
                    rewards,
                    dones,
                    log_probs,
                    values,
                    info,
                    advantages,
                    targets,
                    losses,
                    grads,
                    rollout_states,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                init_hstate,
                init_obs,
                init_env_state,
                update_grad=config["exploratory_grad_updates"],
                return_states=True,
            )

            max_returns = compute_max_returns(dones, rewards)
            scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
            sampler, _ = level_sampler.insert_batch(sampler, new_levels, scores, {"max_return": max_returns})
            rng, _rng = jax.random.split(rng)
            metrics = {
                "update_state": UpdateState.DR,
            } | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, new_levels)

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.DR,
                num_dr_updates=train_state.num_dr_updates + 1,
                dr_last_level_batch=new_levels,
                dr_last_level_batch_scores=scores,
                dr_last_rollout_batch=jax.tree.map(
                    lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
                ),
            )
            return (rng, train_state), metrics

        def on_replay_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            This samples levels from the level buffer, and updates the policy on them.
            """
            sampler = train_state.sampler

            # Collect trajectories on replay levels
            rng, rng_levels, rng_reset = jax.random.split(rng, 3)
            sampler, (level_inds, levels) = level_sampler.sample_replay_levels(
                sampler, rng_levels, config["num_train_envs"]
            )
            init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
                jax.random.split(rng_reset, config["num_train_envs"]), levels, env_params
            )
            init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
            (
                (rng, train_state, new_hstate, last_obs, last_env_state),
                (
                    obs,
                    actions,
                    rewards,
                    dones,
                    log_probs,
                    values,
                    info,
                    advantages,
                    targets,
                    losses,
                    grads,
                    rollout_states,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                init_hstate,
                init_obs,
                init_env_state,
                update_grad=True,
                return_states=True,
            )

            max_returns = jnp.maximum(
                level_sampler.get_levels_extra(sampler, level_inds)["max_return"], compute_max_returns(dones, rewards)
            )
            scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
            sampler = level_sampler.update_batch(sampler, level_inds, scores, {"max_return": max_returns})
            rng, _rng = jax.random.split(rng)
            metrics = {
                "update_state": UpdateState.REPLAY,
            } | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, levels)

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.REPLAY,
                num_replay_updates=train_state.num_replay_updates + 1,
                replay_last_level_batch=levels,
                replay_last_level_batch_scores=scores,
                replay_last_rollout_batch=jax.tree.map(
                    lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
                ),
            )
            return (rng, train_state), metrics

        def on_mutate_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            This mutates the previous batch of replay levels and potentially adds them to the level buffer.
            This also updates the policy iff `config["exploratory_grad_updates"]` is True.
            """
            sampler = train_state.sampler
            rng, rng_mutate, rng_reset = jax.random.split(rng, 3)

            # mutate
            parent_levels = train_state.replay_last_level_batch
            child_levels = jax.vmap(mutate_world, (0, 0, None))(
                jax.random.split(rng_mutate, config["num_train_envs"]), parent_levels, config["num_edits"]
            )
            init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
                jax.random.split(rng_reset, config["num_train_envs"]), child_levels, env_params
            )
            init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

            # rollout
            (
                (rng, train_state, new_hstate, last_obs, last_env_state),
                (
                    obs,
                    actions,
                    rewards,
                    dones,
                    log_probs,
                    values,
                    info,
                    advantages,
                    targets,
                    losses,
                    grads,
                    rollout_states,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                init_hstate,
                init_obs,
                init_env_state,
                update_grad=config["exploratory_grad_updates"],
                return_states=True,
            )

            max_returns = compute_max_returns(dones, rewards)
            scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
            sampler, _ = level_sampler.insert_batch(sampler, child_levels, scores, {"max_return": max_returns})
            rng, _rng = jax.random.split(rng)
            metrics = {
                "update_state": UpdateState.MUTATE,
            } | get_all_metrics(
                _rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, child_levels
            )

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.DR,
                num_mutation_updates=train_state.num_mutation_updates + 1,
                mutation_last_level_batch=child_levels,
                mutation_last_level_batch_scores=scores,
                mutation_last_rollout_batch=jax.tree.map(
                    lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
                ),
            )
            return (rng, train_state), metrics

        rng, train_state = carry
        rng, rng_replay = jax.random.split(rng)

        # The train step makes a decision on which branch to take, either on_new, on_replay or on_mutate.
        # on_mutate is only called if the replay branch has been taken before (as it uses `train_state.update_state`).
        branches = [
            on_new_levels,
            on_replay_levels,
        ]
        if config["use_accel"]:
            s = train_state.update_state
            branch = (1 - s) * level_sampler.sample_replay_decision(train_state.sampler, rng_replay) + 2 * s
            branches.append(on_mutate_levels)
        else:
            branch = level_sampler.sample_replay_decision(train_state.sampler, rng_replay).astype(int)

        return jax.lax.switch(branch, branches, rng, train_state)
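
    # Branch selection in `train_step`, in brief: without ACCEL, `sample_replay_decision` picks
    # between sampling fresh levels (0 -> on_new_levels) and replaying buffered ones
    # (1 -> on_replay_levels). With ACCEL, a replay step sets update_state to REPLAY, which forces
    # the *next* step onto branch 2 (on_mutate_levels) so the levels just replayed get mutated;
    # the mutate step then resets update_state to DR and the replay decision is sampled again.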

    def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
        """
        This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
        Since it calls `general_eval` with `return_trajectories=True`, it returns a tuple whose first element
        contains (states, cum_rewards, done_idx, episode_lengths, infos) and whose second element contains the
        per-step (dones, rewards); the per-level arrays have a dimension of size num_eval_levels.
        """
        num_levels = config["num_eval_levels"]
        return general_eval(
            rng,
            eval_env,
            env_params,
            train_state,
            all_eval_levels,
            env_params.max_timesteps,
            num_levels,
            keep_states=keep_states,
            return_trajectories=True,
        )

    def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
        return general_eval(
            rng,
            env,
            env_params,
            train_state,
            DR_EVAL_LEVELS,
            env_params.max_timesteps,
            NUM_EVAL_DR_LEVELS,
            keep_states=keep_states,
        )

    def train_and_eval_step(runner_state, _):
        """
        This function runs the train_step for a certain number of iterations, and then evaluates the policy.
        It returns the updated train state, and a dictionary of metrics.
        """
        # Train
        (rng, train_state), metrics = jax.lax.scan(train_step, runner_state, None, config["eval_freq"])

        # Eval
        metrics["update_count"] = (
            train_state.num_dr_updates + train_state.num_replay_updates + train_state.num_mutation_updates
        )
        vid_frequency = get_video_frequency(config, metrics["update_count"])
        should_log_videos = metrics["update_count"] % vid_frequency == 0

        def _compute_eval_learnability(dones, rewards, infos):
            # Vectorise over the leading eval-attempt axis; episode counts are aggregated below.
            @jax.vmap
            def _single(d, r, i):
                learn, num_eps, num_succ = compute_learnability(config, d, r, i, config["num_eval_levels"])
                return num_eps, num_succ.squeeze(-1)

            num_eps, num_succ = _single(dones, rewards, infos)
            num_eps, num_succ = num_eps.sum(axis=0), num_succ.sum(axis=0)
            success_rate = num_succ / jnp.maximum(1, num_eps)
            return success_rate * (1 - success_rate)

        def _get_eval(rng):
            metrics = {}
            rng, rng_eval = jax.random.split(rng)
            (states, cum_rewards, done_idx, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
                eval, (0, None)
            )(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)

            # learnability here of the holdout set:
            eval_learn = _compute_eval_learnability(eval_dones, eval_rewards, eval_infos)

            # Collect Metrics
            eval_returns = cum_rewards.mean(axis=0)  # (num_eval_levels,)
            eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
                1, eval_dones.sum(axis=1)
            )
            eval_solves = eval_solves.mean(axis=0)
            metrics["eval_returns"] = eval_returns
            metrics["eval_ep_lengths"] = episode_lengths.mean(axis=0)
            metrics["eval_learn"] = eval_learn
            metrics["eval_solves"] = eval_solves
            metrics["eval_get_max_eplen"] = (episode_lengths == env_params.max_timesteps).mean(axis=0)
            metrics["episode_return_bigger_than_negative"] = (cum_rewards > -0.4).mean(axis=0)

            if config["EVAL_ON_SAMPLED"]:
                states_dr, cum_rewards_dr, done_idx_dr, episode_lengths_dr, infos_dr = jax.vmap(
                    eval_on_dr_levels, (0, None)
                )(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)
                eval_dr_returns = cum_rewards_dr.mean(axis=0).mean()
                eval_dr_eplen = episode_lengths_dr.mean(axis=0).mean()
                my_eval_dones = infos_dr["returned_episode"]
                eval_dr_solves = (infos_dr["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
                    1, my_eval_dones.sum(axis=1)
                )
                metrics["eval_dr_returns"] = eval_dr_returns
                metrics["eval_dr_eplen"] = eval_dr_eplen
                metrics["eval_dr_solve_rates"] = eval_dr_solves
            return metrics, states, episode_lengths, cum_rewards

        def _get_videos(rng, states, episode_lengths, cum_rewards):
            metrics = {"log_videos": True}
            # just grab the first run
            states, episode_lengths = jax.tree_util.tree_map(
                lambda x: x[0], (states, episode_lengths)
            )  # (num_steps, num_eval_levels, ...), (num_eval_levels,)
            # And one attempt
            states = jax.tree_util.tree_map(lambda x: x[:, :], states)
            episode_lengths = episode_lengths[:]
            images = jax.vmap(jax.vmap(render_fn_eval))(
                states.env_state.env_state.env_state
            )  # (num_steps, num_eval_levels, ...)
            frames = images.transpose(
                0, 1, 4, 2, 3
            )  # WandB expects color channel before image dimensions when dealing with animations

            def _get_video(rollout_batch):
                states = rollout_batch[0]
                images = jax.vmap(render_fn)(states)  # dimensions are (steps, x, y, 3)
                return (
                    # jax.tree.map(lambda x: x[:].transpose(0, 2, 1, 3)[:, ::-1], images).transpose(0, 3, 1, 2),
                    images.transpose(0, 3, 1, 2),
                    # images.transpose(0, 1, 4, 2, 3),
                    rollout_batch[1][:].argmax(),
                )

            # rollouts
            metrics["dr_rollout"], metrics["dr_ep_len"] = _get_video(train_state.dr_last_rollout_batch)
            metrics["replay_rollout"], metrics["replay_ep_len"] = _get_video(train_state.replay_last_rollout_batch)
            metrics["mutation_rollout"], metrics["mutation_ep_len"] = _get_video(
                train_state.mutation_last_rollout_batch
            )

            metrics["eval_animation"] = (frames, episode_lengths)
            metrics["eval_returns_video"] = cum_rewards[0]
            metrics["eval_len_video"] = episode_lengths

            # Eval on sampled
            return metrics

        def _get_dummy_videos(rng, states, episode_lengths, cum_rewards):
            n_eval = config["num_eval_levels"]
            nsteps = env_params.max_timesteps
            nsteps2 = config["outer_rollout_steps"] * config["num_steps"]
            img_size = (
                env.static_env_params.screen_dim[0] // env.static_env_params.downscale,
                env.static_env_params.screen_dim[1] // env.static_env_params.downscale,
            )
            return {
                "log_videos": False,
                "dr_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
                "dr_ep_len": jnp.zeros((), jnp.int32),
                "replay_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
                "replay_ep_len": jnp.zeros((), jnp.int32),
                "mutation_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
                "mutation_ep_len": jnp.zeros((), jnp.int32),
                # "eval_returns": jnp.zeros((n_eval,), jnp.float32),
                # "eval_solves": jnp.zeros((n_eval,), jnp.float32),
                # "eval_learn": jnp.zeros((n_eval,), jnp.float32),
                # "eval_ep_lengths": jnp.zeros((n_eval,), jnp.int32),
                "eval_animation": (
                    jnp.zeros((nsteps, n_eval, 3, *img_size), jnp.float32),
                    jnp.zeros((n_eval,), jnp.int32),
                ),
                "eval_returns_video": jnp.zeros((n_eval,), jnp.float32),
                "eval_len_video": jnp.zeros((n_eval,), jnp.int32),
            }

        rng, rng_eval, rng_vid = jax.random.split(rng, 3)
        metrics_eval, states, episode_lengths, cum_rewards = _get_eval(rng_eval)
        metrics = {
            **metrics,
            **metrics_eval,
            **jax.lax.cond(
                should_log_videos, _get_videos, _get_dummy_videos, rng_vid, states, episode_lengths, cum_rewards
            ),
        }

        max_num_images = 8
        top_regret_ones = max_num_images // 2
        bot_regret_ones = max_num_images - top_regret_ones

        def get_values(level_batch, scores):
            args = jnp.argsort(scores)  # low scores are at the start, high scores are at the end
            low_scores = args[:bot_regret_ones]
            high_scores = args[-top_regret_ones:]
            low_levels = jax.tree.map(lambda x: x[low_scores], level_batch)
            high_levels = jax.tree.map(lambda x: x[high_scores], level_batch)
            low_scores = scores[low_scores]
            high_scores = scores[high_scores]
            # now concatenate:
            return jax.vmap(render_fn)(
                jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), low_levels, high_levels)
            ), jnp.concatenate([low_scores, high_scores], axis=0)

        metrics["dr_levels"], metrics["dr_scores"] = get_values(
            train_state.dr_last_level_batch, train_state.dr_last_level_batch_scores
        )
        metrics["replay_levels"], metrics["replay_scores"] = get_values(
            train_state.replay_last_level_batch, train_state.replay_last_level_batch_scores
        )
        metrics["mutation_levels"], metrics["mutation_scores"] = get_values(
            train_state.mutation_last_level_batch, train_state.mutation_last_level_batch_scores
        )

        def _t(i):
            return jax.lax.select(i == 0, config["num_steps"], i)

        metrics["dr_ep_len"] = _t(train_state.dr_last_rollout_batch[1][:].argmax())
        metrics["replay_ep_len"] = _t(train_state.replay_last_rollout_batch[1][:].argmax())
        metrics["mutation_ep_len"] = _t(train_state.mutation_last_rollout_batch[1][:].argmax())

        highest_scoring_level = level_sampler.get_levels(train_state.sampler, train_state.sampler["scores"].argmax())
        highest_weighted_level = level_sampler.get_levels(
            train_state.sampler, level_sampler.level_weights(train_state.sampler).argmax()
        )

        metrics["highest_scoring_level"] = render_fn(highest_scoring_level)
        metrics["highest_weighted_level"] = render_fn(highest_weighted_level)

        # log_eval(metrics, train_state_to_log_dict(runner_state[1], level_sampler))
        jax.debug.callback(log_eval, metrics, train_state_to_log_dict(runner_state[1], level_sampler))
        return (rng, train_state), {"update_count": metrics["update_count"]}

    def log_checkpoint(update_count, train_state):
        if config["save_path"] is not None and config["checkpoint_save_freq"] > 1:
            steps = (
                int(update_count)
                * int(config["num_train_envs"])
                * int(config["num_steps"])
                * int(config["outer_rollout_steps"])
            )
            # save_params_to_wandb(train_state.params, steps, config)
            save_model_to_wandb(train_state, steps, config)

    def train_eval_and_checkpoint_step(runner_state, _):
        runner_state, metrics = jax.lax.scan(
            train_and_eval_step, runner_state, xs=jnp.arange(config["checkpoint_save_freq"] // config["eval_freq"])
        )
        jax.debug.callback(log_checkpoint, metrics["update_count"][-1], runner_state[1])
        return runner_state, metrics

    # Set up the train states
    rng = jax.random.PRNGKey(config["seed"])
    rng_init, rng_train = jax.random.split(rng)

    train_state = create_train_state(rng_init)
    runner_state = (rng_train, train_state)

    runner_state, metrics = jax.lax.scan(
        train_eval_and_checkpoint_step,
        runner_state,
        xs=jnp.arange(config["num_updates"] // config["checkpoint_save_freq"]),
    )

    if config["save_path"] is not None:
        # save_params_to_wandb(runner_state[1].params, config["total_timesteps"], config)
        save_model_to_wandb(runner_state[1], config["total_timesteps"], config, is_final=True)

    return runner_state[1]


if __name__ == "__main__":
    main()
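
# Example invocation (assuming the Hydra entry point above and a config exposing the keys this
# script reads, e.g. `ued.replay_prob` and `ued.use_accel`):
#   python experiments/plr.py ued.replay_prob=0.5 ued.use_accel=true
# The script path and config layout here are assumptions; adjust them to the actual repository structure.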