jaxpolylog.jax_polylog
- jax_polylog = <jax._src.custom_derivatives.custom_jvp object>
This function computes the polylogarithm of order s at the point(s) z using JAX. It is a jax.custom_jvp object (see the signature above), so it supports automatic differentiation, and it is optimized for performance. The order s must be an integer, and z may be real or complex.
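The signature above shows that jax_polylog is a jax.custom_jvp object. As a rough illustration of how such a function can carry its own derivative rule, the sketch below truncates the series to a fixed number of terms and attaches a JVP based on the identity d/dz Li_s(z) = Li_{s-1}(z) / z. It is not jaxpolylog's implementation: series_polylog and its n_terms argument are hypothetical, the truncated series is only meaningful for |z| < 1, and the rule as written divides by z, so it is undefined at z = 0.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper, not part of jaxpolylog.
@partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def series_polylog(z, s, n_terms):
    """Truncated series sum_{k=1}^{n_terms} z^k / k^s."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float32)
    # z^1, z^2, ..., z^n_terms via a running product (works for negative and complex z).
    powers = jnp.cumprod(jnp.full(n_terms, z))
    # 1 / k^s written as exp(-s * log k), valid for any integer s since k >= 1.
    return jnp.sum(powers * jnp.exp(-s * jnp.log(k)))


@series_polylog.defjvp
def _series_polylog_jvp(s, n_terms, primals, tangents):
    (z,), (z_dot,) = primals, tangents
    value = series_polylog(z, s, n_terms)
    # d/dz Li_s(z) = Li_{s-1}(z) / z. For the truncated sum this is exactly the
    # derivative of the partial sum, since k * z^(k-1) / k^s = z^(k-1) / k^(s-1).
    slope = series_polylog(z, s - 1, n_terms) / z
    return value, slope * z_dot


# Li_2(1/2) and its derivative Li_1(1/2) / (1/2) = 2 ln 2.
print(series_polylog(0.5, 2, 200))
print(jax.grad(series_polylog)(0.5, 2, 200))
```

Declaring s and n_terms in nondiff_argnums keeps them as ordinary Python integers, so n_terms can set an array shape inside the function while jax.grad differentiates only with respect to z.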
Mathematical Definition: The polylogarithm function is defined as:
\[\text{Li}_s(z) = \sum_{k=1}^{\infty} \frac{z^k}{k^s}\]
for |z| < 1, and it can be analytically continued to other values of z. For specific integer values of s, the function can be expressed in terms of elementary functions:
- For s = 1: Li_1(z) = -ln(1 - z)
- For s = 0: Li_0(z) = z / (1 - z)
- For s = -1: Li_{-1}(z) = z / (1 - z)^2
- For s = -2: Li_{-2}(z) = z(1 + z) / (1 - z)^3
- For s = -3: Li_{-3}(z) = z(1 + 4z + z^2) / (1 - z)^4
- For s = -4: Li_{-4}(z) = z(1 + 11z + 11z^2 + z^3) / (1 - z)^5
- For s = -5: Li_{-5}(z) = z(1 + 26z + 66z^2 + 26z^3 + z^4) / (1 - z)^6
- For s = -6: Li_{-6}(z) = z(1 + 57z + 302z^2 + 302z^3 + 57z^4 + z^5) / (1 - z)^7
- For s = -7: Li_{-7}(z) = z(1 + 120z + 1191z^2 + 2416z^3 + 1191z^4 + 120z^5 + z^6) / (1 - z)^8
- For s = -8: Li_{-8}(z) = z(1 + 247z + 4293z^2 + 15619z^3 + 15619z^4 + 4293z^5 + 247z^6 + z^7) / (1 - z)^9
- For s = -9: Li_{-9}(z) = z(1 + 502z + 14608z^2 + 88234z^3 + 156190z^4 + 88234z^5 + 14608z^6 + 502z^7 + z^8) / (1 - z)^10
- For other integer values of s, the function is computed using the series definition.
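The closed forms for s ≤ 0 in the list above share one pattern: Li_{-n}(z) = z A_n(z) / (1 - z)^{n+1}, where the coefficients of the polynomial A_n are the Eulerian numbers A(n, m) (with A_0(z) = 1 for s = 0). The sketch below is an independent check, not code from jaxpolylog, and its helper names are hypothetical: it builds the Eulerian numbers with the recurrence A(n, m) = (n - m) A(n-1, m-1) + (m + 1) A(n-1, m) and compares each closed form with a truncated series at z = 0.3.

```python
import jax.numpy as jnp


def eulerian_row(n):
    """Eulerian numbers A(n, 0), ..., A(n, n-1); A_0 is taken to be [1]."""
    row = [1]
    for m in range(2, n + 1):
        prev = row
        # A(m, j) = (m - j) * A(m-1, j-1) + (j + 1) * A(m-1, j)
        row = [
            (m - j) * (prev[j - 1] if j >= 1 else 0)
            + (j + 1) * (prev[j] if j < len(prev) else 0)
            for j in range(m)
        ]
    return row


def li_negative_order(z, n):
    """Closed form Li_{-n}(z) = z * A_n(z) / (1 - z)^(n + 1) for integer n >= 0."""
    # jnp.polyval expects the highest-degree coefficient first.
    coeffs = jnp.array(eulerian_row(n)[::-1], dtype=jnp.float32)
    return z * jnp.polyval(coeffs, z) / (1.0 - z) ** (n + 1)


def li_series(z, s, n_terms):
    """Truncated series sum_{k=1}^{n_terms} z^k / k^s (only sensible for |z| < 1)."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float32)
    return jnp.sum(z ** k * jnp.exp(-s * jnp.log(k)))


print(eulerian_row(9))  # coefficients of the s = -9 numerator polynomial
for n in range(10):
    print(-n, li_negative_order(0.3, n), li_series(0.3, -n, 200))
```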
- Parameters:
z (complex) – The input value(s) at which to evaluate the polylogarithm. Can be a scalar or an array.
s (int) – The order of the polylogarithm. Must be an integer.
p_range (int) – The number of terms to include in the series expansion for values of s that have no predefined closed form. Higher values increase accuracy but also computation time.
approx (str) – The approximation method to use. Must be one of "inf", "integral", "zero", or "patch".
pval (float, optional) – Transition parameter for the "patch" method. Points with t = |log z| / (2π) < pval use the "zero" expansion; points with t ≥ pval and |z| < 1 use the "inf" series. Points with |z| ≥ 1 always use "zero", because the "inf" series diverges for |z| > 1 and converges slowly, if at all, on |z| = 1. The default is _PVAL_OPTIMAL ≈ 0.2322, the fixed point of e^{-2πt} = t, at which the two truncation errors are equal. Ignored for approx values other than "patch".
- Returns:
complex – The computed polylogarithm values at the input z.
- Raises:
ValueError – If s is not an integer or if approx is not one of the specified methods.
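The p_range and pval descriptions above can be made concrete with a small sketch. It is an illustration under assumptions, not jaxpolylog's code: partial_sum, tail_bound, pval_fixed_point and patch_uses_inf are hypothetical helpers, and p_range is assumed to behave like the number of terms N of a plain partial sum. For |z| < 1 and s ≥ 0 the omitted tail is at most |z|^{N+1} / ((N+1)^s (1 - |z|)), so each extra term shrinks the error by roughly a factor of |z|; for real z in (0, 1) that factor equals e^{-2πt}, which is consistent with the pval description balancing e^{-2πt} against t. The fixed point is located by bisection because the direct iteration t ← e^{-2πt} diverges: its slope at the root is about -1.46.

```python
import math

import jax.numpy as jnp


def partial_sum(z, s, n_terms):
    """Hypothetical helper: first n_terms terms of sum_k z^k / k^s."""
    k = jnp.arange(1, n_terms + 1, dtype=jnp.float32)
    return float(jnp.sum(z ** k * jnp.exp(-s * jnp.log(k))))


def tail_bound(z, s, n_terms):
    """Upper bound on the omitted tail, valid for |z| < 1 and s >= 0."""
    r = abs(z)
    return r ** (n_terms + 1) / ((n_terms + 1) ** s * (1.0 - r))


def pval_fixed_point(iters=60):
    """Root of g(t) = exp(-2*pi*t) - t on [0, 1] by bisection."""
    lo, hi = 0.0, 1.0  # g(0) = 1 > 0 and g(1) = exp(-2*pi) - 1 < 0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if math.exp(-2.0 * math.pi * mid) - mid > 0.0:
            lo = mid  # g is decreasing, so the root lies right of mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def patch_uses_inf(z, pval):
    """True where the "patch" rule would pick the "inf" series, False for "zero"."""
    z = jnp.asarray(z, dtype=jnp.complex64)
    t = jnp.abs(jnp.log(z)) / (2.0 * jnp.pi)
    # "inf" needs t >= pval and |z| < 1; every other point goes to "zero".
    return (t >= pval) & (jnp.abs(z) < 1.0)


# Truncation error of Li_2(1/2) versus the number of terms. Once the bound
# falls below float32 rounding (about 1e-7), rounding dominates the error.
exact = math.pi ** 2 / 12 - math.log(2) ** 2 / 2
for n_terms in (5, 10, 20):
    err = abs(partial_sum(0.5, 2, n_terms) - exact)
    print(n_terms, err, tail_bound(0.5, 2, n_terms))

pval = pval_fixed_point()
print(pval)
print(patch_uses_inf([0.1, 0.9, 1.5, -0.5], pval))
```

With pval ≈ 0.232, z = 0.1 and z = -0.5 would be routed to "inf", while z = 0.9 (too close to 1) and z = 1.5 (outside the unit disk) would be routed to "zero".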
Example Usage:
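```python
import jax.numpy as jnp
from jaxpolylog import jax_polylog as polylog

# Evaluate the dilogarithm Li_2(z) at three real points.
z = jnp.array([0.5, 0.9, 1.0])
s = 2
# p_range: number of series terms used for orders without a closed form.
result = polylog(z, s, p_range=1000)
print(result)
```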
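Two of the three printed values have closed forms that make a convenient sanity check: Li_2(1/2) = π²/12 - (ln 2)²/2 and Li_2(1) = ζ(2) = π²/6. Li_2(0.9) has no such form, but the reflection formula Li_2(z) + Li_2(1 - z) = π²/6 - ln(z) ln(1 - z) ties it to the quickly converging series for Li_2(0.1). The snippet below uses only the Python standard library and does not call jaxpolylog; li2_series is a hypothetical helper.

```python
import math


def li2_series(z, n_terms=200):
    """Partial sum of Li_2(z) = sum_k z^k / k^2; converges quickly for small |z|."""
    return sum(z ** k / k ** 2 for k in range(1, n_terms + 1))


li2_half = math.pi ** 2 / 12 - math.log(2) ** 2 / 2  # Li_2(1/2), about 0.582241
li2_one = math.pi ** 2 / 6                            # Li_2(1) = zeta(2), about 1.644934
# Reflection: Li_2(0.9) = pi^2/6 - ln(0.9) ln(0.1) - Li_2(0.1), about 1.299715
li2_09 = li2_one - math.log(0.9) * math.log(0.1) - li2_series(0.1)

print([li2_half, li2_09, li2_one])
```

At z = 1 a plain truncated series converges slowly: the tail after N terms of sum 1/k^2 lies between 1/(N + 1) and 1/N, so 1000 terms give only about three correct digits of π²/6.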