# ==============================================================================
# Copyright 2022-2026 Andreas Schachner
#
# This file is part of jaxpolylog.
#
# jaxpolylog is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# jaxpolylog is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with jaxpolylog. If not, see <https://www.gnu.org/licenses/>.
# ==============================================================================
#
# ------------------------------------------------------------------------------
# This file holds functions for polylogarithms using JAX.
# ------------------------------------------------------------------------------
# Important standard libraries
import os, sys, warnings
import math
from fractions import Fraction
from functools import partial, lru_cache
# Important JAX libraries
import jax
from jax import custom_jvp
from jax import jit, vmap, config
import jax.numpy as jnp
from jax import Array
from numpy.typing import ArrayLike
# Enable 64-bit precision (float64 / complex128 as JAX's default types).
# NOTE: this is a process-wide JAX configuration flag set at import time, so
# it also affects any other JAX code running in the same Python process.
config.update("jax_enable_x64", True)
# ---------------------------------------------------------------------------
# Stable coefficient utilities for the ``"zero"`` series expansion.
# ---------------------------------------------------------------------------
# The ``"zero"`` branch evaluates Li_s(z) as a closed-form term in μ = log z
# (containing log(-μ) for s ≥ 2) plus the power series Σ_k ζ(s-k)/k! · μ^k,
# which converges for |μ| < 2π. The kth series coefficient is
# ζ(s-k)/k!. The old implementation built this via
# ``Bs = jax.scipy.special.bernoulli(p_range); zeta_neg = -Bs/k``
# which overflows fp64 for ``p_range ≳ 250`` because individual
# Bernoulli numbers grow factorially even though the coefficient
# B_k / k! we actually need decays geometrically. The helpers below
# compute the coefficients without ever materialising large ``B_k`` as
# floats, using exact :class:`fractions.Fraction` Bernoulli numbers and
# (where needed) the identity
# ζ(2m) = |B_{2m}| · (2π)^{2m} / (2 · (2m)!).
# All computation happens at module-trace time and is cached.
@lru_cache(maxsize=4)
def _bernoulli_fractions_up_to(N: int) -> tuple:
    r"""Return the exact Bernoulli numbers ``(B_0, B_1, ..., B_N)``.

    Uses the Akiyama–Tanigawa recurrence on :class:`fractions.Fraction`
    values, so every entry is an exact rational (this convention yields
    ``B_1 = +1/2``). The cost is O(N²) big-rational operations, which is why
    results are memoised. Consumed by :func:`_zero_branch_coeffs` and
    :func:`_zeta_pos_int`.

    Args:
        N: Highest index to produce. A negative ``N`` yields an empty tuple.

    Returns:
        tuple: ``N + 1`` :class:`Fraction` objects, index ``i`` holding ``B_i``.
    """
    work = [Fraction(0)] * (N + 1)
    numbers = []
    for row in range(N + 1):
        work[row] = Fraction(1, row + 1)
        # Sweep the working row right-to-left; its head is the next B.
        for col in reversed(range(1, row + 1)):
            work[col - 1] = col * (work[col - 1] - work[col])
        numbers.append(work[0])
    return tuple(numbers)
