Triton矩阵乘法Kernel深度解析

概述

本文深入分析Triton编写的矩阵乘法kernel,重点解释分块累加机制、内存访问模式和并行执行策略。

代码示例

import torch 
import triton 
import triton.language  as tl 

@triton.jit 
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,  # 矩阵指针
    M, N, K,  # 矩阵维度 
    stride_am, stride_ak,  # A的步长 
    stride_bk, stride_bn,  # B的步长 
    stride_cm, stride_cn,  # C的步长 
    BLOCK_SIZE: tl.constexpr,   # 分块大小 
):
    # 获取程序ID
    pid = tl.program_id(axis=0) 
    # 计算当前块处理的C矩阵范围 
    rm = pid * BLOCK_SIZE + tl.arange(0,  BLOCK_SIZE)
    rn = tl.arange(0,  BLOCK_SIZE)
    # 初始化累加器
    acc = tl.zeros((BLOCK_SIZE,  BLOCK_SIZE), dtype=tl.float32) 

    # 分块乘法累加
    for k in range(0, K, BLOCK_SIZE):
        a = tl.load(a_ptr  + rm[:, None] * stride_am + k * stride_ak, mask=rm[:, None] < M)
        b = tl.load(b_ptr  + k * stride_bk + rn[None, :] * stride_bn, mask=rn[None, :] < N)
        acc += tl.dot(a,  b)

    # 存储结果 
    tl.store(c_ptr  + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc)

def matmul(a, b):
    # 验证输入
    assert a.shape[1]  == b.shape[0],  "Incompatible dimensions"
    M, K = a.shape 
    K, N = b.shape 
    # 分配输出
    c = torch.empty((M,  N), device=a.device,  dtype=a.dtype) 
    # 计算网格大小 
    grid = lambda META: (triton.cdiv(M,  META['BLOCK_SIZE']),)
    # 启动内核 
    matmul_kernel[grid](
        a, b, c, M, N, K,
        a.stride(0),  a.stride(1), 
        b.stride(0),  b.stride(1), 
        c.stride(0),  c.stride(1), 
        BLOCK_SIZE=32 
    )
    return c

# 测试代码
if __name__ == "__main__":
    a = torch.randn(128, 64, device='cuda', dtype=torch.float32)
    b = torch.randn(64, 256, device='cuda', dtype=torch.float32)
    c = matmul(a, b)
    print(c.shape)  # 应该输出 (128, 256)
    print(c)  # 打印结果矩阵

分块累加详解

循环次数与数据加载

以上述示例为例,输入矩阵维度:

  • A: (128, 64)
  • B: (64, 256)
  • C: (128, 256)
  • BLOCK_SIZE = 32

循环次数

for k in range(0, K, BLOCK_SIZE):  # K=64, BLOCK_SIZE=32

循环次数:64 / 32 = 2次,k的值分别为0和32。

每次循环的内存加载

第1次循环 (k=0):

  • A块形状: (32, 32) - 从A矩阵的列0-31加载32列数据
  • B块形状: (32, 32) - 从B矩阵的行0-31加载32行数据
  • 运算: tl.dot(a, b) 执行 (32×32) × (32×32) 的矩阵乘法
  • 结果: 得到 (32, 32) 的中间结果矩阵

第2次循环 (k=32):

  • A块形状: (32, 32) - 从A矩阵的列32-63加载32列数据
  • B块形状: (32, 32) - 从B矩阵的行32-63加载32行数据
  • 运算: tl.dot(a, b) 执行 (32×32) × (32×32) 的矩阵乘法
  • 结果: 得到 (32, 32) 的中间结果矩阵

累加过程

acc += tl.dot(a, b)
  • 初始: acc 是 (32, 32) 的零矩阵
  • 第1次累加后: acc = 第1次dot运算结果 (32×32)
  • 第2次累加后: acc = 第1次结果 + 第2次结果 (32×32)

累加完成后,acc 包含了当前工作组负责的C矩阵的一个(32, 32)块的完整计算结果。

并行执行策略

Kernel内部与Grid的分工

Kernel内部的职责

  • 只关心一个固定的(BLOCK_SIZE, BLOCK_SIZE)输出块
  • 在K维度上进行分块累加
  • 处理局部计算
rm = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # 固定的行范围
rn = tl.arange(0, BLOCK_SIZE)                     # 固定的列范围

