bridge#
- class flax.nnx.bridge.ToNNX(*args, **kwargs)[source]#
一个包装器,用于将任何 Linen 模块转换为 NNX 模块。
生成的 NNX 模块可以独立使用所有 NNX API,也可以用作另一个 NNX 模块的子模块。
由于 Linen 模块初始化需要示例输入,因此您需要使用参数调用 lazy_init 来初始化变量。
示例
>>> from flax import linen as nn, nnx >>> import jax >>> linen_module = nn.Dense(features=64) >>> x = jax.numpy.ones((1, 32)) >>> # Like Linen init(), initialize with a sample input >>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) >>> # Like Linen apply(), but using NNX's direct call method >>> y = model(x) >>> model.kernel.shape (32, 64)
- 参数
module – Linen 模块实例。
rngs – 传递给任何 NNX 模块的 nnx.Rngs 实例。
- 返回值
一个有状态的 NNX 模块,其行为与包装的 Linen 模块相同。
方法
lazy_init
(*args, **kwargs)调用此模块上的 nnx.bridge.lazy_init() 的快捷方式。
- class flax.nnx.bridge.ToLinen(nnx_class, args=(), kwargs=<factory>, skip_rng=False, metadata_type=<class 'flax.nnx.bridge.variables.NNXMeta'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
一个包装器,用于将任何 NNX 模块转换为 Linen 模块。
生成的 Linen 模块可以独立使用所有 Linen API,也可以用作另一个 Linen 模块的子模块。
由于 NNX 模块是有状态的,并且拥有状态,因此我们只在初始化时创建一次,并将跟踪其状态和静态数据作为单独的变量。
示例
>>> from flax import linen as nn, nnx >>> import jax >>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64)) >>> x = jax.numpy.ones((1, 32)) >>> y, variables = model.init_with_output(jax.random.key(0), x) >>> y.shape (1, 64) >>> variables['params']['kernel'].shape (32, 64) >>> # The static GraphDef of the underlying NNX module >>> variables.keys() dict_keys(['nnx', 'params']) >>> type(variables['nnx']['graphdef']) <class 'flax.nnx.graph.NodeDef'>
- 参数
nnx_class – NNX 模块类(而不是实例!)。
args – 通常会传递给创建 NNX 模块的参数。
kwargs – 通常会传递给创建 NNX 模块的关键字参数。
skip_rng – 如果此 NNX 模块在初始化期间不需要 rngs 参数(不常见),则为 True。
- 返回值
一个有状态的 NNX 模块,其行为与包装的 Linen 模块相同。
方法
- flax.nnx.bridge.to_linen(nnx_class, *args, name=None, **kwargs)[source]#
如果用户没有更改任何默认字段,则为 nnx.bridge.ToLinen 的快捷方式。
- class flax.nnx.bridge.NNXMeta(var_type, value, metadata)[source]#
用于 nnx.VariableState 的默认 Flax 元数据类。
- __call__(**kwargs)#
将 self 作为函数调用。
- add_axis(index, params)[source]#
在轴元数据中添加一个新的轴。
请注意,add_axis 和 remove_axis 应该互为反函数(即:
x.add_axis(i, p).remove_axis(i, p) == x
)- 参数
index – 新轴将要插入的位置
params – 由引入新轴的变换传递的任意参数字典(例如:
nn.scan
或nn.vmap
)。用户将此字典作为 metadata_param 参数传递给变换。
- 返回值
与 self 类型相同的新实例,具有相同
unbox
内容,并更新轴元数据。
- remove_axis(index, params)[source]#
从轴元数据中删除一个轴。
请注意,add_axis 和 remove_axis 应该互为反函数(即:
x.remove_axis(i, p).add_axis(i, p) == x
)- 参数
index – 将要删除的轴的位置
params – 由引入轴的变换传递的任意参数字典(例如:
nn.scan
或nn.vmap
)。用户将此字典作为 metadata_param 参数传递给变换。
- 返回值
与 self 类型相同的新实例,具有相同
unbox
内容,并更新轴元数据。
- replace(**updates)#
“返回一个新的对象,将指定的字段替换为新值。
- replace_boxed(val)[source]#
用提供的 value 替换 boxed value。
- 参数
val – 将要由此 AxisMetadata 包装器 boxed 的新 value
- 返回值
与 self 类型相同的新的实例,具有 val 作为新的
unbox
内容
- unbox()[source]#
返回 AxisMetadata 盒子的内容。
请注意,与
meta.unbox
不同,unbox 调用不应递归地拆解元数据。它应该简单地返回其直接包装的 value,即使该 value 本身是 AxisMetadata 的实例。在实践中,AxisMetadata 子类应注册为 PyTree 节点,以支持将实例传递到 JAX 和 Flax API。 此节点返回的叶子应对应于 unbox 返回的值。
- 返回值
解箱后的值。
方法
add_axis
(index, params)在轴元数据中添加一个新的轴。
获取分区规范
()返回此分区的 value 的
Partitionspec
。remove_axis
(index, params)从轴元数据中删除一个轴。
replace
(**updates)"返回一个新对象,用新值替换指定的字段。
replace_boxed
(val)用提供的 value 替换 boxed value。
解箱
()返回 AxisMetadata 盒子的内容。