[docs]
@lru_cache(maxsize=64)
def _zeta_pos_int(n: int) -> float:
    r"""Riemann ``ζ(n)`` for integer ``n ≥ 2``, correctly rounded to fp64.

    * ``n ≥ 54``: returns ``1.0``. Here ``0 < ζ(n) - 1 < 2^{-53}`` (less than
      half an ulp of 1.0), so ``1.0`` already is the correctly rounded value.
      The short-circuit also avoids the ``OverflowError`` that
      ``(2π)**n`` / ``60.0**n`` / ``float(k**n)`` raise for large ``n``
      (odd ``n`` from about 169, even ``n`` from about 388).
    * Even ``n``: closed form ``ζ(2m) = |B_{2m}| (2π)^{2m} / (2 · (2m)!)`` from
      the exact Bernoulli table.
    * Odd ``n``: direct summation of ``k^{-n}`` for ``k < K = 60`` plus a
      five-term Euler–Maclaurin tail (integral, endpoint, and the ``B_2``,
      ``B_4``, ``B_6`` corrections). All terms are accumulated with
      :func:`math.fsum`; plain left-to-right summation drops the last bit
      (e.g. it returned ``1.0`` instead of ``1 + 2^{-52}`` for ``ζ(53)``).

    Args:
        n: Integer argument, ``n ≥ 2``.

    Returns:
        float: ``ζ(n)``.

    Raises:
        ValueError: If ``n < 2`` (``ζ(1)`` is a pole).
    """
    if n < 2:
        raise ValueError(f"_zeta_pos_int requires n >= 2, got {n}")
    # ζ(n) - 1 ≈ 2^{-n} is below half an ulp of 1.0 once n ≥ 54.
    if n >= 54:
        return 1.0
    if n % 2 == 0:
        B = _bernoulli_fractions_up_to(n)
        ratio = abs(B[n]) / (2 * math.factorial(n))
        return float(ratio) * (2.0 * math.pi) ** n
    K = 60
    Kf = float(K)
    terms = [1.0 / k ** n for k in range(1, K)]
    # Euler–Maclaurin estimate of the tail Σ_{k ≥ K} k^{-n}.
    terms += [
        1.0 / ((n - 1) * Kf ** (n - 1)),
        0.5 / Kf ** n,
        n / (12.0 * Kf ** (n + 1)),
        -n * (n + 1) * (n + 2) / (720.0 * Kf ** (n + 3)),
        n * (n + 1) * (n + 2) * (n + 3) * (n + 4) / (30240.0 * Kf ** (n + 5)),
    ]
    return math.fsum(terms)
# Hard cap on the number of coefficients used by the "zero" expansion (the
# table built by ``_zero_branch_coeffs``; ``jax_polylog`` uses
# ``min(p_range, _ZERO_BRANCH_P_MAX)``). The series converges geometrically
# with ratio t = |μ|/(2π), so 200 terms leave a truncation error of order
# t^200: ≈ 10^{-60} at t = 0.5 and ≈ 10^{-127} at the default patch threshold
# t ≈ 0.2323. A larger ``p_range`` adds no fp64 precision, only precomputation
# cost.
_ZERO_BRANCH_P_MAX: int = 200
[docs]
@lru_cache(maxsize=64)
def _zero_branch_coeffs(s: int, P: int) -> tuple:
    r"""Precompute ``c[k] = ζ(s-k) / k!`` for ``k = 0..P-1``.

    Branch table, with ``arg = s - k``:

    * ``arg ≥ 2``: ``ζ(arg)`` via :func:`_zeta_pos_int`.
    * ``arg = 1`` (``k = s-1``): coefficient left at 0. In
      :func:`jax_polylog` this term is replaced by the explicit
      ``log(-μ)`` term.
    * ``arg = 0`` (``k = s``): ``ζ(0) = -1/2``.
    * ``arg < 0``, with ``n = 1 - arg ≥ 2``:

      - ``n`` odd: ``B_n = 0``, so the coefficient is 0.
      - ``n`` even: coefficient ``-B_n / (n · k!)``, formed as an exact
        :class:`Fraction` and only then converted to float, so neither
        ``B_n`` nor ``k!`` is ever materialised as a float.

    The Bernoulli table must reach the largest index the loop touches,
    ``n_max = P - s`` (attained at ``k = P - 1``). For ``s ≥ 2`` that is
    below ``P``. For negative orders (e.g. ``s ≤ -10`` from
    :func:`jax_polylog`) it exceeds ``P``, so a table sized by ``P`` alone
    raised ``IndexError``.

    Args:
        s: Polylogarithm order.
        P: Number of coefficients.

    Returns:
        tuple: ``P`` floats ``(c[0], ..., c[P-1])``; empty when ``P ≤ 0``.
    """
    coeffs = [0.0] * P
    n_max = P - s
    # Bernoulli numbers are needed only if some arg = s - k is negative,
    # i.e. P ≥ 1 and n_max ≥ 2. Keep the previous table size max(2, P) when it
    # suffices, and grow it to n_max for negative orders.
    needs_bernoulli = P > 0 and n_max >= 2
    bern = _bernoulli_fractions_up_to(max(2, P, n_max)) if needs_bernoulli else None
    for k in range(P):
        arg = s - k
        if arg >= 2:
            coeffs[k] = _zeta_pos_int(arg) / math.factorial(k)
        elif arg == 0:
            coeffs[k] = -0.5 / math.factorial(k)
        elif arg < 0:
            n = 1 - arg
            if n % 2 == 0:
                coeffs[k] = float(-bern[n] / (n * math.factorial(k)))
        # arg == 1 and odd n ≥ 3 keep their 0.0 coefficient.
    return tuple(coeffs)
