module#

class flax.nnx.Module(*args, **kwargs)[source]#

Base class for all neural network modules.

Layers and models should subclass this class.

A Module can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.

You can define arbitrary "forward pass" methods on a Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
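
Because submodules are stored as plain attributes, the resulting tree can be split into a static graph definition and a state pytree and then merged back together; a minimal sketch using nnx.split() and nnx.merge() with the Model defined above:

>>> graphdef, state = nnx.split(model)      # separate structure from state
>>> restored = nnx.merge(graphdef, state)   # rebuild an equivalent Module
>>> assert (restored(x) == y).all()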
eval(**attributes)[source]#

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. Its primary use is to control the runtime behavior of the Dropout and BatchNorm Modules.

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters

**attributes – additional attributes passed to set_attributes.
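
Since eval() delegates to set_attributes(), the call above behaves roughly like the explicit call below (a sketch of the equivalence; extra keyword arguments passed to eval() are forwarded in the same way):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(deterministic=True, use_running_average=True,
...                      raise_if_not_found=False)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)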

iter_children()[source]#

Iterates over all children Modules of the current Module. This method is similar to iter_modules(), except that it only iterates over the immediate children and does not recurse further down.

iter_children creates a generator that yields the key and the Module instance, where the key is a string representing the attribute name used to access the corresponding child Module.

Example

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...  print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
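
Because iter_children yields (name, module) pairs, it composes directly with plain Python; for instance, collecting the direct children into a dict keyed by attribute name (a small usage sketch with the Block above):

>>> children = dict(model.iter_children())
>>> sorted(children.keys())
['batch_norm', 'dropout', 'linear', 'submodule']
>>> isinstance(children['submodule'], SubModule)
True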
iter_modules()[source]#

Recursively iterates over all nested Modules of the current Module, including the current Module itself.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
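
Since each path is a tuple from the root, iter_modules makes it easy to select modules by type; for example, gathering every nnx.Linear in the tree (a usage sketch based on the Block above):

>>> linears = {path: m for path, m in model.iter_modules()
...            if isinstance(m, nnx.Linear)}
>>> sorted(linears.keys())
[('linear',), ('submodule', 'linear1'), ('submodule', 'linear2')]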
set_attributes(*filters, raise_if_not_found=True, **attributes)[source]#

Sets the attributes of nested Modules, including the current Module. If an attribute is not found in a Module, it is ignored.

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters
  • *filters – Filters used to select which Modules to set the attributes of.

  • raise_if_not_found – If True (the default), a ValueError is raised if at least one instance of an attribute is not found in any of the selected Modules.

  • **attributes – The attributes to set.
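
When raise_if_not_found=False, attributes that no selected Module defines are silently skipped, which is convenient for flags that only some layer types understand; a hypothetical sketch (no Module in this Block defines a decode attribute, so the call is a no-op):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(decode=True, raise_if_not_found=False)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)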

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores the value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which is useful for filtering in split(), state() and pop().

By default, the values are stored in a tuple and each stored value is appended at the end. This way, all intermediate values can be tracked when the same Module is called multiple times.

Example usage

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
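
Because the sown values are wrapped in the Intermediate Variable type, they can be filtered separately from the parameters afterwards; a minimal sketch using nnx.state() to continue the example above:

>>> intermediates = nnx.state(model, nnx.Intermediate)  # only the sown values
>>> params = nnx.state(model, nnx.Param)                # parameters only
>>> assert 'i' in intermediates and 'i' not in params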

Alternatively, custom init/reduce functions can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
Parameters
  • variable_type – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name – A string denoting the Module attribute name where the sown value is stored.

  • value – The value to store.

  • reduce_fn – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to store. The default is an empty tuple.

train(**attributes)[source]#

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. Its primary use is to control the runtime behavior of the Dropout and BatchNorm Modules.

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters

**attributes – additional attributes passed to set_attributes.
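
As with eval(), train() delegates to set_attributes(); the call above behaves roughly like the explicit call below (a sketch of the equivalence; extra keyword arguments passed to train() are forwarded in the same way):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(deterministic=False, use_running_average=False,
...                      raise_if_not_found=False)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)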