Spaces:
Runtime error
Runtime error
File size: 2,032 Bytes
e0f25ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from typing import Any, Sequence
from chex import PRNGKey
import distrax
from flax import struct
import jax
import jax.numpy as jnp
@struct.dataclass
class HybridAction:
    """Container pairing the discrete and continuous parts of one action.

    Declared through ``struct.dataclass`` (imported from ``flax``). Field
    order matters: ``HybridActionDistribution._sample_n`` builds instances
    positionally as ``HybridAction(discrete_sample, continuous_sample)``.
    """

    # Discrete component of the action.
    # NOTE(review): annotated as ``int``, but HybridActionDistribution._sample_n
    # stores the output of a Categorical ``_sample_n`` call here -- presumably
    # an integer array rather than a Python int; confirm.
    discrete: int
    # Continuous component of the action.
    continuous: jnp.ndarray
class HybridActionDistribution(distrax.Distribution):
    """Joint distribution over a discrete choice and a continuous vector.

    Internally holds a ``distrax.Categorical`` and a
    ``distrax.MultivariateNormalDiag``. Samples come back as ``HybridAction``
    instances; log-probabilities and entropies are the sums of the two
    components' values.
    """

    def __init__(self, discrete_logits, continuous_mu, continuous_sigma) -> None:
        """Create the two component distributions.

        Args:
            discrete_logits: Logits handed to ``distrax.Categorical``.
            continuous_mu: Location argument of the diagonal Gaussian.
            continuous_sigma: Diagonal scale argument of the Gaussian.
        """
        self.discrete = distrax.Categorical(logits=discrete_logits)
        self.continuous = distrax.MultivariateNormalDiag(
            loc=continuous_mu,
            scale_diag=continuous_sigma,
        )

    def _sample_n(self, rng: PRNGKey, n: int) -> Any:
        """Draw ``n`` samples from each component and bundle them.

        Args:
            rng: PRNG key; it is split three ways and only the second and
                third sub-keys are consumed.
            n: Number of samples requested from each component.

        Returns:
            A ``HybridAction`` holding the discrete and continuous draws.
        """
        # Keep the three-way split (first sub-key unused) so the random
        # stream matches what callers already observe.
        subkeys = jax.random.split(rng, 3)
        discrete_draws = self.discrete._sample_n(subkeys[1], n)
        continuous_draws = self.continuous._sample_n(subkeys[2], n)
        return HybridAction(discrete=discrete_draws, continuous=continuous_draws)

    def log_prob(self, value: Any):
        """Return the summed log-probability of both parts of ``value``.

        Args:
            value: Object exposing ``.discrete`` and ``.continuous``.
        """
        # The two parts are scored by separate distributions; summing the
        # log-probabilities multiplies the underlying probabilities.
        return (
            self.discrete.log_prob(value.discrete)
            + self.continuous.log_prob(value.continuous)
        )

    def entropy(self):
        """Return the discrete entropy plus the continuous entropy."""
        discrete_entropy = self.discrete.entropy()
        continuous_entropy = self.continuous.entropy()
        return discrete_entropy + continuous_entropy

    def event_shape(self) -> Sequence[int]:
        """Return an empty tuple.

        NOTE(review): the distrax base class appears to expose
        ``event_shape`` as a property, whereas this is a plain method that
        must be called. Left unchanged for caller compatibility -- confirm.
        """
        return tuple()
class MultiDiscreteActionDistribution(distrax.Distribution):
    """Several categorical distributions carved out of one flat logit vector.

    The last axis of ``flat_logits`` is cut into consecutive slices, one per
    entry of ``number_of_dims_per_distribution``; each slice parameterises
    its own ``distrax.Categorical``.
    """

    def __init__(self, flat_logits, number_of_dims_per_distribution) -> None:
        """Slice the logits and build one categorical per slice.

        Args:
            flat_logits: Array whose last axis holds the logits of every
                component, laid out back to back.
            number_of_dims_per_distribution: Iterable giving the slice width
                of each component, in order.
        """
        components = []
        start = 0
        for width in number_of_dims_per_distribution:
            stop = start + width
            component_logits = flat_logits[..., start:stop]
            components.append(distrax.Categorical(logits=component_logits))
            start = stop
        self.distributions = components

    def _sample_n(self, key: PRNGKey, n: int) -> Any:
        """Draw ``n`` samples from every component.

        Args:
            key: PRNG key, split into one sub-key per component.
            n: Number of samples requested from each component.

        Returns:
            The per-component draws joined along a new trailing axis, so
            index ``i`` on the last axis belongs to component ``i``.
        """
        sub_keys = jax.random.split(key, len(self.distributions))
        columns = []
        for sub_key, component in zip(sub_keys, self.distributions):
            draws = component._sample_n(sub_key, n)
            # Append a trailing axis so the draws can be joined along it.
            columns.append(draws[..., None])
        return jnp.concatenate(columns, axis=-1)

    def log_prob(self, value: Any):
        """Return the sum of each component's log-probability.

        Args:
            value: Array whose last axis indexes the components;
                ``value[..., i]`` is scored by component ``i``.
        """
        total = 0
        for index, component in enumerate(self.distributions):
            total = total + component.log_prob(value[..., index])
        return total

    def entropy(self):
        """Return the sum of the component entropies."""
        total = 0
        for component in self.distributions:
            total = total + component.entropy()
        return total

    def event_shape(self) -> Sequence[int]:
        """Return an empty tuple.

        NOTE(review): the distrax base class appears to expose
        ``event_shape`` as a property, whereas this is a plain method that
        must be called. Left unchanged for caller compatibility -- confirm.
        """
        return tuple()
|