[docs]
def _Li1_over_z_stable(z):
    r"""Return ``Li_1(z)/z = -log1p(-z)/z`` in a form safe for nested autodiff.

    The closed form is numerically fine as a value. However, differentiating
    it repeatedly with JAX produces pairs of ``O(1/z)`` terms, e.g.
    ``1/(z(1-z)) + log1p(-z)/z²``, that cancel analytically but lose all
    fp64 bits once ``|z| ≲ 10⁻²``.

    For ``|z| < 0.5`` this function therefore evaluates the truncated series

    .. math::
        \frac{\mathrm{Li}_1(z)}{z} = \sum_{k=0}^{\infty} \frac{z^k}{k+1}
        = 1 + \frac{z}{2} + \frac{z^2}{3} + \cdots

    up to the ``z^{60}/61`` term. It is a polynomial, so its derivatives of any
    order are well conditioned. The truncation error is at most about
    ``0.5^{60} ≈ 10⁻¹⁸`` at the switch-over radius. Elsewhere the closed form
    is used.

    Both branches are always traced. Each branch receives a harmless
    placeholder argument where it is not selected, so a NaN/inf from the
    unused branch can never leak through ``jnp.where`` (``0 · NaN = NaN``).
    """
    small = jnp.abs(z) < 0.5
    series_arg = jnp.where(small, z, jnp.asarray(0.3 + 0.0j, dtype=z.dtype))
    closed_arg = jnp.where(small, jnp.asarray(0.7 + 0.0j, dtype=z.dtype), z)

    # Forward accumulation 1 + z/2 + z^2/3 + ... + z^60/61.
    acc = jnp.ones_like(series_arg)
    power = series_arg
    for denom in range(2, 62):
        acc = acc + power / float(denom)
        power = power * series_arg

    closed = -jnp.log1p(-closed_arg) / closed_arg
    return jnp.where(small, acc, closed)
[docs]
def _compute_pval_optimal() -> float:
    r"""
    Return the crossover ``t*`` used as the default ``pval`` of ``"patch"``.

    The ``"patch"`` method chooses between two expansions of Li_s(z):

    * ``"inf"``: ``Σ z^k / k^s``, valid for ``|z| < 1``. Its truncation
      error after ``N`` terms behaves like ``|z|^N``.
    * ``"zero"``: expansion in ``μ = log z`` about ``z = 1``, valid for
      ``|μ| < 2π``. Its error after ``N`` terms behaves like ``t^N`` with
      ``t = |μ|/(2π)``.

    For real ``0 < z < 1`` we have ``|z| = e^{-|μ|} = e^{-2πt}``. The two
    errors therefore coincide when

    .. math::
        e^{-2πt} = t \,,

    independently of ``N``. The left side minus the right side is strictly
    decreasing, positive at ``t = 0.05`` and negative at ``t = 0.5``. Bisection
    on that bracket therefore finds the unique root, ``t* ≈ 0.2323``. It runs
    80 halvings, which is more than enough to collapse the bracket to
    adjacent doubles.

    Returns:
        float: The crossover parameter ``t*``.
    """
    import numpy as _np

    lower, upper = 0.05, 0.50
    steps = 0
    while steps < 80:
        centre = 0.5 * (lower + upper)
        if _np.exp(-2.0 * _np.pi * centre) > centre:
            lower = centre  # root lies to the right of centre
        else:
            upper = centre  # root lies at or to the left of centre
        steps += 1
    return 0.5 * (lower + upper)
# Optimal patch transition t* = |log z| / (2π): the value at which the
# truncation errors of the "inf" series and the "zero" expansion match.
# Computed once at import time by bisection on e^{-2πt} = t (t* ≈ 0.2323)
# and used as the default ``pval`` of ``jax_polylog`` and
# ``jax_polylog_vmap``.
_PVAL_OPTIMAL: float = _compute_pval_optimal()
@partial(jit, static_argnums=(2,))
def intgrand(z: complex, t: complex, s: int) -> complex:
    r"""
    **Description:**
    Integrand ``log(t)^(s-1) / (1 - z·t)`` of the integral representation of
    the polylogarithm, evaluated element-wise over the quadrature nodes ``t``.

    Args:
        z (complex): Point at which the polylogarithm is evaluated
            (a scalar in the callers in this module).
        t (complex): Integration variable, typically an array of nodes in
            ``(0, 1]``.
        s (int): Polylogarithm order (static under ``jit``).

    Returns:
        complex: The integrand at every node in ``t``.
    """
    log_power = jnp.log(t) ** (s - 1)
    return log_power / (1 - z * t)  # type: ignore
