jaxpolylog – Polylogarithms in JAX with autodiff

jaxpolylog is a JAX-native implementation of the polylogarithm

\[\mathrm{Li}_s(z) \;=\; \sum_{k=1}^{\infty}\frac{z^k}{k^s}\]

for integer order \(s\) and complex argument \(z\), designed so that the function evaluates, JIT-compiles, vectorises, and differentiates to arbitrary order under jax.grad() / jax.jvp() without losing fp64 precision – including in the deep-LCS regimes where \(|z|\) may be as small as \(10^{-30}\) or below.
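As a short worked step, differentiating the series term by term (valid for \(|z| < 1\)) shows why derivatives of every order stay inside the polylogarithm family, and where inverse powers of \(z\) enter:

\[z\,\frac{d}{dz}\mathrm{Li}_s(z) \;=\; \sum_{k=1}^{\infty}\frac{k\,z^k}{k^s} \;=\; \mathrm{Li}_{s-1}(z), \qquad \frac{d^2}{dz^2}\mathrm{Li}_s(z) \;=\; \frac{\mathrm{Li}_{s-2}(z) - \mathrm{Li}_{s-1}(z)}{z^2}\]

Each derivative lowers the order by one and brings in a further factor of \(1/z\); the familiar base case is \(\mathrm{Li}_1(z) = -\ln(1 - z)\).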

New here?

Start with the introduction. It explains the polylogarithm and the four numerical strategies ("inf", "zero", "patch", "integral") the package implements.

Introduction

Looking for the API?

The API reference documents the public functions jax_polylog() and jax_polylog_vmap() and the custom JVP rule underneath.

API reference

Why is this package important?

Polylogarithms appear throughout advanced calculations in theoretical physics and number theory: Feynman-integral coefficients, modular forms, period integrals on Calabi–Yau manifolds, instanton sums in string theory, finite-temperature equations of state, and the multi-polylog basis of multi-loop QFT. Neither the standard JAX stack nor NumPy/SciPy ships a usable polylogarithm:

  • scipy.special has no polylogarithm; mpmath.polylog() does, but it is pure Python, not vectorisable, and not differentiable.

  • jax.scipy.special.zeta() covers only the special case \(z = 1\), where \(\mathrm{Li}_s(1) = \zeta(s)\) for \(s > 1\).

  • A naive jnp.sum of \(\sum z^k / k^s\) is unusable: it diverges for \(|z| > 1\), converges so slowly near \(z = 1\) that a truncated sum loses all precision there, and – most importantly – its higher derivatives under jax.grad() cascade \(1/z^n\) factors that overflow fp64 the moment \(|z|\) drops below \(\sim 10^{-2}\).

jaxpolylog solves all three problems at once. See Introduction for the full numerical story.
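The first two failure modes of the naive series listed above are easy to reproduce. The sketch below is for illustration only: `naive_polylog` is a hypothetical helper written for this example, not part of jaxpolylog, and the 200-term cutoff is an arbitrary choice. It compares the truncated sum against two classical closed forms, \(\mathrm{Li}_2(1/2) = \pi^2/12 - (\ln 2)^2/2\) and \(\mathrm{Li}_2(1) = \pi^2/6\).

```python
import math

import jax
import jax.numpy as jnp

# JAX defaults to 32-bit floats; the comparisons below need fp64.
jax.config.update("jax_enable_x64", True)


def naive_polylog(z, s, n_terms=200):
    """Truncated defining series sum_{k=1}^{n_terms} z**k / k**s (illustration only)."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float64)
    return jnp.sum(z ** k / k ** s)


# Well inside the unit disc the truncated series is accurate.
exact_half = math.pi ** 2 / 12 - math.log(2.0) ** 2 / 2
err_half = abs(float(naive_polylog(0.5, 2)) - exact_half)

# At z = 1 the same 200 terms leave a truncation error of roughly 1 / n_terms.
err_one = abs(float(naive_polylog(1.0, 2)) - math.pi ** 2 / 6)

# Outside the unit disc the individual terms grow without bound.
last_term_outside = 1.5 ** 200 / 200 ** 2

print(f"error at z = 0.5: {err_half:.1e}")
print(f"error at z = 1.0: {err_one:.1e}")
print(f"200th term at z = 1.5: {last_term_outside:.1e}")
```

Adding more terms does not rescue the sum near \(z = 1\): the truncation error for \(s = 2\) shrinks only like one over the number of terms, so each extra digit costs ten times more work.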

Quick start

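import jax
import jax.numpy as jnp

# JAX defaults to 32-bit floats; enable 64-bit before creating any arrays
jax.config.update("jax_enable_x64", True)

from jaxpolylog import jax_polylog, jax_polylog_vmap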

# Evaluate Li_3 at a single point using the auto-patched series
z   = jnp.array(0.7 + 0.1j)
val = jax_polylog(z, s=3, p_range=200, approx="patch")

# Vectorised evaluation along an axis
zs   = jnp.linspace(0.1, 0.95, 32) + 0.0j
vals = jax_polylog_vmap(zs, s=2, p_range=200, approx="patch")

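# Reverse-mode autodiff, nested for higher order. The input is real and is
# promoted to complex inside the lambda: jax.grad needs a real-valued output,
# and the gradient with respect to a complex input is itself complex, so
# nesting jax.grad on a complex input raises a TypeError.
li3_real = lambda x: jax_polylog(x + 0.0j, 3, 200, "patch").real
dLi3     = jax.grad(li3_real)(0.5)
d2Li3    = jax.grad(jax.grad(li3_real))(0.5)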
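For genuinely complex arguments, JAX offers two standard routes to the complex derivative of a holomorphic function: forward mode via jax.jvp(), and reverse mode via jax.grad() with holomorphic=True. The sketch below shows both on a stand-in function so that it runs without the package installed. `series_polylog` is a hypothetical truncated series written for this example (valid only for \(|z| < 1\)), not the jaxpolylog implementation, and whether jax_polylog() behaves identically under holomorphic=True is an assumption that is not verified here. The results are checked against the identity \(z\,\frac{d}{dz}\mathrm{Li}_s(z) = \mathrm{Li}_{s-1}(z)\).

```python
import jax
import jax.numpy as jnp

# JAX defaults to 32-bit floats; enable 64-bit for fp64 / complex128.
jax.config.update("jax_enable_x64", True)


def series_polylog(z, s, n_terms=200):
    """Hypothetical stand-in: truncated series for |z| < 1, illustration only."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float64)
    return jnp.sum(z ** k / k ** s)


z0 = jnp.asarray(0.3 + 0.2j)  # complex128 once x64 is enabled
li3 = lambda z: series_polylog(z, 3)

# Forward mode: push the unit tangent 1 + 0j through the function.
value, d_fwd = jax.jvp(li3, (z0,), (jnp.ones_like(z0),))

# Reverse mode on a complex-valued holomorphic function needs holomorphic=True.
d_rev = jax.grad(li3, holomorphic=True)(z0)
d2_rev = jax.grad(jax.grad(li3, holomorphic=True), holomorphic=True)(z0)

# Analytic references obtained from z * d/dz Li_s(z) = Li_{s-1}(z).
d_ref = series_polylog(z0, 2) / z0
d2_ref = (series_polylog(z0, 1) - series_polylog(z0, 2)) / z0 ** 2

print("forward vs reference:", abs(d_fwd - d_ref))
print("reverse vs reference:", abs(d_rev - d_ref))
print("second derivative vs reference:", abs(d2_rev - d2_ref))
```

The holomorphic route keeps both the real and imaginary parts of the derivative, whereas taking `.real` before jax.grad() differentiates only the real part of the function.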