jaxpolylog.jax_polylog

jax_polylog = <jax._src.custom_derivatives.custom_jvp object>

This function computes the polylogarithm of order s at point z using JAX. It is a jax.custom_jvp object, so it supports automatic differentiation with respect to z through a custom JVP rule, and it is optimized for performance. It is defined for integer s and accepts both real and complex z.
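Because jax_polylog is a jax.custom_jvp object, its derivative comes from a supplied JVP rule rather than from differentiating the evaluation code. For the series Li_s(z) = sum_k z^k / k^s, the identity d/dz Li_s(z) = Li_{s-1}(z) / z makes such a rule natural. The sketch below is a hypothetical toy (toy_polylog, a fixed 200-term truncated series valid only for |z| < 1), not the library's implementation; it only shows how a rule of this kind can be attached with jax.custom_jvp.

```python
from functools import partial

import jax
import jax.numpy as jnp

N_TERMS = 200  # hypothetical truncation length for this toy only


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def toy_polylog(z, s):
    """Toy Li_s(z): truncated series sum_{k=1}^{N_TERMS} z^k / k^s, for |z| < 1 only."""
    z = jnp.asarray(z)
    k = jnp.arange(1, N_TERMS + 1, dtype=jnp.float32)
    # z^1, z^2, ..., z^N_TERMS via a running product (well defined for negative real z too).
    powers = jnp.cumprod(jnp.broadcast_to(z[..., None], z.shape + (N_TERMS,)), axis=-1)
    return jnp.sum(powers / k**s, axis=-1)


@toy_polylog.defjvp
def _toy_polylog_jvp(s, primals, tangents):
    (z,), (dz,) = primals, tangents
    y = toy_polylog(z, s)
    # d/dz Li_s(z) = Li_{s-1}(z) / z; this form is undefined at z = 0.
    dy = toy_polylog(z, s - 1) / z * dz
    return y, dy


# Gradient of Li_2 at 0.5 is Li_1(0.5) / 0.5, i.e. approximately 2 ln 2.
print(jax.grad(lambda z: toy_polylog(z, 2))(0.5))
```

The integer order s is declared non-differentiable, so the rule only propagates tangents in z, and reverse mode (jax.grad) works because the tangent output is linear in dz.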

Mathematical Definition: The polylogarithm function is defined as:

\[\text{Li}_s(z) = \sum_{k=1}^{\infty} \frac{z^k}{k^s}\]

for |z| < 1, and it can be analytically continued to other values of z. For the following integer values of s, the function has a closed form in terms of elementary functions:

- For s = 1: Li_1(z) = -ln(1 - z)
- For s = 0: Li_0(z) = z / (1 - z)
- For s = -1: Li_{-1}(z) = z / (1 - z)^2
- For s = -2: Li_{-2}(z) = z(1 + z) / (1 - z)^3
- For s = -3: Li_{-3}(z) = z(1 + 4z + z^2) / (1 - z)^4
- For s = -4: Li_{-4}(z) = z(1 + 11z + 11z^2 + z^3) / (1 - z)^5
- For s = -5: Li_{-5}(z) = z(1 + 26z + 66z^2 + 26z^3 + z^4) / (1 - z)^6
- For s = -6: Li_{-6}(z) = z(1 + 57z + 302z^2 + 302z^3 + 57z^4 + z^5) / (1 - z)^7
- For s = -7: Li_{-7}(z) = z(1 + 120z + 1191z^2 + 2416z^3 + 1191z^4 + 120z^5 + z^6) / (1 - z)^8
- For s = -8: Li_{-8}(z) = z(1 + 247z + 4293z^2 + 15619z^3 + 15619z^4 + 4293z^5 + 247z^6 + z^7) / (1 - z)^9
- For s = -9: Li_{-9}(z) = z(1 + 502z + 14608z^2 + 88234z^3 + 156190z^4 + 88234z^5 + 14608z^6 + 502z^7 + z^8) / (1 - z)^10

For all other integer values of s, the function is computed numerically, as controlled by p_range and approx (see Parameters).
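The numerator coefficients in the closed forms for s ≤ -1 are the Eulerian numbers A(n, m), so every case s = -n follows one pattern:

\[\text{Li}_{-n}(z) = \frac{z \sum_{m=0}^{n-1} A(n,m)\, z^{m}}{(1 - z)^{n+1}}, \qquad n \ge 1,\]

with the n = 0 numerator taken as 1, which gives Li_0(z) = z / (1 - z). The sketch below generates the coefficients with the standard recurrence A(n, m) = (m + 1) A(n-1, m) + (n - m) A(n-1, m-1) and compares the closed form against a truncated form of the series definition. The helper names (eulerian_row, li_neg_closed_form, li_series) are hypothetical and not part of jaxpolylog.

```python
import jax.numpy as jnp


def eulerian_row(n):
    """Hypothetical helper: Eulerian numbers A(n, 0..n-1); returns [1] for n = 0."""
    row = [1]  # A(1, 0) = 1, also used as the n = 0 numerator
    for k in range(2, n + 1):
        prev = row + [0]  # prev[m] = A(k - 1, m), padded with A(k - 1, k - 1) = 0
        row = [(m + 1) * prev[m] + (k - m) * (prev[m - 1] if m > 0 else 0)
               for m in range(k)]
    return row


def li_neg_closed_form(z, n):
    """Li_{-n}(z) = z * sum_m A(n, m) z^m / (1 - z)^(n + 1) for integer n >= 0."""
    coeffs = jnp.array(eulerian_row(n)[::-1], dtype=jnp.float32)  # highest degree first
    return z * jnp.polyval(coeffs, z) / (1 - z) ** (n + 1)


def li_series(z, s, n_terms=2000):
    """Truncated series definition sum_{k=1}^{n_terms} z^k / k^s, for 0 < z < 1."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float32)
    return jnp.sum(z**k / k**s)


for n in range(10):
    print(n, eulerian_row(n), float(li_neg_closed_form(0.3, n)), float(li_series(0.3, -n)))
```

As a hand check, z = 0.5 and n = 3 give 0.5 · (1 + 2 + 0.25) / 0.5^4 = 26 from the closed form, matching the known sum of k^3 / 2^k.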

Parameters:
  • z (complex) – The input value(s) at which to evaluate the polylogarithm. Can be a scalar or an array.

  • s (int) – The order of the polylogarithm. Must be an integer.

  • p_range (int) – The number of terms kept in the expansion for values of s that have no closed form above. Higher values increase accuracy but also computation time.

  • approx (str) – The approximation method to use. Must be one of "inf", "integral", "zero", or "patch".

  • pval (float, optional) – Transition parameter for the "patch" method. Points with t = |log z| / (2π) < pval use the "zero" expansion; points with t ≥ pval and |z| < 1 use the "inf" series. Points with |z| ≥ 1 always use "zero" (the "inf" series diverges there). The default is _PVAL_OPTIMAL ≈ 0.2322, the fixed point of e^{-2πt} = t that equates both truncation errors. Ignored for approx values other than "patch".
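The "patch" rule described for pval can be sketched as a per-point mask. The two hypothetical helpers below (not part of jaxpolylog) recompute the default transition value as the root of e^{-2πt} = t by bisection, and mark the points that the rule would send to the "zero" expansion. Bisection is used because plain fixed-point iteration t ← e^{-2πt} does not converge here: the map's derivative has magnitude 2πt ≈ 1.46 at the root.

```python
import math

import jax.numpy as jnp


def pval_fixed_point(lo=0.0, hi=1.0, iters=60):
    """Hypothetical helper: root of exp(-2*pi*t) = t on [lo, hi] by bisection."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if math.exp(-2 * math.pi * mid) > mid:  # still left of the root
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def patch_uses_zero_expansion(z, pval):
    """Hypothetical helper: True where the "patch" rule selects the "zero" expansion."""
    z = jnp.asarray(z, dtype=jnp.complex64)
    t = jnp.abs(jnp.log(z)) / (2 * jnp.pi)
    # t < pval -> "zero"; |z| >= 1 -> "zero" (the "inf" series diverges); else "inf".
    return (t < pval) | (jnp.abs(z) >= 1)


pval = pval_fixed_point()
mask = patch_uses_zero_expansion(jnp.array([0.1, 0.9, 1.5, 5.0]), pval)
print(pval, mask)
```

Here z = 0.1 has t ≈ 0.37, above the threshold, so it stays on the "inf" series; z = 0.9 falls below the threshold, and both points with |z| ≥ 1 use "zero" regardless of t.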

Returns:

complex – The computed polylogarithm values at the input z.

Raises:

ValueError – If s is not an integer or if approx is not one of the specified methods.
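A minimal sketch of the checks listed under Raises, assuming they run eagerly on the plain Python arguments before any tracing; check_polylog_args and APPROX_METHODS are hypothetical names, not jaxpolylog API.

```python
import numbers

APPROX_METHODS = ("inf", "integral", "zero", "patch")


def check_polylog_args(s, approx):
    """Hypothetical validator mirroring the documented ValueError conditions."""
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(s, bool) or not isinstance(s, numbers.Integral):
        raise ValueError(f"s must be an integer, got {s!r}")
    if approx not in APPROX_METHODS:
        raise ValueError(f"approx must be one of {APPROX_METHODS}, got {approx!r}")
```

Using numbers.Integral rather than int also accepts NumPy integer scalars such as numpy.int64(2).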

Example Usage:

```python
import jax.numpy as jnp
from jaxpolylog import jax_polylog as polylog

z = jnp.array([0.5, 0.9, 1.0])
s = 2
# approx is listed as a required parameter above, so pass it explicitly.
result = polylog(z, s, p_range=1000, approx="patch")
print(result)
```