[docs]
@partial(custom_jvp, nondiff_argnums=(1, 2, 3, 4,))
@partial(jit, static_argnums=(1, 2, 3, 4,))
def jax_polylog(z: complex, s: int, p_range: int, approx: str, pval: float = _PVAL_OPTIMAL) -> complex:  # type: ignore
    r"""
    **Description:**
    Compute the polylogarithm :math:`\mathrm{Li}_s(z)` of integer order ``s``
    with JAX. Derivatives with respect to ``z`` come from a registered
    ``custom_jvp`` rule (``d/dz Li_s = Li_{s-1}/z``). ``s``, ``p_range``,
    ``approx`` and ``pval`` are static, non-differentiable arguments.

    **Mathematical Definition:**

    .. math::
        \text{Li}_s(z) = \sum_{k=1}^{\infty} \frac{z^k}{k^s}

    for :math:`|z| < 1`, analytically continued elsewhere.

    For ``-9 ≤ s ≤ 1`` closed forms are returned (``p_range``, ``approx``
    and ``pval`` are then unused):

    - For `s = 1`: `Li_1(z) = -ln(1 - z)`
    - For `s = 0`: `Li_0(z) = z / (1 - z)`
    - For `s = -1`: `Li_{-1}(z) = z / (1 - z)^2`
    - For `s = -2`: `Li_{-2}(z) = z(1 + z) / (1 - z)^3`
    - For `s = -3`: `Li_{-3}(z) = z(1 + 4z + z^2) / (1 - z)^4`
    - For `s = -4`: `Li_{-4}(z) = z(1 + 11z + 11z^2 + z^3) / (1 - z)^5`
    - For `s = -5`: `Li_{-5}(z) = z(1 + 26z + 66z^2 + 26z^3 + z^4) / (1 - z)^6`
    - For `s = -6`: `Li_{-6}(z) = z(1 + 57z + 302z^2 + 302z^3 + 57z^4 + z^5) / (1 - z)^7`
    - For `s = -7`: `Li_{-7}(z) = z(1 + 120z + 1191z^2 + 2416z^3 + 1191z^4 + 120z^5 + z^6) / (1 - z)^8`
    - For `s = -8`: `Li_{-8}(z) = z(1 + 247z + 4293z^2 + 15619z^3 + 15619z^4 + 4293z^5 + 247z^6 + z^7) / (1 - z)^9`
    - For `s = -9`: `Li_{-9}(z) = z(1 + 502z + 14608z^2 + 88234z^3 + 156190z^4 + 88234z^5 + 14608z^6 + 502z^7 + z^8) / (1 - z)^10`

    All other orders use the method selected by ``approx``:

    - ``"inf"``: truncated power series ``Σ_{k=1}^{p_range-1} z^k / k^s``
      (convergent for ``|z| < 1``).
    - ``"integral"``: trapezoidal quadrature on ``p_range`` nodes of
      ``Li_s(z) = z/Γ(s) ∫_0^1 (-log t)^{s-1} / (1 - z t) dt`` (``s ≥ 2``).
    - ``"zero"``: expansion in ``μ = log z`` about ``z = 1``,
      ``Li_s(z) = T_s(μ) + Σ_k ζ(s-k) μ^k / k!``. Here
      ``T_s = μ^{s-1}/(s-1)! · (H_{s-1} - log(-μ))`` for ``s ≥ 2`` (the
      ``k = s-1`` series term is dropped), and ``T_s = (-s)! (-μ)^{s-1}``
      for ``s ≤ 0``. Convergent for ``|μ| < 2π``; at most
      ``_ZERO_BRANCH_P_MAX`` terms are used.
    - ``"patch"``: ``"inf"`` where ``|z| < 1`` and ``|log z|/(2π) ≥ pval``,
      ``"zero"`` everywhere else.

    Args:
        z (complex): Point at which to evaluate. The series/integral methods
            expect a scalar; use :func:`jax_polylog_vmap` for 1-D arrays. The
            closed-form orders (``-9 ≤ s ≤ 1``) also accept arrays.
        s (int): The order of the polylogarithm. Must be an integer.
        p_range (int): Number of series terms (``"inf"``, ``"zero"``) or
            quadrature nodes (``"integral"``). Higher values increase accuracy
            and computation time.
        approx (str): One of ``"inf"``, ``"integral"``, ``"zero"``, or
            ``"patch"``. It has no default.
        pval (float, optional): Transition parameter for ``"patch"``. Points
            with ``t = |log z| / (2π) < pval`` use ``"zero"``. Points with
            ``t ≥ pval`` *and* ``|z| < 1`` use ``"inf"``. Points with
            ``|z| ≥ 1`` always use ``"zero"``, since ``"inf"`` diverges there.
            Defaults to :data:`_PVAL_OPTIMAL` ≈ 0.2323, the fixed point of
            ``e^{-2πt} = t`` that equates both truncation errors. Ignored by
            the other methods.

    Returns:
        complex: The computed polylogarithm value(s) at `z`.

    Raises:
        ValueError: If `s` is not an integer or if `approx` is not one of the specified methods.

    **Example Usage:**

    ```python
    import jax.numpy as jnp
    from jaxpolylog import jax_polylog as polylog
    # Series orders expect a scalar z; pass the static arguments positionally.
    result = polylog(jnp.asarray(0.5 + 0.0j), 2, 1000, "patch")
    print(result)  # ≈ Li_2(1/2) = π²/12 - log(2)²/2 ≈ 0.5822405
    # For 1-D arrays of z use ``jax_polylog_vmap`` from the same module.
    ```
    """
    # Check if s is an integer
    if not isinstance(s, int):
        raise ValueError("The order 's' must be an integer.")
    if approx not in ["inf", "integral", "zero", "patch"]:
        raise ValueError("The approximation method must be one of 'inf', 'integral', 'zero', or 'patch'.")
    # Closed forms for -9 <= s <= 1
    if s == 1:
        return -jnp.log(1-z)  # type: ignore
    elif s == 0:
        return z/(1-z)
    elif s == -1:
        return z/(1-z)**2
    elif s == -2:
        return ((z*(1 + z))/(1 - z)**3)
    elif s == -3:
        return (z*(1 + 4*z + z**2))/(-1 + z)**4
    elif s == -4:
        return ((z*(1 + 11*z + 11*z**2 + z**3))/(1 - z)**5)
    elif s == -5:
        return (z*(1 + 26*z + 66*z**2 + 26*z**3 + z**4))/(-1 + z)**6
    elif s == -6:
        return ((z*(1 + 57*z + 302*z**2 + 302*z**3 + 57*z**4 + z**5))/(1 - z)**7)
    elif s == -7:
        return (z*(1 + 120*z + 1191*z**2 + 2416*z**3 + 1191*z**4 + 120*z**5 + z**6))/(-1 + z)**8
    elif s == -8:
        return ((z*(1 + 247*z + 4293*z**2 + 15619*z**3 + 15619*z**4 + 4293*z**5 + 247*z**6 + z**7))/(1 - z)**9)
    elif s == -9:
        return (z*(1 + 502*z + 14608*z**2 + 88234*z**3 + 156190*z**4 + 88234*z**5 + 14608*z**6 + 502*z**7 + z**8))/(-1 + z)**10
    else:
        if approx == "inf":
            # Power series about z = 0 (convergent for |z| < 1):
            # Li_s(z) ≈ Σ_{k=1}^{p_range-1} z^k / k^s.
            polylog_range = jnp.arange(1, p_range)
            # Evaluate k^s in floating point. In the integer dtype k**s
            # silently overflows int64 once k^s > 2^63 (e.g. k = 999, s = 7),
            # and negative s cannot be represented at all. This matches the
            # float(k)**s used by _Li_over_z for the derivative.
            k_to_s = polylog_range.astype(float) ** s
            return jnp.sum(z**polylog_range / k_to_s)  # type: ignore
        elif approx == "integral":
            # Li_s(z) = z (-1)^{s-1}/Γ(s) ∫_0^1 log(t)^{s-1}/(1 - z t) dt,
            # trapezoidal rule on p_range nodes (lower limit nudged off 0).
            polylog_range = jnp.linspace(0+1e-20, 1., p_range)
            return z*(-1)**(s-1)/jax.scipy.special.gamma(s)*jax.scipy.integrate.trapezoid(intgrand(z, polylog_range, s), x=polylog_range)
        elif approx == "zero":
            # Expansion around z = 1 in μ = log z, convergent for |μ| < 2π:
            # Li_s(z) = term1 + Σ_k c[k] μ^k with c[k] = ζ(s-k)/k!.
            #
            # c[k] is precomputed once per (s, P) by _zero_branch_coeffs
            # from exact Bernoulli numbers. The old
            # jax.scipy.special.bernoulli(p_range) route overflowed fp64
            # for p_range ≳ 250, because individual B_n grow factorially.
            mu = jnp.log(z)
            if s >= 2:
                # Singular part μ^{s-1}/(s-1)! · (H_{s-1} - log(-μ)); the
                # matching k = s-1 coefficient is zero in the table.
                Hs = jnp.sum(1.0 / jnp.arange(1, s))
                term1 = mu ** (s - 1) / jax.scipy.special.gamma(s) * (Hs - jnp.log(-mu))
            else:
                # Only s <= -10 reaches here (s in [-9, 1] returned above).
                # For non-positive integer s the singular part is
                # Γ(1-s)(-μ)^{s-1} = (-s)! (-μ)^{s-1}. The Γ(s)-based form
                # above sits on a pole of Γ and does not apply.
                term1 = float(math.factorial(-s)) * (-mu) ** (s - 1)
            P = min(p_range, _ZERO_BRANCH_P_MAX)
            coeffs = jnp.asarray(_zero_branch_coeffs(s, P))
            powers = mu ** jnp.arange(P)
            term2 = jnp.sum(coeffs * powers)
            return term1 + term2  # type: ignore
        elif approx == "patch":
            # Double-where dispatch. Both branches must be traced for vmap,
            # but each branch is *evaluated* on a safe argument when it is
            # inactive. This avoids IEEE ``0 · NaN = NaN`` poisoning when one
            # branch overflows (the "zero" branch at small |z|, where μ^k
            # grows) or diverges (the "inf" branch at |z| ≥ 1).
            #
            #   |z| < 1 AND t = |log z|/(2π) ≥ pval  → "inf" branch
            #   otherwise                              → "zero" branch
            #
            # The safe substitutes (|z|=0.5 for inf, |z|=0.9 for zero) sit
            # inside each branch's convergence region, so the inactive branch
            # yields a finite, discarded number.
            mu = jnp.log(z)
            t = jnp.abs(mu) / (2.0 * jnp.pi)
            use_inf = (jnp.abs(z) < 1.0) & (t >= pval)
            SAFE_INF = jnp.asarray(0.5 + 0.0j, dtype=z.dtype)
            SAFE_ZERO = jnp.asarray(0.9 + 0.0j, dtype=z.dtype)
            z_for_inf = jnp.where(use_inf, z, SAFE_INF)
            z_for_zero = jnp.where(use_inf, SAFE_ZERO, z)
            Lis_inf = jax_polylog(z_for_inf, s, p_range, "inf", pval)
            Lis_zero = jax_polylog(z_for_zero, s, p_range, "zero", pval)
            return jnp.where(use_inf, Lis_inf, Lis_zero)
