使用过滤器

使用过滤器#

注意:本页内容与新的 Flax NNX API 相关。

过滤器在 Flax NNX 中被广泛使用,作为在 API(如 nnx.splitnnx.state 和许多 Flax NNX 变换中)创建 State 组的一种方式。例如

from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()

graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})

这里 nnx.Paramnnx.BatchStat 用作过滤器,将模型分成两组:一组包含参数,另一组包含批统计信息。但是,这引出了以下问题

  • 什么是过滤器?

  • 为什么诸如 ParamBatchStat 之类的类型是过滤器?

  • 如何对 State 进行分组/过滤?

过滤器协议#

一般来说,过滤器是以下形式的谓词函数


(path: tuple[Key, ...], value: Any) -> bool

其中 Key 是一种可哈希且可比较的类型,pathKey 的元组,表示嵌套结构中值的路径,value 是该路径上的值。如果该值应包含在组中,则该函数返回 True;否则返回 False

类型显然不是这种形式的函数,因此它们被视为过滤器的理由是,正如我们将在下面看到的那样,类型和某些其他字面量会被转换为谓词。例如,Param 大致被转换为类似于以下的谓词

def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param) or (
    hasattr(value, 'type') and issubclass(value.type, nnx.Param)
  )

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True

此类函数匹配任何是 Param 实例的值,或任何具有 type 属性(该属性是 Param 的子类)的值。在内部,Flax NNX 使用 OfType,它为给定类型定义了这种形式的可调用对象

is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True

过滤器 DSL#

为了避免用户必须创建这些函数,Flax NNX 公开了一个小型 DSL,它被形式化为 nnx.filterlib.Filter 类型,该类型允许用户传递类型、布尔值、省略号、元组/列表等,并在内部将其转换为相应的谓词。

以下是 Flax NNX 中包含的所有可调用过滤器及其 DSL 字面量(如果可用)的列表

字面量

可调用对象

描述

...True

Everything()

匹配所有值

NoneFalse

Nothing()

不匹配任何值

type

OfType(type)

匹配是 type 实例的值,或具有 type 属性(该属性是 type 实例)的值

PathContains(key)

匹配具有包含给定 key 的关联 path 的值

'{filter}' str

WithTag('{filter}')

匹配具有等于 '{filter}' 的字符串 tag 属性的值。由 RngKeyRngCount 使用。

(*filters) tuple[*filters] list

Any(*filters)

匹配与任何内部 filters 匹配的值

All(*filters)

匹配与所有内部 filters 匹配的值

Not(filter)

匹配不与内部 filter 匹配的值

让我们看看 DSL 在 nnx.vmap 示例中的实际应用。假设我们要将所有参数和第 0 轴上的 dropout Rng(Keys|Counts) 向量化,并广播其余部分。为此,我们可以使用以下过滤器

from functools import partial

@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})
def forward(model, x):
  ...

这里 (nnx.Param, 'dropout') 扩展为 Any(OfType(nnx.Param), WithTag('dropout')),而 ... 扩展为 Everything()

如果要手动将字面量转换为谓词,可以使用 nnx.filterlib.to_predicate

is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))

print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.nnx.variables.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.nnx.variables.Param'>), WithTag('dropout'))

分组状态#

了解了过滤器后,让我们看看 nnx.split 的大致实现。关键思想

  • 使用 nnx.graph.flatten 获取节点的 GraphDefState 表示。

  • 将所有过滤器转换为谓词。

  • 使用 State.flat_state 获取状态的扁平表示。

  • 遍历扁平状态中的所有 (path, value) 对,并根据谓词对其进行分组。

  • 使用 State.from_flat_state 将扁平状态转换为嵌套的 State

from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  graphdef, state, _ = nnx.graph.flatten(node)
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]

  for path, value in state.flat_state().items():
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        flat_states[i][path] = value
        break
    else:
      raise ValueError(f'No filter matched {path = } {value = }')

  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states

# lets test it...
foo = Foo()

graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})

需要注意的一点非常重要,即 过滤是依赖顺序的。匹配某个值的第一个过滤器将保留它,因此应将更具体的过滤器放在更通用的过滤器之前。例如,如果我们创建了一个 SpecialParam 类型,它是 Param 的子类,以及一个包含这两种类型参数的 Bar 对象,如果我们尝试在 SpecialParam 之前拆分 Param,那么所有值都将被放置在 Param 组中,而 SpecialParam 组将为空,因为所有 SpecialParam 也都是 Param

class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()

graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  ),
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
special_params = State({})

反转顺序将确保首先捕获 SpecialParam

graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
special_params = State({
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})