Quickstart: evaluating \(\mathrm{Li}_s(z)\) with jaxpolylog
This notebook walks through every public entry point of jaxpolylog on
small, runnable examples:
- Direct evaluation at a single point with each of the four approx strategies ("inf", "zero", "patch", "integral"), benchmarked against mpmath.polylog.
- Regime comparison along a path approaching \(z = 1\), showing where each strategy wins.
- Vectorisation with jax_polylog_vmap.
- Automatic differentiation of \(\mathrm{Li}_s\) to second order, verified against the analytic identity \(\tfrac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_s(z) = \mathrm{Li}_{s-1}(z)/z\).
All examples assume the package has been installed (pip install jaxpolylog) and that mpmath is available for cross-checks. The plot in Section 3 also needs matplotlib and tqdm.
1. Setup
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import mpmath
import jaxpolylog
from jaxpolylog import jax_polylog, jax_polylog_vmap
2. Direct evaluation against mpmath
A convenient way to parameterise points near \(z = 1\) is by
\[ z = e^{2\pi\mathrm{i}t}, \qquad t = \frac{\log z}{2\pi\mathrm{i}}, \]
so that small \(|t| = |\log z|/(2\pi)\) corresponds to small \(|\log z|\) and therefore to
the regime where the "zero" Laurent expansion converges fastest.
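For background, here is the classical expansion of the polylogarithm about \(z = 1\) for a positive integer order \(s = m\). It is written in \(\mu = \log z = 2\pi\mathrm{i}t\), with \(H_n\) the \(n\)-th harmonic number. This is a standard identity added here for reference. The notebook does not state that the "zero" strategy truncates exactly this series; that is our assumption.

\[
\mathrm{Li}_m(e^{\mu}) = \frac{\mu^{m-1}}{(m-1)!}\Bigl[H_{m-1} - \log(-\mu)\Bigr] + \sum_{\substack{k=0 \\ k \neq m-1}}^{\infty} \frac{\zeta(m-k)}{k!}\,\mu^{k}, \qquad |\mu| < 2\pi.
\]

The nonzero coefficients \(\zeta(m-k)/k!\) behave like \((2\pi)^{-k}\) up to polynomial factors in \(k\). The terms therefore decay roughly geometrically with ratio \(|\mu|/(2\pi) = |t|\). This is consistent with fast convergence at small \(|t|\) and with the \(2\pi\) radius of convergence discussed in Section 3.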
We evaluate \(\mathrm{Li}_3(z)\) at three representative points and
compare every available strategy against mpmath.polylog, which we
take as the ground truth.
def li_table(s, t, p_range_inf=100, p_range_zero=100,
             p_range_int=10**5, p_range_patch=100):
    """Evaluate Li_s(exp(2 pi i t)) using every strategy + mpmath."""
    z = np.exp(2*np.pi*1j*t)
    return {
        "z"          : z,
        "Li_inf"     : complex(jax_polylog(z, s, p_range_inf, "inf")),
        "Li_zero"    : complex(jax_polylog(z, s, p_range_zero, "zero")),
        "Li_patch"   : complex(jax_polylog(z, s, p_range_patch, "patch")),
        "Li_integral": complex(jax_polylog(z, s, p_range_int, "integral")),
        "Li_mpmath"  : complex(mpmath.polylog(s, z)),
    }
for t in [2j, 1j, 1e-5j]:
    print(f"t = {t!r}")
    res = li_table(s=3, t=t)
    for k, v in res.items():
        print(f"  {k:<12s} = {v}")
    print()
All four strategies agree with mpmath at \(t = 2\mathrm{i}\) and \(t = \mathrm{i}\). There \(z = e^{-4\pi} \approx 3.5\times10^{-6}\) and \(z = e^{-2\pi} \approx 1.9\times10^{-3}\) respectively, both deep inside the unit disk.
At \(t = 10^{-5}\mathrm{i}\), \(z = e^{-2\pi\cdot 10^{-5}} \approx 0.99994\) is extremely close to \(1\). Here the
naive "inf" series stalls (truncation error dominates), while
"zero" and "patch" reproduce the mpmath value to all displayed
digits.
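The following minimal NumPy sketch shows why a 100-term series stalls at that point. It assumes (this is not confirmed by the notebook) that "inf" with p_range=100 sums the first 100 terms of \(\sum_k z^k/k^s\). The helper li_taylor is hypothetical. At \(z \approx 0.99994\) the factor \(z^k\) barely decays over the first few thousand terms, so the neglected tail is close to \(\sum_{k>100} 1/k^3 < 1/(2\cdot 100^2) = 5\times10^{-5}\).

```python
import numpy as np

# Section 2's point closest to z = 1: t = 1e-5 i, so z = exp(-2*pi*1e-5) (real, just below 1)
z = float(np.exp(2 * np.pi * 1j * 1e-5j).real)
s = 3


def li_taylor(z, s, n_terms):
    """Hypothetical helper: truncated Taylor series sum_{k=1}^{n_terms} z**k / k**s."""
    k = np.arange(1, n_terms + 1, dtype=np.float64)
    return np.sum(z ** k / k ** s)


# Reference with 10**6 terms; the omitted tail is below sum_{k>10**6} 1/k**3 ~ 5e-13
reference = li_taylor(z, s, 10**6)
approx_100 = li_taylor(z, s, 100)

trunc_err = abs(reference - approx_100)
bound = 1.0 / (2 * 100**2)  # integral bound for sum_{k>100} 1/k**3

print(f"z                    = {z:.8f}")
print(f"100-term tail        = {trunc_err:.3e}")
print(f"1/(2 * 100**2) bound = {bound:.3e}")
```

Near \(z = 1\) the tail shrinks only like \(1/(2N^2)\) in the number of terms \(N\). Halving the error therefore costs about \(\sqrt{2}\) times more terms.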
3. Regime comparison along \(z \to 1\)
# Sweep t along the positive imaginary axis, from very close to z = 1
# (t = 1e-4 i) out to z = exp(-2 pi) (t = i), and record
# |Li_2(z) - Li_2_mpmath| for the three JAX strategies.
from tqdm.auto import tqdm
N = 1024
X = np.linspace(1e-4, 1.0, N) # X = Im(t) in [1e-4, 1]; z = exp(-2 pi X) is real in (0, 1)
t = X * 1j
z = np.exp(2*np.pi*1j*t)
s = 2
Li_inf_v = jax_polylog_vmap(z, s, p_range=100, approx="inf")
Li_zero_v = jax_polylog_vmap(z, s, p_range=100, approx="zero")
Li_patch_v = jax_polylog_vmap(z, s, p_range=100, approx="patch")
Li_mpm = np.array([complex(mpmath.polylog(s, z0)) for z0 in tqdm(z)])
err_inf = np.abs(np.asarray(Li_inf_v) - Li_mpm)
err_zero = np.abs(np.asarray(Li_zero_v) - Li_mpm)
err_patch = np.abs(np.asarray(Li_patch_v) - Li_mpm)
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(8, 5), dpi=120)
ax.plot(X, err_inf, label='"inf" series', lw=1.4)
ax.plot(X, err_zero, label='"zero" series', lw=1.4)
ax.plot(X, err_patch, label='"patch" (auto)', lw=2.0, color='k', ls='--')
ax.set_yscale('log')
ax.set_xlabel(r'$\mathrm{Im}\,t = |\log z|\,/\,2\pi$')
ax.set_ylabel(r'$|\mathrm{Li}_2(z) - \mathrm{mpmath}|$')
ax.set_title('Truncation error of each strategy along $z = e^{2\\pi i t}$, $s = 2$')
ax.axvline(jaxpolylog.polylogs._PVAL_OPTIMAL, color='grey', lw=0.8, ls=':',
label=fr'$t_\star \approx {jaxpolylog.polylogs._PVAL_OPTIMAL:.4f}$')
ax.legend()
fig.tight_layout()
plt.show()
The plot shows each strategy's truncation error against the normalised log-distance \(|t| = \mathrm{Im}\,t = |\log z|/(2\pi)\):
- The "inf" series wins at large \(|t|\) (away from \(z = 1\)) but loses precision rapidly as \(|t| \to 0\).
- The "zero" series wins at small \(|t|\) (near \(z = 1\)), but its error grows as \(|t|\) approaches \(1\), where \(|\log z|\) approaches the \(2\pi\) radius of convergence.
- "patch" (dashed black) tracks the lower envelope of the two by switching at the analytical crossover \(t_\star\), marked by the dotted vertical line.
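The crossover \(t_\star\) is the fixed point of \(e^{-2\pi t} = t\), with \(t\) read as the plotted coordinate \(\mathrm{Im}\,t\). Along this path the equation sets \(|z| = e^{-2\pi t}\) equal to \(|\log z|/(2\pi) = t\).
It is pre-computed at import time and stored in
jaxpolylog.polylogs._PVAL_OPTIMAL.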
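The sketch below solves the crossover equation with Newton's method and pairs it with a hypothetical selector, pick_strategy, to show how a switch at \(t_\star\) routes points. It is not jaxpolylog's code, and the notebook does not show how "patch" decides. The selector assumes the "inf" terms \(z^k/k^s\) shrink like \(|z|^k\) and the "zero" terms like \((|\log z|/2\pi)^k\). It then picks whichever ratio is smaller.

```python
import cmath
import math


def crossover_t(tol=1e-15, max_iter=50):
    """Solve exp(-2*pi*t) = t for t > 0 by Newton's method on g(t) = t - exp(-2*pi*t)."""
    t = 0.25  # starting guess; g(0.2) < 0 < g(0.25), so the root lies in between
    for _ in range(max_iter):
        e = math.exp(-2 * math.pi * t)
        step = (t - e) / (1 + 2 * math.pi * e)  # g(t) / g'(t); g'(t) > 1 everywhere
        t -= step
        if abs(step) < tol:
            break
    return t


def pick_strategy(z):
    """Hypothetical selector: choose the series with the smaller assumed convergence ratio."""
    ratio_inf = abs(z)                              # "inf": terms shrink like |z|**k
    ratio_zero = abs(cmath.log(z)) / (2 * math.pi)  # "zero": terms shrink like (|log z|/2pi)**k
    return "inf" if ratio_inf < ratio_zero else "zero"


t_star = crossover_t()
residual = math.exp(-2 * math.pi * t_star) - t_star
# Same root in closed form: t_star = W(2*pi) / (2*pi), with W the principal Lambert W function
print(f"t_star   = {t_star:.12f}")
print(f"residual = {residual:.1e}")

# On the sweep path z = exp(-2*pi*x), the selector flips exactly at x = t_star
for x in (0.5 * t_star, 0.99 * t_star, 1.01 * t_star, 2 * t_star):
    print(f"x = {x:.4f} -> {pick_strategy(math.exp(-2 * math.pi * x))}")
```

The root is \(t_\star \approx 0.2323\). Comparing t_star with jaxpolylog.polylogs._PVAL_OPTIMAL is a quick consistency check, assuming that constant is defined by the same equation.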
4. Vectorisation
# jax_polylog_vmap is a thin JIT-compiled vmap wrapper. Use it for
# batch evaluation along the leading axis.
zs = jnp.linspace(0.1, 0.95, 32) + 0.0j
vals = jax_polylog_vmap(zs, s=2, p_range=200, approx="patch")
print("input shape:", zs.shape)
print("output shape:", vals.shape)
print("first 5 values:", vals[:5])
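The comment in the cell describes jax_polylog_vmap as a thin JIT-compiled vmap wrapper. Its source is not shown here, so the sketch below only illustrates the general jax.jit(jax.vmap(...)) pattern. It uses a hypothetical scalar function, li_series, which is a plain truncated Taylor series and not jaxpolylog's algorithm.

```python
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def li_series(z, s, n_terms=200):
    """Hypothetical scalar stand-in: sum_{k=1}^{n_terms} z**k / k**s (valid for |z| well below 1)."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float64)
    return jnp.sum(z ** k / k ** s)


# Batch over the leading axis of z (in_axes=0) while sharing s (in_axes=None),
# then JIT-compile; s is static, so each integer order gets its own compiled kernel.
li_batched = jax.jit(jax.vmap(li_series, in_axes=(0, None)), static_argnums=1)

zs = jnp.linspace(0.1, 0.95, 32) + 0.0j
vals = li_batched(zs, 2)
print("output shape:", vals.shape)

# Spot check against the closed form Li_2(1/2) = pi^2/12 - (log 2)^2 / 2
half = li_batched(jnp.array([0.5 + 0.0j]), 2)[0]
err_half = float(abs(half - (math.pi**2 / 12 - math.log(2) ** 2 / 2)))
print(f"|Li_2(1/2) error| = {err_half:.1e}")
```

Only z is batched. With in_axes=None the order s is shared by every element, which mirrors how the cell passes a scalar s alongside a batch of zs.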
5. Automatic differentiation
The custom-JVP rule of jax_polylog implements the analytic identity
\[ \frac{\mathrm{d}}{\mathrm{d}z}\,\mathrm{Li}_s(z) = \frac{\mathrm{Li}_{s-1}(z)}{z}, \]
evaluated through a stable helper that never divides by \(z\) in the
series regime. We verify this at one point by comparing
jax.grad(Li_3) against \(\mathrm{Li}_2(z) / z\).
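For readers unfamiliar with jax.custom_jvp, here is a minimal sketch of such a rule. It is not jaxpolylog's implementation. The function li, its truncated Taylor series and the helper _li_over_z are hypothetical, and the series is only meaningful for \(|z|\) comfortably below 1. The helper evaluates \(\mathrm{Li}_{s-1}(z)/z\) as \(\sum_k z^{k-1}/k^{s-1}\), so \(z\) never appears as a divisor.

```python
from functools import partial
import math

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N_TERMS = 200  # fixed truncation; only adequate for |z| comfortably below 1


def _li_series(z, s):
    # Hypothetical truncated Taylor series: sum_{k=1}^{N} z**k / k**s
    k = jnp.arange(1, N_TERMS + 1, dtype=jnp.float64)
    return jnp.sum(z ** k / k ** s)


def _li_over_z(z, s):
    # Li_s(z) / z written as sum_{k=1}^{N} z**(k-1) / k**s, so z is never a divisor
    k = jnp.arange(1, N_TERMS + 1, dtype=jnp.float64)
    return jnp.sum(z ** (k - 1) / k ** s)


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def li(z, s):
    return _li_series(z, s)


@li.defjvp
def _li_jvp(s, primals, tangents):
    (z,), (z_dot,) = primals, tangents
    # d/dz Li_s(z) = Li_{s-1}(z) / z, evaluated by the division-free helper
    return li(z, s), _li_over_z(z, s - 1) * z_dot


def f(x):
    return li(x, 3)  # real in, real out, as required by jax.grad


z0 = 0.5
d1 = float(jax.grad(f)(z0))
li2_half = math.pi**2 / 12 - math.log(2) ** 2 / 2  # closed form of Li_2(1/2)
print(f"jax.grad(li)(1/2) = {d1:.15f}")
print(f"Li_2(1/2) / (1/2) = {li2_half / z0:.15f}")
```

The rule _li_jvp is built from ordinary jnp operations, so JAX can differentiate through it again. As a result, jax.grad(jax.grad(f)) also works on this sketch.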
# Real-valued probe (Li_3 of a complex z is complex; pick the real
# part to obtain a real-to-real function that jax.grad can handle).
def f(z_real):
    z = z_real + 0.0j
    return jax_polylog(z, 3, 200, "patch").real
z_val = 0.5
df_dz = float(jax.grad(f)(z_val))
# Analytic derivative: Li_2(z) / z, then take .real for the real axis.
ana = float((jax_polylog(z_val + 0.0j, 2, 200, "patch") / z_val).real)
print(f"jax.grad(Li_3)(z) = {df_dz:+.12e}")
print(f"Li_2(z) / z = {ana:+.12e}")
print(f"absolute deviation = {abs(df_dz - ana):.2e}")
# Second-order derivative: d^2/dz^2 Li_3(z) = d/dz [Li_2(z)/z]
d2f = float(jax.grad(jax.grad(f))(z_val))
print(f"jax.grad(jax.grad(Li_3))(z) = {d2f:+.12e}")
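The cell does not compare the second-order value with anything. Differentiating the first-order identity once more gives a reference. The numbers below use the standard closed forms \(\mathrm{Li}_1(z) = -\log(1-z)\) and \(\mathrm{Li}_2(1/2) = \pi^2/12 - (\log 2)^2/2\):

\[
\frac{\mathrm{d}^2}{\mathrm{d}z^2}\mathrm{Li}_3(z)
= \frac{\mathrm{d}}{\mathrm{d}z}\,\frac{\mathrm{Li}_2(z)}{z}
= \frac{\mathrm{Li}_1(z)}{z^2} - \frac{\mathrm{Li}_2(z)}{z^2}
= \frac{\mathrm{Li}_1(z) - \mathrm{Li}_2(z)}{z^2},
\]
\[
\left.\frac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_3(z)\right|_{z=1/2} = \frac{\pi^2}{6} - (\log 2)^2 \approx 1.1644811,
\qquad
\left.\frac{\mathrm{d}^2}{\mathrm{d}z^2}\mathrm{Li}_3(z)\right|_{z=1/2} = 4\left(\log 2 - \frac{\pi^2}{12} + \frac{(\log 2)^2}{2}\right) \approx 0.4436266.
\]

The printed jax.grad(jax.grad(Li_3)) value can be checked against 0.4436266.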
What next
- Read the introduction page for the mathematical background and a description of the four evaluation strategies.
- Read the API reference for the full signatures of jax_polylog, jax_polylog_vmap, the custom JVP rule, and the numerical helpers.
- For a deep-LCS regime stress test (third- and fourth-order derivatives at \(|z| \le 10^{-30}\)), see tests/test_patch_high_deriv_vmap.py in the source tree.