[docs]
def _Li_over_z(z: complex, s: int, p_range: int, approx: str, pval: float) -> complex:
    r"""
    **Description:**
    Compute :math:`\mathrm{Li}_s(z)/z`, avoiding an explicit division by
    ``z`` wherever the method allows it.

    This is the analytic derivative
    :math:`\mathrm{d}/\mathrm{d}z\,\mathrm{Li}_{s+1}(z) = \mathrm{Li}_s(z)/z`
    used by the JVP rule of :func:`jax_polylog`. Mathematically it equals
    ``jax_polylog(z, s, ...)/z``. The evaluation depends on the order and
    method:

    * ``s == 1``: :func:`_Li1_over_z_stable` (series for ``|z| < 0.5``,
      closed form otherwise).
    * ``-9 ≤ s ≤ 0``: closed forms with the factor ``z`` cancelled, so no
      ``1/z`` remains.
    * Other ``s`` with ``"inf"`` (and the inf side of ``"patch"``): the
      re-indexed series ``Σ_{k=1}^{p_range-1} z^{k-1}/k^s``.
    * ``"integral"``: the integral representation without its leading ``z``.
    * ``"zero"`` (and the zero side of ``"patch"``): literal
      ``jax_polylog(...)/z``.

    Keeping ``1/z`` out of the expression matters under repeated autodiff.
    Otherwise :math:`1/z^n` factors cascade and overflow float64 when
    :math:`|z|` is very small (as in LCS-regime period integrals where
    :math:`|z| = e^{-2\pi q\cdot\mathrm{Im}(t)}` may be
    :math:`\lesssim 10^{-100}`).

    NOTE: the ``"inf"``/``"patch"`` series are unrolled with a Python loop,
    so the traced graph grows linearly with ``p_range`` (``p_range - 2``
    terms). For ``p_range ≤ 1`` the series returns just its constant ``1``.
    If ``approx`` is not one of the four method names, the non-closed-form
    path falls through and returns ``None``. Callers are expected to validate
    ``approx`` first (:func:`jax_polylog` does).

    Args:
        z (complex): The input value(s).
        s (int): The order parameter (one less than the order of the parent
            polylogarithm being differentiated).
        p_range (int): Number of terms in the series expansion.
        approx (str): Approximation method. See :func:`jax_polylog`.
        pval (float): Transition parameter for the ``"patch"`` method.

    Returns:
        complex: The value :math:`\mathrm{Li}_s(z)/z`.
    """
    # Closed-form orders: each expression is the closed-form Li_s(z) divided
    # by z and simplified, so for s <= 0 no division by z remains.
    if s == 1:
        # Li_1(z)/z = -log1p(-z)/z. Delegated to _Li1_over_z_stable, which
        # evaluates a 60-term Taylor polynomial for |z| < 0.5 (no 1/z factors
        # under nested autodiff) and the closed form -log1p(-z)/z for
        # |z| >= 0.5.
        return _Li1_over_z_stable(z)
    elif s == 0:
        return 1.0 / (1.0 - z)
    elif s == -1:
        return 1.0 / (1.0 - z)**2
    elif s == -2:
        return (1.0 + z) / (1.0 - z)**3
    elif s == -3:
        return (1.0 + 4.0*z + z**2) / (-1.0 + z)**4
    elif s == -4:
        return (1.0 + 11.0*z + 11.0*z**2 + z**3) / (1.0 - z)**5
    elif s == -5:
        return (1.0 + 26.0*z + 66.0*z**2 + 26.0*z**3 + z**4) / (-1.0 + z)**6
    elif s == -6:
        return (1.0 + 57.0*z + 302.0*z**2 + 302.0*z**3 + 57.0*z**4 + z**5) / (1.0 - z)**7
    elif s == -7:
        return (1.0 + 120.0*z + 1191.0*z**2 + 2416.0*z**3 + 1191.0*z**4 + 120.0*z**5 + z**6) / (-1.0 + z)**8
    elif s == -8:
        return (1.0 + 247.0*z + 4293.0*z**2 + 15619.0*z**3 + 15619.0*z**4 + 4293.0*z**5 + 247.0*z**6 + z**7) / (1.0 - z)**9
    elif s == -9:
        return (1.0 + 502.0*z + 14608.0*z**2 + 88234.0*z**3 + 156190.0*z**4 + 88234.0*z**5 + 14608.0*z**6 + 502.0*z**7 + z**8) / (-1.0 + z)**10
    else:
        if approx == "inf":
            # Re-indexed series: Li_s(z)/z = sum_{k=1}^{p_range-1} z^{k-1}/k^s.
            # The k=1 term is the constant 1. The remaining terms use static
            # Python-int exponents, so the expression is a polynomial in z
            # whose derivatives never introduce 1/z factors. k^s is taken in
            # floating point (float(k)**s).
            result = 1.0 + 0.0 * z  # k=1 term, broadcast to z's shape/dtype
            for k in range(2, p_range):
                result = result + z**(k - 1) / float(k)**s
            return result
        elif approx == "integral":
            # Li_s(z)/z = (-1)^{s-1}/Γ(s) · ∫₀¹ log(t)^{s-1}/(1-z·t) dt
            polylog_range = jnp.linspace(0+1e-20, 1., p_range)
            return ((-1)**(s-1) / jax.scipy.special.gamma(s)
                    * jax.scipy.integrate.trapezoid(intgrand(z, polylog_range, s),
                                                    x=polylog_range))
        elif approx == "zero":
            # The "zero" expansion is meant for |log z| < 2π, i.e. z near 1,
            # where z is not tiny. The literal division by z is therefore
            # acceptable here.
            return jax_polylog(z, s, p_range, "zero", pval) / z
        elif approx == "patch":
            # Double-where dispatch, mirroring jax_polylog's patch branch:
            # each side sees a safe argument when inactive, so
            # ``0 · NaN = NaN`` poisoning cannot reach the ``jnp.where``.
            # Division by z on the zero side:
            #  - when active, either |z| >= 1 or |log z| < 2π·pval, which
            #    implies |z| > exp(-2π·pval) (≈ 0.23 for the default pval);
            #  - when inactive, the divisor is SAFE_ZERO = 0.9.
            mu = jnp.log(z)
            t = jnp.abs(mu) / (2.0 * jnp.pi)
            use_inf = (jnp.abs(z) < 1.0) & (t >= pval)
            SAFE_INF = jnp.asarray(0.5 + 0.0j, dtype=z.dtype)
            SAFE_ZERO = jnp.asarray(0.9 + 0.0j, dtype=z.dtype)
            z_for_inf = jnp.where(use_inf, z, SAFE_INF)
            z_for_zero = jnp.where(use_inf, SAFE_ZERO, z)
            # Same re-indexed polynomial series as the "inf" branch above.
            Lis_inf_over_z = 1.0 + 0.0 * z_for_inf
            for k in range(2, p_range):
                Lis_inf_over_z = Lis_inf_over_z + z_for_inf ** (k - 1) / float(k) ** s
            Lis_zero_over_z = jax_polylog(z_for_zero, s, p_range, "zero", pval) / z_for_zero
            return jnp.where(use_inf, Lis_inf_over_z, Lis_zero_over_z)
