File size: 18,096 Bytes
6e5cc8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 |
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import mixed_precision as prec
from dreamerv2 import common
from dreamerv2 import expl
# Shorthand alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions
class Agent(common.Module):
    """DreamerV2 agent with optional uncertainty-penalized offline training.

    Holds a world model, a task actor-critic and an exploration behavior.
    ``train`` updates everything in one call; for offline training,
    ``model_train`` and ``agent_train`` split that into two separately
    callable steps, and ``agent_train`` optimizes the actor-critic on
    ``penalized_reward`` instead of the raw predicted reward.
    """

    def __init__(self, config, obs_space, act_space, step):
        """Build the world model and the behaviors.

        Args:
          config: Agent configuration; reads at least ``expl_behavior``,
            ``offline_tune_lmbd`` and, when not tuning, ``offline_lmbd``.
          obs_space: Mapping from observation key to space, passed to
            ``WorldModel``.
          act_space: Mapping whose ``'action'`` entry is the action space.
          step: Step counter convertible with ``int()``; its value is
            mirrored into ``self.tfstep``.
        """
        self.config = config
        self.obs_space = obs_space
        self.act_space = act_space['action']
        self.step = step
        # The second positional parameter of tf.Variable is ``trainable``, not
        # ``dtype``; pass the dtype by keyword so the counter is really int64.
        self.tfstep = tf.Variable(
            int(self.step), dtype=tf.int64, trainable=False)
        self.wm = WorldModel(config, obs_space, self.tfstep)
        self._task_behavior = ActorCritic(config, self.act_space, self.tfstep)
        if config.expl_behavior == 'greedy':
            self._expl_behavior = self._task_behavior
        else:
            self._expl_behavior = getattr(expl, config.expl_behavior)(
                self.config, self.act_space, self.wm, self.tfstep,
                lambda seq: self.wm.heads['reward'](seq['feat']).mode())
        if self.config.offline_tune_lmbd:
            # The penalty weight is learned in log space, so
            # lambda = exp(log_lmbd) stays positive; it starts at exp(0) = 1.
            self.log_lmbd = tf.Variable(0, dtype=tf.float32)
            self.lmbd_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
        else:
            self.lmbd = self.config.offline_lmbd

    @tf.function
    def policy(self, obs, state=None, mode='train'):
        """Select actions for a batch of observations.

        Args:
          obs: Mapping of observations; must contain ``'reward'`` (its length
            gives the batch size) and ``'is_first'``.
          state: ``(latent, action)`` from the previous call, or None to start
            from the RSSM initial state and a zero action.
          mode: ``'train'`` samples from the task actor, ``'explore'`` samples
            from the exploration actor, ``'eval'`` takes the task actor's mode.

        Returns:
          ``({'action': action}, (latent, action))``.

        Raises:
          ValueError: If ``mode`` is not one of the values above.
        """
        obs = tf.nest.map_structure(tf.tensor, obs)
        # Sync the TF step counter with the Python-side counter.
        tf.py_function(lambda: self.tfstep.assign(
            int(self.step), read_value=False), [], [])
        if state is None:
            latent = self.wm.rssm.initial(len(obs['reward']))
            action = tf.zeros((len(obs['reward']),) + self.act_space.shape)
            state = latent, action
        latent, action = state
        embed = self.wm.encoder(self.wm.preprocess(obs))
        sample = (mode == 'train') or not self.config.eval_state_mean
        latent, _ = self.wm.rssm.obs_step(
            latent, action, embed, obs['is_first'], sample)
        feat = self.wm.rssm.get_feat(latent)
        if mode == 'eval':
            actor = self._task_behavior.actor(feat)
            action = actor.mode()
            noise = self.config.eval_noise
        elif mode == 'explore':
            actor = self._expl_behavior.actor(feat)
            action = actor.sample()
            noise = self.config.expl_noise
        elif mode == 'train':
            actor = self._task_behavior.actor(feat)
            action = actor.sample()
            noise = self.config.expl_noise
        else:
            # Without this branch ``noise`` would be unbound below.
            raise ValueError(f'Unknown policy mode: {mode!r}')
        action = common.action_noise(action, noise, self.act_space)
        outputs = {'action': action}
        state = (latent, action)
        return outputs, state

    @tf.function
    def train(self, data, state=None):
        """Train the world model, the task behavior and the explorer.

        Returns:
          ``(state, metrics)``: the world model's final RSSM state and the
          merged metrics (exploration metrics prefixed with ``'expl_'``).
        """
        metrics = {}
        state, outputs, mets = self.wm.train(data, state)
        metrics.update(mets)
        start = outputs['post']
        reward = lambda seq: self.wm.heads['reward'](seq['feat']).mode()
        metrics.update(self._task_behavior.train(
            self.wm, start, data['is_terminal'], reward))
        if self.config.expl_behavior != 'greedy':
            mets = self._expl_behavior.train(start, outputs, data)[-1]
            metrics.update({'expl_' + key: value for key, value in mets.items()})
        return state, metrics

    # The following methods split the above for offline training.
    @tf.function
    def model_train(self, data):
        """Take one world-model training step and return its metrics."""
        _, _, metrics = self.wm.train(data)
        return metrics

    @tf.function
    def agent_train(self, data):
        """Train the task actor-critic on the uncertainty-penalized reward.

        Start states are the world model's posteriors for ``data``. The
        returned metrics also include the current penalty weight under
        ``'lambda'``.
        """
        data = self.wm.preprocess(data)
        embed = self.wm.encoder(data)
        start, _ = self.wm.rssm.observe(embed, data['action'], data['is_first'])
        metrics = self._task_behavior.train(
            self.wm, start, data['is_terminal'], self.penalized_reward)
        if self.config.offline_tune_lmbd:
            metrics['lambda'] = tf.exp(self.log_lmbd)
        else:
            metrics['lambda'] = self.lmbd
        return metrics

    @tf.function
    def compute_penalty(self, seq):
        """Return an ensemble-disagreement penalty for imagined states.

        The distributions come from
        ``self.wm.rssm.get_dist(seq, ensemble=True)``. Every branch treats
        axis 0 of the result as the ensemble-member axis.

        ``config.offline_penalty_type`` selects the penalty:
          * ``'log_prob_ens'``: std over members of the log-probability each
            member assigns to ``seq['stoch']``.
          * ``'meandis'``: mean over members of the norm (over the last two
            axes) of each member's mean minus the ensemble-average mean.
          * ``'mixstd_mean'``: standard deviation of the uniform mixture of
            the members, averaged over the last two axes.
          * anything else: no penalty (``0``).
        """
        penalty_type = self.config.offline_penalty_type
        if penalty_type == 'log_prob_ens':
            dist = self.wm.rssm.get_dist(seq, ensemble=True)
            penalty = tf.math.reduce_std(dist.log_prob(seq['stoch']), axis=0)
        elif penalty_type == 'meandis':
            dist = self.wm.rssm.get_dist(seq, ensemble=True)
            m = dist.mean()
            mean_pred = tf.math.reduce_mean(m, axis=0)
            penalty = tf.math.reduce_mean(
                tf.norm(m - mean_pred, axis=[-1, -2]), axis=0)
        elif penalty_type == 'mixstd_mean':
            dist = self.wm.rssm.get_dist(seq, ensemble=True)
            # Std of the uniform mixture over the ensemble axis via the law of
            # total variance (the decomposition MixtureSameFamily uses):
            #   Var = E_k[Var_k] + E_k[(mu_k - E_k[mu_k])^2].
            # The previous tfd.BatchReshape-based mixture only relabelled the
            # batch shape [E, ...] -> [..., E] without transposing the data,
            # so each "mixture" combined unrelated entries.
            means = dist.mean()
            mix_mean = tf.math.reduce_mean(means, axis=0)
            mix_var = (
                tf.math.reduce_mean(dist.variance(), axis=0)
                + tf.math.reduce_mean(tf.square(means - mix_mean), axis=0))
            penalty = tf.math.reduce_mean(tf.sqrt(mix_var), axis=[-1, -2])
        else:
            penalty = 0
        return penalty

    @tf.function
    def penalized_reward(self, seq):
        """Return the predicted reward minus ``lambda * compute_penalty(seq)``.

        When ``config.offline_tune_lmbd`` is set, this also takes one Adam
        step on ``log_lmbd`` with the loss
        ``mean(log_lmbd * (stop_gradient(lambda * penalty) - offline_lmbd_cons))``.
        The returned reward uses the lambda computed before that step.
        """
        rew = self.wm.heads['reward'](seq['feat']).mode()
        penalty = self.compute_penalty(seq)
        if self.config.offline_tune_lmbd:
            with tf.GradientTape() as tape:
                lmbd = tf.exp(self.log_lmbd)
                lambda_loss = tf.math.reduce_mean(
                    self.log_lmbd * (tf.stop_gradient(lmbd * penalty)
                                     - self.config.offline_lmbd_cons))
            variables = [self.log_lmbd]
            grads = tape.gradient(lambda_loss, variables)
            self.lmbd_optimizer.apply_gradients(zip(grads, variables))
        else:
            lmbd = self.lmbd
        return rew - lmbd * penalty

    @tf.function
    def report(self, data):
        """Return open-loop video predictions for each decoder CNN key."""
        report = {}
        data = self.wm.preprocess(data)
        for key in self.wm.heads['decoder'].cnn_keys:
            name = key.replace('/', '_')
            report[f'openl_{name}'] = self.wm.video_pred(data, key)
        return report
