随机#

class flax.nnx.Dropout(*args, **kwargs)[source]#

创建一个丢弃层。

要使用丢弃,请调用 train() 方法(或在构造函数或调用时传入 deterministic=False)。

要禁用丢弃,请调用 eval() 方法(或在构造函数或调用时传入 deterministic=True)。

示例用法

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train() # use dropout
>>> model(x)
Array([[-0.9353421,  0.       ,  1.434417 ,  0.       ]], dtype=float32)

>>> model.eval() # don't use dropout
>>> model(x)
Array([[-0.46767104, -0.7213411 ,  0.7172085 , -0.31562346]], dtype=float32)
rate#

丢弃概率。(_不是_ 保留率!)

类型

float

broadcast_dims#

将共享相同丢弃掩码的维度

类型

collections.abc.Sequence[int]

deterministic#

如果为 false,则输入将按 1 / (1 - rate) 缩放并进行掩码,而如果为 true,则不应用任何掩码,并将按原样返回输入。

类型

bool

rng_collection#

请求 rng 密钥时要使用的 rng 集合名称。

类型

str

rngs#

rng 密钥。

类型

flax.nnx.rnglib.Rngs | None