module#
- class flax.nnx.Module(*args, **kwargs)[source]#
所有神经网络模块的基类。
层和模型应该继承这个类。
Module
可以包含子模块,并且可以通过这种方式以树形结构嵌套。子模块可以在__init__
方法中作为常规属性分配。可以在
Module
子类上定义任意“前向传递”方法。虽然没有任何方法会被特殊对待,但__call__
是一个常用的选择,因为这样就可以直接调用Module:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
- sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#
sow()
可用于收集中间值,而无需承担在每次 Module 调用中显式传递容器所带来的开销。sow()
将值存储在一个新的Module
属性中,用name
表示。该值将被类型为variable_type
的Variable
包装,这对于在split()
、state()
和pop()
中过滤很有用。默认情况下,值存储在元组中,并且每个存储的值都追加到末尾。这样,当多次调用同一个模块时,就可以跟踪所有中间值。
示例用法:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)
>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
或者,可以传递自定义 init/reduce 函数:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value
>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
- 参数
variable_type – 存储值的
Variable
类型。通常使用Intermediate
来表示中间值。name – 表示
Module
属性名称的字符串,sow 的值将存储在该属性中。value – 要存储的值。
reduce_fn – 用于将现有值与新值组合的函数。默认情况下,将值追加到元组。
init_fn – 对于存储的第一个值,
reduce_fn
将接收init_fn
的结果以及要存储的值。默认值为空元组。
- iter_modules()[source]#
递归遍历当前 Module 的所有嵌套
Module
,包括当前 Module。iter_modules
创建一个生成器,它会生成路径和 Module 实例,其中路径是由字符串或整数组成的元组,表示从根 Module 到该 Module 的路径。示例:
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
- eval(**attributes)[source]#
将 Module 设置为评估模式。
eval
使用set_attributes
递归地为所有具有这些属性的嵌套 Module 设置deterministic=True
和use_running_average=True
。它主要用于控制Dropout
和BatchNorm
Module 的运行时行为。示例:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
- 参数
**attributes – 传递给
set_attributes
的其他属性。
- iter_children()[source]#
遍历当前 Module 的所有子
Module
。此方法类似于iter_modules()
,不同之处在于它只遍历直接子级,并且不会递归到更下方。iter_children
创建一个生成器,它会生成键和 Module 实例,其中键是一个字符串,表示当前 Module 上用于访问相应子 Module 的属性名称。示例:
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...   print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
- iter_modules()[source]#
递归遍历当前 Module 的所有嵌套
Module
,包括当前 Module。iter_modules
创建一个生成器,它会生成路径和 Module 实例,其中路径是由字符串或整数组成的元组,表示从根 Module 到该 Module 的路径。示例:
>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
- set_attributes(*filters, raise_if_not_found=True, **attributes)[source]#
设置嵌套 Module 的属性,包括当前 Module。如果某个 Module 中不存在某个属性,则不会在该 Module 上设置该属性。
示例:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Filter
可用于设置特定 Module 的属性:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
- 参数
*filters – 用于选择要设置属性的 Module 的过滤器。
raise_if_not_found – 如果为 True(默认值),则当某个属性在所有所选 Module 中连一个实例都找不到时,会引发 ValueError。
**attributes – 要设置的属性。
- sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#
sow()
可用于收集中间值,而无需承担在每次 Module 调用中显式传递容器所带来的开销。sow()
将值存储在一个新的Module
属性中,用name
表示。该值将被类型为variable_type
的Variable
包装,这对于在split()
、state()
和pop()
中过滤很有用。默认情况下,值存储在元组中,并且每个存储的值都追加到末尾。这样,当多次调用同一个模块时,就可以跟踪所有中间值。
示例用法:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)
>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
或者,可以传递自定义 init/reduce 函数:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value
>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
- 参数
variable_type – 存储值的
Variable
类型。通常使用Intermediate
来表示中间值。name – 表示
Module
属性名称的字符串,sow 的值将存储在该属性中。value – 要存储的值。
reduce_fn – 用于将现有值与新值组合的函数。默认情况下,将值追加到元组。
init_fn – 对于存储的第一个值,
reduce_fn
将接收init_fn
的结果以及要存储的值。默认值为空元组。
- train(**attributes)[source]#
将 Module 设置为训练模式。
train
使用set_attributes
递归地为所有具有这些属性的嵌套 Module 设置deterministic=False
和use_running_average=False
。它主要用于控制Dropout
和BatchNorm
Module 的运行时行为。示例:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
- 参数
**attributes – 传递给
set_attributes
的其他属性。