@jax_polylog.defjvp
def _jax_polylog_jvp(s: int, p_range: int, approx: str, pval: float, primals: tuple, tangents: tuple) -> tuple:
    r"""
    **Description:**
    Forward-mode rule registered on :func:`jax_polylog` through ``defjvp``.

    It applies the identity

    .. math::
        \frac{\mathrm{d}}{\mathrm{d}z}\,\mathrm{Li}_s(z) = \frac{\mathrm{Li}_{s-1}(z)}{z}\,,

    obtaining the right-hand side from :func:`_Li_over_z`, which avoids an
    explicit division by ``z`` where it can. The tangent output is linear in
    ``z_dot``, so JAX can transpose this rule to obtain reverse-mode
    derivatives. :func:`jax.jvp` also works on it directly.

    Args:
        s, p_range, approx, pval: Static (non-differentiable) arguments,
            forwarded unchanged.
        primals (tuple): One-element tuple ``(z,)``.
        tangents (tuple): One-element tuple ``(z_dot,)``.

    Returns:
        tuple: ``(Li_s(z), (Li_{s-1}(z)/z) · z_dot)``.
    """
    (z,) = primals
    (z_dot,) = tangents
    value = jax_polylog(z, s, p_range, approx, pval)
    slope = _Li_over_z(z, s - 1, p_range, approx, pval)
    return value, slope * z_dot
