decoding.utils
Miscellaneous helper functions.
"""
Miscellaneous helper functions.
"""

import secrets

import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr

from decoding.types import FVX, KEY


def getkey() -> KEY:
    """
    Get a random key for use in JAX functions.

    The key is seeded with a cryptographically secure random integer drawn
    from `secrets`, so callers never need to manage a seed themselves.

    Returns:
        A random key.

    Examples:
        ```python
        import jax.random as jr
        from decoding.utils import getkey

        key = getkey()
        x = jr.normal(key, (10,))
        assert x.shape == (10,)
        ```

    """
    # Uniform over [0, 2**32), i.e. the full unsigned 32-bit seed range.
    seed = secrets.randbelow(1 << 32)
    return jr.PRNGKey(seed)


def logsoftmax(x: FVX, *, t: float = 1.0) -> FVX:
    """
    Compute the log-softmax of a vector.

    Args:
        x: The input vector.
        t: The temperature of the softmax. `t == 0` puts all probability on
            the (first) maximal entry of `x`; `t == inf` gives a uniform
            distribution; any other value must be positive.

    Returns:
        The log-softmax of the input vector.

    Raises:
        ValueError: If `t` is negative or NaN.

    Examples:
        ```python
        import jax.numpy as jnp
        import jax.nn as jnn
        from decoding.utils import logsoftmax

        x = jnp.array([1.0, 2.0, 3.0])
        logp = logsoftmax(x)
        assert jnp.isclose(jnn.logsumexp(logp), 0.0, atol=1e-6)
        ```

    """
    # `not t >= 0` is True both for negative temperatures (which would
    # silently invert the ranking of `x`) and for NaN (which would silently
    # turn every output into NaN).
    if not t >= 0:
        msg = f"temperature must be non-negative, got {t}"
        raise ValueError(msg)
    if t == 0:
        # Zero-temperature limit: log-prob 0 on the argmax, -inf elsewhere.
        logp = jnp.where(jnp.arange(x.size) == jnp.argmax(x), 0.0, -jnp.inf)
    elif t == float("inf"):
        # Infinite-temperature limit: uniform, log(1 / n) for every entry.
        logp = jnp.full_like(x, -jnp.log(x.size))
    else:
        logp = jnn.log_softmax(x / t)
    return logp
def getkey() -> jaxtyping.UInt[Array, '2']:
def getkey() -> KEY:
    """
    Get a random key for use in JAX functions.

    The key is seeded with a cryptographically secure random integer drawn
    from `secrets`, so callers never need to manage a seed themselves.

    Returns:
        A random key.

    Examples:
        ```python
        import jax.random as jr
        from decoding.utils import getkey

        key = getkey()
        x = jr.normal(key, (10,))
        assert x.shape == (10,)
        ```

    """
    # Uniform over [0, 2**32), i.e. the full unsigned 32-bit seed range.
    seed = secrets.randbelow(1 << 32)
    return jr.PRNGKey(seed)
Get a random key for use in JAX functions.
Returns:
A random JAX PRNG key, seeded with a cryptographically secure random 32-bit integer.
Examples:
# Example: obtain a freshly seeded key and use it to sample a normal vector.
import jax.random as jr
from decoding.utils import getkey

key = getkey()
x = jr.normal(key, (10,))  # the shape argument fixes the output length
assert x.shape == (10,)
def logsoftmax(x: jaxtyping.Float[Array, 'x'], *, t: float = 1.0) -> jaxtyping.Float[Array, 'x']:
def logsoftmax(x: FVX, *, t: float = 1.0) -> FVX:
    """
    Compute the log-softmax of a vector.

    Args:
        x: The input vector.
        t: The temperature of the softmax. `t == 0` puts all probability on
            the (first) maximal entry of `x`; `t == inf` gives a uniform
            distribution; any other value must be positive.

    Returns:
        The log-softmax of the input vector.

    Raises:
        ValueError: If `t` is negative or NaN.

    Examples:
        ```python
        import jax.numpy as jnp
        import jax.nn as jnn
        from decoding.utils import logsoftmax

        x = jnp.array([1.0, 2.0, 3.0])
        logp = logsoftmax(x)
        assert jnp.isclose(jnn.logsumexp(logp), 0.0, atol=1e-6)
        ```

    """
    # `not t >= 0` is True both for negative temperatures (which would
    # silently invert the ranking of `x`) and for NaN (which would silently
    # turn every output into NaN).
    if not t >= 0:
        msg = f"temperature must be non-negative, got {t}"
        raise ValueError(msg)
    if t == 0:
        # Zero-temperature limit: log-prob 0 on the argmax, -inf elsewhere.
        logp = jnp.where(jnp.arange(x.size) == jnp.argmax(x), 0.0, -jnp.inf)
    elif t == float("inf"):
        # Infinite-temperature limit: uniform, log(1 / n) for every entry.
        logp = jnp.full_like(x, -jnp.log(x.size))
    else:
        logp = jnn.log_softmax(x / t)
    return logp
Compute the log-softmax of a vector.
Arguments:
- x: The input vector.
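- t: The temperature of the softmax. `t = 0` puts all probability on the (first) maximal entry of `x`, and `t = inf` gives a uniform distribution.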
Returns:
The log-softmax of the input vector.
Examples:
# Example: log-probabilities from logsoftmax are normalized, i.e. their
# logsumexp is 0 up to floating-point rounding.
import jax.numpy as jnp
import jax.nn as jnn
from decoding.utils import logsoftmax

x = jnp.array([1.0, 2.0, 3.0])
logp = logsoftmax(x)
# Compare with a tolerance: exact `== 0.0` on a float32 reduction is fragile.
assert jnp.isclose(jnn.logsumexp(logp), 0.0, atol=1e-6)