Turning SymPy expressions into JAX functions
Turn SymPy expressions into parametrized, differentiable, vectorizable JAX functions. All SymPy floats become trainable input parameters. SymPy symbols become columns of a passed matrix.

Installation

    pip install git+https://github.com/MilesCranmer/sympy2jax.git

Example

    import sympy
    from sympy import symbols
    import jax
    import jax.numpy as jnp
    from jax import random
    from sympy2jax import sympy2jax

Let’s create an expression in SymPy:

    x, y = symbols('x y')
    expression = 1.0 * sympy.cos(x) +
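The mechanism described above (floats become trainable parameters, symbols become columns of the input matrix) can be sketched without relying on the package's own API. What follows is a minimal, hypothetical translator and not sympy2jax's implementation: the names `to_jax` and `_UNARY`, and the expression `2.0 * cos(x) + 3.0 * y`, are made up for illustration, and only a handful of SymPy node types are handled.

```python
import jax
import jax.numpy as jnp
import sympy

# Hypothetical lookup table from SymPy function classes to jax.numpy functions.
_UNARY = {sympy.cos: jnp.cos, sympy.sin: jnp.sin, sympy.exp: jnp.exp}


def to_jax(expr, symbols_in):
    """Return (f, params); f(X, params) evaluates expr on each row of X."""
    floats = []  # initial values of the trainable parameters

    def build(node):
        if isinstance(node, sympy.Float):
            # Every SymPy float becomes one entry of the parameter vector.
            idx = len(floats)
            floats.append(float(node))
            return lambda X, p: p[idx]
        if isinstance(node, sympy.Symbol):
            # Every symbol reads one column of the input matrix.
            col = symbols_in.index(node)
            return lambda X, p: X[:, col]
        if isinstance(node, sympy.Rational):
            # Integers and exact rationals stay fixed constants.
            const = float(node)
            return lambda X, p: const
        children = [build(arg) for arg in node.args]
        if isinstance(node, sympy.Add):
            return lambda X, p: sum(c(X, p) for c in children)
        if isinstance(node, sympy.Mul):
            def mul(X, p):
                out = children[0](X, p)
                for c in children[1:]:
                    out = out * c(X, p)
                return out
            return mul
        if isinstance(node, sympy.Pow):
            return lambda X, p: children[0](X, p) ** children[1](X, p)
        if node.func in _UNARY:
            fn = _UNARY[node.func]
            return lambda X, p: fn(children[0](X, p))
        raise NotImplementedError(f"unsupported node: {type(node)}")

    f = build(expr)
    return f, jnp.array(floats)


# Illustrative expression (an assumption, not the one from the text above).
x, y = sympy.symbols("x y")
expression = 2.0 * sympy.cos(x) + 3.0 * y

f, params = to_jax(expression, [x, y])
X = jnp.array([[0.0, 2.0]])  # one row; columns are x and y

value = jax.jit(f)(X, params)  # vectorized over the rows of X
grads = jax.grad(lambda p: jnp.sum(f(X, p)))(params)  # d(output)/d(floats)
```

Because the returned function is built only from `jax.numpy` operations, it composes with `jax.jit`, `jax.grad`, and `jax.vmap` like any other JAX function. The order of entries in `params` follows SymPy's internal argument ordering, so it is safest to treat the vector as opaque.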