图表#
- flax.nnx.split(node, *filters)[source]#
将图表节点拆分为
GraphDef
和一个或多个State`s. State is a ``Mapping`
从字符串或整数到Variables
、数组或嵌套状态。GraphDef 包含重建Module
图表所需的所有静态信息,类似于 JAX 的PyTreeDef
。split()
与merge()
结合使用,可以在图表的有状态和无状态表示之间无缝切换。示例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': VariableState( type=Param, value=(2,) ), 'scale': VariableState( type=Param, value=(2,) ) }, 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': VariableState( type=BatchStat, value=(2,) ), 'var': VariableState( type=BatchStat, value=(2,) ) } })
split()
和merge()
主要用于直接与 JAX 转换交互,有关更多信息,请参阅 函数式 API。- 参数
node – 要拆分的图表节点。
*filters – 一些可选过滤器,用于将状态分组为相互排斥的子状态。
- 返回值
GraphDef
和一个或多个States
等于传递的过滤器数量。如果没有传递过滤器,则返回单个State
。
- flax.nnx.merge(graphdef, state, /, *states)[source]#
split()
的反函数。merge
接受一个GraphDef
和一个或多个State
,并创建一个新的节点,其结构与原始节点相同。示例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> new_node = nnx.merge(graphdef, params, batch_stats) >>> assert isinstance(new_node, Foo) >>> assert isinstance(new_node.batch_norm, nnx.BatchNorm) >>> assert isinstance(new_node.linear, nnx.Linear)
- flax.nnx.update(node, state, /, *states)[source]#
使用新的
State
原地更新给定的图表节点。示例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> def loss_fn(model, x, y): ... return jnp.mean((y - model(x))**2) >>> prev_loss = loss_fn(model, x, y) >>> grads = nnx.grad(loss_fn)(model, x, y) >>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads) >>> nnx.update(model, new_state) >>> assert loss_fn(model, x, y) < prev_loss
- flax.nnx.pop(node, *filters)[source]#
从图表节点中弹出一种或多种
Variable
类型。示例用法
>>> from flax import nnx >>> import jax.numpy as jnp >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... x = self.linear1(x) ... self.sow(nnx.Intermediate, 'i', x) ... x = self.linear2(x) ... return x >>> x = jnp.ones((1, 2)) >>> model = Model(rngs=nnx.Rngs(0)) >>> assert not hasattr(model, 'i') >>> y = model(x) >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) >>> assert intermediates['i'].value[0].shape == (1, 3) >>> assert not hasattr(model, 'i')
- flax.nnx.state(node, *filters)[source]#
-
示例用法
>>> from flax import nnx >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... def __call__(self, x): ... return self.linear(self.batch_norm(x)) >>> model = Model(rngs=nnx.Rngs(0)) >>> # get the learnable parameters from the batch norm and linear layer >>> params = nnx.state(model, nnx.Param) >>> # get the batch statistics from the batch norm layer >>> batch_stats = nnx.state(model, nnx.BatchStat) >>> # get them separately >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat) >>> # get them together >>> state = nnx.state(model)
- flax.nnx.graphdef(node, /)[source]#
获取给定图表节点的
GraphDef
。示例用法
>>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> graphdef, _ = nnx.split(model) >>> assert graphdef == nnx.graphdef(model)
- flax.nnx.iter_graph(node, /)[source]#
遍历给定图表节点的所有嵌套节点和叶子,包括当前节点。
iter_graph
创建一个生成器,该生成器会生成路径和值对,其中路径是表示从根节点到值的路径的字符串或整数元组。重复节点仅访问一次。叶子包括静态值。- 示例:
>>> from flax import nnx >>> import jax.numpy as jnp ... >>> class Linear(nnx.Module): ... def __init__(self, din, dout, *, rngs: nnx.Rngs): ... self.din, self.dout = din, dout ... self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... >>> module = Linear(3, 4, rngs=nnx.Rngs(0)) >>> graph = [module, module] ... >>> for path, value in nnx.iter_graph(graph): ... print(path, type(value).__name__) ... (0, 'b') Param (0, 'din') int (0, 'dout') int (0, 'w') Param (0,) Linear () list
- flax.nnx.clone(node)[source]#
创建给定图表节点的深层副本。
示例用法
>>> from flax import nnx >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> cloned_model = nnx.clone(model) >>> model.bias.value += 1 >>> assert (model.bias.value != cloned_model.bias.value).all()
- 参数
node – 一个图表节点对象。
- 返回值
该
Module
对象的深层副本。
- flax.nnx.call(graphdef_state, /)[source]#
调用由 (GraphDef, State) 对定义的底层图表节点的方法。
call
接收一个(GraphDef, State)
对并创建一个代理对象,可用于调用底层图节点上的方法。 当调用方法时,将返回输出以及一个新的 (GraphDef, State) 对,该对表示图节点的更新状态。call
等效于merge()
>method
>split`()
但在纯 JAX 函数中使用起来更方便。示例
>>> from flax import nnx >>> import jax >>> import jax.numpy as jnp ... >>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> linear = StatefulLinear(3, 2, nnx.Rngs(0)) >>> linear_state = nnx.split(linear) ... >>> @jax.jit ... def forward(x, linear_state): ... y, linear_state = nnx.call(linear_state)(x) ... return y, linear_state ... >>> x = jnp.ones((1, 3)) >>> y, linear_state = forward(x, linear_state) >>> y, linear_state = forward(x, linear_state) ... >>> linear = nnx.merge(*linear_state) >>> linear.count.value Array(2, dtype=uint32)
由
call
返回的代理对象支持索引和属性访问以访问嵌套方法。 在下面的示例中,increment
方法索引用于调用nodes
字典的b
键处的StatefulLinear
模块的increment
方法。>>> class StatefulLinear(nnx.Module): ... def __init__(self, din, dout, rngs): ... self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout))) ... self.b = nnx.Param(jnp.zeros((dout,))) ... self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32)) ... ... def increment(self): ... self.count += 1 ... ... def __call__(self, x): ... self.increment() ... return x @ self.w + self.b ... >>> rngs = nnx.Rngs(0) >>> nodes = dict( ... a=StatefulLinear(3, 2, rngs), ... b=StatefulLinear(2, 1, rngs), ... ) ... >>> node_state = nnx.split(nodes) >>> # use attribute access >>> _, node_state = nnx.call(node_state)['b'].increment() ... >>> nodes = nnx.merge(*node_state) >>> nodes['a'].count.value Array(0, dtype=uint32) >>> nodes['b'].count.value Array(1, dtype=uint32)
- class flax.nnx.UpdateContext(tag, ref_index, index_ref)[source]#
用于处理复杂状态更新的上下文管理器。
- split(node, *filters)[source]#
将图表节点拆分为
GraphDef
和一个或多个State`s. State is a ``Mapping`
从字符串或整数到Variables
、数组或嵌套状态。GraphDef 包含重建Module
图表所需的所有静态信息,类似于 JAX 的PyTreeDef
。split()
与merge()
结合使用,可以在图表的有状态和无状态表示之间无缝切换。示例用法
>>> from flax import nnx >>> import jax, jax.numpy as jnp ... >>> class Foo(nnx.Module): ... def __init__(self, rngs): ... self.batch_norm = nnx.BatchNorm(2, rngs=rngs) ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... >>> node = Foo(nnx.Rngs(0)) >>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat) ... >>> jax.tree.map(jnp.shape, params) State({ 'batch_norm': { 'bias': VariableState( type=Param, value=(2,) ), 'scale': VariableState( type=Param, value=(2,) ) }, 'linear': { 'bias': VariableState( type=Param, value=(3,) ), 'kernel': VariableState( type=Param, value=(2, 3) ) } }) >>> jax.tree.map(jnp.shape, batch_stats) State({ 'batch_norm': { 'mean': VariableState( type=BatchStat, value=(2,) ), 'var': VariableState( type=BatchStat, value=(2,) ) } })
- flax.nnx.update_context(tag)[source]#
创建一个
UpdateContext
上下文管理器,可用于处理比nnx.update
可以处理的更复杂的 state 更新,包括对静态属性和图结构的更新。UpdateContext 公开了一个
split
和merge
API,与nnx.split
/nnx.merge
的签名相同,但执行一些簿记以获取必要的 信息,以便能够根据转换内部进行的更改完美地更新输入对象。 UpdateContext 必须总共调用 split 和 merge 4 次,第一次和最后一次调用发生在转换之外,第二次和第三次调用发生在转换内部,如下面的图表所示idxmap (2) merge ─────────────────────────────► split (3) ▲ │ │ inside │ │. . . . . . . . . . . . . . . . . . │ index_mapping │ outside │ │ ▼ (1) split──────────────────────────────► merge (4) refmap
第一次调用 split
(1)
创建一个refmap
用于跟踪外部引用,第一次调用 merge(2)
创建一个idxmap
用于跟踪内部引用。 第二次调用 split(3)
组合 refmap 和 idxmap 以生成index_mapping
,该映射指示外部引用如何映射到内部引用。 最后,对 merge 的最后一次调用(4)
使用 index_mapping 和 refmap 来重建转换的输出,同时重用/更新内部引用。 为了避免内存泄漏,idxmap 在(3)
之后清除,refmap 在(4)
之后清除,两者在上下文管理器退出后清除。以下是一个简单示例,展示了
update_context
的用法>>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> with nnx.update_context('example') as ctx: ... graphdef, state = ctx.split(m1) ... @jax.jit ... def f(graphdef, state): ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 ... m2.ref = m2 # create a reference cycle ... return ctx.split(m2) ... graphdef_out, state_out = f(graphdef, state) ... m3 = ctx.merge(graphdef_out, state_out) ... >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1
请注意,
update_context
接收一个tag
参数,该参数主要用作安全机制,以减少在使用current_update_context()
访问当前活动上下文时意外使用错误的 UpdateContext 的风险。 current_update_context 可用作访问当前活动上下文的一种方式,而无需将其作为捕获传递>>> from flax import nnx ... >>> m1 = nnx.Dict({}) >>> @jax.jit ... def f(graphdef, state): ... ctx = nnx.current_update_context('example') ... m2 = ctx.merge(graphdef, state) ... m2.a = 1 # insert static attribute ... m2.ref = m2 # create a reference cycle ... return ctx.split(m2) ... >>> @nnx.update_context('example') ... def g(m1): ... ctx = nnx.current_update_context('example') ... graphdef, state = ctx.split(m1) ... graphdef_out, state_out = f(graphdef, state) ... return ctx.merge(graphdef_out, state_out) ... >>> m3 = g(m1) >>> assert m1 is m3 >>> assert m1.a == 1 >>> assert m1.ref is m1
如上面的代码所示,
update_context
也可以用作装饰器,为函数的持续时间创建/激活一个 UpdateContext 上下文。 可以使用current_update_context()
访问该上下文。- 参数
tag – 用于标识上下文的字符串标签。
- flax.nnx.current_update_context(tag)[source]#
返回给定标签的当前活动
UpdateContext
。