A simple implementation of Hamiltonian Monte Carlo in JAX

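This is a simple implementation of Hamiltonian Monte Carlo (HMC) in JAX. It is vectorized and supports pytree parameters, i.e. parameters stored in nested containers such as dicts, lists, and tuples of arrays rather than in a single flat array. Here’s a minimal example of sampling from a distribution:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats  # make the stats submodule explicitly available
from hmc import hmc_sampler

# define target distribution: Student's t with df=1 (a standard Cauchy),
# summed over all coordinates
def target_log_pdf(params):
    return jax.scipy.stats.t.logpdf(params, df=1).sum()

# run HMC
params_init = jnp.zeros(10)
```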
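The library's own `hmc_sampler` interface isn't shown here, so the following is only a self-contained sketch of how a pytree-aware, vectorized HMC sampler can be put together in JAX, not this package's implementation. The helpers `hmc_step`, `run_chain`, and `std_normal_log_pdf` are hypothetical names. The sketch assumes an identity mass matrix, a fixed step size, and a fixed number of leapfrog steps. Momenta mirror the parameter pytree through `jax.tree_util`, chains are iterated with `jax.lax.scan`, and several chains run in parallel via `jax.vmap`.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats

tree_map = jax.tree_util.tree_map


def hmc_step(key, params, log_pdf, step_size=0.1, n_leapfrog=10):
    """One HMC transition with an identity mass matrix (hypothetical helper).

    `params` may be any pytree of float arrays; the momentum has the same structure.
    """
    grad_fn = jax.grad(log_pdf)
    leaves, treedef = jax.tree_util.tree_flatten(params)
    key_mom, key_acc = jax.random.split(key)
    mom_keys = jax.random.split(key_mom, len(leaves))
    # Draw a standard-normal momentum for every leaf of the parameter pytree.
    momentum = jax.tree_util.tree_unflatten(
        treedef,
        [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(mom_keys, leaves)],
    )

    def kinetic(p):
        # K(p) = 0.5 * |p|^2, summed over all leaves.
        return 0.5 * sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(p))

    def leapfrog(q, p):
        # Potential U(q) = -log_pdf(q), so a momentum update adds step_size * grad(log_pdf).
        p = tree_map(lambda pi, gi: pi + 0.5 * step_size * gi, p, grad_fn(q))

        def body(_, carry):
            q, p = carry
            q = tree_map(lambda qi, pi: qi + step_size * pi, q, p)
            p = tree_map(lambda pi, gi: pi + step_size * gi, p, grad_fn(q))
            return q, p

        # n_leapfrog - 1 full steps, then a final position step and half momentum step.
        q, p = jax.lax.fori_loop(0, n_leapfrog - 1, body, (q, p))
        q = tree_map(lambda qi, pi: qi + step_size * pi, q, p)
        p = tree_map(lambda pi, gi: pi + 0.5 * step_size * gi, p, grad_fn(q))
        return q, p

    q_new, p_new = leapfrog(params, momentum)
    # Metropolis correction on the joint energy H = -log_pdf(q) + K(p).
    log_accept = (log_pdf(q_new) - kinetic(p_new)) - (log_pdf(params) - kinetic(momentum))
    accept = jnp.log(jax.random.uniform(key_acc)) < log_accept
    # jnp.where (not a Python `if`) keeps the step traceable under vmap/scan.
    new_params = tree_map(lambda a, b: jnp.where(accept, a, b), q_new, params)
    return new_params, accept


def run_chain(key, params_init, log_pdf, n_samples=1000, step_size=0.1, n_leapfrog=10):
    """Run one chain with lax.scan; returns stacked samples and accept flags."""

    def step(params, k):
        params, accept = hmc_step(k, params, log_pdf, step_size, n_leapfrog)
        return params, (params, accept)

    keys = jax.random.split(key, n_samples)
    _, (samples, accepts) = jax.lax.scan(step, params_init, keys)
    return samples, accepts


def std_normal_log_pdf(params):
    # Hypothetical test target: independent N(0, 1) on every entry of the pytree.
    return sum(jnp.sum(jax.scipy.stats.norm.logpdf(x)) for x in jax.tree_util.tree_leaves(params))


# Usage: 4 chains over a dict-valued parameter, vectorized across chains with vmap.
n_chains = 4
init = {"w": jnp.zeros((n_chains, 10)), "b": jnp.zeros((n_chains,))}
chain_keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
samples, accepts = jax.vmap(
    lambda k, p: run_chain(k, p, std_normal_log_pdf, n_samples=1000)
)(chain_keys, init)
```

Each leaf of `samples` gains two leading axes (chain, draw), so `samples["w"]` has shape `(4, 1000, 10)`. A fixed step size is the simplest choice. For heavy-tailed targets such as the Cauchy example, it may mix poorly without some form of step-size adaptation.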