# 只需要在K维度上分块累加
for k in range(0, K, BLOCK_SIZE):
    a = tl.load(...)  # A的(32, 32)块
    b = tl.load(...)  # B的(32, 32)块
    acc += tl.dot(a, b)  # 累加中间结果

Grid的职责

  • 负责并行调度,覆盖整个输出矩阵
  • 启动多个kernel实例
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE']),)
# 对于M=128: 启动4个kernel实例
# pid=0: 处理C[0:32, 0:32]
# pid=1: 处理C[32:64, 0:32] 
# pid=2: 处理C[64:96, 0:32]
# pid=3: 处理C[96:128, 0:32]

设计思想

分工明确:

  • Kernel: 专注计算一个小块,处理K维度的累加
  • Grid: 负责并行调度,覆盖整个输出矩阵

每个kernel实例是一个"工作单元",独立计算一个输出块,通过大量并行的工作单元来加速整个矩阵乘法。其实要想独立计算一个输出块,也只能在完整的K维度上实现。参考矩阵乘法分块原理文章

内存访问模式详解

指针计算过程

以A矩阵加载为例:

a = tl.load(a_ptr + rm[:, None] * stride_am + k * stride_ak, mask=rm[:, None] < M)

rm的维度变化

rm = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# 假设pid=0, BLOCK_SIZE=32
# rm = [0, 1, 2, ..., 31]  # 形状: (32,) 一维数组

rm[:, None]  # 将rm从(32,)变成(32, 1)的二维数组
# [[0],
#  [1], 
#  [2],
#  ...
#  [31]]

矩阵A的内存布局: 对于矩阵A(128, 64),按行主序存储:

stride_am = 64  # 行步长:每行有64个元素
stride_ak = 1   # 列步长:相邻列元素间距为1

# A[i,j]的内存地址 = a_ptr + i * stride_am + j * stride_ak
#                  = a_ptr + i * 64 + j * 1

第1次循环 (k=0, pid=0) 的地址计算

rm[:, None] * stride_am = [[0],   * 64 = [[0],
                           [1],          [64],
                           [2],          [128],
                           ...           ...
                           [31]]         [1984]]

k * stride_ak = 0 * 1 = 0

# 最终地址对应A矩阵的:
A[0, 0:32]   # 第0行,列0-31
A[1, 0:32]   # 第1行,列0-31  
A[2, 0:32]   # 第2行,列0-31
...
A[31, 0:32]  # 第31行,列0-31

为什么使用[:, None]: 通过rm[:, None]将一维索引变成二维,可以利用广播机制一次性计算出整个块的所有地址,而不需要循环。这是Triton中高效内存访问的关键技巧。

为什么使用正方形块?

1. 内存访问模式的对称性

对于矩阵乘法 C = A × B:

  • A矩阵访问:按行读取数据(连续内存访问)
  • B矩阵访问:按列读取数据(跨步内存访问)

使用正方形块可以让A和B的访问模式更加对称。

2. 缓存效率

正方形块能更好地利用缓存:

  • 空间局部性:正方形块在内存中的布局更紧凑
  • 时间局部性:数据被重复使用的次数相等

例如,32×32的块:

  • A的每行被重用32次
  • B的每列被重用32次
  • 重用次数相等,缓存效率最优

3. 计算单元利用率

现代GPU/处理器的计算单元通常设计为处理方形数据块:

  • Tensor Core:通常是16×16或32×32
  • 向量处理单元:对称的SIMD操作
  • 并行度平衡:行和列方向的并行度相等

4. 算法复杂度

对于N×N的正方形块,矩阵乘法需要:

  • 内存读取:2N²次(A和B各N²)
  • 计算操作:N³次乘加运算
  • 计算密度:N³/(2N²) = N/2

正方形块能达到最优的计算密度比。

5. 通用性与简单性

虽然某些矩形块可能有更好的计算密度,但正方形块提供了:

  • 通用性:适合各种矩阵尺寸
  • 简单性:编程和调优更容易
  • 硬件友好:与硬件设计更匹配
  • 数值稳定性:避免极端的长宽比

总结

Triton矩阵乘法kernel的核心设计思想是:

  1. 分而治之:将大矩阵分解为小块并行处理
  2. 内核专一:每个kernel只负责一个输出块的计算
  3. 高效内存访问:使用广播机制优化内存访问模式
  4. 正方形分块:平衡计算效率和硬件适配性

这种设计充分利用了GPU的并行计算能力和内存层次结构,是高性能矩阵乘法实现的典型范例。