See: nmc_numpyro/proposals.py#L25
Currently, the code block performing eigenvalue recovery for Minka's Normal proposal
lamb, vec = jnp.linalg.eigh(sigma)
sigma = vec @ jnp.diag(jnp.maximum(lamb, normal_sigma_eigen_epsilon)) @ vec.T
is guarded only by a nonzero dimensionality check for inputs.
This means the block, which is expensive because of the call to jnp.linalg.eigh and the matrix multiplications that rebuild sigma, runs on every call with nonzero input dimensionality, whether or not it actually changes sigma.
However, additionally guarding on jnp.all(jnp.linalg.eigvalsh(sigma) > 0) produces a hard-to-avoid JIT compilation error.
Since enabling JIT matters far more for performance (~500x vs. <1.05x), the current implementation is an acceptable, non-showstopping workaround. Nonetheless, successfully implementing such a guard would yield a further speedup.
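One possible direction, sketched here under assumptions rather than tested against nmc_numpyro: a Python-level `if` on a traced value generally cannot be evaluated under jax.jit, whereas jax.lax.cond accepts a traced predicate. The names clip_eigenvalues, guarded_clip and eps are hypothetical; eps stands in for normal_sigma_eigen_epsilon.

```python
import jax
import jax.numpy as jnp


def clip_eigenvalues(sigma, eps):
    # Same recovery step as in the issue: floor the eigenvalues at eps
    # and rebuild the matrix from its eigendecomposition.
    lamb, vec = jnp.linalg.eigh(sigma)
    return vec @ jnp.diag(jnp.maximum(lamb, eps)) @ vec.T


@jax.jit
def guarded_clip(sigma, eps):
    # Traced boolean: True when at least one eigenvalue is not strictly positive.
    needs_fix = jnp.any(jnp.linalg.eigvalsh(sigma) <= 0)
    # lax.cond takes a traced predicate; both branches must return
    # arrays of the same shape and dtype.
    return jax.lax.cond(
        needs_fix,
        lambda s: clip_eigenvalues(s, eps),  # repair branch
        lambda s: s,                         # pass-through branch
        sigma,
    )


# Hypothetical usage: one positive-definite input, one indefinite input.
good = jnp.array([[2.0, 0.0], [0.0, 1.0]])
bad = jnp.array([[1.0, 0.0], [0.0, -1.0]])
fixed_good = guarded_clip(good, 1e-3)
fixed_bad = guarded_clip(bad, 1e-3)
```

Two caveats on this sketch. The predicate itself calls jnp.linalg.eigvalsh, so the guard only saves the eigenvector computation and the matrix products, and whether that is a net win would need measuring. Also, if the function is batched with jax.vmap over a batched predicate, lax.cond is lowered to a select and both branches are evaluated, which removes the saving.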