class WorldModel(common.Module):
    """World model: observation encoder, ensemble RSSM and prediction heads.

    Trained by minimizing the RSSM KL loss plus the negative log-likelihood of
    the data under each head. Also provides latent imagination rollouts
    (``imagine``) and open-loop video summaries (``video_pred``).
    """

    def __init__(self, config, obs_space, tfstep):
        """Build the model components.

        Args:
          config: Configuration; reads ``rssm``, ``encoder``, ``decoder``,
            ``reward_head``, ``pred_discount``, ``discount_head``,
            ``grad_heads`` and ``model_opt``.
          obs_space: Mapping from observation key to a space with ``.shape``.
          tfstep: Shared step-counter variable (stored, not used here).
        """
        shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
        self.config = config
        self.tfstep = tfstep
        self.rssm = common.EnsembleRSSM(**config.rssm)
        self.encoder = common.Encoder(shapes, **config.encoder)
        self.heads = {
            'decoder': common.Decoder(shapes, **config.decoder),
            'reward': common.MLP([], **config.reward_head),
        }
        if config.pred_discount:
            self.heads['discount'] = common.MLP([], **config.discount_head)
        # Every head allowed to send gradients into the features must exist.
        for name in config.grad_heads:
            assert name in self.heads, name
        self.model_opt = common.Optimizer('model', **config.model_opt)

    def train(self, data, state=None):
        """Take one optimizer step on the world-model loss.

        Args:
          data: Batch of observations before ``preprocess``.
          state: Optional RSSM state that ``observe`` starts from.

        Returns:
          ``(last_state, outputs, metrics)`` as produced by ``loss``, with the
          optimizer's metrics merged into ``metrics``.
        """
        with tf.GradientTape() as model_tape:
            model_loss, state, outputs, metrics = self.loss(data, state)
        modules = [self.encoder, self.rssm, *self.heads.values()]
        metrics.update(self.model_opt(model_tape, model_loss, modules))
        return state, outputs, metrics

    def loss(self, data, state=None):
        """Compute the world-model loss.

        Returns:
          model_loss: Sum of the KL loss and each head's negative mean
            log-likelihood, each term scaled by ``config.loss_scales``
            (default 1.0).
          last_state: Posterior entries at index -1 of axis 1 (presumably
            the time axis; TODO confirm).
          outs: Dict with ``embed``, ``feat``, ``post``, ``prior``, ``likes``
            and ``kl``.
          metrics: Per-term losses, mean KL and prior/posterior entropies.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(
            embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.config.kl)
        assert len(kl_loss.shape) == 0
        likes = {}
        losses = {'kl': kl_loss}
        feat = self.rssm.get_feat(post)
        for name, head in self.heads.items():
            # Heads not in grad_heads see stop-gradient features, so their
            # losses train only the head, not the representation.
            grad_head = (name in self.config.grad_heads)
            inp = feat if grad_head else tf.stop_gradient(feat)
            out = head(inp)
            # A head may return a dict of distributions keyed by data key.
            dists = out if isinstance(out, dict) else {name: out}
            for key, dist in dists.items():
                like = tf.cast(dist.log_prob(data[key]), tf.float32)
                likes[key] = like
                losses[key] = -like.mean()
        model_loss = sum(
            self.config.loss_scales.get(k, 1.0) * v for k, v in losses.items())
        outs = dict(
            embed=embed, feat=feat, post=post,
            prior=prior, likes=likes, kl=kl_value)
        metrics = {f'{name}_loss': value for name, value in losses.items()}
        metrics['model_kl'] = kl_value.mean()
        metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
        metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
        last_state = {k: v[:, -1] for k, v in post.items()}
        return model_loss, last_state, outs, metrics

    def imagine(self, policy, start, is_terminal, horizon):
        """Roll ``policy`` out in latent space for ``horizon`` steps.

        The first two axes of every ``start`` entry are flattened into one
        batch axis. The returned dict stacks ``horizon + 1`` entries along
        axis 0 (start state first) and adds ``'discount'`` and ``'weight'``.

        Args:
          policy: Callable mapping features to an action distribution.
          start: RSSM state dict to start from.
          is_terminal: Terminal flags for the start states, or None.
          horizon: Number of imagined steps.
        """
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}
        start['feat'] = self.rssm.get_feat(start)
        # Placeholder action for the start state, shaped like policy output.
        start['action'] = tf.zeros_like(policy(start['feat']).mode())
        seq = {k: [v] for k, v in start.items()}
        for _ in range(horizon):
            # The policy input is stop-gradient'ed.
            action = policy(tf.stop_gradient(seq['feat'][-1])).sample()
            state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
            feat = self.rssm.get_feat(state)
            for key, value in {**state, 'action': action, 'feat': feat}.items():
                seq[key].append(value)
        seq = {k: tf.stack(v, 0) for k, v in seq.items()}
        if 'discount' in self.heads:
            disc = self.heads['discount'](seq['feat']).mean()
            if is_terminal is not None:
                # Override discount prediction for the first step with the true
                # discount factor from the replay buffer.
                true_first = 1.0 - flatten(is_terminal).astype(disc.dtype)
                true_first *= self.config.discount
                disc = tf.concat([true_first[None], disc[1:]], 0)
        else:
            disc = self.config.discount * tf.ones(seq['feat'].shape[:-1])
        seq['discount'] = disc
        # Shift discount factors because they imply whether the following state
        # will be valid, not whether the current state is valid.
        seq['weight'] = tf.math.cumprod(
            tf.concat([tf.ones_like(disc[:1]), disc[:-1]], 0), 0)
        return seq

    @tf.function
    def preprocess(self, obs):
        """Return a converted shallow copy of the observation dict ``obs``.

        Keys starting with ``'log_'`` are left unchanged. int32 values are
        cast to the mixed-precision compute dtype; uint8 values are cast and
        scaled to [-0.5, 0.5]. The reward is transformed per
        ``config.clip_rewards`` (``'identity'``, ``'sign'`` or ``'tanh'``).
        ``'discount'`` is added from ``'is_terminal'`` when missing.
        """
        dtype = prec.global_policy().compute_dtype
        obs = obs.copy()
        for key, value in obs.items():
            if key.startswith('log_'):
                continue
            if value.dtype == tf.int32:
                value = value.astype(dtype)
            if value.dtype == tf.uint8:
                value = value.astype(dtype) / 255.0 - 0.5
            obs[key] = value
        obs['reward'] = {
            'identity': tf.identity,
            'sign': tf.sign,
            'tanh': tf.tanh,
        }[self.config.clip_rewards](obs['reward'])
        if 'discount' not in obs:
            obs['discount'] = 1.0 - obs['is_terminal'].astype(dtype)
            obs['discount'] *= self.config.discount
        return obs

    @tf.function
    def video_pred(self, data, key):
        """Build an open-loop prediction video for image key ``key``.

        Uses the first 6 sequences of ``data`` (already preprocessed). The
        first 5 steps are reconstructed from posteriors; the remaining steps
        are imagined open-loop from the data's actions. Ground truth, model
        output and a scaled error image are stacked along the height axis,
        and the batch is tiled along width.

        Returns:
          Tensor of shape ``(T, 3 * H, B * W, C)`` for per-sequence frames of
          shape ``(H, W, C)``.
        """
        decoder = self.heads['decoder']
        truth = data[key][:6] + 0.5
        embed = self.encoder(data)
        states, _ = self.rssm.observe(
            embed[:6, :5], data['action'][:6, :5], data['is_first'][:6, :5])
        recon = decoder(self.rssm.get_feat(states))[key].mode()[:6]
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.rssm.imagine(data['action'][:6, 5:], init)
        openl = decoder(self.rssm.get_feat(prior))[key].mode()
        model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1)
        error = (model - truth + 1) / 2
        video = tf.concat([truth, model, error], 2)
        B, T, H, W, C = video.shape
        return video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
class ActorCritic(common.Module):
    """Actor and critic trained on imagined rollouts of a world model.

    The critic regresses lambda-returns computed with a (optionally slow)
    target critic. The actor maximizes those returns via dynamics
    backpropagation, REINFORCE, or a scheduled mix of both, plus an entropy
    bonus.
    """

    def __init__(self, config, act_space, tfstep):
        """Build the actor, critic(s), optimizers and reward normalizer.

        ``actor.dist`` and ``actor_grad`` set to ``'auto'`` are resolved from
        whether ``act_space`` has an ``n`` attribute (treated as discrete).
        """
        self.config = config
        self.act_space = act_space
        self.tfstep = tfstep
        discrete = hasattr(act_space, 'n')
        if self.config.actor.dist == 'auto':
            self.config = self.config.update({
                'actor.dist': 'onehot' if discrete else 'trunc_normal'})
        if self.config.actor_grad == 'auto':
            self.config = self.config.update({
                'actor_grad': 'reinforce' if discrete else 'dynamics'})
        self.actor = common.MLP(act_space.shape[0], **self.config.actor)
        self.critic = common.MLP([], **self.config.critic)
        if self.config.slow_target:
            self._target_critic = common.MLP([], **self.config.critic)
            # The second positional parameter of tf.Variable is ``trainable``,
            # not ``dtype``; pass the dtype by keyword for a true int64 counter.
            self._updates = tf.Variable(0, dtype=tf.int64, trainable=False)
        else:
            self._target_critic = self.critic
        self.actor_opt = common.Optimizer('actor', **self.config.actor_opt)
        self.critic_opt = common.Optimizer('critic', **self.config.critic_opt)
        self.rewnorm = common.StreamNorm(**self.config.reward_norm)

    def train(self, world_model, start, is_terminal, reward_fn):
        """Take one actor step and one critic step on an imagined rollout.

        Args:
          world_model: Model providing ``imagine``.
          start: Start states for imagination.
          is_terminal: Terminal flags of the start states.
          reward_fn: Callable mapping the imagined sequence to rewards; its
            output is normalized by ``self.rewnorm``.

        Returns:
          Dict of metrics from the optimizers, reward normalization, targets
          and losses.
        """
        metrics = {}
        hor = self.config.imag_horizon
        # The weights are is_terminal flags for the imagination start states.
        # Technically, they should multiply the losses from the second trajectory
        # step onwards, which is the first imagined step. However, we are not
        # training the action that led into the first step anyway, so we can use
        # them to scale the whole sequence.
        with tf.GradientTape() as actor_tape:
            seq = world_model.imagine(self.actor, start, is_terminal, hor)
            reward = reward_fn(seq)
            seq['reward'], mets1 = self.rewnorm(reward)
            mets1 = {f'reward_{k}': v for k, v in mets1.items()}
            target, mets2 = self.target(seq)
            actor_loss, mets3 = self.actor_loss(seq, target)
        with tf.GradientTape() as critic_tape:
            critic_loss, mets4 = self.critic_loss(seq, target)
        metrics.update(self.actor_opt(actor_tape, actor_loss, self.actor))
        metrics.update(self.critic_opt(critic_tape, critic_loss, self.critic))
        metrics.update(**mets1, **mets2, **mets3, **mets4)
        self.update_slow_target()  # Variables exist after first forward pass.
        return metrics

    def actor_loss(self, seq, target):
        """Return ``(actor_loss, metrics)`` for the configured ``actor_grad``."""
        # Actions: 0 [a1] [a2] a3
        # ^ | ^ | ^ |
        # / v / v / v
        # States: [z0]->[z1]-> z2 -> z3
        # Targets: t0 [t1] [t2]
        # Baselines: [v0] [v1] v2 v3
        # Entropies: [e1] [e2]
        # Weights: [ 1] [w1] w2 w3
        # Loss: l1 l2
        metrics = {}
        # Two states are lost at the end of the trajectory, one for the bootstrap
        # value prediction and one because the corresponding action does not lead
        # anywhere anymore. One target is lost at the start of the trajectory
        # because the initial state comes from the replay buffer.
        policy = self.actor(tf.stop_gradient(seq['feat'][:-2]))
        if self.config.actor_grad == 'dynamics':
            objective = target[1:]
        elif self.config.actor_grad == 'reinforce':
            baseline = self._target_critic(seq['feat'][:-2]).mode()
            advantage = tf.stop_gradient(target[1:] - baseline)
            objective = policy.log_prob(seq['action'][1:-1]) * advantage
        elif self.config.actor_grad == 'both':
            baseline = self._target_critic(seq['feat'][:-2]).mode()
            advantage = tf.stop_gradient(target[1:] - baseline)
            objective = policy.log_prob(seq['action'][1:-1]) * advantage
            mix = common.schedule(self.config.actor_grad_mix, self.tfstep)
            objective = mix * target[1:] + (1 - mix) * objective
            metrics['actor_grad_mix'] = mix
        else:
            raise NotImplementedError(self.config.actor_grad)
        ent = policy.entropy()
        ent_scale = common.schedule(self.config.actor_ent, self.tfstep)
        objective += ent_scale * ent
        weight = tf.stop_gradient(seq['weight'])
        actor_loss = -(weight[:-2] * objective).mean()
        metrics['actor_ent'] = ent.mean()
        metrics['actor_ent_scale'] = ent_scale
        return actor_loss, metrics

    def critic_loss(self, seq, target):
        """Return the weighted negative log-likelihood of stop-gradient targets."""
        # States: [z0] [z1] [z2] z3
        # Rewards: [r0] [r1] [r2] r3
        # Values: [v0] [v1] [v2] v3
        # Weights: [ 1] [w1] [w2] w3
        # Targets: [t0] [t1] [t2]
        # Loss: l0 l1 l2
        dist = self.critic(seq['feat'][:-1])
        target = tf.stop_gradient(target)
        weight = tf.stop_gradient(seq['weight'])
        critic_loss = -(dist.log_prob(target) * weight[:-1]).mean()
        metrics = {'critic': dist.mode().mean()}
        return critic_loss, metrics

    def target(self, seq):
        """Return lambda-returns (one fewer step than ``seq``) and metrics."""
        # States: [z0] [z1] [z2] [z3]
        # Rewards: [r0] [r1] [r2] r3
        # Values: [v0] [v1] [v2] [v3]
        # Discount: [d0] [d1] [d2] d3
        # Targets: t0 t1 t2
        reward = tf.cast(seq['reward'], tf.float32)
        disc = tf.cast(seq['discount'], tf.float32)
        value = self._target_critic(seq['feat']).mode()
        # Skipping last time step because it is used for bootstrapping.
        target = common.lambda_return(
            reward[:-1], value[:-1], disc[:-1],
            bootstrap=value[-1],
            lambda_=self.config.discount_lambda,
            axis=0)
        metrics = {'critic_slow': value.mean(), 'critic_target': target.mean()}
        return target, metrics

    def update_slow_target(self):
        """Blend critic weights into the slow target critic.

        Runs every ``slow_target_update`` calls: the first update copies the
        weights (mix 1.0), later ones use ``slow_target_fraction``. The call
        counter is incremented on every call. No-op without ``slow_target``.
        """
        if self.config.slow_target:
            if self._updates % self.config.slow_target_update == 0:
                mix = 1.0 if self._updates == 0 else float(
                    self.config.slow_target_fraction)
                for s, d in zip(self.critic.variables, self._target_critic.variables):
                    d.assign(mix * s + (1 - mix) * d)
            self._updates.assign_add(1)
|