Update layers.py for the JAX deprecation of "shape"

#4
by zxyse - opened

Use a generic tuple instead of jax.core.NamedShape.
Use NumPy to compute the total size (number of elements) of a shape.
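A minimal sketch of the second change, assuming shapes are now passed around as plain tuples of ints. The helper name `total_size` is hypothetical and is not taken from layers.py. It uses `np.prod` with an explicit integer dtype so the result is an int, and so the empty shape of a scalar yields 1:

```python
import numpy as np


def total_size(shape: tuple) -> int:
    """Hypothetical helper: number of elements in an array with this shape.

    The shape is a plain tuple of ints, not a jax.core.NamedShape.
    """
    # np.prod over an empty tuple returns 1, which is the correct size for a scalar.
    # An explicit int64 dtype avoids a float result; int() gives a Python int.
    return int(np.prod(shape, dtype=np.int64))


print(total_size((2, 3, 4)))  # 24
print(total_size(()))         # 1
```

Because `.shape` on both NumPy arrays and `jax.Array` objects is already a plain tuple, the same helper works on either. On Python 3.8+, `math.prod(shape)` is a dependency-free alternative.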
