decoding.utils

Miscellaneous helper functions.

 1"""
 2Miscellaneous helper functions.
 3"""
 4
 5import secrets
 6
 7import jax.nn as jnn
 8import jax.numpy as jnp
 9import jax.random as jr
10
11from decoding.types import FVX, KEY
12
13
def getkey() -> KEY:
    """
    Create a freshly seeded JAX PRNG key.

    The seed is an integer drawn with `secrets.randbelow` from the range
    `[0, 2**32)` and handed to `jax.random.PRNGKey`.

    Returns:
        A PRNG key built from the randomly drawn seed.

    Examples:
        ```python
        import jax.random as jr
        from decoding.utils import getkey

        key = getkey()
        x = jr.normal(key, (10,))
        assert x.shape == (10,)
        ```

    """
    seed = secrets.randbelow(2**32)
    return jr.PRNGKey(seed)
33
34
def logsoftmax(x: FVX, *, t: float = 1.0) -> FVX:
    """
    Compute the log-softmax of a vector at a given temperature.

    The two limiting temperatures are handled explicitly instead of dividing
    by `t`:

    * `t == 0` puts all probability mass on the `argmax` entry (log-prob `0.0`
      there, `-inf` everywhere else).
    * `t == inf` yields the uniform distribution (`-log(x.size)` everywhere).

    Every branch returns an array with the same dtype as `x`.

    Args:
        x: The input vector.
        t: The temperature of the softmax. Must be non-negative.

    Returns:
        The log-softmax of the input vector.

    Raises:
        ValueError: If `t` is negative or NaN.

    Examples:
        ```python
        import jax.numpy as jnp
        import jax.nn as jnn
        from decoding.utils import logsoftmax

        x = jnp.array([1.0, 2.0, 3.0])
        logp = logsoftmax(x)
        assert jnn.logsumexp(logp) == 0.0
        ```

    """
    # `not t >= 0` also rejects NaN, for which every comparison is False.
    if not t >= 0:
        msg = f"Temperature must be non-negative, got {t}"
        raise ValueError(msg)
    if t == 0:
        peak = jnp.arange(x.size) == jnp.argmax(x)
        # The Python-scalar branches of `where` produce the default float
        # dtype; cast back so this branch agrees with the other two.
        logp = jnp.where(peak, 0.0, -jnp.inf).astype(x.dtype)
    elif t == float("inf"):
        logp = jnp.full_like(x, -jnp.log(x.size))
    else:
        logp = jnn.log_softmax(x / t)
    return logp
def getkey() -> jaxtyping.UInt[Array, '2']:
def getkey() -> KEY:
    """
    Create a freshly seeded JAX PRNG key.

    The seed is an integer drawn with `secrets.randbelow` from the range
    `[0, 2**32)` and handed to `jax.random.PRNGKey`.

    Returns:
        A PRNG key built from the randomly drawn seed.

    Examples:
        ```python
        import jax.random as jr
        from decoding.utils import getkey

        key = getkey()
        x = jr.normal(key, (10,))
        assert x.shape == (10,)
        ```

    """
    seed = secrets.randbelow(2**32)
    return jr.PRNGKey(seed)

Get a random key for use in JAX functions.

Returns:

A random key.

Examples:
import jax.random as jr
from decoding.utils import getkey

key = getkey()
x = jr.normal(key, (10,))
assert x.shape == (10,)
def logsoftmax( x: jaxtyping.Float[Array, 'x'], *, t: float = 1.0) -> jaxtyping.Float[Array, 'x']:
def logsoftmax(x: FVX, *, t: float = 1.0) -> FVX:
    """
    Compute the log-softmax of a vector at a given temperature.

    The two limiting temperatures are handled explicitly instead of dividing
    by `t`:

    * `t == 0` puts all probability mass on the `argmax` entry (log-prob `0.0`
      there, `-inf` everywhere else).
    * `t == inf` yields the uniform distribution (`-log(x.size)` everywhere).

    Every branch returns an array with the same dtype as `x`.

    Args:
        x: The input vector.
        t: The temperature of the softmax. Must be non-negative.

    Returns:
        The log-softmax of the input vector.

    Raises:
        ValueError: If `t` is negative or NaN.

    Examples:
        ```python
        import jax.numpy as jnp
        import jax.nn as jnn
        from decoding.utils import logsoftmax

        x = jnp.array([1.0, 2.0, 3.0])
        logp = logsoftmax(x)
        assert jnn.logsumexp(logp) == 0.0
        ```

    """
    # `not t >= 0` also rejects NaN, for which every comparison is False.
    if not t >= 0:
        msg = f"Temperature must be non-negative, got {t}"
        raise ValueError(msg)
    if t == 0:
        peak = jnp.arange(x.size) == jnp.argmax(x)
        # The Python-scalar branches of `where` produce the default float
        # dtype; cast back so this branch agrees with the other two.
        logp = jnp.where(peak, 0.0, -jnp.inf).astype(x.dtype)
    elif t == float("inf"):
        logp = jnp.full_like(x, -jnp.log(x.size))
    else:
        logp = jnn.log_softmax(x / t)
    return logp

Compute the log-softmax of a vector.

Arguments:
  • x: The input vector.
  • t: The temperature of the softmax.
Returns:

The log-softmax of the input vector.

Examples:
import jax.numpy as jnp
import jax.nn as jnn
from decoding.utils import logsoftmax

x = jnp.array([1.0, 2.0, 3.0])
logp = logsoftmax(x)
assert jnn.logsumexp(logp) == 0.0