rnglib#

class flax.nnx.Rngs(*args, **kwargs)[source]#

NNX rng 容器类。 要实例化 Rngs,请传入一个整数,指定起始种子。 Rngs 可以具有不同的“流”,允许用户生成不同的 rng 密钥。 例如,要为 paramsdropout 流生成密钥

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() != rng2.dropout()

因为我们传入的是 params=1,所以 params 的起始种子为 1,而 dropout 的起始种子默认是我们传入的 0,因为我们没有为 dropout 指定种子。 如果我们没有为 params 指定种子,那么两个流都将默认使用我们传入的 0

>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() == rng2.dropout()

The Rngs 容器类包含每个流的单独计数器。 每次调用流以生成新的 rng 密钥时,计数器都会增加 1。 要生成新的 rng 密钥,我们将当前 rng 流的计数器值折叠到其相应的起始种子中。 如果我们尝试为在实例化时未指定流生成 rng 密钥,则将使用 default 流(即,传递给 Rngs 的第一个位置参数在实例化期间是 default 起始种子)。

>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` stream starting seed is defaulted to 100, counter is 0
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # empty stream starting seed is defaulted to 100, counter is 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)

让我们看看在 Module 中使用 Rngs 的示例,并通过手动串联 Rngs 验证输出

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     # Linear uses the `params` stream twice for kernel and bias
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     # Dropout uses the `dropout` stream once
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))

>>> def assert_same(x, rng_seed, **rng_kwargs):
...   model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
...   out = model(x)
...
...   # manual forward propagation
...   rngs = nnx.Rngs(rng_seed, **rng_kwargs)
...   kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
...   assert (model.linear.kernel.value==kernel).all()
...   bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
...   assert (model.linear.bias.value==bias).all()
...   mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
...   # dropout scales the output proportional to the dropout rate
...   manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
...   assert (out == manual_out).all()

>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
__init__(default=None, /, **rngs)[source]#
参数
  • defaultdefault 流的起始种子。 任何从 **rngs 关键字参数中未指定流生成的密钥都将默认使用此起始种子。

  • **rngs – 可选的关键字参数,用于指定不同 rng 流的起始种子。 关键字是流名称,其值是该流的相应起始种子。

class flax.nnx.RngStream(*args: 'Any', **kwargs: 'Any')[source]#
flax.nnx.reseed(node, /, **stream_keys)[source]#

使用新密钥更新指定 RNG 流的密钥。

参数
  • node – 要在其中重新设置 RNG 流的节点。

  • **stream_keys – 流名称到新密钥的映射。 密钥可以是整数或 jax 数组。 如果传入整数,则将使用 jax.random.key 生成密钥。

引发

ValueError – 如果现有的流密钥不是标量。

示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)