Deepseek 开源周第三发 DeepGEMM:DeepGEMM 是一个支持密集型和 MoE GEMM 的 FP8 GEMM 库,核心逻辑仅约300行代码,极限情况下可以将 NVIDIA H800 的计算性能提高 2.7 倍。本文将详细介绍 DeepGEMM 的设计原理、性能优势以及快速启动方法。
介绍
DeepGEMM 是一个专为高效 FP8 通用矩阵乘法(GEMM)设计的库。它支持普通和混合专家(MoE)分组 GEMM,并且核心逻辑仅约300行代码。该库专为 NVIDIA H800 GPU 优化,能够在某些情况下将计算性能提升至1350+ TFLOPS。DeepGEMM 的设计非常简洁,易于理解和优化,是学习 Hopper FP8 矩阵乘法和优化技术的极佳资源。
核心功能
1. 支持多种 GEMM 类型
DeepGEMM 支持以下类型的 GEMM:
- 密集型 GEMM:适用于普通模型。
- 分组 GEMM(连续布局):适用于 MoE 模型中的专家具有相同形状的场景。
- 分组 GEMM(掩膜布局):适用于推理解码阶段,支持掩码张量。
2. 高性能
DeepGEMM 在极限情况下可以将 NVIDIA H800 的计算性能提高 2.7 倍。以下是一些性能测试结果:
M | N | K | 计算 (TFLOPS) | 内存带宽 (GB/秒) | 加速 (倍) |
---|---|---|---|---|---|
64 | 2112 | 7168 | 206 | 1688 | 2.7 |
64 | 24576 | 1536 | 289 | 2455 | 1.7 |
64 | 32768 | 512 | 219 | 2143 | 1.8 |
... | ... | ... | ... | ... | ... |
快速启动
系统要求
- Hopper 架构 GPU,支持 sm_90a
- Python 3.8 或更高版本
- CUDA 12.3 或更高版本(推荐 12.8 或更高版本)
- PyTorch 2.1 或更高版本
- CUTLASS 3.6 或更高版本(可以通过 Git 子模块克隆)
安装步骤
安装依赖:
pip install torch cutlass
克隆 DeepGEMM 仓库:
git clone https://github.com/deepseek-ai/DeepGEMM.git cd DeepGEMM
使用 GEMM 内核:
正常密集 GEMM:
from deep_gemm import gemm_fp8_fp8_bf16_nt result = gemm_fp8_fp8_bf16_nt(A, B, C)
分组 GEMM(连续布局):
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous result = m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(A, B, C, group_m)
分组 GEMM(掩膜布局):
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked result = m_grouped_gemm_fp8_fp8_bf16_nt_masked(A, B, C, mask)
优化细节
1. 持久扭曲专业化
DeepGEMM 中的内核经过了 warp 专门化,以实现重叠数据移动、张量核心 MMA 指令和 CUDA 核心提升。这种方法能够在不同 warpgroups 之间优化寄存器计数控制,并尽可能重叠操作。
2. Hopper TMA 功能
Hopper 引入了张量内存加速器(TMA),DeepGEMM 充分利用 TMA 来实现更快、更异步的数据移动。具体应用包括 LHS、LHS 缩放因子和 RHS 矩阵的 TMA 负载、TMA 存储输出矩阵等。
3. 完全 JIT 设计
DeepGEMM 采用完全即时(JIT)设计,安装时无需编译。所有内核均在运行时编译,具有更高的灵活性和性能。GEMM 形状、块大小和管道阶段数等参数被视为编译时常量,编译器可以进行更多优化。
结论
DeepGEMM 是一个简洁高效、易于理解和优化的 FP8 GEMM 库,专为 NVIDIA H800 GPU 优化。通过减少依赖、简化设计和充分利用 Hopper 架构的新功能,DeepGEMM 能够在多种矩阵形状下达到出色的性能。对于需要高性能 GEMM 计算的深度学习项目,DeepGEMM 是一个值得尝试的工具。
评论 (0)