Burn 是一款新一代的深度学习框架,兼顾了灵活性、效率和可移植性,毫不妥协。
性能
因为我们相信深度学习框架的目标是将计算转化为有用的智能,所以 Burn 把性能作为核心支柱。我们努力通过多种优化技术来实现极高的效率,下面将详细介绍这些技术。
点击每个部分了解更多详情 👇
自动内核融合 💥
使用 Burn 意味着你的模型会在任何后端上被自动优化。我们会在可能的情况下自动、动态地创建自定义内核,以最小化不同内存空间之间的数据搬运,这在内存移动成为瓶颈时非常有用。
例如,你可以使用高级 Tensor API 来编写自己的 GELU 激活函数(见下面的 Rust 代码片段):
fn gelu_custom<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
let x = x.clone() * ((x / SQRT_2).erf() + 1);
x / 2
}
在运行时,Burn 会为你的实现自动生成一个底层自定义内核,其性能可与手写 GPU 实现相媲美。这个内核大约有 60 行 WGSL(WebGPU Shading Language)代码,这是一种非常繁琐的底层着色器语言,你大概不会想用它来写深度学习模型。
异步执行 ❤️🔥
在 first-party backends 中,Burn 使用异步执行方式,这使得各种优化(如自动内核融合)成为可能。
异步执行保证框架的正常执行不会阻塞模型计算,这意味着框架开销不会显著影响执行速度。反过来,模型中的高强度计算也不会干扰框架的响应性。更多关于异步后端的信息可以参阅 这篇博客。
线程安全的构建块 🦞
Burn 通过利用 Rust 的所有权系统 强调线程安全。在 Burn 中,每个模块都拥有自己的权重。因此可以把一个模块发送到另一个线程计算梯度,然后再把梯度发送回主线程进行聚合,从而实现多设备训练。
这与 PyTorch 的方式完全不同,PyTorch 的反向传播会直接修改每个 Tensor 参数的 _grad
属性,这不是线程安全的操作,需要使用底层同步机制(参考 distributed training)。虽然 PyTorch 依然很快,但它在不同后端之间不易兼容,也难以实现。
智能内存管理 🦀
深度学习框架的主要职责之一是减少运行模型所需的内存。最简单的方式是为每个 Tensor 分配独立内存,当 Tensor 释放时再回收。但频繁分配和回收内存非常昂贵,因此通常需要内存池来提高吞吐量。Burn 提供了便捷的内存管理策略选择机制。详情可参阅 这篇博客。
Burn 还会跟踪何时可以原地(in-place)修改 Tensor,从而进一步节省内存。虽然这是一个小优化,但在大型模型的训练和推理中累计效果明显。更多信息请见 Tensor 处理的博客。
自动内核选择 🎯
优秀的深度学习框架必须确保模型在各种硬件上都能高效运行。然而,不同硬件的执行速度差异很大。例如,矩阵乘法内核有许多参数配置,这些参数对矩阵大小和硬件非常敏感,配置不当可能会使速度降低十倍以上。
Burn 的自研后端会自动运行基准测试,为当前硬件和矩阵大小选择最佳配置,并采用合理的缓存策略。虽然会增加少量预热时间,但长远来看能节省大量时间。此功能可根据需要关闭,以便在冷启动时优先减少延迟。
硬件特定优化 🔥
深度学习的核心操作大多依赖矩阵乘法,因此硬件厂商越来越多地针对这种工作负载优化芯片。例如,Nvidia 有 Tensor Cores,大部分手机也有 AI 专用芯片。目前 Burn 在 LibTorch、Candle、CUDA、Metal 和 WGPU/SPIR-V 后端中支持 Tensor Cores,但其他加速器暂不支持。我们希望 这个 issue 能尽快解决,以便支持 WGPU 后端。
自定义 Backend 扩展 🎒
Burn 旨在成为最灵活的深度学习框架。除了兼容各种后端,Burn 还支持扩展后端功能,以满足你的个性化需求。
例如,你可以添加自定义操作(如 flash attention)或为特定后端手写内核以提升性能。详情可参考 Burn Book 🔥。
Backend
Burn 致力于在尽可能多的硬件上实现高性能和稳健实现。我们相信这种灵活性对现代需求至关重要:你可以在云端训练模型,然后部署到客户的各种硬件上。
支持的 Backends
Backend | Devices | Class |
---|---|---|
CUDA | NVIDIA GPUs | First-Party |
ROCm | AMD GPUs | First-Party |
Metal | Apple GPUs | First-Party |
Vulkan | Linux & Windows 上的大部分 GPU | First-Party |
Wgpu | 大多数 GPU | First-Party |
NdArray | 大多数 CPU | Third-Party |
LibTorch | 大多数 GPU 和 CPU | Third-Party |
Candle | Nvidia、Apple GPU & CPU | Third-Party |
相比其他框架,Burn 在多后端支持上有着完全不同的设计思路。大部分代码都基于 Backend trait 进行泛型抽象,这使 Burn 可以灵活切换后端,还能通过组合添加功能(如自动微分和自动内核融合)。