超大模型加载转换Trick
在深度学习领域,大模型的训练和推理通常需要消耗大量的计算和内存。如何高效地加载和使用大模型是一个相当关键的问题。在这篇博客中,我将分享一些关于更快加载大模型和减少内存的技巧.
问题分析
假设现在我们有一个236B 超大模型的原始权重的 checkpoint.pth
文件, 比如 DeepSeek Chat V2, 以BF16 格式存储, 一个标准的加载流程如下
1 | import torch |
在这段代码的中, my_model = BigModelClass(...)
会初始化一个模型, torch.load(checkpoint_file)
函数会将模型权重从磁盘加载到内存中。然后,my_model.load_state_dict(state_dict)
函数会将权重从内存加载到模型的参数中。这两个步骤都可能会消耗大量的时间和内存。理想情况下, 一个236B BF16格式的模型需要占据 472GB 的内存, 上面的代码会有两个模型副本, 这意味着峰值需要944GB 内存, 接近1T ,这是非常夸张的也是不可接受的.
我们用一段简单的代码来验证上面的推断, 首先初始化一个 1B size 的模型并存下来,
1 | import torch |
The model has 1.0001 B trainable parameters
The model’s memory size is approximately 3.73 GB.
然后 按照上面的方式加载模型, 并统计cpu 内存占用, torch 默认是FP32 格式, 1B模型占用约 4GB 内存(实际为3.73GB左右), 下面代码验证后基本符合预期
1 | def print_usage(): |
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 8.5
1 | start_time = time.time() |
Loading the state_dict took 2.09 seconds
After Load the state_dict:
memory: 4.06 GB
CPU percent: 7.0
4.06 - 0.34 = 3.72基本一致
1 |
|
Init the model took 7.23 seconds
After Init the model:
memory: 7.79 GB
CPU percent: 7.6
7.79 - 4.06 = 3.73 基本一致
1 | start_time = time.time() |
Loading the state_dict to model took 2.63 seconds
After Load the state_dict to model:
memory: 7.79 GB
CPU percent: 16.4
问题解决
分析清楚在加载和初始化环节中各个流程的开销, 我们来看看可以如何加速每个过程.
使用torch.load(mmap=True)
首先,让我们考虑一下当我们使用 加载检查点时会发生什么torch.load
。当我们使用 保存检查点时torch.save
,张量存储会使用保存它们的设备进行标记。使用torch.load
,张量存储将加载到它们标记的设备(除非使用标志覆盖此行为 map_location
)。为了便于解释,我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不可行的:
- CPU RAM 小于检查点的大小。
- 等待整个检查点加载到 RAM 中,然后再执行某些按张量处理等操作。
1 | start_time = time.time() |
loading time without mmap=2.0737619400024414
memory: 4.06 GB
CPU percent: 8.7
torch.load
中的mmap
参数图解决上述两个问题。顾名思义,mmap
关键字参数 totorch.load
使用mmap 调用 ,将磁盘上的文件映射到虚拟内存,并让操作系统自动处理到物理内存的加载和卸载。当这个标志被传递时,张量存储将被内存映射。
1 | start_time = time.time() |
loading time with mmap=0.003424406051635742
memory: 0.34 GB
通过上面对比,我们可以发现 使用mmap可以加速模型加载并减少内存占用, 对于236B的模型, 我们实际上并不需要 1TB的 CPU内存来完成转换
使用 torch.device('meta')
当模型size 巨大时, 模型初始化也需要巨大时间, 我们扩大一下模型size到25B, 初始化一个模型就需要接近3分钟.
1 | size = 50000 |
init time=184.56671452522278
The model has 25.0005 B trainable parameters
The model’s memory size is approximately 93.13 GB.
但在load 模型时, 初始化这一步是多余的, 我们实际上只需要知道模型的所有 key 和 对应的 shape,
这个时候, torch.device('meta')
这个 上下文就可以发挥作用了, torch.device()
上下文管理器确保工厂调用将像它们被传递了指定的”device”作为参数一样执行。在 torch.device('meta')
上的张量不携带数据。然而,它们具有张量所具有的所有其他元数据,例如.size()
、.stride()
、.requires_grad
等。
1 | with torch.device('meta'): |
注意, 在使用 torch.device('meta')
后, 我们需要加上 assign=True
参数来让参数被加载. 最后一段代码可以check 所有参数被正确加载了, 加载后的参数的 device应该不再是 meta
了.
实验结果
最后, 我们直接上一个100B size大小的大模型来对比, 是否使用 torch.load(mmap=True)
和torch.device('meta')
速度差别.
1 | size = 100000 |
The model has 100.001 B trainable parameters
The model’s memory size is approximately 186.27 GB.
加速前
1 | start_time = time.time() |
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 9.1
Loading the state_dict took 852.06 seconds
After Load the state_dict:
memory: 372.87 GB
CPU percent: 5.0
Init the model took 518.15 seconds
After Init the model:
memory: 745.41 GB
CPU percent: 4.9
Loading the state_dict to model took 125.63 seconds
After Load the state_dict to model:
memory: 745.41 GB
CPU percent: 11.7
tensor([[-0.0015, 0.0017, -0.0009, …, -0.0036, 0.0041, 0.0052]],
grad_fn=\
) One time forward 6.95 seconds
memory: 745.42 GB
CPU percent: 11.4
加速后
1 | start_time = time.time() |
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 9.1
Loading the state_dict took 0.11 seconds
After Load the state_dict:
memory: 0.34 GB
CPU percent: 6.1
Init the model took 0.00 seconds
After Init the model:
memory: 0.34 GB
CPU percent: 4.3
Loading the state_dict to model took 0.00 seconds
After Load the state_dict to model:
memory: 0.34 GB
CPU percent: 10.0
tensor([[ 0.0080, -0.0017, -0.0027, …, -0.0011, 0.0097, -0.0048]],
grad_fn=\
) One time forward 48.37 seconds
memory: 372.85 GB
CPU percent: 5.2
tensor([[ 0.0038, 0.0014, -0.0076, …, -0.0016, 0.0004, -0.0018]],
grad_fn=\
) One time forward 3.28 seconds
memory: 372.86 GB
CPU percent: 13.4
通过上面的对比, 加速前100B模型加载时间为
852.06 + 518.15 + 125.63 = 1495(s) = 25 (min)
而使用 mmap
+ meta device
加载几乎没有时间开销, 只有模型真正运行时才会从硬盘拷贝权重到CPU RAM.