使用过滤器#
注意:本页内容与新的 Flax NNX API 相关。
过滤器在 Flax NNX 中被广泛使用,作为在 API(如 nnx.split
、nnx.state
和许多 Flax NNX 变换中)创建 State
组的一种方式。例如
from flax import nnx
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = nnx.BatchStat(True)
foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
batch_stats = State({
'b': VariableState(
type=BatchStat,
value=True
)
})
这里 nnx.Param
和 nnx.BatchStat
用作过滤器,将模型分成两组:一组包含参数,另一组包含批统计信息。但是,这引出了以下问题
什么是过滤器?
为什么诸如
Param
或BatchStat
之类的类型是过滤器?如何对
State
进行分组/过滤?
过滤器协议#
一般来说,过滤器是以下形式的谓词函数
(path: tuple[Key, ...], value: Any) -> bool
其中 Key
是一种可哈希且可比较的类型,path
是 Key
的元组,表示嵌套结构中值的路径,value
是该路径上的值。如果该值应包含在组中,则该函数返回 True
;否则返回 False
。
类型显然不是这种形式的函数,因此它们被视为过滤器的理由是,正如我们将在下面看到的那样,类型和某些其他字面量会被转换为谓词。例如,Param
大致被转换为类似于以下的谓词
def is_param(path, value) -> bool:
return isinstance(value, nnx.Param) or (
hasattr(value, 'type') and issubclass(value.type, nnx.Param)
)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
此类函数匹配任何是 Param
实例的值,或任何具有 type
属性(该属性是 Param
的子类)的值。在内部,Flax NNX 使用 OfType
,它为给定类型定义了这种形式的可调用对象
is_param = nnx.OfType(nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
过滤器 DSL#
为了避免用户必须创建这些函数,Flax NNX 公开了一个小型 DSL,它被形式化为 nnx.filterlib.Filter
类型,该类型允许用户传递类型、布尔值、省略号、元组/列表等,并在内部将其转换为相应的谓词。
以下是 Flax NNX 中包含的所有可调用过滤器及其 DSL 字面量(如果可用)的列表
字面量 |
可调用对象 |
描述 |
---|---|---|
|
|
匹配所有值 |
|
|
不匹配任何值 |
|
|
匹配是 |
|
匹配具有包含给定 |
|
|
|
匹配具有等于 |
|
|
匹配与任何内部 |
|
匹配与所有内部 |
|
|
匹配不与内部 |
让我们看看 DSL 在 nnx.vmap
示例中的实际应用。假设我们要将所有参数和第 0 轴上的 dropout
Rng(Keys|Counts) 向量化,并广播其余部分。为此,我们可以使用以下过滤器
from functools import partial
@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})
def forward(model, x):
...
这里 (nnx.Param, 'dropout')
扩展为 Any(OfType(nnx.Param), WithTag('dropout'))
,而 ...
扩展为 Everything()
。
如果要手动将字面量转换为谓词,可以使用 nnx.filterlib.to_predicate
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))
print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.nnx.variables.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.nnx.variables.Param'>), WithTag('dropout'))
分组状态#
了解了过滤器后,让我们看看 nnx.split
的大致实现。关键思想
使用
nnx.graph.flatten
获取节点的GraphDef
和State
表示。将所有过滤器转换为谓词。
使用
State.flat_state
获取状态的扁平表示。遍历扁平状态中的所有
(path, value)
对,并根据谓词对其进行分组。使用
State.from_flat_state
将扁平状态转换为嵌套的State
。
from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]
def split(node, *filters):
graphdef, state, _ = nnx.graph.flatten(node)
predicates = [nnx.filterlib.to_predicate(f) for f in filters]
flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]
for path, value in state.flat_state().items():
for i, predicate in enumerate(predicates):
if predicate(path, value):
flat_states[i][path] = value
break
else:
raise ValueError(f'No filter matched {path = } {value = }')
states: tuple[nnx.GraphState, ...] = tuple(
nnx.State.from_flat_path(flat_state) for flat_state in flat_states
)
return graphdef, *states
# lets test it...
foo = Foo()
graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
batch_stats = State({
'b': VariableState(
type=BatchStat,
value=True
)
})
需要注意的一点非常重要,即 过滤是依赖顺序的。匹配某个值的第一个过滤器将保留它,因此应将更具体的过滤器放在更通用的过滤器之前。例如,如果我们创建了一个 SpecialParam
类型,它是 Param
的子类,以及一个包含这两种类型参数的 Bar
对象,如果我们尝试在 SpecialParam
之前拆分 Param
,那么所有值都将被放置在 Param
组中,而 SpecialParam
组将为空,因为所有 SpecialParam
也都是 Param
class SpecialParam(nnx.Param):
pass
class Bar(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = SpecialParam(0)
bar = Bar()
graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': VariableState(
type=Param,
value=0
),
'b': VariableState(
type=SpecialParam,
value=0
)
})
special_params = State({})
反转顺序将确保首先捕获 SpecialParam
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
'a': VariableState(
type=Param,
value=0
)
})
special_params = State({
'b': VariableState(
type=SpecialParam,
value=0
)
})