jaxpolylog – Polylogarithms in JAX with autodiff
jaxpolylog is a JAX-native implementation of the polylogarithm
\(\mathrm{Li}_s(z)\) for integer order \(s\) and complex argument \(z\).
It is designed so that the function can be evaluated, JIT-compiled,
vectorised, and differentiated to arbitrary order under jax.grad() /
jax.jvp() without losing fp64 precision. This includes the
deep-LCS regimes, where \(|z|\) may be as small as
\(10^{-30}\) or below.
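For reference, here are the standard definition, the derivative identity and a few special values. These are textbook facts, stated for convenience rather than taken from the package:

```latex
\begin{aligned}
\mathrm{Li}_s(z) &= \sum_{k=1}^{\infty} \frac{z^k}{k^s}, \qquad |z| < 1, \\
z \, \frac{d}{dz} \mathrm{Li}_s(z) &= \mathrm{Li}_{s-1}(z), \\
\mathrm{Li}_1(z) &= -\log(1 - z), \qquad \mathrm{Li}_0(z) = \frac{z}{1 - z}, \\
\mathrm{Li}_s(1) &= \zeta(s) \quad (s > 1).
\end{aligned}
```

Outside \(|z| < 1\), \(\mathrm{Li}_s\) is defined by analytic continuation. Applying the identity \(n\) times writes the \(n\)-th derivative of \(\mathrm{Li}_s\) as a combination of \(\mathrm{Li}_{s-1}, \dots, \mathrm{Li}_{s-n}\) divided by powers of \(z\).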
Start with the introduction. It explains the polylogarithm
and the four numerical strategies ("inf", "zero",
"patch", "integral") the package implements.
The API reference documents the public functions
jax_polylog() and jax_polylog_vmap() and the
custom JVP rule underneath.
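The package's actual JVP rule is documented in the API reference. The plain-Python sketch below only illustrates the kind of identity such a rule can use. It evaluates \(\frac{d}{dz}\mathrm{Li}_s(z) = \mathrm{Li}_{s-1}(z)/z\) as the term-by-term series \(\sum_{k \ge 1} z^{k-1}/k^{s-1}\), which never divides by \(z\). li_series and li_jvp are hypothetical helpers valid only for \(|z| < 1\), not the jaxpolylog API. In JAX, a rule of this shape would be attached to a function with jax.custom_jvp and its defjvp decorator.

```python
def li_series(z, s, n_terms=200):
    """Truncated series sum_{k=1}^{n_terms} z**k / k**s (valid for |z| < 1)."""
    total, zk = 0j, 1 + 0j
    for k in range(1, n_terms + 1):
        zk *= z  # z**k, built incrementally
        total += zk / k**s
    return total


def li_jvp(z, dz, s, n_terms=200):
    """Hypothetical forward-mode rule: returns (Li_s(z), Li_s'(z) * dz).

    Li_s'(z) equals Li_{s-1}(z) / z, but here it is summed term by term as
    sum_{k>=1} z**(k-1) / k**(s-1), so there is no division by z and
    z = 0 or |z| ~ 1e-30 is harmless.
    """
    primal, deriv, zk = 0j, 0j, 1 + 0j  # zk holds z**(k-1) at loop start
    for k in range(1, n_terms + 1):
        deriv += zk / k ** (s - 1)  # d/dz of z**k / k**s
        zk *= z                     # now z**k
        primal += zk / k**s
    return primal, deriv * dz
```

At tiny \(|z|\) the tangent tends to 1, because \(\mathrm{Li}_s'(0) = 1\). A central finite difference of li_series reproduces the tangent away from zero.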
Why is this package important?
Polylogarithms appear throughout theoretical physics and number theory. Examples include Feynman-integral coefficients, modular forms, period integrals on Calabi–Yau manifolds, instanton sums in string theory, finite-temperature equations of state, and the multi-polylog basis of multi-loop QFT. Yet none of the standard JAX or NumPy/SciPy stacks ships a usable polylogarithm:
- scipy.special has no polylogarithm.
- mpmath.polylog() does, but it is pure Python, not vectorisable, and not differentiable.
- jax.scipy.special.zeta() covers only \(z = 1\), where \(\mathrm{Li}_s(1) = \zeta(s)\).
- A naive jnp.sum of \(\sum_k z^k / k^s\) is unusable. It diverges for \(|z| > 1\) and converges slowly at best on \(|z| = 1\). It loses all precision near \(z = 1\). Most importantly, its higher derivatives under jax.grad() cascade \(1/z^n\) factors that overflow fp64 once \(|z|\) drops below \(\sim 10^{-2}\).
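The sketch below shows why the derivative cascade is fragile. It uses plain Python floats and hypothetical helper names, and compares two ways of computing \(\mathrm{Li}_s''(z)\):

- the identity-based form \((\mathrm{Li}_{s-2}(z) - \mathrm{Li}_{s-1}(z))/z^2\), which is what repeated use of \(z\,\frac{d}{dz}\mathrm{Li}_s = \mathrm{Li}_{s-1}\) produces;
- the term-by-term series.

At small \(|z|\), the numerator of the first form is a difference of two nearly equal numbers of size about \(|z|\). It therefore loses roughly \(\log_{10}(1/|z|)\) significant digits, and at \(z = 10^{-16}\) it collapses to 0 while the true value is about \(1/4\). This demonstrates catastrophic cancellation only. The fp64 overflow described above for higher orders is a related failure of the same \(1/z^n\) form and is not reproduced here.

```python
def li_series(z, s, n_terms=200):
    """Truncated series sum_{k=1}^{n_terms} z**k / k**s (valid for |z| < 1)."""
    total, zk = 0.0, 1.0
    for k in range(1, n_terms + 1):
        zk *= z  # z**k, built incrementally
        total += zk / k**s
    return total


def d2_via_identity(z, s):
    """Li_s''(z) from applying z d/dz Li_s = Li_{s-1} twice:
    Li_s''(z) = (Li_{s-2}(z) - Li_{s-1}(z)) / z**2.
    The subtraction cancels catastrophically when |z| is small."""
    return (li_series(z, s - 2) - li_series(z, s - 1)) / z**2


def d2_direct(z, s, n_terms=200):
    """Term-by-term second derivative: sum_{k>=2} k (k-1) z**(k-2) / k**s."""
    total, zk = 0.0, 1.0  # zk holds z**(k-2)
    for k in range(2, n_terms + 1):
        total += k * (k - 1) * zk / k**s
        zk *= z
    return total
```

At \(z = 1/2\) the two forms agree to rounding error. At \(z = 10^{-16}\), d2_via_identity returns 0.0 while d2_direct returns 0.25.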
jaxpolylog solves all three problems at once. See
Introduction for the full numerical story.
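The plain-Python sketch below shows why a single representation is not enough. It compares two textbook forms of \(\mathrm{Li}_s\):

- the power series about \(z = 0\);
- the integral representation \(\mathrm{Li}_s(z) = \frac{z}{\Gamma(s)} \int_0^\infty \frac{t^{s-1}}{e^t - z}\,dt\), valid for integer \(s \ge 1\) and \(z\) off the ray \([1, \infty)\).

It is an assumption that these resemble the package's "zero" and "integral" strategies. li_series and li_integral are hypothetical helpers, not part of the jaxpolylog API.

```python
import math


def li_series(z: complex, s: int, n_terms: int = 200) -> complex:
    """Truncated power series about z = 0: sum_{k=1}^{n_terms} z**k / k**s.

    Accurate only well inside the unit disc |z| < 1.
    """
    total = 0j
    zk = 1 + 0j
    for k in range(1, n_terms + 1):
        zk *= z  # z**k, built incrementally
        total += zk / k**s
    return total


def li_integral(z: complex, s: int, t_max: float = 60.0, n_steps: int = 20_000) -> complex:
    """Integral representation, for integer s >= 1 and z off the ray [1, inf):

        Li_s(z) = z / Gamma(s) * integral_0^inf t**(s-1) / (exp(t) - z) dt

    Evaluated with composite Simpson's rule on [0, t_max]. The tail beyond
    t_max is dropped; it decays like t**(s-1) * exp(-t).
    """
    h = t_max / n_steps  # n_steps must be even for Simpson's rule

    def integrand(t: float) -> complex:
        return t ** (s - 1) / (math.exp(t) - z)

    acc = integrand(0.0) + integrand(t_max)
    for i in range(1, n_steps):
        acc += (4 if i % 2 else 2) * integrand(i * h)  # Simpson weights 4, 2, 4, ...
    return z / math.gamma(s) * acc * h / 3
```

Inside the disc the two agree, for example both give \(\mathrm{Li}_2(1/2) = \pi^2/12 - \tfrac{1}{2}\ln^2 2\). Outside it, the integral keeps working where the series does not. At \(z = -1\) the 200-term series is still off by about \(10^{-5}\), and at \(z = -3\) it diverges. The integral, by contrast, matches \(-\pi^2/12\) at \(z = -1\) and the dilogarithm inversion formula \(\mathrm{Li}_2(z) + \mathrm{Li}_2(1/z) = -\pi^2/6 - \tfrac{1}{2}\ln^2(-z)\) at \(z = -3\).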
Quick start
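import jax
import jax.numpy as jnp
from jaxpolylog import jax_polylog, jax_polylog_vmap
# JAX defaults to 32-bit floats; enable fp64 before creating any arrays
jax.config.update("jax_enable_x64", True)
# Evaluate Li_3 at a single point using the auto-patched series
z = jnp.array(0.7 + 0.1j)
val = jax_polylog(z, s=3, p_range=200, approx="patch")
# Vectorised evaluation along an axis
zs = jnp.linspace(0.1, 0.95, 32) + 0.0j
vals = jax_polylog_vmap(zs, s=2, p_range=200, approx="patch")
# Forward mode: push the tangent dz = 1 through Li_3 at z
_, dLi3_fwd = jax.jvp(lambda z: jax_polylog(z, 3, 200, "patch"), (z,), (jnp.ones_like(z),))
# Reverse mode: jax.grad needs a real scalar output, so differentiate
# Re Li_3; for holomorphic f, JAX's gradient of Re f is f'(z)
re_li3 = lambda z: jax_polylog(z, 3, 200, "patch").real
dLi3 = jax.grad(re_li3)(0.5 + 0.0j)
# The first gradient is complex-valued, so take .real again before
# differentiating a second time (jax.grad rejects complex outputs)
d2Li3 = jax.grad(lambda z: jax.grad(re_li3)(z).real)(0.5 + 0.0j)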
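The quick-start gradients at \(z = 1/2\) can be cross-checked against classical closed forms, without relying on the package. The reference values are \(\mathrm{Li}_3'(z) = \mathrm{Li}_2(z)/z\) and \(\mathrm{Li}_3''(z) = (\mathrm{Li}_1(z) - \mathrm{Li}_2(z))/z^2\). The sketch below computes them and confirms them with a term-by-term series. series_derivative is a hypothetical helper, valid only for \(|z| < 1\).

```python
import math

# Classical closed forms at z = 1/2 (standard results, independent of jaxpolylog)
LN2 = math.log(2.0)                      # Li_1(1/2) = -log(1 - 1/2)
LI2_HALF = math.pi**2 / 12 - LN2**2 / 2  # Li_2(1/2)

# d/dz Li_3(z) = Li_2(z) / z
d1_ref = LI2_HALF / 0.5
# d^2/dz^2 Li_3(z) = (Li_1(z) - Li_2(z)) / z**2
d2_ref = (LN2 - LI2_HALF) / 0.25


def series_derivative(z, s, order, n_terms=200):
    """order-th derivative of sum_k z**k / k**s, differentiated term by term."""
    total = 0.0
    for k in range(max(order, 1), n_terms + 1):
        falling = math.perm(k, order)  # k (k-1) ... (k-order+1)
        total += falling * z ** (k - order) / k**s
    return total
```

With fp64 enabled, dLi3 and d2Li3 from the quick start would be expected to agree with d1_ref ≈ 1.16448 and d2_ref ≈ 0.44363 to near machine precision.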