# ``jax_polylog`` mapped over axis 0 of ``z``. The order, term count, method
# and patch threshold are shared (unmapped) across the batch.
jax_polylog_vmap_tmp = vmap(jax_polylog, in_axes=(0,) + (None,) * 4)
[docs]
@partial(jit, static_argnames=['s', 'p_range', 'approx', 'pval'])
def jax_polylog_vmap(z: complex, s: int, p_range: int, approx: str = "inf", pval: float = _PVAL_OPTIMAL) -> complex:
    r"""
    **Description:**
    Evaluate :func:`jax_polylog` element-wise over a 1-D array of points,
    using ``jax.vmap`` over the leading axis of ``z``.

    Args:
        z (complex): 1-D array of evaluation points.
        s (int): Integer order of the polylogarithm (static).
        p_range (int): Number of series terms or quadrature nodes for orders
            without a closed form (static). Higher values increase accuracy
            and computation time.
        approx (str, optional): ``"inf"``, ``"integral"``, ``"zero"`` or
            ``"patch"`` (static). Defaults to ``"inf"``.
        pval (float, optional): ``"patch"`` transition parameter (static); see
            :func:`jax_polylog`. Defaults to :data:`_PVAL_OPTIMAL`.

    Returns:
        complex: Array of polylogarithm values, one per entry of ``z``.
    """
    static_args = (s, p_range, approx, pval)
    return jax_polylog_vmap_tmp(z, *static_args)