Model surgery#

Note: This page relates to the new Flax NNX API.

In this guide, you will learn how to perform model surgery with Flax NNX, along with several real-world use cases:

  • Pythonic module manipulation: Pythonic ways to manipulate the sub-modules of a given model.

  • Manipulating an abstract model or state: a key trick for working with Flax NNX modules and state without any memory allocation.

  • Checkpoint surgery from a raw state to a model: how to manipulate parameter state when it is incompatible with existing model code.

  • Partial initialization: how to initialize only part of a model from scratch, using either a naive or a memory-efficient approach.

from typing import *
from pprint import pprint
import functools

import jax
from jax import lax, numpy as jnp, tree_util as jtu

from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax

key = jax.random.key(0)
class TwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(dim, dim, rngs=rngs)
    self.linear2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    return self.linear2(x)

Pythonic module manipulation#

Model surgery is easiest when you already have a complete model loaded with the correct parameters and you do not intend to change your model definition code.

You can perform a variety of Pythonic operations on its sub-modules, such as sub-module swapping, module sharing, variable sharing, and monkey-patching.

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
np.testing.assert_allclose(model(x), model.linear2(model.linear1(x)))

# Sub-module swapping
original1, original2 = model.linear1, model.linear2
model.linear1, model.linear2 = model.linear2, model.linear1
np.testing.assert_allclose(model(x), original1(original2(x)))

# Module sharing (tying all weights)
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear2 = model.linear1
assert not hasattr(nnx.state(model), 'linear2')
np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))

# Variable sharing (weight-tying)
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel  # the bias parameter is kept separate
assert hasattr(nnx.state(model), 'linear2')
assert hasattr(nnx.state(model)['linear2'], 'bias')
assert not hasattr(nnx.state(model)['linear2'], 'kernel')

# Monkey-patching
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
def awesome_layer(x): return x
model.linear2 = awesome_layer
np.testing.assert_allclose(model(x), model.linear1(x))

Creating an abstract model or state without memory allocation#

For more complex model surgery, the key technique is to create and manipulate an abstract model or state without allocating any real parameter data. This makes trial iteration faster and removes any concern about memory constraints.

To create an abstract model:

  • create a function that returns a valid Flax NNX model; and

  • run nnx.eval_shape (not jax.eval_shape) on it.

Now you can use nnx.split as usual to get its abstract state. Note that all the fields that would be jax.Arrays in a real model are now abstract jax.ShapeDtypeStruct instances carrying only shape/dtype/sharding information.

abs_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
gdef, abs_state = nnx.split(abs_model)
pprint(abs_state)
State({
  'linear1': {
    'bias': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  },
  'linear2': {
    'bias': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4,), dtype=float32)
    ),
    'kernel': VariableState(
      type=Param,
      value=ShapeDtypeStruct(shape=(4, 4), dtype=float32)
    )
  }
})
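
Because each leaf is a jax.ShapeDtypeStruct, you can inspect shapes and dtypes without allocating any device memory. A small sketch reusing abs_state from above (kernel_meta is just a local name for illustration):

# The leaves are jax.ShapeDtypeStruct objects - no device buffers are allocated.
kernel_meta = abs_state['linear1']['kernel'].value
print(kernel_meta.shape, kernel_meta.dtype)  # (4, 4) float32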

The abstract model becomes equivalent to a real model once you fill the value of every VariableState leaf with a real jax array.

model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
abs_state['linear1']['kernel'].value = model.linear1.kernel
abs_state['linear1']['bias'].value = model.linear1.bias
abs_state['linear2']['kernel'].value = model.linear2.kernel
abs_state['linear2']['bias'].value = model.linear2.bias
nnx.update(abs_model, abs_state)
np.testing.assert_allclose(abs_model(x), model(x))  # They are equivalent now!
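
If you don't need per-leaf control, roughly the same effect can be achieved in one call by pushing a real model's full state into the abstract model with nnx.update. A minimal sketch reusing the objects above (abs_model2 is just an illustrative name):

# Bulk alternative: update the abstract model with the real model's entire state at once.
abs_model2 = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
nnx.update(abs_model2, nnx.state(model))
np.testing.assert_allclose(abs_model2(x), model(x))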

Checkpoint surgery#

With the abstract state technique in hand, you can perform arbitrary manipulation on any checkpoint (or runtime parameter pytree) to make it fit your given model code, and then call nnx.update to merge it in.

This can be helpful when you are trying to change the model code significantly (for example, when migrating from Flax Linen to Flax NNX) and the old weights are no longer naturally compatible with the new model state structure. Let's run a simple example here.

# Save a version of model into a checkpoint
checkpointer = orbax.PyTreeCheckpointer()
old_model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
checkpointer.save('/tmp/nnx-surgery-state', nnx.state(old_model), force=True)

In this new model, the sub-modules are renamed from linear(1|2) to layer(1|2). Since the pytree structure has changed, the old checkpoint cannot be loaded with the new model state structure.

class ModifiedTwoLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(dim, dim, rngs=rngs)  # no longer linear1!
    self.layer2 = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    x = self.layer1(x)
    return self.layer2(x)

abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
try:
  with_item = checkpointer.restore('/tmp/nnx-surgery-state', item=nnx.state(abs_model))
  print(with_item)
except Exception as e:
  print(f'This will throw error: {type(e)}: {e}')
This will throw error: <class 'KeyError'>: 'layer1'
/Users/ivyzheng/envs/py310/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1401: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
  warnings.warn(

However, you can load the parameter pytree as a raw dictionary, perform the renames there, and generate a new state that is guaranteed to be compatible with your new model definition.

def module_from_variables_dict(module_factory, variables, map_key_fn):
  if map_key_fn is None:
    map_key_fn = lambda path: path
  mdl = nnx.eval_shape(module_factory)
  graph_def, state = nnx.split(mdl)
  state = state.flat_state()
  for path, val in flax.traverse_util.flatten_dict(variables).items():
    mapped_path = map_key_fn(path)
    if mapped_path not in state:
      raise ValueError(f"{mapped_path} doesn't exist in {state.keys()}")
    state[mapped_path].value = val
  state = nnx.State.from_flat_path(state)
  return nnx.merge(graph_def, state)

# Make your local change on the checkpoint.
raw = checkpointer.restore('/tmp/nnx-surgery-state')
pprint(raw)
raw['layer1'], raw['layer2'] = raw['linear1'], raw['linear2']
del raw['linear1'], raw['linear2']

restored_model = module_from_variables_dict(
  lambda: nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0))),
  raw,
  lambda path: path[:-1] if path[-1] == 'raw_value' else path
)

np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
{'linear1': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'raw_value': Array([[-0.80345297, -0.34071913, -0.9408296 ,  0.01005968],
       [ 0.26146442,  1.1247735 ,  0.54563737, -0.374164  ],
       [ 1.0281805 , -0.6798804 , -0.1488401 ,  0.05694951],
       [-0.44308168, -0.60587114,  0.434087  , -0.40541083]],      dtype=float32)}},
 'linear2': {'bias': {'raw_value': Array([0., 0., 0., 0.], dtype=float32)},
             'kernel': {'raw_value': Array([[ 0.21010089,  0.8289361 ,  0.04589564,  0.5422644 ],
       [ 0.41914317,  0.84359694, -0.47937787, -0.49135214],
       [-0.46072108,  0.4630125 ,  0.39276958, -0.9441406 ],
       [-0.6690758 , -0.18474789, -0.57622856,  0.4821079 ]],      dtype=float32)}}}
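
Alternatively, the rename could be folded into the map_key_fn argument instead of mutating the raw dictionary, which leaves the restored checkpoint untouched. Below is a sketch of that variant; rename_keys, raw2, and restored_model2 are illustrative names, not part of the API:

# Illustrative helper: map checkpoint paths like ('linear1', 'kernel', 'raw_value')
# onto the new model's state paths like ('layer1', 'kernel') on the fly.
def rename_keys(path):
  path = tuple(p.replace('linear', 'layer') for p in path)
  return path[:-1] if path[-1] == 'raw_value' else path

raw2 = checkpointer.restore('/tmp/nnx-surgery-state')
restored_model2 = module_from_variables_dict(
  lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)),
  raw2,
  rename_keys,
)
np.testing.assert_allclose(restored_model2(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))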

Partial initialization#

In some cases (such as with LoRA), you may want to randomly initialize only part of your model parameters. This can be achieved through either naive partial initialization or memory-efficient partial initialization.

Naive partial initialization#

You can simply initialize the whole model and then swap in the pretrained parameters. However, this approach may allocate extra memory along the way if your modification requires re-creating module parameters that you will later discard. See the example below.

Note: You can use jax.live_arrays() to check all the arrays that are live in memory at any given time. This call can be "messed up" when you run a single notebook cell multiple times (because old Python variables get garbage-collected), but restarting the kernel and running the cells from scratch will always yield the same output.

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# On this line, an extra kernel and bias are created inside the new LoRALinear!
# They are wasted, since you are going to use the kernel and bias from `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
      ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 discarded - only lora_a & lora_b are used in model)')
Number of jax arrays in memory at start: 34
Number of jax arrays in memory midway: 38 (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)
Number of jax arrays in memory at end: 36 (2 discarded - only lora_a & lora_b are used in model)

Memory-efficient partial initialization#

Use nnx.jit's efficiently compiled code to make sure only the state parameters you need are initialized.

# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

# Use `nnx.jit` (which wraps `jax.jit`) to automatically skip unused arrays - memory efficient!
@functools.partial(nnx.jit, donate_argnums=0, static_argnums=1)
def partial_init(old_state, rngs):
  model = TwoLayerMLP(4, rngs=rngs)
  # Create a new state.
  model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=rngs)
  # Add the existing state.
  nnx.update(model, old_state)
  return model

print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# Note that `old_state` will be deleted after this `partial_init` call.
good_model = partial_init(old_state, nnx.Rngs(42))
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 new created - lora_a and lora_b)')
Number of jax arrays in memory at start: 40
Number of jax arrays in memory at end: 42 (2 new created - lora_a and lora_b)
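
As a quick sanity check (a sketch that relies on the deterministic initialization from nnx.Rngs(0)), you could confirm that the pretrained kernel was carried over while linear1 was swapped for a LoRALinear; reference is just an illustrative local variable:

# Rebuild a reference model with the same seed the pretrained state came from.
reference = TwoLayerMLP(4, rngs=nnx.Rngs(0))
assert isinstance(good_model.linear1, nnx.LoRALinear)
np.testing.assert_allclose(good_model.linear2.kernel.value,
                           reference.linear2.kernel.value)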